Merging upstream version 10.0.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent 528822bfd4
commit b7d21c45b7
98 changed files with 4080 additions and 1666 deletions
CHANGELOG.md (23 changed lines)

@@ -1,6 +1,29 @@
 Changelog
 =========
 
+v10.0.0
+------
+
+Changes:
+
+- Breaking: replaced SQLGlot annotations with comments. Now comments can be preserved after transpilation, and they can appear in other places besides SELECT's expressions.
+- Breaking: renamed list_get to seq_get.
+- Breaking: activated mypy type checking for SQLGlot.
+- New: Azure Databricks support.
+- New: placeholders can now be replaced in an expression.
+- New: null-safe equal operator (<=>).
+- New: [SET statements](https://github.com/tobymao/sqlglot/pull/673) for MySQL.
+- New: [SHOW commands](https://dev.mysql.com/doc/refman/8.0/en/show.html) for MySQL.
+- New: [FORMAT function](https://www.w3schools.com/sql/func_sqlserver_format.asp) for TSQL.
+- New: CROSS APPLY / OUTER APPLY [support](https://github.com/tobymao/sqlglot/pull/641) for TSQL.
+- New: added formats for TSQL's [DATENAME/DATEPART functions](https://learn.microsoft.com/en-us/sql/t-sql/functions/datename-transact-sql?view=sql-server-ver16).
+- New: added styles for TSQL's [CONVERT function](https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16).
+- Improvement: [refactored the schema](https://github.com/tobymao/sqlglot/pull/668) to be more lenient; before it needed an exact db.table match, now it finds the table if there are no ambiguities.
+- Improvement: allow functions to [inherit](https://github.com/tobymao/sqlglot/pull/674) their arguments' types, so that annotating CASE, IF, etc. is possible.
+- Improvement: allow [joining with the same names](https://github.com/tobymao/sqlglot/pull/660) in the Python executor.
+- Improvement: the "using" field can now be set for the [join expression builders](https://github.com/tobymao/sqlglot/pull/636).
+- Improvement: qualify_columns [now qualifies](https://github.com/tobymao/sqlglot/pull/635) only non-alias columns in the HAVING clause.
+
 v9.0.0
 ------
 
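For context on two of the entries above, a minimal usage sketch against the public API (not part of the diff itself; the exact transpiled output can vary by target dialect and version):

```python
import sqlglot
from sqlglot import exp, parse_one

# The MySQL null-safe equal operator (<=>) now parses and transpiles.
print(sqlglot.transpile("SELECT a <=> b", read="mysql", write="presto")[0])

# Placeholders can be replaced inside an expression tree, e.g. by
# rewriting every Placeholder node into a concrete literal.
tree = parse_one("SELECT * FROM t WHERE id = :id")
filled = tree.transform(
    lambda node: exp.Literal.number(42) if isinstance(node, exp.Placeholder) else node
)
print(filled.sql())  # SELECT * FROM t WHERE id = 42
```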
README.md (49 changed lines)

@@ -14,7 +14,7 @@ Contributions are very welcome in SQLGlot; read the [contribution guide](https:/
 * [Install](#install)
 * [Documentation](#documentation)
-* [Run Tests & Lint](#run-tests-and-lint)
+* [Run Tests and Lint](#run-tests-and-lint)
 * [Examples](#examples)
 * [Formatting and Transpiling](#formatting-and-transpiling)
 * [Metadata](#metadata)

@@ -22,7 +22,6 @@ Contributions are very welcome in SQLGlot; read the [contribution guide](https:/
 * [Unsupported Errors](#unsupported-errors)
 * [Build and Modify SQL](#build-and-modify-sql)
 * [SQL Optimizer](#sql-optimizer)
-* [SQL Annotations](#sql-annotations)
 * [AST Introspection](#ast-introspection)
 * [AST Diff](#ast-diff)
 * [Custom Dialects](#custom-dialects)

@@ -51,7 +50,7 @@ pip3 install -r dev-requirements.txt
 
 ## Documentation
 
-SQLGlot's uses [pdocs](https://pdoc.dev/) to serve its API documentation:
+SQLGlot uses [pdoc](https://pdoc.dev/) to serve its API documentation:
 
 ```
 pdoc sqlglot --docformat google

@@ -121,6 +120,39 @@ LEFT JOIN `baz`
   ON `f`.`a` = `baz`.`a`
 ```
 
+Comments are also preserved on a best-effort basis when transpiling SQL code:
+
+```python
+sql = """
+/* multi
+   line
+   comment
+*/
+SELECT
+  tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
+  CAST(x AS INT), # comment 3
+  y -- comment 4
+FROM
+  bar /* comment 5 */,
+  tbl # comment 6
+"""
+
+print(sqlglot.transpile(sql, read='mysql', pretty=True)[0])
+```
+
+```sql
+/* multi
+   line
+   comment
+*/
+SELECT
+  tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
+  CAST(x AS INT), -- comment 3
+  y -- comment 4
+FROM bar /* comment 5 */, tbl /* comment 6 */
+```
+
 ### Metadata
 
 You can explore SQL with expression helpers to do things like find columns and tables:

@@ -249,17 +281,6 @@ WHERE
   "x"."Z" = CAST('2021-02-01' AS DATE)
 ```
 
-### SQL Annotations
-
-SQLGlot supports annotations in the sql expression. This is an experimental feature that is not part of any of the SQL standards but it can be useful when needing to annotate what a selected field is supposed to be. Below is an example:
-
-```sql
-SELECT
-  user # primary_key,
-  country
-FROM users
-```
-
 ### AST Introspection
 
 You can see the AST version of the sql by calling `repr`:
@@ -1,15 +1,8 @@
 #!/bin/bash -e
 
 [[ -z "${GITHUB_ACTIONS}" ]] && RETURN_ERROR_CODE='' || RETURN_ERROR_CODE='--check'
-python -m autoflake -i -r ${RETURN_ERROR_CODE} \
-  --expand-star-imports \
-  --remove-all-unused-imports \
-  --ignore-init-module-imports \
-  --remove-duplicate-keys \
-  --remove-unused-variables \
-  sqlglot/ tests/
-python -m isort --profile black sqlglot/ tests/
-python -m black ${RETURN_ERROR_CODE} --line-length 120 sqlglot/ tests/
-python -m mypy sqlglot tests
+TARGETS="sqlglot/ tests/"
+python -m mypy $TARGETS
+python -m autoflake -i -r ${RETURN_ERROR_CODE} $TARGETS
+python -m isort $TARGETS
+python -m black --line-length 100 ${RETURN_ERROR_CODE} $TARGETS
 python -m unittest
setup.cfg (15 changed lines)

@@ -3,7 +3,7 @@ disallow_untyped_calls = False
 no_implicit_optional = True
 
 [mypy-sqlglot.*]
-ignore_errors = True
+ignore_errors = False
 
 [mypy-sqlglot.dataframe.*]
 ignore_errors = False

@@ -13,3 +13,16 @@ ignore_errors = True
 
 [mypy-tests.dataframe.*]
 ignore_errors = False
+
+[autoflake]
+in-place = True
+expand-star-imports = True
+remove-all-unused-imports = True
+ignore-init-module-imports = True
+remove-duplicate-keys = True
+remove-unused-variables = True
+quiet = True
+
+[isort]
+profile=black
+known_first_party=sqlglot
setup.py (1 changed line)

@@ -21,6 +21,7 @@ setup(
     author_email="toby.mao@gmail.com",
     license="MIT",
     packages=find_packages(include=["sqlglot", "sqlglot.*"]),
+    package_data={"sqlglot": ["py.typed"]},
     classifiers=[
         "Development Status :: 5 - Production/Stable",
         "Intended Audience :: Developers",
sqlglot/__init__.py

@@ -1,5 +1,9 @@
 """## Python SQL parser, transpiler and optimizer."""
 
+from __future__ import annotations
+
+import typing as t
+
 from sqlglot import expressions as exp
 from sqlglot.dialects import Dialect, Dialects
 from sqlglot.diff import diff

@@ -20,51 +24,54 @@ from sqlglot.expressions import (
     subquery,
 )
 from sqlglot.expressions import table_ as table
-from sqlglot.expressions import union
+from sqlglot.expressions import to_column, to_table, union
 from sqlglot.generator import Generator
 from sqlglot.parser import Parser
 from sqlglot.schema import MappingSchema
 from sqlglot.tokens import Tokenizer, TokenType
 
-__version__ = "9.0.6"
+__version__ = "10.0.1"
 
 pretty = False
 
 schema = MappingSchema()
 
 
-def parse(sql, read=None, **opts):
+def parse(
+    sql: str, read: t.Optional[str | Dialect] = None, **opts
+) -> t.List[t.Optional[Expression]]:
     """
-    Parses the given SQL string into a collection of syntax trees, one per
-    parsed SQL statement.
+    Parses the given SQL string into a collection of syntax trees, one per parsed SQL statement.
 
     Args:
-        sql (str): the SQL code string to parse.
-        read (str): the SQL dialect to apply during parsing
-            (eg. "spark", "hive", "presto", "mysql").
+        sql: the SQL code string to parse.
+        read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
         **opts: other options.
 
     Returns:
-        typing.List[Expression]: the list of parsed syntax trees.
+        The resulting syntax tree collection.
     """
     dialect = Dialect.get_or_raise(read)()
     return dialect.parse(sql, **opts)
 
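A quick usage sketch of the annotated `parse` (illustrative, not part of the diff):

```python
import sqlglot

# parse returns one syntax tree per statement in the input string.
trees = sqlglot.parse("SELECT 1; SELECT 2", read="mysql")
print(len(trees))  # 2
print(trees[0])    # repr of the first parsed SELECT
```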
 
-def parse_one(sql, read=None, into=None, **opts):
+def parse_one(
+    sql: str,
+    read: t.Optional[str | Dialect] = None,
+    into: t.Optional[Expression | str] = None,
+    **opts,
+) -> t.Optional[Expression]:
     """
-    Parses the given SQL string and returns a syntax tree for the first
-    parsed SQL statement.
+    Parses the given SQL string and returns a syntax tree for the first parsed SQL statement.
 
     Args:
-        sql (str): the SQL code string to parse.
-        read (str): the SQL dialect to apply during parsing
-            (eg. "spark", "hive", "presto", "mysql").
-        into (Expression): the SQLGlot Expression to parse into
+        sql: the SQL code string to parse.
+        read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
+        into: the SQLGlot Expression to parse into.
         **opts: other options.
 
     Returns:
-        Expression: the syntax tree for the first parsed statement.
+        The syntax tree for the first parsed statement.
     """
 
     dialect = Dialect.get_or_raise(read)()
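And a sketch of `parse_one`, which returns a single tree (or `None`) rather than a list:

```python
import sqlglot
from sqlglot import exp

tree = sqlglot.parse_one("SELECT a FROM t", read="hive")
assert isinstance(tree, exp.Select)  # a single statement yields one tree
print(tree.sql(dialect="presto"))    # SELECT a FROM t
```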
@@ -77,25 +84,29 @@ def parse_one(sql, read=None, into=None, **opts):
     return result[0] if result else None
 
 
-def transpile(sql, read=None, write=None, identity=True, error_level=None, **opts):
+def transpile(
+    sql: str,
+    read: t.Optional[str | Dialect] = None,
+    write: t.Optional[str | Dialect] = None,
+    identity: bool = True,
+    error_level: t.Optional[ErrorLevel] = None,
+    **opts,
+) -> t.List[str]:
     """
-    Parses the given SQL string using the source dialect and returns a list of SQL strings
-    transformed to conform to the target dialect. Each string in the returned list represents
-    a single transformed SQL statement.
+    Parses the given SQL string in accordance with the source dialect and returns a list of SQL strings transformed
+    to conform to the target dialect. Each string in the returned list represents a single transformed SQL statement.
 
     Args:
-        sql (str): the SQL code string to transpile.
-        read (str): the source dialect used to parse the input string
-            (eg. "spark", "hive", "presto", "mysql").
-        write (str): the target dialect into which the input should be transformed
-            (eg. "spark", "hive", "presto", "mysql").
-        identity (bool): if set to True and if the target dialect is not specified
-            the source dialect will be used as both: the source and the target dialect.
-        error_level (ErrorLevel): the desired error level of the parser.
+        sql: the SQL code string to transpile.
+        read: the source dialect used to parse the input string (eg. "spark", "hive", "presto", "mysql").
+        write: the target dialect into which the input should be transformed (eg. "spark", "hive", "presto", "mysql").
+        identity: if set to `True` and if the target dialect is not specified the source dialect will be used as both:
+            the source and the target dialect.
+        error_level: the desired error level of the parser.
         **opts: other options.
 
     Returns:
-        typing.List[str]: the list of transpiled SQL statements / expressions.
+        The list of transpiled SQL statements.
     """
     write = write or read if identity else write
     return [
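A usage sketch of `transpile`; the EPOCH_MS example is the one from the project README, and the printed output is indicative:

```python
import sqlglot

# Two statements in, two transpiled strings out. With identity=True the
# source dialect is reused when no target dialect is given.
out = sqlglot.transpile(
    "SELECT EPOCH_MS(1618088028295); SELECT 1", read="duckdb", write="hive"
)
print(out[0])    # SELECT FROM_UNIXTIME(1618088028295 / 1000)
print(len(out))  # 2
```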
sqlglot/__main__.py

@@ -49,7 +49,10 @@ args = parser.parse_args()
 error_level = sqlglot.ErrorLevel[args.error_level.upper()]
 
 if args.parse:
-    sqls = [repr(expression) for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)]
+    sqls = [
+        repr(expression)
+        for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)
+    ]
 else:
     sqls = sqlglot.transpile(
         args.sql,
sqlglot/dataframe/sql/_typing.py

@@ -10,11 +10,17 @@ if t.TYPE_CHECKING:
     from sqlglot.dataframe.sql.types import StructType
 
 ColumnLiterals = t.TypeVar(
-    "ColumnLiterals", bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+    "ColumnLiterals",
+    bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
 )
 ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str])
 ColumnOrLiteral = t.TypeVar(
-    "ColumnOrLiteral", bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+    "ColumnOrLiteral",
+    bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
+)
+SchemaInput = t.TypeVar(
+    "SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]]
+)
+OutputExpressionContainer = t.TypeVar(
+    "OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert]
 )
-SchemaInput = t.TypeVar("SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]])
-OutputExpressionContainer = t.TypeVar("OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert])
sqlglot/dataframe/sql/column.py

@@ -18,7 +18,11 @@ class Column:
             expression = expression.expression  # type: ignore
         elif expression is None or not isinstance(expression, (str, exp.Expression)):
             expression = self._lit(expression).expression  # type: ignore
-        self.expression: exp.Expression = sqlglot.maybe_parse(expression, dialect="spark")
+
+        expression = sqlglot.maybe_parse(expression, dialect="spark")
+        if expression is None:
+            raise ValueError(f"Could not parse {expression}")
+        self.expression: exp.Expression = expression
 
     def __repr__(self):
         return repr(self.expression)

@@ -135,21 +139,29 @@ class Column:
     ) -> Column:
         ensured_column = None if column is None else cls.ensure_col(column)
         ensure_expression_values = {
-            k: [Column.ensure_col(x).expression for x in v] if is_iterable(v) else Column.ensure_col(v).expression
+            k: [Column.ensure_col(x).expression for x in v]
+            if is_iterable(v)
+            else Column.ensure_col(v).expression
             for k, v in kwargs.items()
         }
         new_expression = (
             callable_expression(**ensure_expression_values)
             if ensured_column is None
-            else callable_expression(this=ensured_column.column_expression, **ensure_expression_values)
+            else callable_expression(
+                this=ensured_column.column_expression, **ensure_expression_values
+            )
         )
         return Column(new_expression)
 
     def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
-        return Column(klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs))
+        return Column(
+            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
+        )
 
     def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
-        return Column(klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs))
+        return Column(
+            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
+        )
 
     def unary_op(self, klass: t.Callable, **kwargs) -> Column:
         return Column(klass(this=self.column_expression, **kwargs))

@@ -188,7 +200,7 @@ class Column:
         expression.set("table", exp.to_identifier(table_name))
         return Column(expression)
 
-    def sql(self, **kwargs) -> Column:
+    def sql(self, **kwargs) -> str:
         return self.expression.sql(**{"dialect": "spark", **kwargs})
 
     def alias(self, name: str) -> Column:

@@ -265,10 +277,14 @@ class Column:
         )
 
     def like(self, other: str):
-        return self.invoke_expression_over_column(self, exp.Like, expression=self._lit(other).expression)
+        return self.invoke_expression_over_column(
+            self, exp.Like, expression=self._lit(other).expression
+        )
 
     def ilike(self, other: str):
-        return self.invoke_expression_over_column(self, exp.ILike, expression=self._lit(other).expression)
+        return self.invoke_expression_over_column(
+            self, exp.ILike, expression=self._lit(other).expression
+        )
 
     def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
         startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos

@@ -287,10 +303,18 @@ class Column:
         lowerBound: t.Union[ColumnOrLiteral],
         upperBound: t.Union[ColumnOrLiteral],
     ) -> Column:
-        lower_bound_exp = self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
-        upper_bound_exp = self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
+        lower_bound_exp = (
+            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
+        )
+        upper_bound_exp = (
+            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
+        )
         return Column(
-            exp.Between(this=self.column_expression, low=lower_bound_exp.expression, high=upper_bound_exp.expression)
+            exp.Between(
+                this=self.column_expression,
+                low=lower_bound_exp.expression,
+                high=upper_bound_exp.expression,
+            )
         )
 
     def over(self, window: WindowSpec) -> Column:
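The Column changes above are mostly line-length reformatting plus one real fix: `sql()` now advertises its true return type (`str`). A small sketch of the helpers being touched, assuming the documented `sqlglot.dataframe` functions module (output shown is indicative):

```python
from sqlglot.dataframe.sql import functions as F

# between()/like() build expression trees; sql() renders them as Spark SQL.
predicate = F.col("price").between(1, 10) & F.col("region").like("EU%")
print(predicate.sql())  # e.g. price BETWEEN 1 AND 10 AND region LIKE 'EU%'
```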
sqlglot/dataframe/sql/dataframe.py

@@ -21,7 +21,12 @@ from sqlglot.optimizer import optimize as optimize_func
 from sqlglot.optimizer.qualify_columns import qualify_columns
 
 if t.TYPE_CHECKING:
-    from sqlglot.dataframe.sql._typing import ColumnLiterals, OutputExpressionContainer
+    from sqlglot.dataframe.sql._typing import (
+        ColumnLiterals,
+        ColumnOrLiteral,
+        ColumnOrName,
+        OutputExpressionContainer,
+    )
     from sqlglot.dataframe.sql.session import SparkSession

@@ -83,7 +88,9 @@ class DataFrame:
                 return from_exp.alias_or_name
             table_alias = from_exp.find(exp.TableAlias)
             if not table_alias:
-                raise RuntimeError(f"Could not find an alias name for this expression: {self.expression}")
+                raise RuntimeError(
+                    f"Could not find an alias name for this expression: {self.expression}"
+                )
             return table_alias.alias_or_name
         return self.expression.ctes[-1].alias

@@ -132,12 +139,16 @@ class DataFrame:
         cte.set("sequence_id", sequence_id or self.sequence_id)
         return cte, name
 
-    def _ensure_list_of_columns(
-        self, cols: t.Union[str, t.Iterable[str], Column, t.Iterable[Column]]
-    ) -> t.List[Column]:
-        columns = ensure_list(cols)
-        columns = Column.ensure_cols(columns)
-        return columns
+    @t.overload
+    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
+        ...
+
+    @t.overload
+    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
+        ...
+
+    def _ensure_list_of_columns(self, cols):
+        return Column.ensure_cols(ensure_list(cols))
 
     def _ensure_and_normalize_cols(self, cols):
         cols = self._ensure_list_of_columns(cols)

@@ -153,10 +164,16 @@ class DataFrame:
         df = self._resolve_pending_hints()
         sequence_id = sequence_id or df.sequence_id
         expression = df.expression.copy()
-        cte_expression, cte_name = df._create_cte_from_expression(expression=expression, sequence_id=sequence_id)
-        new_expression = df._add_ctes_to_expression(exp.Select(), expression.ctes + [cte_expression])
+        cte_expression, cte_name = df._create_cte_from_expression(
+            expression=expression, sequence_id=sequence_id
+        )
+        new_expression = df._add_ctes_to_expression(
+            exp.Select(), expression.ctes + [cte_expression]
+        )
         sel_columns = df._get_outer_select_columns(cte_expression)
-        new_expression = new_expression.from_(cte_name).select(*[x.alias_or_name for x in sel_columns])
+        new_expression = new_expression.from_(cte_name).select(
+            *[x.alias_or_name for x in sel_columns]
+        )
         return df.copy(expression=new_expression, sequence_id=sequence_id)
 
     def _resolve_pending_hints(self) -> DataFrame:

@@ -169,16 +186,23 @@ class DataFrame:
             hint_expression.args.get("expressions").append(hint)
             df.pending_hints.remove(hint)
 
-        join_aliases = {join_table.alias_or_name for join_table in get_tables_from_expression_with_join(expression)}
+        join_aliases = {
+            join_table.alias_or_name
+            for join_table in get_tables_from_expression_with_join(expression)
+        }
         if join_aliases:
             for hint in df.pending_join_hints:
                 for sequence_id_expression in hint.expressions:
                     sequence_id_or_name = sequence_id_expression.alias_or_name
                     sequence_ids_to_match = [sequence_id_or_name]
                     if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
-                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[sequence_id_or_name]
+                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
+                            sequence_id_or_name
+                        ]
                     matching_ctes = [
-                        cte for cte in reversed(expression.ctes) if cte.args["sequence_id"] in sequence_ids_to_match
+                        cte
+                        for cte in reversed(expression.ctes)
+                        if cte.args["sequence_id"] in sequence_ids_to_match
                     ]
                     for matching_cte in matching_ctes:
                         if matching_cte.alias_or_name in join_aliases:

@@ -193,9 +217,14 @@ class DataFrame:
     def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
         hint_name = hint_name.upper()
         hint_expression = (
-            exp.JoinHint(this=hint_name, expressions=[exp.to_table(parameter.alias_or_name) for parameter in args])
+            exp.JoinHint(
+                this=hint_name,
+                expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
+            )
             if hint_name in JOIN_HINTS
-            else exp.Anonymous(this=hint_name, expressions=[parameter.expression for parameter in args])
+            else exp.Anonymous(
+                this=hint_name, expressions=[parameter.expression for parameter in args]
+            )
         )
         new_df = self.copy()
         new_df.pending_hints.append(hint_expression)

@@ -245,7 +274,9 @@ class DataFrame:
     def _get_select_expressions(
         self,
     ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
-        select_expressions: t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]] = []
+        select_expressions: t.List[
+            t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
+        ] = []
         main_select_ctes: t.List[exp.CTE] = []
         for cte in self.expression.ctes:
             cache_storage_level = cte.args.get("cache_storage_level")

@@ -279,14 +310,19 @@ class DataFrame:
                 cache_table_name = df._create_hash_from_expression(select_expression)
                 cache_table = exp.to_table(cache_table_name)
                 original_alias_name = select_expression.args["cte_alias_name"]
-                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(cache_table_name)
+
+                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
+                    cache_table_name
+                )
                 sqlglot.schema.add_table(cache_table_name, select_expression.named_selects)
                 cache_storage_level = select_expression.args["cache_storage_level"]
                 options = [
                     exp.Literal.string("storageLevel"),
                     exp.Literal.string(cache_storage_level),
                 ]
-                expression = exp.Cache(this=cache_table, expression=select_expression, lazy=True, options=options)
+                expression = exp.Cache(
+                    this=cache_table, expression=select_expression, lazy=True, options=options
+                )
                 # We will drop the "view" if it exists before running the cache table
                 output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
             elif expression_type == exp.Create:

@@ -305,7 +341,9 @@ class DataFrame:
                 raise ValueError(f"Invalid expression type: {expression_type}")
             output_expressions.append(expression)
 
-        return [expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions]
+        return [
+            expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
+        ]
 
     def copy(self, **kwargs) -> DataFrame:
         return DataFrame(**object_to_dict(self, **kwargs))

@@ -317,7 +355,9 @@ class DataFrame:
         if self.expression.args.get("joins"):
             ambiguous_cols = [col for col in cols if not col.column_expression.table]
             if ambiguous_cols:
-                join_table_identifiers = [x.this for x in get_tables_from_expression_with_join(self.expression)]
+                join_table_identifiers = [
+                    x.this for x in get_tables_from_expression_with_join(self.expression)
+                ]
                 cte_names_in_join = [x.this for x in join_table_identifiers]
                 for ambiguous_col in ambiguous_cols:
                     ctes_with_column = [

@@ -367,14 +407,20 @@ class DataFrame:
 
     @operation(Operation.FROM)
     def join(
-        self, other_df: DataFrame, on: t.Union[str, t.List[str], Column, t.List[Column]], how: str = "inner", **kwargs
+        self,
+        other_df: DataFrame,
+        on: t.Union[str, t.List[str], Column, t.List[Column]],
+        how: str = "inner",
+        **kwargs,
     ) -> DataFrame:
         other_df = other_df._convert_leaf_to_cte()
         pre_join_self_latest_cte_name = self.latest_cte_name
         columns = self._ensure_and_normalize_cols(on)
         join_type = how.replace("_", " ")
         if isinstance(columns[0].expression, exp.Column):
-            join_columns = [Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns]
+            join_columns = [
+                Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns
+            ]
             join_clause = functools.reduce(
                 lambda x, y: x & y,
                 [

@@ -402,7 +448,9 @@ class DataFrame:
                 for column in self._get_outer_select_columns(other_df)
             ]
             column_value_mapping = {
-                column.alias_or_name if not isinstance(column.expression.this, exp.Star) else column.sql(): column
+                column.alias_or_name
+                if not isinstance(column.expression.this, exp.Star)
+                else column.sql(): column
                 for column in other_columns + self_columns + join_columns
             }
             all_columns = [

@@ -410,16 +458,22 @@ class DataFrame:
                 for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
             ]
         new_df = self.copy(
-            expression=self.expression.join(other_df.latest_cte_name, on=join_clause.expression, join_type=join_type)
+            expression=self.expression.join(
+                other_df.latest_cte_name, on=join_clause.expression, join_type=join_type
+            )
         )
-        new_df.expression = new_df._add_ctes_to_expression(new_df.expression, other_df.expression.ctes)
+        new_df.expression = new_df._add_ctes_to_expression(
+            new_df.expression, other_df.expression.ctes
+        )
         new_df.pending_hints.extend(other_df.pending_hints)
         new_df = new_df.select.__wrapped__(new_df, *all_columns)
         return new_df
 
     @operation(Operation.ORDER_BY)
     def orderBy(
-        self, *cols: t.Union[str, Column], ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None
+        self,
+        *cols: t.Union[str, Column],
+        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
     ) -> DataFrame:
         """
         This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark

@@ -429,7 +483,10 @@ class DataFrame:
         columns = self._ensure_and_normalize_cols(cols)
         pre_ordered_col_indexes = [
             x
-            for x in [i if isinstance(col.expression, exp.Ordered) else None for i, col in enumerate(columns)]
+            for x in [
+                i if isinstance(col.expression, exp.Ordered) else None
+                for i, col in enumerate(columns)
+            ]
             if x is not None
         ]
         if ascending is None:

@@ -478,7 +535,9 @@ class DataFrame:
         for r_column in r_columns_unused:
             l_expressions.append(exp.alias_(exp.Null(), r_column))
             r_expressions.append(r_column)
-        r_df = other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
+        r_df = (
+            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
+        )
         l_df = self.copy()
         if allowMissingColumns:
             l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))

@@ -536,7 +595,9 @@ class DataFrame:
                 f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
                 f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
             )
-        if_null_checks = [F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns]
+        if_null_checks = [
+            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
+        ]
         nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
         num_nulls = nulls_added_together.alias("num_nulls")
         new_df = new_df.select(num_nulls, append=True)

@@ -576,11 +637,15 @@ class DataFrame:
             value_columns = [lit(value) for value in values]
 
         null_replacement_mapping = {
-            column.alias_or_name: (F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name))
+            column.alias_or_name: (
+                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
+            )
             for column, value in zip(columns, value_columns)
         }
         null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
-        null_replacement_columns = [null_replacement_mapping[column.alias_or_name] for column in all_columns]
+        null_replacement_columns = [
+            null_replacement_mapping[column.alias_or_name] for column in all_columns
+        ]
         new_df = new_df.select(*null_replacement_columns)
         return new_df
 

@@ -589,12 +654,11 @@ class DataFrame:
         self,
         to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
         value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
-        subset: t.Optional[t.Union[str, t.List[str]]] = None,
+        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
     ) -> DataFrame:
         from sqlglot.dataframe.sql.functions import lit
 
         old_values = None
-        subset = ensure_list(subset)
         new_df = self.copy()
         all_columns = self._get_outer_select_columns(new_df.expression)
         all_column_mapping = {column.alias_or_name: column for column in all_columns}

@@ -605,7 +669,9 @@ class DataFrame:
             new_values = list(to_replace.values())
         elif not old_values and isinstance(to_replace, list):
             assert isinstance(value, list), "value must be a list since the replacements are a list"
-            assert len(to_replace) == len(value), "the replacements and values must be the same length"
+            assert len(to_replace) == len(
+                value
+            ), "the replacements and values must be the same length"
             old_values = to_replace
             new_values = value
         else:

@@ -635,7 +701,9 @@ class DataFrame:
     def withColumn(self, colName: str, col: Column) -> DataFrame:
         col = self._ensure_and_normalize_col(col)
         existing_col_names = self.expression.named_selects
-        existing_col_index = existing_col_names.index(colName) if colName in existing_col_names else None
+        existing_col_index = (
+            existing_col_names.index(colName) if colName in existing_col_names else None
+        )
         if existing_col_index:
             expression = self.expression.copy()
             expression.expressions[existing_col_index] = col.expression

@@ -645,7 +713,11 @@ class DataFrame:
     @operation(Operation.SELECT)
     def withColumnRenamed(self, existing: str, new: str):
         expression = self.expression.copy()
-        existing_columns = [expression for expression in expression.expressions if expression.alias_or_name == existing]
+        existing_columns = [
+            expression
+            for expression in expression.expressions
+            if expression.alias_or_name == existing
+        ]
         if not existing_columns:
             raise ValueError("Tried to rename a column that doesn't exist")
         for existing_column in existing_columns:

@@ -674,15 +746,19 @@ class DataFrame:
     def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
         parameter_list = ensure_list(parameters)
         parameter_columns = (
-            self._ensure_list_of_columns(parameter_list) if parameters else Column.ensure_cols([self.sequence_id])
+            self._ensure_list_of_columns(parameter_list)
+            if parameters
+            else Column.ensure_cols([self.sequence_id])
        )
         return self._hint(name, parameter_columns)
 
     @operation(Operation.NO_OP)
-    def repartition(self, numPartitions: t.Union[int, str], *cols: t.Union[int, str]) -> DataFrame:
-        num_partitions = Column.ensure_cols(ensure_list(numPartitions))
+    def repartition(
+        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
+    ) -> DataFrame:
+        num_partition_cols = self._ensure_list_of_columns(numPartitions)
         columns = self._ensure_and_normalize_cols(cols)
-        args = num_partitions + columns
+        args = num_partition_cols + columns
         return self._hint("repartition", args)
 
     @operation(Operation.NO_OP)
sqlglot/dataframe/sql/functions.py

@@ -45,7 +45,11 @@ def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
 
 def when(condition: Column, value: t.Any) -> Column:
     true_value = value if isinstance(value, Column) else lit(value)
-    return Column(glotexp.Case(ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)]))
+    return Column(
+        glotexp.Case(
+            ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)]
+        )
+    )
 
 
 def asc(col: ColumnOrName) -> Column:
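The reformatted `when` above builds a CASE expression; paired with `Column.otherwise` it mirrors PySpark. A usage sketch following the same pattern the DataFrame code itself uses (output shown is indicative):

```python
from sqlglot.dataframe.sql import functions as F

# F.when(...).otherwise(...) is the pattern dropna/fillna rely on.
flag = F.when(F.col("x").isNull(), F.lit(1)).otherwise(F.lit(0))
print(flag.sql())  # e.g. CASE WHEN x IS NULL THEN 1 ELSE 0 END
```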
@@ -407,7 +411,9 @@ def percentile_approx(
         return Column.invoke_expression_over_column(
             col, glotexp.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy
         )
-    return Column.invoke_expression_over_column(col, glotexp.ApproxQuantile, quantile=lit(percentage))
+    return Column.invoke_expression_over_column(
+        col, glotexp.ApproxQuantile, quantile=lit(percentage)
+    )
 
 
 def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column:

@@ -471,7 +477,9 @@ def factorial(col: ColumnOrName) -> Column:
     return Column.invoke_anonymous_function(col, "FACTORIAL")
 
 
-def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None) -> Column:
+def lag(
+    col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None
+) -> Column:
     if default is not None:
         return Column.invoke_anonymous_function(col, "LAG", offset, default)
     if offset != 1:

@@ -479,7 +487,9 @@ def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[Colu
     return Column.invoke_anonymous_function(col, "LAG")
 
 
-def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None) -> Column:
+def lead(
+    col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None
+) -> Column:
     if default is not None:
         return Column.invoke_anonymous_function(col, "LEAD", offset, default)
     if offset != 1:

@@ -487,7 +497,9 @@ def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.A
     return Column.invoke_anonymous_function(col, "LEAD")
 
 
-def nth_value(col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None) -> Column:
+def nth_value(
+    col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None
+) -> Column:
     if ignoreNulls is not None:
         raise NotImplementedError("There is currently not support for `ignoreNulls` parameter")
     if offset != 1:

@@ -571,7 +583,9 @@ def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Colum
     return Column.invoke_anonymous_function(start, "ADD_MONTHS", months)
 
 
-def months_between(date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None) -> Column:
+def months_between(
+    date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None
+) -> Column:
     if roundOff is None:
         return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2)
     return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2, roundOff)

@@ -611,9 +625,13 @@ def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
     return Column.invoke_expression_over_column(col, glotexp.UnixToStr)
 
 
-def unix_timestamp(timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None) -> Column:
+def unix_timestamp(
+    timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None
+) -> Column:
     if format is not None:
-        return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix, format=lit(format))
+        return Column.invoke_expression_over_column(
+            timestamp, glotexp.StrToUnix, format=lit(format)
+        )
     return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix)

@@ -642,7 +660,9 @@ def window(
             timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration), lit(startTime)
         )
     if slideDuration is not None:
-        return Column.invoke_anonymous_function(timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration))
+        return Column.invoke_anonymous_function(
+            timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration)
+        )
     if startTime is not None:
         return Column.invoke_anonymous_function(
             timeColumn, "WINDOW", lit(windowDuration), lit(windowDuration), lit(startTime)

@@ -731,7 +751,9 @@ def trim(col: ColumnOrName) -> Column:
 
 
 def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols))
+    return Column.invoke_expression_over_column(
+        None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols)
+    )
 
 
 def decode(col: ColumnOrName, charset: str) -> Column:

@@ -768,7 +790,9 @@ def overlay(
 
 
 def sentences(
-    string: ColumnOrName, language: t.Optional[ColumnOrName] = None, country: t.Optional[ColumnOrName] = None
+    string: ColumnOrName,
+    language: t.Optional[ColumnOrName] = None,
+    country: t.Optional[ColumnOrName] = None,
 ) -> Column:
     if language is not None and country is not None:
         return Column.invoke_anonymous_function(string, "SENTENCES", language, country)

@@ -794,7 +818,9 @@ def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column:
 def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column:
     substr_col = lit(substr)
     if pos is not None:
-        return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col, position=pos)
+        return Column.invoke_expression_over_column(
+            str, glotexp.StrPosition, substr=substr_col, position=pos
+        )
     return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col)

@@ -872,7 +898,10 @@ def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
 def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
     cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols  # type: ignore
     return Column.invoke_expression_over_column(
-        None, glotexp.VarMap, keys=array(*cols[::2]).expression, values=array(*cols[1::2]).expression
+        None,
+        glotexp.VarMap,
+        keys=array(*cols[::2]).expression,
+        values=array(*cols[1::2]).expression,
     )

@@ -882,29 +911,39 @@ def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column:
 
 def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
     value_col = value if isinstance(value, Column) else lit(value)
-    return Column.invoke_expression_over_column(col, glotexp.ArrayContains, expression=value_col.expression)
+    return Column.invoke_expression_over_column(
+        col, glotexp.ArrayContains, expression=value_col.expression
+    )
 
 
 def arrays_overlap(col1: ColumnOrName, col2: ColumnOrName) -> Column:
     return Column.invoke_anonymous_function(col1, "ARRAYS_OVERLAP", Column.ensure_col(col2))
 
 
-def slice(x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]) -> Column:
+def slice(
+    x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]
+) -> Column:
     start_col = start if isinstance(start, Column) else lit(start)
     length_col = length if isinstance(length, Column) else lit(length)
     return Column.invoke_anonymous_function(x, "SLICE", start_col, length_col)
 
 
-def array_join(col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None) -> Column:
+def array_join(
+    col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None
+) -> Column:
     if null_replacement is not None:
-        return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement))
+        return Column.invoke_anonymous_function(
+            col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement)
+        )
     return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter))
 
 
 def concat(*cols: ColumnOrName) -> Column:
     if len(cols) == 1:
         return Column.invoke_anonymous_function(cols[0], "CONCAT")
-    return Column.invoke_anonymous_function(cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]])
+    return Column.invoke_anonymous_function(
+        cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]]
+    )
 
 
 def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:

@@ -1076,7 +1115,9 @@ def map_concat(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column
     return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT", *columns[1:])
 
 
-def sequence(start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None) -> Column:
+def sequence(
+    start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None
+) -> Column:
     if step is not None:
         return Column.invoke_anonymous_function(start, "SEQUENCE", stop, step)
     return Column.invoke_anonymous_function(start, "SEQUENCE", stop)

@@ -1103,12 +1144,15 @@ def aggregate(
     merge_exp = _get_lambda_from_func(merge)
     if finish is not None:
         finish_exp = _get_lambda_from_func(finish)
|
||||||
return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp))
|
return Column.invoke_anonymous_function(
|
||||||
|
col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp)
|
||||||
|
)
|
||||||
return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp))
|
return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp))
|
||||||
|
|
||||||
|
|
||||||
def transform(
|
def transform(
|
||||||
col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]
|
col: ColumnOrName,
|
||||||
|
f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
|
||||||
) -> Column:
|
) -> Column:
|
||||||
f_expression = _get_lambda_from_func(f)
|
f_expression = _get_lambda_from_func(f)
|
||||||
return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression))
|
return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression))
|
||||||
|
@ -1124,12 +1168,17 @@ def forall(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column:
|
||||||
return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression))
|
return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression))
|
||||||
|
|
||||||
|
|
||||||
def filter(col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]) -> Column:
|
def filter(
|
||||||
|
col: ColumnOrName,
|
||||||
|
f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
|
||||||
|
) -> Column:
|
||||||
f_expression = _get_lambda_from_func(f)
|
f_expression = _get_lambda_from_func(f)
|
||||||
return Column.invoke_expression_over_column(col, glotexp.ArrayFilter, expression=f_expression)
|
return Column.invoke_expression_over_column(col, glotexp.ArrayFilter, expression=f_expression)
|
||||||
|
|
||||||
|
|
||||||
def zip_with(left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]) -> Column:
|
def zip_with(
|
||||||
|
left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]
|
||||||
|
) -> Column:
|
||||||
f_expression = _get_lambda_from_func(f)
|
f_expression = _get_lambda_from_func(f)
|
||||||
return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression))
|
return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression))
|
||||||
|
|
||||||
|
@ -1163,7 +1212,10 @@ def _lambda_quoted(value: str) -> t.Optional[bool]:
|
||||||
|
|
||||||
|
|
||||||
def _get_lambda_from_func(lambda_expression: t.Callable):
|
def _get_lambda_from_func(lambda_expression: t.Callable):
|
||||||
variables = [glotexp.to_identifier(x, quoted=_lambda_quoted(x)) for x in lambda_expression.__code__.co_varnames]
|
variables = [
|
||||||
|
glotexp.to_identifier(x, quoted=_lambda_quoted(x))
|
||||||
|
for x in lambda_expression.__code__.co_varnames
|
||||||
|
]
|
||||||
return glotexp.Lambda(
|
return glotexp.Lambda(
|
||||||
this=lambda_expression(*[Column(x) for x in variables]).expression,
|
this=lambda_expression(*[Column(x) for x in variables]).expression,
|
||||||
expressions=variables,
|
expressions=variables,
|
||||||
|
|
|
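The `_get_lambda_from_func` helper above is what lets the higher-order functions in this file (`filter`, `transform`, `zip_with`, `aggregate`) accept plain Python lambdas. A usage sketch, assuming the public `sqlglot.dataframe` API; the SQL in the comment is indicative rather than verbatim output:

```python
from sqlglot.dataframe.sql import functions as F

# The Python lambda's argument name ("x") becomes the SQL lambda variable,
# and its body is captured as a sqlglot expression tree.
col = F.filter("numbers", lambda x: x > F.lit(0))

# Render the wrapped expression as Spark SQL, e.g. FILTER(numbers, x -> x > 0).
print(col.expression.sql(dialect="spark"))
```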
sqlglot/dataframe/sql/group.py
@@ -17,7 +17,9 @@ class GroupedData:
         self.last_op = last_op
         self.group_by_cols = group_by_cols

-    def _get_function_applied_columns(self, func_name: str, cols: t.Tuple[str, ...]) -> t.List[Column]:
+    def _get_function_applied_columns(
+        self, func_name: str, cols: t.Tuple[str, ...]
+    ) -> t.List[Column]:
         func_name = func_name.lower()
         return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]

@@ -30,9 +32,9 @@ class GroupedData:
         )
         cols = self._df._ensure_and_normalize_cols(columns)

-        expression = self._df.expression.group_by(*[x.expression for x in self.group_by_cols]).select(
-            *[x.expression for x in self.group_by_cols + cols], append=False
-        )
+        expression = self._df.expression.group_by(
+            *[x.expression for x in self.group_by_cols]
+        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
         return self._df.copy(expression=expression)

     def count(self) -> DataFrame:
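A sketch of the `agg` path changed above, which now builds the GROUP BY first and then selects the grouping columns plus the aggregates with `append=False`. The rows and column names are illustrative:

```python
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
df = spark.createDataFrame([{"grp": "a", "x": 1}, {"grp": "b", "x": 2}])

# groupBy returns the GroupedData object patched above; agg builds the final SELECT.
agg_df = df.groupBy(F.col("grp")).agg(F.sum("x"))
print(agg_df.sql()[0])  # Spark SQL with a GROUP BY over the generated CTEs
```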
sqlglot/dataframe/sql/normalize.py
@@ -23,7 +23,9 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[
     replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)


-def replace_alias_name_with_cte_name(spark: SparkSession, expression_context: exp.Select, id: exp.Identifier):
+def replace_alias_name_with_cte_name(
+    spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
+):
     if id.alias_or_name in spark.name_to_sequence_id_mapping:
         for cte in reversed(expression_context.ctes):
             if cte.args["sequence_id"] in spark.name_to_sequence_id_mapping[id.alias_or_name]:
@@ -40,8 +42,12 @@ def replace_branch_and_sequence_ids_with_cte_name(
     # id then it keeps that reference. This handles the weird edge case in spark that shouldn't
     # be common in practice
     if expression_context.args.get("joins") and id.alias_or_name in spark.known_branch_ids:
-        join_table_aliases = [x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)]
-        ctes_in_join = [cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases]
+        join_table_aliases = [
+            x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)
+        ]
+        ctes_in_join = [
+            cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases
+        ]
         if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]:
             assert len(ctes_in_join) == 2
             _set_alias_name(id, ctes_in_join[0].alias_or_name)
@@ -58,7 +64,6 @@ def _set_alias_name(id: exp.Identifier, name: str):


 def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]:
-    values = ensure_list(values)
     results = []
     for value in values:
         if isinstance(value, str):
sqlglot/dataframe/sql/readwriter.py
@@ -19,12 +19,19 @@ class DataFrameReader:
         from sqlglot.dataframe.sql.dataframe import DataFrame

         sqlglot.schema.add_table(tableName)
-        return DataFrame(self.spark, exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)))
+        return DataFrame(
+            self.spark,
+            exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)),
+        )


 class DataFrameWriter:
     def __init__(
-        self, df: DataFrame, spark: t.Optional[SparkSession] = None, mode: t.Optional[str] = None, by_name: bool = False
+        self,
+        df: DataFrame,
+        spark: t.Optional[SparkSession] = None,
+        mode: t.Optional[str] = None,
+        by_name: bool = False,
     ):
         self._df = df
         self._spark = spark or df.spark
@@ -33,7 +40,10 @@ class DataFrameWriter:

     def copy(self, **kwargs) -> DataFrameWriter:
         return DataFrameWriter(
-            **{k[1:] if k.startswith("_") else k: v for k, v in object_to_dict(self, **kwargs).items()}
+            **{
+                k[1:] if k.startswith("_") else k: v
+                for k, v in object_to_dict(self, **kwargs).items()
+            }
         )

     def sql(self, **kwargs) -> t.List[str]:
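`DataFrameReader.table` resolves its projection from the global `sqlglot.schema`, which is why the reformatted call above selects `sqlglot.schema.column_names(tableName)`. A minimal sketch of that contract; the table and columns are made up:

```python
import sqlglot

# Registering a column mapping lets table() expand SELECT to explicit columns.
sqlglot.schema.add_table("employee", {"id": "int", "name": "string"})
assert sqlglot.schema.column_names("employee") == ["id", "name"]
```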
sqlglot/dataframe/sql/session.py
@@ -67,13 +67,20 @@ class SparkSession:

         data_expressions = [
             exp.Tuple(
-                expressions=list(map(lambda x: F.lit(x).expression, row if not isinstance(row, dict) else row.values()))
+                expressions=list(
+                    map(
+                        lambda x: F.lit(x).expression,
+                        row if not isinstance(row, dict) else row.values(),
+                    )
+                )
             )
             for row in data
         ]

         sel_columns = [
-            F.col(name).cast(data_type).alias(name).expression if data_type is not None else F.col(name).expression
+            F.col(name).cast(data_type).alias(name).expression
+            if data_type is not None
+            else F.col(name).expression
             for name, data_type in column_mapping.items()
         ]

@@ -106,10 +113,12 @@ class SparkSession:
             select_expression.set("with", expression.args.get("with"))
             expression.set("with", None)
             del expression.args["expression"]
-            df = DataFrame(self, select_expression, output_expression_container=expression)
+            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
             df = df._convert_leaf_to_cte()
         else:
-            raise ValueError("Unknown expression type provided in the SQL. Please create an issue with the SQL.")
+            raise ValueError(
+                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
+            )
         return df

     @property
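A sketch of the `createDataFrame` path reflowed above: rows may be sequences or dicts, and dict rows contribute their `values()` as the tuple expressions. Only the API visible in this diff is assumed:

```python
from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
# Dict rows hit the `row if not isinstance(row, dict) else row.values()` branch.
df = spark.createDataFrame([{"a": 1}, {"a": 2}])
print(df.sql()[0])  # a SELECT over a VALUES-style derived table
```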
sqlglot/dataframe/sql/types.py
@@ -158,7 +158,11 @@ class MapType(DataType):

 class StructField(DataType):
     def __init__(
-        self, name: str, dataType: DataType, nullable: bool = True, metadata: t.Optional[t.Dict[str, t.Any]] = None
+        self,
+        name: str,
+        dataType: DataType,
+        nullable: bool = True,
+        metadata: t.Optional[t.Dict[str, t.Any]] = None,
     ):
         self.name = name
         self.dataType = dataType
sqlglot/dataframe/sql/window.py
@@ -74,8 +74,13 @@ class WindowSpec:
         window_spec.expression.args["order"].set("expressions", order_by)
         return window_spec

-    def _calc_start_end(self, start: int, end: int) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
-        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {"start_side": None, "end_side": None}
+    def _calc_start_end(
+        self, start: int, end: int
+    ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
+        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
+            "start_side": None,
+            "end_side": None,
+        }
         if start == Window.currentRow:
             kwargs["start"] = "CURRENT ROW"
         else:
@@ -83,7 +88,9 @@ class WindowSpec:
                 **kwargs,
                 **{
                     "start_side": "PRECEDING",
-                    "start": "UNBOUNDED" if start <= Window.unboundedPreceding else F.lit(start).expression,
+                    "start": "UNBOUNDED"
+                    if start <= Window.unboundedPreceding
+                    else F.lit(start).expression,
                 },
             }
         if end == Window.currentRow:
@@ -93,7 +100,9 @@ class WindowSpec:
                 **kwargs,
                 **{
                     "end_side": "FOLLOWING",
-                    "end": "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression,
+                    "end": "UNBOUNDED"
+                    if end >= Window.unboundedFollowing
+                    else F.lit(end).expression,
                 },
             }
         return kwargs
@@ -103,7 +112,10 @@ class WindowSpec:
         spec = self._calc_start_end(start, end)
         spec["kind"] = "ROWS"
         window_spec.expression.set(
-            "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
+            "spec",
+            exp.WindowSpec(
+                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
+            ),
         )
         return window_spec

@@ -112,6 +124,9 @@ class WindowSpec:
         spec = self._calc_start_end(start, end)
         spec["kind"] = "RANGE"
         window_spec.expression.set(
-            "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
+            "spec",
+            exp.WindowSpec(
+                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
+            ),
         )
         return window_spec
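A sketch of the `WindowSpec` bounds logic above: `rowsBetween` funnels both bounds through `_calc_start_end`, clamping anything at or beyond `unboundedPreceding`/`unboundedFollowing` to UNBOUNDED. The SQL in the comment is indicative:

```python
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.window import Window

w = Window.partitionBy("grp").orderBy("ts").rowsBetween(
    Window.unboundedPreceding, Window.currentRow
)
col = F.sum("x").over(w)
# e.g. SUM(x) OVER (PARTITION BY grp ORDER BY ts
#                   ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
print(col.expression.sql(dialect="spark"))
```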
sqlglot/dialects/bigquery.py
@@ -1,21 +1,21 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
     inline_array_sql,
     no_ilike_sql,
     rename_func,
 )
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType


 def _date_add(expression_class):
     def func(args):
-        interval = list_get(args, 1)
+        interval = seq_get(args, 1)
         return expression_class(
-            this=list_get(args, 0),
+            this=seq_get(args, 0),
             expression=interval.this,
             unit=interval.args.get("unit"),
         )
@@ -23,6 +23,13 @@ def _date_add(expression_class):
     return func


+def _date_trunc(args):
+    unit = seq_get(args, 1)
+    if isinstance(unit, exp.Column):
+        unit = exp.Var(this=unit.name)
+    return exp.DateTrunc(this=seq_get(args, 0), expression=unit)
+
+
 def _date_add_sql(data_type, kind):
     def func(self, expression):
         this = self.sql(expression, "this")
@@ -40,7 +47,8 @@ def _derived_table_values_to_unnest(self, expression):
     structs = []
     for row in rows:
         aliases = [
-            exp.alias_(value, column_name) for value, column_name in zip(row, expression.args["alias"].args["columns"])
+            exp.alias_(value, column_name)
+            for value, column_name in zip(row, expression.args["alias"].args["columns"])
         ]
         structs.append(exp.Struct(expressions=aliases))
     unnest_exp = exp.Unnest(expressions=[exp.Array(expressions=structs)])
@@ -89,18 +97,19 @@ class BigQuery(Dialect):
         "%j": "%-j",
     }

-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
         QUOTES = [
             (prefix + quote, quote) if prefix else quote
             for quote in ["'", '"', '"""', "'''"]
             for prefix in ["", "r", "R"]
         ]
+        COMMENTS = ["--", "#", ("/*", "*/")]
         IDENTIFIERS = ["`"]
-        ESCAPE = "\\"
+        ESCAPES = ["\\"]
         HEX_STRINGS = [("0x", ""), ("0X", "")]

         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
+            **tokens.Tokenizer.KEYWORDS,
             "CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
             "CURRENT_TIME": TokenType.CURRENT_TIME,
             "GEOGRAPHY": TokenType.GEOGRAPHY,
@@ -111,35 +120,40 @@ class BigQuery(Dialect):
             "WINDOW": TokenType.WINDOW,
             "NOT DETERMINISTIC": TokenType.VOLATILE,
         }
+        KEYWORDS.pop("DIV")

-    class Parser(Parser):
+    class Parser(parser.Parser):
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
+            "DATE_TRUNC": _date_trunc,
             "DATE_ADD": _date_add(exp.DateAdd),
             "DATETIME_ADD": _date_add(exp.DatetimeAdd),
+            "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
             "TIME_ADD": _date_add(exp.TimeAdd),
             "TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
             "DATE_SUB": _date_add(exp.DateSub),
             "DATETIME_SUB": _date_add(exp.DatetimeSub),
             "TIME_SUB": _date_add(exp.TimeSub),
             "TIMESTAMP_SUB": _date_add(exp.TimestampSub),
-            "PARSE_TIMESTAMP": lambda args: exp.StrToTime(this=list_get(args, 1), format=list_get(args, 0)),
+            "PARSE_TIMESTAMP": lambda args: exp.StrToTime(
+                this=seq_get(args, 1), format=seq_get(args, 0)
+            ),
         }

         NO_PAREN_FUNCTIONS = {
-            **Parser.NO_PAREN_FUNCTIONS,
+            **parser.Parser.NO_PAREN_FUNCTIONS,
             TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
             TokenType.CURRENT_TIME: exp.CurrentTime,
         }

         NESTED_TYPE_TOKENS = {
-            *Parser.NESTED_TYPE_TOKENS,
+            *parser.Parser.NESTED_TYPE_TOKENS,
             TokenType.TABLE,
         }

-    class Generator(Generator):
+    class Generator(generator.Generator):
         TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,
             exp.Array: inline_array_sql,
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.DateAdd: _date_add_sql("DATE", "ADD"),
@@ -148,6 +162,7 @@ class BigQuery(Dialect):
             exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
             exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
             exp.ILike: no_ilike_sql,
+            exp.IntDiv: rename_func("DIV"),
             exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
             exp.TimeAdd: _date_add_sql("TIME", "ADD"),
             exp.TimeSub: _date_add_sql("TIME", "SUB"),
@@ -157,11 +172,13 @@ class BigQuery(Dialect):
             exp.Values: _derived_table_values_to_unnest,
             exp.ReturnsProperty: _returnsproperty_sql,
             exp.Create: _create_sql,
-            exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC",
+            exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
+            if e.name == "IMMUTABLE"
+            else "NOT DETERMINISTIC",
         }

         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.TINYINT: "INT64",
             exp.DataType.Type.SMALLINT: "INT64",
             exp.DataType.Type.INT: "INT64",
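The new BigQuery pieces above in action: `_date_trunc` turns a bare column unit into a `Var`, and the paired `DIV`/`exp.IntDiv` entries make integer division round-trip. A sketch:

```python
import sqlglot
from sqlglot import exp

# DATE_TRUNC's unit (MONTH) parses as a Var, not a column reference.
ast = sqlglot.parse_one("SELECT DATE_TRUNC(d, MONTH) FROM t", read="bigquery")
assert isinstance(ast.find(exp.DateTrunc).expression, exp.Var)

# DIV parses to exp.IntDiv and is rendered back as DIV for BigQuery.
assert (
    sqlglot.transpile("SELECT DIV(10, 3)", read="bigquery", write="bigquery")[0]
    == "SELECT DIV(10, 3)"
)
```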
sqlglot/dialects/clickhouse.py
@@ -1,8 +1,9 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
-from sqlglot.generator import Generator
-from sqlglot.parser import Parser, parse_var_map
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.parser import parse_var_map
+from sqlglot.tokens import TokenType


 def _lower_func(sql):
@@ -14,11 +15,12 @@ class ClickHouse(Dialect):
     normalize_functions = None
     null_ordering = "nulls_are_last"

-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
+        COMMENTS = ["--", "#", "#!", ("/*", "*/")]
         IDENTIFIERS = ['"', "`"]

         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
+            **tokens.Tokenizer.KEYWORDS,
             "FINAL": TokenType.FINAL,
             "DATETIME64": TokenType.DATETIME,
             "INT8": TokenType.TINYINT,
@@ -30,9 +32,9 @@ class ClickHouse(Dialect):
             "TUPLE": TokenType.STRUCT,
         }

-    class Parser(Parser):
+    class Parser(parser.Parser):
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "MAP": parse_var_map,
         }

@@ -44,11 +46,11 @@ class ClickHouse(Dialect):

         return this

-    class Generator(Generator):
+    class Generator(generator.Generator):
         STRUCT_DELIMITER = ("(", ")")

         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.NULLABLE: "Nullable",
             exp.DataType.Type.DATETIME: "DateTime64",
             exp.DataType.Type.MAP: "Map",
@@ -63,7 +65,7 @@ class ClickHouse(Dialect):
         }

         TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,
             exp.Array: inline_array_sql,
             exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
             exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
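The added `COMMENTS` entry teaches the tokenizer ClickHouse's `#!` comment marker; combined with v10's comment preservation, such comments should survive transpilation. A sketch (the exact placement of the preserved comment in the output may vary):

```python
import sqlglot

sql = "SELECT 1 #! clickhouse-style comment"
print(sqlglot.transpile(sql, read="clickhouse")[0])
```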
sqlglot/dialects/databricks.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from sqlglot import exp
 from sqlglot.dialects.dialect import parse_date_delta
 from sqlglot.dialects.spark import Spark
@@ -15,7 +17,7 @@ class Databricks(Spark):

     class Generator(Spark.Generator):
         TRANSFORMS = {
-            **Spark.Generator.TRANSFORMS,
+            **Spark.Generator.TRANSFORMS,  # type: ignore
             exp.DateAdd: generate_date_delta_with_unit_sql,
             exp.DateDiff: generate_date_delta_with_unit_sql,
         }
sqlglot/dialects/dialect.py
@@ -1,8 +1,11 @@
+from __future__ import annotations
+
+import typing as t
 from enum import Enum

 from sqlglot import exp
 from sqlglot.generator import Generator
-from sqlglot.helper import flatten, list_get
+from sqlglot.helper import flatten, seq_get
 from sqlglot.parser import Parser
 from sqlglot.time import format_time
 from sqlglot.tokens import Tokenizer
@@ -32,7 +35,7 @@ class Dialects(str, Enum):


 class _Dialect(type):
-    classes = {}
+    classes: t.Dict[str, Dialect] = {}

     @classmethod
     def __getitem__(cls, key):
@@ -56,19 +59,30 @@ class _Dialect(type):
         klass.generator_class = getattr(klass, "Generator", Generator)

         klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
-        klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0]
+        klass.identifier_start, klass.identifier_end = list(
+            klass.tokenizer_class._IDENTIFIERS.items()
+        )[0]

-        if klass.tokenizer_class._BIT_STRINGS and exp.BitString not in klass.generator_class.TRANSFORMS:
+        if (
+            klass.tokenizer_class._BIT_STRINGS
+            and exp.BitString not in klass.generator_class.TRANSFORMS
+        ):
             bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
             klass.generator_class.TRANSFORMS[
                 exp.BitString
             ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
-        if klass.tokenizer_class._HEX_STRINGS and exp.HexString not in klass.generator_class.TRANSFORMS:
+        if (
+            klass.tokenizer_class._HEX_STRINGS
+            and exp.HexString not in klass.generator_class.TRANSFORMS
+        ):
             hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
             klass.generator_class.TRANSFORMS[
                 exp.HexString
             ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
-        if klass.tokenizer_class._BYTE_STRINGS and exp.ByteString not in klass.generator_class.TRANSFORMS:
+        if (
+            klass.tokenizer_class._BYTE_STRINGS
+            and exp.ByteString not in klass.generator_class.TRANSFORMS
+        ):
             be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
             klass.generator_class.TRANSFORMS[
                 exp.ByteString
@@ -81,13 +95,13 @@ class Dialect(metaclass=_Dialect):
     index_offset = 0
     unnest_column_only = False
     alias_post_tablesample = False
-    normalize_functions = "upper"
+    normalize_functions: t.Optional[str] = "upper"
     null_ordering = "nulls_are_small"

     date_format = "'%Y-%m-%d'"
     dateint_format = "'%Y%m%d'"
     time_format = "'%Y-%m-%d %H:%M:%S'"
-    time_mapping = {}
+    time_mapping: t.Dict[str, str] = {}

     # autofilled
     quote_start = None
@@ -167,7 +181,7 @@ class Dialect(metaclass=_Dialect):
             "quote_end": self.quote_end,
             "identifier_start": self.identifier_start,
             "identifier_end": self.identifier_end,
-            "escape": self.tokenizer_class.ESCAPE,
+            "escape": self.tokenizer_class.ESCAPES[0],
             "index_offset": self.index_offset,
             "time_mapping": self.inverse_time_mapping,
             "time_trie": self.inverse_time_trie,
@@ -195,7 +209,9 @@ def approx_count_distinct_sql(self, expression):


 def if_sql(self, expression):
-    expressions = self.format_args(expression.this, expression.args.get("true"), expression.args.get("false"))
+    expressions = self.format_args(
+        expression.this, expression.args.get("true"), expression.args.get("false")
+    )
     return f"IF({expressions})"


@@ -298,9 +314,9 @@ def format_time_lambda(exp_class, dialect, default=None):

     def _format_time(args):
         return exp_class(
-            this=list_get(args, 0),
+            this=seq_get(args, 0),
             format=Dialect[dialect].format_time(
-                list_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
+                seq_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
             ),
         )

@@ -328,7 +344,9 @@ def create_with_partitions_sql(self, expression):
         "expressions",
         [e for e in schema.expressions if e not in partitions],
     )
-    prop.replace(exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions)))
+    prop.replace(
+        exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions))
+    )
     expression.set("this", schema)

     return self.create_sql(expression)
@@ -337,9 +355,9 @@ def create_with_partitions_sql(self, expression):
 def parse_date_delta(exp_class, unit_mapping=None):
     def inner_func(args):
         unit_based = len(args) == 3
-        this = list_get(args, 2) if unit_based else list_get(args, 0)
-        expression = list_get(args, 1) if unit_based else list_get(args, 1)
-        unit = list_get(args, 0) if unit_based else exp.Literal.string("DAY")
+        this = seq_get(args, 2) if unit_based else seq_get(args, 0)
+        expression = seq_get(args, 1) if unit_based else seq_get(args, 1)
+        unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
         unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
         return exp_class(this=this, expression=expression, unit=unit)
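`seq_get`, the renamed `list_get` used throughout this file, returns `None` for an out-of-range index instead of raising, which is what lets the parser lambdas above index `args` without length checks:

```python
from sqlglot.helper import seq_get

args = ["a", "b"]
assert seq_get(args, 1) == "b"
assert seq_get(args, 5) is None  # no IndexError, optional args just become None
```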
sqlglot/dialects/duckdb.py
@@ -1,4 +1,6 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
     approx_count_distinct_sql,
@@ -12,10 +14,8 @@ from sqlglot.dialects.dialect import (
     rename_func,
     str_position_sql,
 )
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType


 def _unix_to_time(self, expression):
@@ -61,11 +61,14 @@ def _sort_array_sql(self, expression):


 def _sort_array_reverse(args):
-    return exp.SortArray(this=list_get(args, 0), asc=exp.FALSE)
+    return exp.SortArray(this=seq_get(args, 0), asc=exp.FALSE)


 def _struct_pack_sql(self, expression):
-    args = [self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e) for e in expression.expressions]
+    args = [
+        self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e)
+        for e in expression.expressions
+    ]
     return f"STRUCT_PACK({', '.join(args)})"


@@ -76,15 +79,15 @@ def _datatype_sql(self, expression):


 class DuckDB(Dialect):
-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
+            **tokens.Tokenizer.KEYWORDS,
             ":=": TokenType.EQ,
         }

-    class Parser(Parser):
+    class Parser(parser.Parser):
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
             "ARRAY_LENGTH": exp.ArraySize.from_arg_list,
             "ARRAY_SORT": exp.SortArray.from_arg_list,
@@ -92,7 +95,7 @@ class DuckDB(Dialect):
             "EPOCH": exp.TimeToUnix.from_arg_list,
             "EPOCH_MS": lambda args: exp.UnixToTime(
                 this=exp.Div(
-                    this=list_get(args, 0),
+                    this=seq_get(args, 0),
                     expression=exp.Literal.number(1000),
                 )
             ),
@@ -112,11 +115,11 @@ class DuckDB(Dialect):
             "UNNEST": exp.Explode.from_arg_list,
         }

-    class Generator(Generator):
+    class Generator(generator.Generator):
         STRUCT_DELIMITER = ("(", ")")

         TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,
             exp.ApproxDistinct: approx_count_distinct_sql,
             exp.Array: rename_func("LIST_VALUE"),
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
@@ -160,7 +163,7 @@ class DuckDB(Dialect):
         }

         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.VARCHAR: "TEXT",
             exp.DataType.Type.NVARCHAR: "TEXT",
         }
sqlglot/dialects/hive.py
@@ -1,4 +1,6 @@
-from sqlglot import exp, transforms
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     approx_count_distinct_sql,
@@ -13,10 +15,8 @@ from sqlglot.dialects.dialect import (
     struct_extract_sql,
     var_map_sql,
 )
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser, parse_var_map
-from sqlglot.tokens import Tokenizer
+from sqlglot.helper import seq_get
+from sqlglot.parser import parse_var_map


 # (FuncType, Multiplier)
 DATE_DELTA_INTERVAL = {
@@ -34,7 +34,9 @@ def _add_date_sql(self, expression):
     unit = expression.text("unit").upper()
     func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
     modified_increment = (
-        int(expression.text("expression")) * multiplier if expression.expression.is_number else expression.expression
+        int(expression.text("expression")) * multiplier
+        if expression.expression.is_number
+        else expression.expression
     )
     modified_increment = exp.Literal.number(modified_increment)
     return f"{func}({self.format_args(expression.this, modified_increment.this)})"
@@ -165,10 +167,10 @@ class Hive(Dialect):
     dateint_format = "'yyyyMMdd'"
     time_format = "'yyyy-MM-dd HH:mm:ss'"

-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'", '"']
         IDENTIFIERS = ["`"]
-        ESCAPE = "\\"
+        ESCAPES = ["\\"]
         ENCODE = "utf-8"

         NUMERIC_LITERALS = {
@@ -180,40 +182,44 @@ class Hive(Dialect):
             "BD": "DECIMAL",
         }

-    class Parser(Parser):
+    class Parser(parser.Parser):
         STRICT_CAST = False

         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
             "COLLECT_LIST": exp.ArrayAgg.from_arg_list,
             "DATE_ADD": lambda args: exp.TsOrDsAdd(
-                this=list_get(args, 0),
-                expression=list_get(args, 1),
+                this=seq_get(args, 0),
+                expression=seq_get(args, 1),
                 unit=exp.Literal.string("DAY"),
             ),
             "DATEDIFF": lambda args: exp.DateDiff(
-                this=exp.TsOrDsToDate(this=list_get(args, 0)),
-                expression=exp.TsOrDsToDate(this=list_get(args, 1)),
+                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+                expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
             ),
             "DATE_SUB": lambda args: exp.TsOrDsAdd(
-                this=list_get(args, 0),
+                this=seq_get(args, 0),
                 expression=exp.Mul(
-                    this=list_get(args, 1),
+                    this=seq_get(args, 1),
                     expression=exp.Literal.number(-1),
                 ),
                 unit=exp.Literal.string("DAY"),
             ),
             "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"),
-            "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=list_get(args, 0))),
+            "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
             "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
             "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
             "LOCATE": lambda args: exp.StrPosition(
-                this=list_get(args, 1),
-                substr=list_get(args, 0),
-                position=list_get(args, 2),
+                this=seq_get(args, 1),
+                substr=seq_get(args, 0),
+                position=seq_get(args, 2),
             ),
-            "LOG": (lambda args: exp.Log.from_arg_list(args) if len(args) > 1 else exp.Ln.from_arg_list(args)),
+            "LOG": (
+                lambda args: exp.Log.from_arg_list(args)
+                if len(args) > 1
+                else exp.Ln.from_arg_list(args)
+            ),
             "MAP": parse_var_map,
             "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
             "PERCENTILE": exp.Quantile.from_arg_list,
@@ -226,15 +232,16 @@ class Hive(Dialect):
             "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
         }

-    class Generator(Generator):
+    class Generator(generator.Generator):
         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.TEXT: "STRING",
+            exp.DataType.Type.VARBINARY: "BINARY",
         }

         TRANSFORMS = {
-            **Generator.TRANSFORMS,
-            **transforms.UNALIAS_GROUP,
+            **generator.Generator.TRANSFORMS,
+            **transforms.UNALIAS_GROUP,  # type: ignore
             exp.AnonymousProperty: _property_sql,
             exp.ApproxDistinct: approx_count_distinct_sql,
             exp.ArrayAgg: rename_func("COLLECT_LIST"),
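The reflowed `LOG` entry above keeps Hive's arity-based behavior: one argument means natural log, two mean log with an explicit base. A sketch:

```python
import sqlglot
from sqlglot import exp

assert sqlglot.parse_one("SELECT LOG(x) FROM t", read="hive").find(exp.Ln) is not None
assert sqlglot.parse_one("SELECT LOG(2, x) FROM t", read="hive").find(exp.Log) is not None
```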
sqlglot/dialects/mysql.py
@@ -1,4 +1,8 @@
-from sqlglot import exp
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
     no_ilike_sql,
@@ -6,42 +10,47 @@ from sqlglot.dialects.dialect import (
     no_tablesample_sql,
     no_trycast_sql,
 )
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType
+
+
+def _show_parser(*args, **kwargs):
+    def _parse(self):
+        return self._parse_show_mysql(*args, **kwargs)
+
+    return _parse


 def _date_trunc_sql(self, expression):
-    unit = expression.text("unit").lower()
+    unit = expression.name.lower()

-    this = self.sql(expression.this)
+    expr = self.sql(expression.expression)

     if unit == "day":
-        return f"DATE({this})"
+        return f"DATE({expr})"

     if unit == "week":
-        concat = f"CONCAT(YEAR({this}), ' ', WEEK({this}, 1), ' 1')"
+        concat = f"CONCAT(YEAR({expr}), ' ', WEEK({expr}, 1), ' 1')"
         date_format = "%Y %u %w"
     elif unit == "month":
-        concat = f"CONCAT(YEAR({this}), ' ', MONTH({this}), ' 1')"
+        concat = f"CONCAT(YEAR({expr}), ' ', MONTH({expr}), ' 1')"
         date_format = "%Y %c %e"
     elif unit == "quarter":
-        concat = f"CONCAT(YEAR({this}), ' ', QUARTER({this}) * 3 - 2, ' 1')"
+        concat = f"CONCAT(YEAR({expr}), ' ', QUARTER({expr}) * 3 - 2, ' 1')"
         date_format = "%Y %c %e"
     elif unit == "year":
-        concat = f"CONCAT(YEAR({this}), ' 1 1')"
+        concat = f"CONCAT(YEAR({expr}), ' 1 1')"
         date_format = "%Y %c %e"
     else:
         self.unsupported("Unexpected interval unit: {unit}")
-        return f"DATE({this})"
+        return f"DATE({expr})"

     return f"STR_TO_DATE({concat}, '{date_format}')"


 def _str_to_date(args):
-    date_format = MySQL.format_time(list_get(args, 1))
-    return exp.StrToDate(this=list_get(args, 0), format=date_format)
+    date_format = MySQL.format_time(seq_get(args, 1))
+    return exp.StrToDate(this=seq_get(args, 0), format=date_format)


 def _str_to_date_sql(self, expression):
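With `_date_trunc_sql` now reading the unit from `expression.name` and the operand from `expression.expression`, a Postgres-style `DATE_TRUNC('month', col)` can be emulated for MySQL via `STR_TO_DATE`/`CONCAT`. A sketch; the exact generated string follows the branches above:

```python
import sqlglot

print(
    sqlglot.transpile(
        "SELECT DATE_TRUNC('month', created_at)", read="postgres", write="mysql"
    )[0]
)
```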
@@ -66,9 +75,9 @@ def _trim_sql(self, expression):

 def _date_add(expression_class):
     def func(args):
-        interval = list_get(args, 1)
+        interval = seq_get(args, 1)
         return expression_class(
-            this=list_get(args, 0),
+            this=seq_get(args, 0),
             expression=interval.this,
             unit=exp.Literal.string(interval.text("unit").lower()),
         )
@@ -101,15 +110,16 @@ class MySQL(Dialect):
         "%l": "%-I",
     }

-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'", '"']
         COMMENTS = ["--", "#", ("/*", "*/")]
         IDENTIFIERS = ["`"]
+        ESCAPES = ["'", "\\"]
         BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")]
         HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")]

         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
+            **tokens.Tokenizer.KEYWORDS,
             "SEPARATOR": TokenType.SEPARATOR,
             "_ARMSCII8": TokenType.INTRODUCER,
             "_ASCII": TokenType.INTRODUCER,
@@ -156,20 +166,23 @@ class MySQL(Dialect):
             "_UTF32": TokenType.INTRODUCER,
             "_UTF8MB3": TokenType.INTRODUCER,
             "_UTF8MB4": TokenType.INTRODUCER,
+            "@@": TokenType.SESSION_PARAMETER,
         }

-    class Parser(Parser):
+        COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}
+
+    class Parser(parser.Parser):
         STRICT_CAST = False

         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "DATE_ADD": _date_add(exp.DateAdd),
             "DATE_SUB": _date_add(exp.DateSub),
             "STR_TO_DATE": _str_to_date,
         }

         FUNCTION_PARSERS = {
-            **Parser.FUNCTION_PARSERS,
+            **parser.Parser.FUNCTION_PARSERS,
             "GROUP_CONCAT": lambda self: self.expression(
                 exp.GroupConcat,
                 this=self._parse_lambda(),
@@ -178,15 +191,212 @@ class MySQL(Dialect):
             ),
         }

         PROPERTY_PARSERS = {
-            **Parser.PROPERTY_PARSERS,
+            **parser.Parser.PROPERTY_PARSERS,
             TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty),
         }

-    class Generator(Generator):
+        STATEMENT_PARSERS = {
+            **parser.Parser.STATEMENT_PARSERS,
+            TokenType.SHOW: lambda self: self._parse_show(),
+            TokenType.SET: lambda self: self._parse_set(),
+        }
+
+        SHOW_PARSERS = {
+            "BINARY LOGS": _show_parser("BINARY LOGS"),
+            "MASTER LOGS": _show_parser("BINARY LOGS"),
+            "BINLOG EVENTS": _show_parser("BINLOG EVENTS"),
+            "CHARACTER SET": _show_parser("CHARACTER SET"),
+            "CHARSET": _show_parser("CHARACTER SET"),
+            "COLLATION": _show_parser("COLLATION"),
+            "FULL COLUMNS": _show_parser("COLUMNS", target="FROM", full=True),
+            "COLUMNS": _show_parser("COLUMNS", target="FROM"),
+            "CREATE DATABASE": _show_parser("CREATE DATABASE", target=True),
+            "CREATE EVENT": _show_parser("CREATE EVENT", target=True),
+            "CREATE FUNCTION": _show_parser("CREATE FUNCTION", target=True),
+            "CREATE PROCEDURE": _show_parser("CREATE PROCEDURE", target=True),
+            "CREATE TABLE": _show_parser("CREATE TABLE", target=True),
+            "CREATE TRIGGER": _show_parser("CREATE TRIGGER", target=True),
+            "CREATE VIEW": _show_parser("CREATE VIEW", target=True),
+            "DATABASES": _show_parser("DATABASES"),
+            "ENGINE": _show_parser("ENGINE", target=True),
+            "STORAGE ENGINES": _show_parser("ENGINES"),
+            "ENGINES": _show_parser("ENGINES"),
+            "ERRORS": _show_parser("ERRORS"),
+            "EVENTS": _show_parser("EVENTS"),
+            "FUNCTION CODE": _show_parser("FUNCTION CODE", target=True),
+            "FUNCTION STATUS": _show_parser("FUNCTION STATUS"),
+            "GRANTS": _show_parser("GRANTS", target="FOR"),
+            "INDEX": _show_parser("INDEX", target="FROM"),
+            "MASTER STATUS": _show_parser("MASTER STATUS"),
+            "OPEN TABLES": _show_parser("OPEN TABLES"),
+            "PLUGINS": _show_parser("PLUGINS"),
+            "PROCEDURE CODE": _show_parser("PROCEDURE CODE", target=True),
+            "PROCEDURE STATUS": _show_parser("PROCEDURE STATUS"),
+            "PRIVILEGES": _show_parser("PRIVILEGES"),
+            "FULL PROCESSLIST": _show_parser("PROCESSLIST", full=True),
+            "PROCESSLIST": _show_parser("PROCESSLIST"),
+            "PROFILE": _show_parser("PROFILE"),
+            "PROFILES": _show_parser("PROFILES"),
+            "RELAYLOG EVENTS": _show_parser("RELAYLOG EVENTS"),
+            "REPLICAS": _show_parser("REPLICAS"),
+            "SLAVE HOSTS": _show_parser("REPLICAS"),
+            "REPLICA STATUS": _show_parser("REPLICA STATUS"),
+            "SLAVE STATUS": _show_parser("REPLICA STATUS"),
+            "GLOBAL STATUS": _show_parser("STATUS", global_=True),
+            "SESSION STATUS": _show_parser("STATUS"),
+            "STATUS": _show_parser("STATUS"),
+            "TABLE STATUS": _show_parser("TABLE STATUS"),
+            "FULL TABLES": _show_parser("TABLES", full=True),
+            "TABLES": _show_parser("TABLES"),
+            "TRIGGERS": _show_parser("TRIGGERS"),
+            "GLOBAL VARIABLES": _show_parser("VARIABLES", global_=True),
+            "SESSION VARIABLES": _show_parser("VARIABLES"),
+            "VARIABLES": _show_parser("VARIABLES"),
+            "WARNINGS": _show_parser("WARNINGS"),
+        }
+
+        SET_PARSERS = {
+            "GLOBAL": lambda self: self._parse_set_item_assignment("GLOBAL"),
+            "PERSIST": lambda self: self._parse_set_item_assignment("PERSIST"),
+            "PERSIST_ONLY": lambda self: self._parse_set_item_assignment("PERSIST_ONLY"),
+            "SESSION": lambda self: self._parse_set_item_assignment("SESSION"),
+            "LOCAL": lambda self: self._parse_set_item_assignment("LOCAL"),
+            "CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
+            "CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
+            "NAMES": lambda self: self._parse_set_item_names(),
+        }
+
+        PROFILE_TYPES = {
+            "ALL",
+            "BLOCK IO",
+            "CONTEXT SWITCHES",
+            "CPU",
+            "IPC",
+            "MEMORY",
+            "PAGE FAULTS",
+            "SOURCE",
+            "SWAPS",
+        }
+
+        def _parse_show_mysql(self, this, target=False, full=None, global_=None):
+            if target:
+                if isinstance(target, str):
+                    self._match_text(target)
+                target_id = self._parse_id_var()
+            else:
+                target_id = None
+
+            log = self._parse_string() if self._match_text("IN") else None
+
+            if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}:
+                position = self._parse_number() if self._match_text("FROM") else None
+                db = None
+            else:
+                position = None
+                db = self._parse_id_var() if self._match_text("FROM") else None
+
+            channel = self._parse_id_var() if self._match_text("FOR", "CHANNEL") else None
+
+            like = self._parse_string() if self._match_text("LIKE") else None
+            where = self._parse_where()
+
+            if this == "PROFILE":
+                types = self._parse_csv(self._parse_show_profile_type)
+                query = self._parse_number() if self._match_text("FOR", "QUERY") else None
+                offset = self._parse_number() if self._match_text("OFFSET") else None
+                limit = self._parse_number() if self._match_text("LIMIT") else None
+            else:
+                types, query = None, None
+                offset, limit = self._parse_oldstyle_limit()
+
+            mutex = True if self._match_text("MUTEX") else None
|
||||||
|
mutex = False if self._match_text("STATUS") else mutex
|
||||||
|
|
||||||
|
return self.expression(
|
||||||
|
exp.Show,
|
||||||
|
this=this,
|
||||||
|
target=target_id,
|
||||||
|
full=full,
|
||||||
|
log=log,
|
||||||
|
position=position,
|
||||||
|
db=db,
|
||||||
|
channel=channel,
|
||||||
|
like=like,
|
||||||
|
where=where,
|
||||||
|
types=types,
|
||||||
|
query=query,
|
||||||
|
offset=offset,
|
||||||
|
limit=limit,
|
||||||
|
mutex=mutex,
|
||||||
|
**{"global": global_},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _parse_show_profile_type(self):
|
||||||
|
for type_ in self.PROFILE_TYPES:
|
||||||
|
if self._match_text(*type_.split(" ")):
|
||||||
|
return exp.Var(this=type_)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _parse_oldstyle_limit(self):
|
||||||
|
limit = None
|
||||||
|
offset = None
|
||||||
|
if self._match_text("LIMIT"):
|
||||||
|
parts = self._parse_csv(self._parse_number)
|
||||||
|
if len(parts) == 1:
|
||||||
|
limit = parts[0]
|
||||||
|
elif len(parts) == 2:
|
||||||
|
limit = parts[1]
|
||||||
|
offset = parts[0]
|
||||||
|
return offset, limit
|
||||||
|
|
||||||
|
def _default_parse_set_item(self):
|
||||||
|
return self._parse_set_item_assignment(kind=None)
|
||||||
|
|
||||||
|
def _parse_set_item_assignment(self, kind):
|
||||||
|
left = self._parse_primary() or self._parse_id_var()
|
||||||
|
if not self._match(TokenType.EQ):
|
||||||
|
self.raise_error("Expected =")
|
||||||
|
right = self._parse_statement() or self._parse_id_var()
|
||||||
|
|
||||||
|
this = self.expression(
|
||||||
|
exp.EQ,
|
||||||
|
this=left,
|
||||||
|
expression=right,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.expression(
|
||||||
|
exp.SetItem,
|
||||||
|
this=this,
|
||||||
|
kind=kind,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _parse_set_item_charset(self, kind):
|
||||||
|
this = self._parse_string() or self._parse_id_var()
|
||||||
|
|
||||||
|
return self.expression(
|
||||||
|
exp.SetItem,
|
||||||
|
this=this,
|
||||||
|
kind=kind,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _parse_set_item_names(self):
|
||||||
|
charset = self._parse_string() or self._parse_id_var()
|
||||||
|
if self._match_text("COLLATE"):
|
||||||
|
collate = self._parse_string() or self._parse_id_var()
|
||||||
|
else:
|
||||||
|
collate = None
|
||||||
|
return self.expression(
|
||||||
|
exp.SetItem,
|
||||||
|
this=charset,
|
||||||
|
collate=collate,
|
||||||
|
kind="NAMES",
|
||||||
|
)
|
||||||
|
|
||||||
|
class Generator(generator.Generator):
|
||||||
NULL_ORDERING_SUPPORTED = False
|
NULL_ORDERING_SUPPORTED = False
|
||||||
|
|
||||||
TRANSFORMS = {
|
TRANSFORMS = {
|
||||||
**Generator.TRANSFORMS,
|
**generator.Generator.TRANSFORMS,
|
||||||
exp.CurrentDate: no_paren_current_date_sql,
|
exp.CurrentDate: no_paren_current_date_sql,
|
||||||
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
|
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
|
||||||
exp.ILike: no_ilike_sql,
|
exp.ILike: no_ilike_sql,
|
||||||
|
@ -199,6 +409,8 @@ class MySQL(Dialect):
|
||||||
exp.StrToDate: _str_to_date_sql,
|
exp.StrToDate: _str_to_date_sql,
|
||||||
exp.StrToTime: _str_to_date_sql,
|
exp.StrToTime: _str_to_date_sql,
|
||||||
exp.Trim: _trim_sql,
|
exp.Trim: _trim_sql,
|
||||||
|
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
|
||||||
|
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
|
||||||
}
|
}
|
||||||
|
|
||||||
ROOT_PROPERTIES = {
|
ROOT_PROPERTIES = {
|
||||||
|
@ -209,4 +421,69 @@ class MySQL(Dialect):
|
||||||
exp.SchemaCommentProperty,
|
exp.SchemaCommentProperty,
|
||||||
}
|
}
|
||||||
|
|
||||||
WITH_PROPERTIES = {}
|
WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()
|
||||||
|
|
||||||
|
def show_sql(self, expression):
|
||||||
|
this = f" {expression.name}"
|
||||||
|
full = " FULL" if expression.args.get("full") else ""
|
||||||
|
global_ = " GLOBAL" if expression.args.get("global") else ""
|
||||||
|
|
||||||
|
target = self.sql(expression, "target")
|
||||||
|
target = f" {target}" if target else ""
|
||||||
|
if expression.name in {"COLUMNS", "INDEX"}:
|
||||||
|
target = f" FROM{target}"
|
||||||
|
elif expression.name == "GRANTS":
|
||||||
|
target = f" FOR{target}"
|
||||||
|
|
||||||
|
db = self._prefixed_sql("FROM", expression, "db")
|
||||||
|
|
||||||
|
like = self._prefixed_sql("LIKE", expression, "like")
|
||||||
|
where = self.sql(expression, "where")
|
||||||
|
|
||||||
|
types = self.expressions(expression, key="types")
|
||||||
|
types = f" {types}" if types else types
|
||||||
|
query = self._prefixed_sql("FOR QUERY", expression, "query")
|
||||||
|
|
||||||
|
if expression.name == "PROFILE":
|
||||||
|
offset = self._prefixed_sql("OFFSET", expression, "offset")
|
||||||
|
limit = self._prefixed_sql("LIMIT", expression, "limit")
|
||||||
|
else:
|
||||||
|
offset = ""
|
||||||
|
limit = self._oldstyle_limit_sql(expression)
|
||||||
|
|
||||||
|
log = self._prefixed_sql("IN", expression, "log")
|
||||||
|
position = self._prefixed_sql("FROM", expression, "position")
|
||||||
|
|
||||||
|
channel = self._prefixed_sql("FOR CHANNEL", expression, "channel")
|
||||||
|
|
||||||
|
if expression.name == "ENGINE":
|
||||||
|
mutex_or_status = " MUTEX" if expression.args.get("mutex") else " STATUS"
|
||||||
|
else:
|
||||||
|
mutex_or_status = ""
|
||||||
|
|
||||||
|
return f"SHOW{full}{global_}{this}{target}{types}{db}{query}{log}{position}{channel}{mutex_or_status}{like}{where}{offset}{limit}"
|
||||||
|
|
||||||
|
def _prefixed_sql(self, prefix, expression, arg):
|
||||||
|
sql = self.sql(expression, arg)
|
||||||
|
if not sql:
|
||||||
|
return ""
|
||||||
|
return f" {prefix} {sql}"
|
||||||
|
|
||||||
|
def _oldstyle_limit_sql(self, expression):
|
||||||
|
limit = self.sql(expression, "limit")
|
||||||
|
offset = self.sql(expression, "offset")
|
||||||
|
if limit:
|
||||||
|
limit_offset = f"{offset}, {limit}" if offset else limit
|
||||||
|
return f" LIMIT {limit_offset}"
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def setitem_sql(self, expression):
|
||||||
|
kind = self.sql(expression, "kind")
|
||||||
|
kind = f"{kind} " if kind else ""
|
||||||
|
this = self.sql(expression, "this")
|
||||||
|
collate = self.sql(expression, "collate")
|
||||||
|
collate = f" COLLATE {collate}" if collate else ""
|
||||||
|
return f"{kind}{this}{collate}"
|
||||||
|
|
||||||
|
def set_sql(self, expression):
|
||||||
|
return f"SET {self.expressions(expression)}"
|
||||||
|
|
|
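The SHOW and SET machinery above is reachable through the public API. A minimal sketch, assuming the v10 entry points behave as the diff suggests (the printed shapes are expectations, not verified output of this exact revision):

```python
import sqlglot

# SHOW statements now parse into an exp.Show node instead of an opaque
# command, so modifiers like FULL, FROM and LIKE survive a round-trip.
show = sqlglot.parse_one("SHOW FULL TABLES FROM mydb LIKE '%order%'", read="mysql")
print(show.__class__.__name__)    # expected: Show
print(show.sql(dialect="mysql"))  # expected to round-trip the statement

# The null-safe equality operator maps to the new NullSafeEQ expression
# and is generated back as "<=>" by the transform added above.
eq = sqlglot.parse_one("SELECT a <=> b", read="mysql")
print(eq.sql(dialect="mysql"))
```
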
@@ -1,8 +1,9 @@
-from sqlglot import exp, transforms
+from __future__ import annotations
+
+from sqlglot import exp, generator, tokens, transforms
 from sqlglot.dialects.dialect import Dialect, no_ilike_sql
-from sqlglot.generator import Generator
 from sqlglot.helper import csv
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.tokens import TokenType
 
 
 def _limit_sql(self, expression):
@@ -36,9 +37,9 @@ class Oracle(Dialect):
         "YYYY": "%Y",  # 2015
     }
 
-    class Generator(Generator):
+    class Generator(generator.Generator):
         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.TINYINT: "NUMBER",
             exp.DataType.Type.SMALLINT: "NUMBER",
             exp.DataType.Type.INT: "NUMBER",
@@ -49,11 +50,12 @@ class Oracle(Dialect):
             exp.DataType.Type.NVARCHAR: "NVARCHAR2",
             exp.DataType.Type.TEXT: "CLOB",
             exp.DataType.Type.BINARY: "BLOB",
+            exp.DataType.Type.VARBINARY: "BLOB",
         }
 
         TRANSFORMS = {
-            **Generator.TRANSFORMS,
-            **transforms.UNALIAS_GROUP,
+            **generator.Generator.TRANSFORMS,
+            **transforms.UNALIAS_GROUP,  # type: ignore
             exp.ILike: no_ilike_sql,
             exp.Limit: _limit_sql,
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
@@ -86,9 +88,9 @@ class Oracle(Dialect):
         def table_sql(self, expression):
             return super().table_sql(expression, sep=" ")
 
-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
+            **tokens.Tokenizer.KEYWORDS,
             "TOP": TokenType.TOP,
             "VARCHAR2": TokenType.VARCHAR,
             "NVARCHAR2": TokenType.NVARCHAR,
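VARBINARY is promoted to a first-class type across this commit (the per-dialect `"VARBINARY": TokenType.BINARY` tokenizer overrides are removed and generators gain explicit VARBINARY mappings). A quick sketch of the intended effect for Oracle; the output strings are expectations, not verified against this exact revision:

```python
import sqlglot

# Both fixed- and variable-length binary types should land on BLOB in Oracle.
for type_name in ("BINARY", "VARBINARY"):
    sql = sqlglot.transpile(f"SELECT CAST(x AS {type_name})", write="oracle")[0]
    print(sql)  # expected: SELECT CAST(x AS BLOB)
```
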
@@ -1,4 +1,6 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
     arrow_json_extract_scalar_sql,
@@ -9,9 +11,7 @@ from sqlglot.dialects.dialect import (
     no_trycast_sql,
     str_position_sql,
 )
-from sqlglot.generator import Generator
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.tokens import TokenType
 from sqlglot.transforms import delegate, preprocess
 
 
@@ -160,12 +160,12 @@ class Postgres(Dialect):
         "YYYY": "%Y",  # 2015
     }
 
-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
         BIT_STRINGS = [("b'", "'"), ("B'", "'")]
         HEX_STRINGS = [("x'", "'"), ("X'", "'")]
         BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
+            **tokens.Tokenizer.KEYWORDS,
             "ALWAYS": TokenType.ALWAYS,
             "BY DEFAULT": TokenType.BY_DEFAULT,
             "COMMENT ON": TokenType.COMMENT_ON,
@@ -179,31 +179,32 @@ class Postgres(Dialect):
         }
         QUOTES = ["'", "$$"]
         SINGLE_TOKENS = {
-            **Tokenizer.SINGLE_TOKENS,
+            **tokens.Tokenizer.SINGLE_TOKENS,
             "$": TokenType.PARAMETER,
         }
 
-    class Parser(Parser):
+    class Parser(parser.Parser):
         STRICT_CAST = False
 
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "TO_TIMESTAMP": _to_timestamp,
             "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
         }
 
-    class Generator(Generator):
+    class Generator(generator.Generator):
         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.TINYINT: "SMALLINT",
             exp.DataType.Type.FLOAT: "REAL",
             exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
             exp.DataType.Type.BINARY: "BYTEA",
+            exp.DataType.Type.VARBINARY: "BYTEA",
             exp.DataType.Type.DATETIME: "TIMESTAMP",
         }
 
         TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,
             exp.ColumnDef: preprocess(
                 [
                     _auto_increment_to_serial,
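The recurring import change in these files — `from sqlglot.parser import Parser` becoming module imports like `parser`, `generator`, `tokens` — keeps the base classes reachable under names that the nested dialect classes of the same name cannot shadow, which the newly activated mypy checking cares about. A sketch of the resulting pattern, with a made-up dialect name:

```python
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect


class MyDialect(Dialect):  # hypothetical dialect, for illustration only
    class Tokenizer(tokens.Tokenizer):
        KEYWORDS = {
            # Unambiguously the base class, even though this class is
            # itself named Tokenizer.
            **tokens.Tokenizer.KEYWORDS,
            "WHATEVER": tokens.TokenType.VAR,
        }

    class Parser(parser.Parser):
        FUNCTIONS = {**parser.Parser.FUNCTIONS}

    class Generator(generator.Generator):
        TRANSFORMS = {**generator.Generator.TRANSFORMS}
```
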
@@ -1,4 +1,6 @@
-from sqlglot import exp, transforms
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     format_time_lambda,
@@ -10,10 +12,8 @@ from sqlglot.dialects.dialect import (
     struct_extract_sql,
 )
 from sqlglot.dialects.mysql import MySQL
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType
 
 
 def _approx_distinct_sql(self, expression):
@@ -110,30 +110,29 @@ class Presto(Dialect):
     index_offset = 1
     null_ordering = "nulls_are_last"
     time_format = "'%Y-%m-%d %H:%i:%S'"
-    time_mapping = MySQL.time_mapping
+    time_mapping = MySQL.time_mapping  # type: ignore
 
-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
-            "VARBINARY": TokenType.BINARY,
+            **tokens.Tokenizer.KEYWORDS,
             "ROW": TokenType.STRUCT,
         }
 
-    class Parser(Parser):
+    class Parser(parser.Parser):
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
             "CARDINALITY": exp.ArraySize.from_arg_list,
             "CONTAINS": exp.ArrayContains.from_arg_list,
             "DATE_ADD": lambda args: exp.DateAdd(
-                this=list_get(args, 2),
-                expression=list_get(args, 1),
-                unit=list_get(args, 0),
+                this=seq_get(args, 2),
+                expression=seq_get(args, 1),
+                unit=seq_get(args, 0),
             ),
             "DATE_DIFF": lambda args: exp.DateDiff(
-                this=list_get(args, 2),
-                expression=list_get(args, 1),
-                unit=list_get(args, 0),
+                this=seq_get(args, 2),
+                expression=seq_get(args, 1),
+                unit=seq_get(args, 0),
             ),
             "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
             "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
@@ -143,7 +142,7 @@ class Presto(Dialect):
             "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
         }
 
-    class Generator(Generator):
+    class Generator(generator.Generator):
 
         STRUCT_DELIMITER = ("(", ")")
 
@@ -159,7 +158,7 @@ class Presto(Dialect):
         }
 
         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.INT: "INTEGER",
             exp.DataType.Type.FLOAT: "REAL",
             exp.DataType.Type.BINARY: "VARBINARY",
@@ -169,8 +168,8 @@ class Presto(Dialect):
         }
 
         TRANSFORMS = {
-            **Generator.TRANSFORMS,
-            **transforms.UNALIAS_GROUP,
+            **generator.Generator.TRANSFORMS,
+            **transforms.UNALIAS_GROUP,  # type: ignore
             exp.ApproxDistinct: _approx_distinct_sql,
             exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
             exp.ArrayConcat: rename_func("CONCAT"),
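`list_get`'s replacement, `seq_get`, keeps the same contract these function builders rely on: an out-of-range index comes back as None instead of raising. Roughly equivalent behavior (a sketch, not the library's exact source):

```python
import typing as t

T = t.TypeVar("T")


def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
    """Return seq[index], or None when the index is out of range."""
    try:
        return seq[index]
    except IndexError:
        return None


args = ["day", "5"]
print(seq_get(args, 0))  # 'day'
print(seq_get(args, 2))  # None -- optional trailing args simply come back as None
```
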
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from sqlglot import exp
 from sqlglot.dialects.postgres import Postgres
 from sqlglot.tokens import TokenType
@@ -6,29 +8,30 @@ from sqlglot.tokens import TokenType
 class Redshift(Postgres):
     time_format = "'YYYY-MM-DD HH:MI:SS'"
     time_mapping = {
-        **Postgres.time_mapping,
+        **Postgres.time_mapping,  # type: ignore
         "MON": "%b",
         "HH": "%H",
     }
 
     class Tokenizer(Postgres.Tokenizer):
-        ESCAPE = "\\"
+        ESCAPES = ["\\"]
 
         KEYWORDS = {
-            **Postgres.Tokenizer.KEYWORDS,
+            **Postgres.Tokenizer.KEYWORDS,  # type: ignore
             "GEOMETRY": TokenType.GEOMETRY,
             "GEOGRAPHY": TokenType.GEOGRAPHY,
             "HLLSKETCH": TokenType.HLLSKETCH,
             "SUPER": TokenType.SUPER,
             "TIME": TokenType.TIMESTAMP,
             "TIMETZ": TokenType.TIMESTAMPTZ,
-            "VARBYTE": TokenType.BINARY,
+            "VARBYTE": TokenType.VARBINARY,
             "SIMILAR TO": TokenType.SIMILAR_TO,
         }
 
     class Generator(Postgres.Generator):
         TYPE_MAPPING = {
-            **Postgres.Generator.TYPE_MAPPING,
+            **Postgres.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.BINARY: "VARBYTE",
+            exp.DataType.Type.VARBINARY: "VARBYTE",
             exp.DataType.Type.INT: "INTEGER",
         }
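`ESCAPE` becoming the list-valued `ESCAPES` lets a tokenizer declare more than one escape character; for Redshift the behavior is unchanged, backslash is just declared as a one-element list now. A sketch of what that looks like from the outside (attribute name per this diff; round-trip output not verified):

```python
import sqlglot
from sqlglot.dialects.redshift import Redshift

# The tokenizer now exposes its escape characters as a list.
print(Redshift.Tokenizer.ESCAPES)  # expected: ['\\']

# Backslash-escaped quotes inside string literals should still tokenize.
e = sqlglot.parse_one(r"SELECT 'it\'s'", read="redshift")
print(e.sql(dialect="redshift"))
```
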
@@ -1,4 +1,6 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
     format_time_lambda,
@@ -6,10 +8,8 @@ from sqlglot.dialects.dialect import (
     rename_func,
 )
 from sqlglot.expressions import Literal
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType
 
 
 def _check_int(s):
@@ -28,7 +28,9 @@ def _snowflake_to_timestamp(args):
 
         # case: <numeric_expr> [ , <scale> ]
         if second_arg.name not in ["0", "3", "9"]:
-            raise ValueError(f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9")
+            raise ValueError(
+                f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9"
+            )
 
         if second_arg.name == "0":
             timescale = exp.UnixToTime.SECONDS
@@ -39,7 +41,7 @@ def _snowflake_to_timestamp(args):
 
         return exp.UnixToTime(this=first_arg, scale=timescale)
 
-    first_arg = list_get(args, 0)
+    first_arg = seq_get(args, 0)
     if not isinstance(first_arg, Literal):
         # case: <variant_expr>
         return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)
@@ -56,7 +58,7 @@ def _snowflake_to_timestamp(args):
     return exp.UnixToTime.from_arg_list(args)
 
 
-def _unix_to_time(self, expression):
+def _unix_to_time_sql(self, expression):
     scale = expression.args.get("scale")
     timestamp = self.sql(expression, "this")
     if scale in [None, exp.UnixToTime.SECONDS]:
@@ -132,9 +134,9 @@ class Snowflake(Dialect):
         "ff6": "%f",
     }
 
-    class Parser(Parser):
+    class Parser(parser.Parser):
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "ARRAYAGG": exp.ArrayAgg.from_arg_list,
             "IFF": exp.If.from_arg_list,
             "TO_TIMESTAMP": _snowflake_to_timestamp,
@@ -143,18 +145,18 @@ class Snowflake(Dialect):
         }
 
         FUNCTION_PARSERS = {
-            **Parser.FUNCTION_PARSERS,
+            **parser.Parser.FUNCTION_PARSERS,
             "DATE_PART": _parse_date_part,
         }
 
         FUNC_TOKENS = {
-            *Parser.FUNC_TOKENS,
+            *parser.Parser.FUNC_TOKENS,
             TokenType.RLIKE,
             TokenType.TABLE,
         }
 
         COLUMN_OPERATORS = {
-            **Parser.COLUMN_OPERATORS,
+            **parser.Parser.COLUMN_OPERATORS,  # type: ignore
             TokenType.COLON: lambda self, this, path: self.expression(
                 exp.Bracket,
                 this=this,
@@ -163,21 +165,21 @@ class Snowflake(Dialect):
         }
 
         PROPERTY_PARSERS = {
-            **Parser.PROPERTY_PARSERS,
+            **parser.Parser.PROPERTY_PARSERS,
             TokenType.PARTITION_BY: lambda self: self._parse_partitioned_by(),
         }
 
-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'", "$$"]
-        ESCAPE = "\\"
+        ESCAPES = ["\\"]
 
         SINGLE_TOKENS = {
-            **Tokenizer.SINGLE_TOKENS,
+            **tokens.Tokenizer.SINGLE_TOKENS,
             "$": TokenType.PARAMETER,
         }
 
         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
+            **tokens.Tokenizer.KEYWORDS,
             "QUALIFY": TokenType.QUALIFY,
             "DOUBLE PRECISION": TokenType.DOUBLE,
             "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
@@ -187,15 +189,15 @@ class Snowflake(Dialect):
             "SAMPLE": TokenType.TABLE_SAMPLE,
         }
 
-    class Generator(Generator):
+    class Generator(generator.Generator):
         CREATE_TRANSIENT = True
 
         TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,
             exp.ArrayConcat: rename_func("ARRAY_CAT"),
             exp.If: rename_func("IFF"),
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
-            exp.UnixToTime: _unix_to_time,
+            exp.UnixToTime: _unix_to_time_sql,
             exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
             exp.Array: inline_array_sql,
             exp.StrPosition: rename_func("POSITION"),
@@ -204,7 +206,7 @@ class Snowflake(Dialect):
         }
 
         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
         }
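The reformatted ValueError above guards the numeric form of Snowflake's TO_TIMESTAMP, which only accepts scales 0, 3 and 9 (seconds, milliseconds and nanoseconds per Snowflake's documentation). A sketch of both paths, assuming the parsing entry points behave as the diff suggests:

```python
import sqlglot

# A numeric first argument with an accepted scale parses into UnixToTime.
for scale in ("0", "3", "9"):
    e = sqlglot.parse_one(f"SELECT TO_TIMESTAMP(1659981729, {scale})", read="snowflake")
    print(scale, e.sql(dialect="snowflake"))

# Any other scale should hit the ValueError shown in the hunk above.
try:
    sqlglot.parse_one("SELECT TO_TIMESTAMP(1659981729, 6)", read="snowflake")
except ValueError as err:
    print(err)
```
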
@@ -1,8 +1,9 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, parser
 from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func
 from sqlglot.dialects.hive import Hive
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
+from sqlglot.helper import seq_get
 
 
 def _create_sql(self, e):
@@ -46,36 +47,36 @@ def _unix_to_time(self, expression):
 class Spark(Hive):
     class Parser(Hive.Parser):
         FUNCTIONS = {
-            **Hive.Parser.FUNCTIONS,
+            **Hive.Parser.FUNCTIONS,  # type: ignore
             "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
             "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
             "LEFT": lambda args: exp.Substring(
-                this=list_get(args, 0),
+                this=seq_get(args, 0),
                 start=exp.Literal.number(1),
-                length=list_get(args, 1),
+                length=seq_get(args, 1),
             ),
             "SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
-                this=list_get(args, 0),
-                expression=list_get(args, 1),
+                this=seq_get(args, 0),
+                expression=seq_get(args, 1),
             ),
             "SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
-                this=list_get(args, 0),
-                expression=list_get(args, 1),
+                this=seq_get(args, 0),
+                expression=seq_get(args, 1),
             ),
             "RIGHT": lambda args: exp.Substring(
-                this=list_get(args, 0),
+                this=seq_get(args, 0),
                 start=exp.Sub(
-                    this=exp.Length(this=list_get(args, 0)),
-                    expression=exp.Add(this=list_get(args, 1), expression=exp.Literal.number(1)),
+                    this=exp.Length(this=seq_get(args, 0)),
+                    expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
                 ),
-                length=list_get(args, 1),
+                length=seq_get(args, 1),
             ),
             "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
             "IIF": exp.If.from_arg_list,
         }
 
         FUNCTION_PARSERS = {
-            **Parser.FUNCTION_PARSERS,
+            **parser.Parser.FUNCTION_PARSERS,
             "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
             "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
             "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@@ -88,14 +89,14 @@ class Spark(Hive):
 
     class Generator(Hive.Generator):
         TYPE_MAPPING = {
-            **Hive.Generator.TYPE_MAPPING,
+            **Hive.Generator.TYPE_MAPPING,  # type: ignore
            exp.DataType.Type.TINYINT: "BYTE",
            exp.DataType.Type.SMALLINT: "SHORT",
            exp.DataType.Type.BIGINT: "LONG",
         }
 
         TRANSFORMS = {
-            **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort, exp.ILike}},
+            **Hive.Generator.TRANSFORMS,  # type: ignore
             exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
             exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}",
             exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
@@ -114,6 +115,8 @@ class Spark(Hive):
             exp.VariancePop: rename_func("VAR_POP"),
             exp.DateFromParts: rename_func("MAKE_DATE"),
         }
+        TRANSFORMS.pop(exp.ArraySort)
+        TRANSFORMS.pop(exp.ILike)
 
         WRAP_DERIVED_VALUES = False
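Popping inherited transforms after building the dict replaces the earlier dict-comprehension filter; the result is the same, but it reads more directly and plays better with type checking. The pattern in isolation (plain Python, names made up for illustration):

```python
# Start from the parent's transform table, override what differs,
# then drop the entries this dialect wants default generation for.
PARENT_TRANSFORMS = {"ArraySort": "...", "ILike": "...", "ApproxDistinct": "..."}

TRANSFORMS = {
    **PARENT_TRANSFORMS,
    "ApproxDistinct": "APPROX_COUNT_DISTINCT",  # override the parent's rewrite
}
TRANSFORMS.pop("ArraySort")  # fall back to default generation for these
TRANSFORMS.pop("ILike")

print(sorted(TRANSFORMS))  # ['ApproxDistinct']
```
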
@@ -1,4 +1,6 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
     arrow_json_extract_scalar_sql,
@@ -8,31 +10,28 @@ from sqlglot.dialects.dialect import (
     no_trycast_sql,
     rename_func,
 )
-from sqlglot.generator import Generator
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.tokens import TokenType
 
 
 class SQLite(Dialect):
-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
         IDENTIFIERS = ['"', ("[", "]"), "`"]
         HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]
 
         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
-            "VARBINARY": TokenType.BINARY,
+            **tokens.Tokenizer.KEYWORDS,
             "AUTOINCREMENT": TokenType.AUTO_INCREMENT,
         }
 
-    class Parser(Parser):
+    class Parser(parser.Parser):
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "EDITDIST3": exp.Levenshtein.from_arg_list,
         }
 
-    class Generator(Generator):
+    class Generator(generator.Generator):
         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.BOOLEAN: "INTEGER",
             exp.DataType.Type.TINYINT: "INTEGER",
             exp.DataType.Type.SMALLINT: "INTEGER",
@@ -46,6 +45,7 @@ class SQLite(Dialect):
             exp.DataType.Type.VARCHAR: "TEXT",
             exp.DataType.Type.NVARCHAR: "TEXT",
             exp.DataType.Type.BINARY: "BLOB",
+            exp.DataType.Type.VARBINARY: "BLOB",
         }
 
         TOKEN_MAPPING = {
@@ -53,7 +53,7 @@ class SQLite(Dialect):
         }
 
         TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,
             exp.ILike: no_ilike_sql,
             exp.JSONExtract: arrow_json_extract_sql,
             exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
@@ -1,10 +1,12 @@
+from __future__ import annotations
+
 from sqlglot import exp
 from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func
 from sqlglot.dialects.mysql import MySQL
 
 
 class StarRocks(MySQL):
-    class Generator(MySQL.Generator):
+    class Generator(MySQL.Generator):  # type: ignore
         TYPE_MAPPING = {
             **MySQL.Generator.TYPE_MAPPING,
             exp.DataType.Type.TEXT: "STRING",
@@ -13,7 +15,7 @@ class StarRocks(MySQL):
         }
 
         TRANSFORMS = {
-            **MySQL.Generator.TRANSFORMS,
+            **MySQL.Generator.TRANSFORMS,  # type: ignore
             exp.JSONExtractScalar: arrow_json_extract_sql,
             exp.JSONExtract: arrow_json_extract_sql,
             exp.DateDiff: rename_func("DATEDIFF"),
@@ -22,3 +24,4 @@ class StarRocks(MySQL):
             exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.UnixToTime: rename_func("FROM_UNIXTIME"),
         }
+        TRANSFORMS.pop(exp.DateTrunc)
@@ -1,7 +1,7 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser
 from sqlglot.dialects.dialect import Dialect
-from sqlglot.generator import Generator
-from sqlglot.parser import Parser
 
 
 def _if_sql(self, expression):
@@ -20,17 +20,17 @@ def _count_sql(self, expression):
 
 
 class Tableau(Dialect):
-    class Generator(Generator):
+    class Generator(generator.Generator):
         TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,  # type: ignore
             exp.If: _if_sql,
             exp.Coalesce: _coalesce_sql,
             exp.Count: _count_sql,
         }
 
-    class Parser(Parser):
+    class Parser(parser.Parser):
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "IFNULL": exp.Coalesce.from_arg_list,
             "COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
         }
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from sqlglot import exp
 from sqlglot.dialects.presto import Presto
 
@@ -5,7 +7,7 @@ from sqlglot.dialects.presto import Presto
 class Trino(Presto):
     class Generator(Presto.Generator):
         TRANSFORMS = {
-            **Presto.Generator.TRANSFORMS,
+            **Presto.Generator.TRANSFORMS,  # type: ignore
             exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
         }
@@ -1,15 +1,22 @@
+from __future__ import annotations
+
 import re
 
-from sqlglot import exp
+from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func
 from sqlglot.expressions import DataType
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
+from sqlglot.helper import seq_get
 from sqlglot.time import format_time
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.tokens import TokenType
 
-FULL_FORMAT_TIME_MAPPING = {"weekday": "%A", "dw": "%A", "w": "%A", "month": "%B", "mm": "%B", "m": "%B"}
+FULL_FORMAT_TIME_MAPPING = {
+    "weekday": "%A",
+    "dw": "%A",
+    "w": "%A",
+    "month": "%B",
+    "mm": "%B",
+    "m": "%B",
+}
 DATE_DELTA_INTERVAL = {
     "year": "year",
     "yyyy": "year",
@@ -37,11 +44,13 @@ TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
 def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
     def _format_time(args):
         return exp_class(
-            this=list_get(args, 1),
+            this=seq_get(args, 1),
             format=exp.Literal.string(
                 format_time(
-                    list_get(args, 0).name or (TSQL.time_format if default is True else default),
-                    {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING} if full_format_mapping else TSQL.time_mapping,
+                    seq_get(args, 0).name or (TSQL.time_format if default is True else default),
+                    {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING}
+                    if full_format_mapping
+                    else TSQL.time_mapping,
                 )
             ),
         )
@@ -50,12 +59,12 @@ def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
 
 
 def parse_format(args):
-    fmt = list_get(args, 1)
+    fmt = seq_get(args, 1)
     number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this)
     if number_fmt:
-        return exp.NumberToStr(this=list_get(args, 0), format=fmt)
+        return exp.NumberToStr(this=seq_get(args, 0), format=fmt)
     return exp.TimeToStr(
-        this=list_get(args, 0),
+        this=seq_get(args, 0),
         format=exp.Literal.string(
             format_time(fmt.name, TSQL.format_time_mapping)
             if len(fmt.name) == 1
@@ -188,11 +197,11 @@ class TSQL(Dialect):
         "Y": "%a %Y",
     }
 
-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
         IDENTIFIERS = ['"', ("[", "]")]
 
         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
+            **tokens.Tokenizer.KEYWORDS,
             "BIT": TokenType.BOOLEAN,
             "REAL": TokenType.FLOAT,
             "NTEXT": TokenType.TEXT,
@@ -200,7 +209,6 @@ class TSQL(Dialect):
             "DATETIME2": TokenType.DATETIME,
             "DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
             "TIME": TokenType.TIMESTAMP,
-            "VARBINARY": TokenType.BINARY,
             "IMAGE": TokenType.IMAGE,
             "MONEY": TokenType.MONEY,
             "SMALLMONEY": TokenType.SMALLMONEY,
@@ -213,9 +221,9 @@ class TSQL(Dialect):
             "TOP": TokenType.TOP,
         }
 
-    class Parser(Parser):
+    class Parser(parser.Parser):
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "CHARINDEX": exp.StrPosition.from_arg_list,
             "ISNULL": exp.Coalesce.from_arg_list,
             "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
@@ -243,14 +251,16 @@ class TSQL(Dialect):
             this = self._parse_column()
 
             # Retrieve length of datatype and override to default if not specified
-            if list_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
+            if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
                 to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False)
 
             # Check whether a conversion with format is applicable
             if self._match(TokenType.COMMA):
                 format_val = self._parse_number().name
                 if format_val not in TSQL.convert_format_mapping:
-                    raise ValueError(f"CONVERT function at T-SQL does not support format style {format_val}")
+                    raise ValueError(
+                        f"CONVERT function at T-SQL does not support format style {format_val}"
+                    )
                 format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val])
 
                 # Check whether the convert entails a string to date format
@@ -272,9 +282,9 @@ class TSQL(Dialect):
             # Entails a simple cast without any format requirement
             return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
 
-    class Generator(Generator):
+    class Generator(generator.Generator):
         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.BOOLEAN: "BIT",
             exp.DataType.Type.INT: "INTEGER",
             exp.DataType.Type.DECIMAL: "NUMERIC",
@@ -283,7 +293,7 @@ class TSQL(Dialect):
         }
 
         TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,  # type: ignore
             exp.DateAdd: generate_date_delta_with_unit_sql,
             exp.DateDiff: generate_date_delta_with_unit_sql,
             exp.CurrentDate: rename_func("GETDATE"),
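`parse_format` above routes T-SQL FORMAT calls by the format string: a transpile-safe numeric style ('N', 'C') or anything with no date tokens becomes NumberToStr, everything else becomes TimeToStr. A sketch of the routing, assuming the v10 parse entry points (the expression classes are taken from the diff; outputs are not verified here):

```python
import sqlglot
from sqlglot import exp

# Numeric style -> exp.NumberToStr
num = sqlglot.parse_one("SELECT FORMAT(1234567, 'N')", read="tsql")
print(num.find(exp.NumberToStr) is not None)  # expected: True

# Date-style format string -> exp.TimeToStr
dt = sqlglot.parse_one("SELECT FORMAT(GETDATE(), 'yyyy-MM-dd')", read="tsql")
print(dt.find(exp.TimeToStr) is not None)  # expected: True
```
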
@@ -4,7 +4,7 @@ from heapq import heappop, heappush
 
 from sqlglot import Dialect
 from sqlglot import expressions as exp
-from sqlglot.helper import ensure_list
+from sqlglot.helper import ensure_collection
 
 
 @dataclass(frozen=True)
@@ -116,7 +116,9 @@ class ChangeDistiller:
             source_node = self._source_index[kept_source_node_id]
             target_node = self._target_index[kept_target_node_id]
             if not isinstance(source_node, LEAF_EXPRESSION_TYPES) or source_node == target_node:
-                edit_script.extend(self._generate_move_edits(source_node, target_node, matching_set))
+                edit_script.extend(
+                    self._generate_move_edits(source_node, target_node, matching_set)
+                )
                 edit_script.append(Keep(source_node, target_node))
             else:
                 edit_script.append(Update(source_node, target_node))
@@ -158,13 +160,16 @@ class ChangeDistiller:
             max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids))
             if max_leaves_num:
                 common_leaves_num = sum(
-                    1 if s in source_leaf_ids and t in target_leaf_ids else 0 for s, t in leaves_matching_set
+                    1 if s in source_leaf_ids and t in target_leaf_ids else 0
+                    for s, t in leaves_matching_set
                 )
                 leaf_similarity_score = common_leaves_num / max_leaves_num
             else:
                 leaf_similarity_score = 0.0
 
-            adjusted_t = self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4
+            adjusted_t = (
+                self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4
+            )
 
             if leaf_similarity_score >= 0.8 or (
                 leaf_similarity_score >= adjusted_t
@@ -201,7 +206,10 @@ class ChangeDistiller:
         matching_set = set()
         while candidate_matchings:
             _, _, source_leaf, target_leaf = heappop(candidate_matchings)
-            if id(source_leaf) in self._unmatched_source_nodes and id(target_leaf) in self._unmatched_target_nodes:
+            if (
+                id(source_leaf) in self._unmatched_source_nodes
+                and id(target_leaf) in self._unmatched_target_nodes
+            ):
                 matching_set.add((id(source_leaf), id(target_leaf)))
                 self._unmatched_source_nodes.remove(id(source_leaf))
                 self._unmatched_target_nodes.remove(id(target_leaf))
@@ -241,8 +249,7 @@ def _get_leaves(expression):
     has_child_exprs = False
 
     for a in expression.args.values():
-        nodes = ensure_list(a)
-        for node in nodes:
+        for node in ensure_collection(a):
             if isinstance(node, exp.Expression):
                 has_child_exprs = True
                 yield from _get_leaves(node)
@@ -268,7 +275,7 @@ def _expression_only_args(expression):
     args = []
     if expression:
         for a in expression.args.values():
-            args.extend(ensure_list(a))
+            args.extend(ensure_collection(a))
     return [a for a in args if isinstance(a, exp.Expression)]
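Switching `ensure_list` to `ensure_collection` avoids building a fresh list when the arg value is already a collection. The helper's contract is roughly the following (a sketch, not the library's exact source):

```python
from collections.abc import Collection


def ensure_collection(value):
    """Wrap scalars in a list; pass real collections through.

    Strings and bytes count as scalars here, since iterating them
    character by character is never what an AST walk wants.
    """
    if value is None:
        return []
    if isinstance(value, Collection) and not isinstance(value, (str, bytes)):
        return value
    return [value]


print(ensure_collection(None))       # []
print(ensure_collection("leaf"))     # ['leaf']
print(ensure_collection([1, 2, 3]))  # [1, 2, 3] -- passed through unchanged
```
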
@@ -1,3 +1,6 @@
+from __future__ import annotations
+
+import typing as t
 from enum import auto
 
 from sqlglot.helper import AutoName
@@ -30,7 +33,11 @@ class OptimizeError(SqlglotError):
     pass
 
 
-def concat_errors(errors, maximum):
+class SchemaError(SqlglotError):
+    pass
+
+
+def concat_errors(errors: t.Sequence[t.Any], maximum: int) -> str:
     msg = [str(e) for e in errors[:maximum]]
     remaining = len(errors) - maximum
     if remaining > 0:
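`concat_errors` only gains annotations here; its behavior is unchanged. The tail of the function falls outside this hunk, so the following is an illustrative re-creation of what it plausibly does, not the real source:

```python
import typing as t


def concat_errors(errors: t.Sequence[t.Any], maximum: int) -> str:
    # Illustrative only: keep the first `maximum` messages and summarize the rest.
    msg = [str(e) for e in errors[:maximum]]
    remaining = len(errors) - maximum
    if remaining > 0:
        msg.append(f"... and {remaining} more")
    return "\n\n".join(msg)


print(concat_errors(["bad cast", "unknown column", "ambiguous join"], maximum=2))
```
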
--- a/sqlglot/executor/context.py
+++ b/sqlglot/executor/context.py
@@ -19,6 +19,7 @@ class Context:
             env (Optional[dict]): dictionary of functions within the execution context
         """
         self.tables = tables
+        self._table = None
        self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
        self.row_readers = {name: table.reader for name, table in tables.items()}
        self.env = {**(env or {}), "scope": self.row_readers}
@@ -29,8 +30,27 @@ class Context:
     def eval_tuple(self, codes):
         return tuple(self.eval(code) for code in codes)
 
+    @property
+    def table(self):
+        if self._table is None:
+            self._table = list(self.tables.values())[0]
+            for other in self.tables.values():
+                if self._table.columns != other.columns:
+                    raise Exception(f"Columns are different.")
+                if len(self._table.rows) != len(other.rows):
+                    raise Exception(f"Rows are different.")
+        return self._table
+
+    @property
+    def columns(self):
+        return self.table.columns
+
     def __iter__(self):
-        return self.table_iter(list(self.tables)[0])
+        self.env["scope"] = self.row_readers
+        for i in range(len(self.table.rows)):
+            for table in self.tables.values():
+                reader = table[i]
+            yield reader, self
 
     def table_iter(self, table):
         self.env["scope"] = self.row_readers
@@ -38,8 +58,8 @@ class Context:
         for reader in self.tables[table]:
             yield reader, self
 
-    def sort(self, table, key):
-        table = self.tables[table]
+    def sort(self, key):
+        table = self.table
 
         def sort_key(row):
             table.reader.row = row
@@ -47,20 +67,20 @@ class Context:
 
         table.rows.sort(key=sort_key)
 
-    def set_row(self, table, row):
-        self.row_readers[table].row = row
+    def set_row(self, row):
+        for table in self.tables.values():
+            table.reader.row = row
         self.env["scope"] = self.row_readers
 
-    def set_index(self, table, index):
-        self.row_readers[table].row = self.tables[table].rows[index]
+    def set_index(self, index):
+        for table in self.tables.values():
+            table[index]
         self.env["scope"] = self.row_readers
 
-    def set_range(self, table, start, end):
-        self.range_readers[table].range = range(start, end)
+    def set_range(self, start, end):
+        for name in self.tables:
+            self.range_readers[name].range = range(start, end)
         self.env["scope"] = self.range_readers
 
-    def __getitem__(self, table):
-        return self.env["scope"][table]
-
     def __contains__(self, table):
         return table in self.tables
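Taken together, the `Context` changes drop the per-table arguments: every table in a context is now assumed to be row-aligned, and `set_row`/`set_index`/`set_range` position all readers at once. A minimal sketch of the reworked API (invented data; assumes `Table.__getitem__` positions the row reader, as `__iter__` above relies on):

```python
from sqlglot.executor.context import Context
from sqlglot.executor.table import Table

t = Table(["a", "b"], rows=[(1, "x"), (2, "y")])
ctx = Context({"t": t})

print(ctx.columns)  # ('a', 'b'), via the new Context.table property

# Iteration now walks row indices across all (row-aligned) tables at once.
for reader, _ in ctx:
    print(reader["a"])  # 1, then 2
```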
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -2,6 +2,8 @@ import datetime
 import re
 import statistics
 
+from sqlglot.helper import PYTHON_VERSION
+
 
 class reverse_key:
     def __init__(self, obj):
@@ -25,7 +27,7 @@ ENV = {
     "str": str,
     "desc": reverse_key,
     "SUM": sum,
-    "AVG": statistics.fmean if hasattr(statistics, "fmean") else statistics.mean,
+    "AVG": statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean,  # type: ignore
     "COUNT": lambda acc: sum(1 for e in acc if e is not None),
     "MAX": max,
     "MIN": min,
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -1,15 +1,14 @@
 import ast
 import collections
 import itertools
+import math
 
-from sqlglot import exp, planner
+from sqlglot import exp, generator, planner, tokens
 from sqlglot.dialects.dialect import Dialect, inline_array_sql
 from sqlglot.executor.context import Context
 from sqlglot.executor.env import ENV
 from sqlglot.executor.table import Table
-from sqlglot.generator import Generator
 from sqlglot.helper import csv_reader
-from sqlglot.tokens import Tokenizer
 
 
 class PythonExecutor:
@@ -26,7 +25,11 @@ class PythonExecutor:
         while queue:
             node = queue.pop()
             context = self.context(
-                {name: table for dep in node.dependencies for name, table in contexts[dep].tables.items()}
+                {
+                    name: table
+                    for dep in node.dependencies
+                    for name, table in contexts[dep].tables.items()
+                }
             )
             running.add(node)
 
@@ -76,13 +79,10 @@ class PythonExecutor:
         return Table(expression.alias_or_name for expression in expressions)
 
     def scan(self, step, context):
-        if hasattr(step, "source"):
-            source = step.source
+        source = step.source
 
-            if isinstance(source, exp.Expression):
-                source = source.name or source.alias
-        else:
-            source = step.name
+        if isinstance(source, exp.Expression):
+            source = source.name or source.alias
 
         condition = self.generate(step.condition)
         projections = self.generate_tuple(step.projections)
@@ -96,14 +96,12 @@ class PythonExecutor:
 
         if projections:
             sink = self.table(step.projections)
-        elif source in context:
-            sink = Table(context[source].columns)
         else:
             sink = None
 
         for reader, ctx in table_iter:
             if sink is None:
-                sink = Table(ctx[source].columns)
+                sink = Table(reader.columns)
 
             if condition and not ctx.eval(condition):
                 continue
@@ -135,119 +133,102 @@ class PythonExecutor:
                         types.append(type(ast.literal_eval(v)))
                     except (ValueError, SyntaxError):
                         types.append(str)
-                context.set_row(alias, tuple(t(v) for t, v in zip(types, row)))
-                yield context[alias], context
+                context.set_row(tuple(t(v) for t, v in zip(types, row)))
+                yield context.table.reader, context
 
     def join(self, step, context):
         source = step.name
 
-        join_context = self.context({source: context.tables[source]})
-
-        def merge_context(ctx, table):
-            # create a new context where all existing tables are mapped to a new one
-            return self.context({name: table for name in ctx.tables})
+        source_table = context.tables[source]
+        source_context = self.context({source: source_table})
+        column_ranges = {source: range(0, len(source_table.columns))}
 
         for name, join in step.joins.items():
-            join_context = self.context({**join_context.tables, name: context.tables[name]})
+            table = context.tables[name]
+            start = max(r.stop for r in column_ranges.values())
+            column_ranges[name] = range(start, len(table.columns) + start)
+            join_context = self.context({name: table})
 
             if join.get("source_key"):
-                table = self.hash_join(join, source, name, join_context)
+                table = self.hash_join(join, source_context, join_context)
             else:
-                table = self.nested_loop_join(join, source, name, join_context)
+                table = self.nested_loop_join(join, source_context, join_context)
 
-            join_context = merge_context(join_context, table)
+            source_context = self.context(
+                {
+                    name: Table(table.columns, table.rows, column_range)
+                    for name, column_range in column_ranges.items()
+                }
+            )
 
-        # apply projections or conditions
-        context = self.scan(step, join_context)
+        condition = self.generate(step.condition)
+        projections = self.generate_tuple(step.projections)
 
-        # use the scan context since it returns a single table
-        # otherwise there are no projections so all other tables are still in scope
-        if step.projections:
-            return context
+        if not condition or not projections:
+            return source_context
 
-        return merge_context(join_context, context.tables[source])
+        sink = self.table(step.projections if projections else source_context.columns)
 
-    def nested_loop_join(self, _join, a, b, context):
-        table = Table(context.tables[a].columns + context.tables[b].columns)
+        for reader, ctx in join_context:
+            if condition and not ctx.eval(condition):
+                continue
 
-        for reader_a, _ in context.table_iter(a):
-            for reader_b, _ in context.table_iter(b):
+            if projections:
+                sink.append(ctx.eval_tuple(projections))
+            else:
+                sink.append(reader.row)
+
+            if len(sink) >= step.limit:
+                break
+
+        return self.context({step.name: sink})
+
+    def nested_loop_join(self, _join, source_context, join_context):
+        table = Table(source_context.columns + join_context.columns)
+
+        for reader_a, _ in source_context:
+            for reader_b, _ in join_context:
                 table.append(reader_a.row + reader_b.row)
 
         return table
 
-    def hash_join(self, join, a, b, context):
-        a_key = self.generate_tuple(join["source_key"])
-        b_key = self.generate_tuple(join["join_key"])
+    def hash_join(self, join, source_context, join_context):
+        source_key = self.generate_tuple(join["source_key"])
+        join_key = self.generate_tuple(join["join_key"])
 
         results = collections.defaultdict(lambda: ([], []))
 
-        for reader, ctx in context.table_iter(a):
-            results[ctx.eval_tuple(a_key)][0].append(reader.row)
-        for reader, ctx in context.table_iter(b):
-            results[ctx.eval_tuple(b_key)][1].append(reader.row)
+        for reader, ctx in source_context:
+            results[ctx.eval_tuple(source_key)][0].append(reader.row)
+        for reader, ctx in join_context:
+            results[ctx.eval_tuple(join_key)][1].append(reader.row)
+
+        table = Table(source_context.columns + join_context.columns)
 
-        table = Table(context.tables[a].columns + context.tables[b].columns)
         for a_group, b_group in results.values():
             for a_row, b_row in itertools.product(a_group, b_group):
                 table.append(a_row + b_row)
 
         return table
 
-    def sort_merge_join(self, join, a, b, context):
-        a_key = self.generate_tuple(join["source_key"])
-        b_key = self.generate_tuple(join["join_key"])
-
-        context.sort(a, a_key)
-        context.sort(b, b_key)
-
-        a_i = 0
-        b_i = 0
-        a_n = len(context.tables[a])
-        b_n = len(context.tables[b])
-
-        table = Table(context.tables[a].columns + context.tables[b].columns)
-
-        def get_key(source, key, i):
-            context.set_index(source, i)
-            return context.eval_tuple(key)
-
-        while a_i < a_n and b_i < b_n:
-            key = min(get_key(a, a_key, a_i), get_key(b, b_key, b_i))
-
-            a_group = []
-
-            while a_i < a_n and key == get_key(a, a_key, a_i):
-                a_group.append(context[a].row)
-                a_i += 1
-
-            b_group = []
-
-            while b_i < b_n and key == get_key(b, b_key, b_i):
-                b_group.append(context[b].row)
-                b_i += 1
-
-            for a_row, b_row in itertools.product(a_group, b_group):
-                table.append(a_row + b_row)
-
-        return table
-
     def aggregate(self, step, context):
         source = step.source
         group_by = self.generate_tuple(step.group)
         aggregations = self.generate_tuple(step.aggregations)
         operands = self.generate_tuple(step.operands)
 
-        context.sort(source, group_by)
-
-        if step.operands:
+        if operands:
             source_table = context.tables[source]
             operand_table = Table(source_table.columns + self.table(step.operands).columns)
 
             for reader, ctx in context:
                 operand_table.append(reader.row + ctx.eval_tuple(operands))
 
-            context = self.context({source: operand_table})
+            context = self.context(
+                {None: operand_table, **{table: operand_table for table in context.tables}}
+            )
+
+        context.sort(group_by)
 
         group = None
         start = 0
@@ -256,15 +237,15 @@ class PythonExecutor:
         table = self.table(step.group + step.aggregations)
 
         for i in range(length):
-            context.set_index(source, i)
+            context.set_index(i)
             key = context.eval_tuple(group_by)
             group = key if group is None else group
             end += 1
 
             if i == length - 1:
-                context.set_range(source, start, end - 1)
+                context.set_range(start, end - 1)
             elif key != group:
-                context.set_range(source, start, end - 2)
+                context.set_range(start, end - 2)
             else:
                 continue
@@ -272,13 +253,32 @@ class PythonExecutor:
                 group = key
                 start = end - 2
 
-        return self.scan(step, self.context({source: table}))
+        context = self.context({step.name: table, **{name: table for name in context.tables}})
+
+        if step.projections:
+            return self.scan(step, context)
+        return context
 
     def sort(self, step, context):
-        table = list(context.tables)[0]
-        key = self.generate_tuple(step.key)
-        context.sort(table, key)
-        return self.scan(step, context)
+        projections = self.generate_tuple(step.projections)
+
+        sink = self.table(step.projections)
+
+        for reader, ctx in context:
+            sink.append(ctx.eval_tuple(projections))
+
+        context = self.context(
+            {
+                None: sink,
+                **{table: sink for table in context.tables},
+            }
+        )
+        context.sort(self.generate_tuple(step.key))
+
+        if not math.isinf(step.limit):
+            context.table.rows = context.table.rows[0 : step.limit]
+
+        return self.context({step.name: context.table})
 
 
 def _cast_py(self, expression):
@@ -293,7 +293,7 @@ def _cast_py(self, expression):
 
 
 def _column_py(self, expression):
-    table = self.sql(expression, "table")
+    table = self.sql(expression, "table") or None
     this = self.sql(expression, "this")
     return f"scope[{table}][{this}]"
 
@@ -319,10 +319,10 @@ def _ordered_py(self, expression):
 
 
 class Python(Dialect):
-    class Tokenizer(Tokenizer):
-        ESCAPE = "\\"
+    class Tokenizer(tokens.Tokenizer):
+        ESCAPES = ["\\"]
 
-    class Generator(Generator):
+    class Generator(generator.Generator):
         TRANSFORMS = {
             exp.Alias: lambda self, e: self.sql(e.this),
             exp.Array: inline_array_sql,
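End to end, the executor rework above is exercised through `sqlglot.executor.execute`. A hedged example (data and schema invented here; assumes `execute` accepts `schema=` and `tables=` keyword arguments, as in the README of this release):

```python
from sqlglot.executor import execute

tables = {
    "sushi": [{"id": 1, "price": 1.0}, {"id": 2, "price": 2.0}],
    "order_items": [{"sushi_id": 1, "order_id": 1}, {"sushi_id": 2, "order_id": 1}],
}

# The join now flows through hash_join/nested_loop_join over a single
# row-aligned Context instead of per-table readers.
result = execute(
    "SELECT o.order_id, SUM(s.price) AS total "
    "FROM order_items o JOIN sushi s ON o.sushi_id = s.id "
    "GROUP BY o.order_id",
    schema={
        "sushi": {"id": "int", "price": "double"},
        "order_items": {"sushi_id": "int", "order_id": "int"},
    },
    tables=tables,
)
print(result)  # a Table with columns (order_id, total)
```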
--- a/sqlglot/executor/table.py
+++ b/sqlglot/executor/table.py
@@ -1,10 +1,12 @@
 class Table:
-    def __init__(self, *columns, rows=None):
-        self.columns = tuple(columns if isinstance(columns[0], str) else columns[0])
+    def __init__(self, columns, rows=None, column_range=None):
+        self.columns = tuple(columns)
+        self.column_range = column_range
+        self.reader = RowReader(self.columns, self.column_range)
+
         self.rows = rows or []
         if rows:
             assert len(rows[0]) == len(self.columns)
-        self.reader = RowReader(self.columns)
         self.range_reader = RangeReader(self)
 
     def append(self, row):
@@ -29,15 +31,22 @@ class Table:
         return self.reader
 
     def __repr__(self):
-        widths = {column: len(column) for column in self.columns}
-        lines = [" ".join(column for column in self.columns)]
+        columns = tuple(
+            column
+            for i, column in enumerate(self.columns)
+            if not self.column_range or i in self.column_range
+        )
+        widths = {column: len(column) for column in columns}
+        lines = [" ".join(column for column in columns)]
 
         for i, row in enumerate(self):
             if i > 10:
                 break
 
             lines.append(
-                " ".join(str(row[column]).rjust(widths[column])[0 : widths[column]] for column in self.columns)
+                " ".join(
+                    str(row[column]).rjust(widths[column])[0 : widths[column]] for column in columns
+                )
             )
         return "\n".join(lines)
@@ -70,8 +79,10 @@ class RangeReader:
 
 
 class RowReader:
-    def __init__(self, columns):
-        self.columns = {column: i for i, column in enumerate(columns)}
+    def __init__(self, columns, column_range=None):
+        self.columns = {
+            column: i for i, column in enumerate(columns) if not column_range or i in column_range
+        }
         self.row = None
 
     def __getitem__(self, column):
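The new `column_range` is the piece that lets several logical tables share one physical joined row, which the join rework above depends on. A small runnable sketch (invented columns):

```python
from sqlglot.executor.table import Table

columns = ("a", "b", "c", "d")
rows = [(1, 2, 3, 4)]

# Two views over the same joined rows: the left table sees columns 0-1,
# the right table sees columns 2-3.
left = Table(columns, rows, column_range=range(0, 2))
right = Table(columns, rows, column_range=range(2, 4))

print(left.reader.columns)   # {'a': 0, 'b': 1}
print(right.reader.columns)  # {'c': 2, 'd': 3}
```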
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -1,6 +1,9 @@
+from __future__ import annotations
+
 import datetime
 import numbers
 import re
+import typing as t
 from collections import deque
 from copy import deepcopy
 from enum import auto
@@ -9,12 +12,15 @@ from sqlglot.errors import ParseError
 from sqlglot.helper import (
     AutoName,
     camel_to_snake_case,
-    ensure_list,
-    list_get,
+    ensure_collection,
+    seq_get,
     split_num_words,
     subclasses,
 )
 
+if t.TYPE_CHECKING:
+    from sqlglot.dialects.dialect import Dialect
+
 
 class _Expression(type):
     def __new__(cls, clsname, bases, attrs):
@@ -35,27 +41,30 @@ class Expression(metaclass=_Expression):
             or optional (False).
     """
 
-    key = None
+    key = "Expression"
     arg_types = {"this": True}
-    __slots__ = ("args", "parent", "arg_key", "type")
+    __slots__ = ("args", "parent", "arg_key", "type", "comment")
 
     def __init__(self, **args):
         self.args = args
         self.parent = None
         self.arg_key = None
         self.type = None
+        self.comment = None
 
         for arg_key, value in self.args.items():
             self._set_parent(arg_key, value)
 
-    def __eq__(self, other):
+    def __eq__(self, other) -> bool:
         return type(self) is type(other) and _norm_args(self) == _norm_args(other)
 
-    def __hash__(self):
+    def __hash__(self) -> int:
         return hash(
             (
                 self.key,
-                tuple((k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()),
+                tuple(
+                    (k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()
+                ),
             )
         )
@@ -79,6 +88,19 @@ class Expression(metaclass=_Expression):
             return field.this
         return ""
 
+    def find_comment(self, key: str) -> str:
+        """
+        Finds the comment that is attached to a specified child node.
+
+        Args:
+            key: the key of the target child node (e.g. "this", "expression", etc).
+
+        Returns:
+            The comment attached to the child node, or the empty string, if it doesn't exist.
+        """
+        field = self.args.get(key)
+        return field.comment if isinstance(field, Expression) else ""
+
     @property
     def is_string(self):
         return isinstance(self, Literal) and self.args["is_string"]
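With `comment` now a slot on every node (and the `Annotation` wrapper removed further down), a parsed tree carries its comments directly. An indicative example, assuming the parser attaches a trailing comment to the node it follows:

```python
import sqlglot

ast = sqlglot.parse_one("SELECT a /* primary key */, b FROM t")

# The first projection should carry the comment it was parsed with:
print(ast.expressions[0].comment)  # ' primary key '
```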
@@ -114,7 +136,10 @@ class Expression(metaclass=_Expression):
         return self.alias or self.name
 
     def __deepcopy__(self, memo):
-        return self.__class__(**deepcopy(self.args))
+        copy = self.__class__(**deepcopy(self.args))
+        copy.comment = self.comment
+        copy.type = self.type
+        return copy
 
     def copy(self):
         new = deepcopy(self)
@@ -249,9 +274,7 @@ class Expression(metaclass=_Expression):
             return
 
         for k, v in self.args.items():
-            nodes = ensure_list(v)
-
-            for node in nodes:
+            for node in ensure_collection(v):
                 if isinstance(node, Expression):
                     yield from node.dfs(self, k, prune)
@@ -274,9 +297,7 @@ class Expression(metaclass=_Expression):
 
             if isinstance(item, Expression):
                 for k, v in item.args.items():
-                    nodes = ensure_list(v)
-
-                    for node in nodes:
+                    for node in ensure_collection(v):
                         if isinstance(node, Expression):
                             queue.append((node, item, k))
@@ -319,7 +340,7 @@ class Expression(metaclass=_Expression):
     def __repr__(self):
         return self.to_s()
 
-    def sql(self, dialect=None, **opts):
+    def sql(self, dialect: Dialect | str | None = None, **opts) -> str:
         """
         Returns SQL string representation of this tree.
@@ -335,7 +356,7 @@ class Expression(metaclass=_Expression):
 
         return Dialect.get_or_raise(dialect)().generate(self, **opts)
 
-    def to_s(self, hide_missing=True, level=0):
+    def to_s(self, hide_missing: bool = True, level: int = 0) -> str:
         indent = "" if not level else "\n"
         indent += "".join([" "] * level)
         left = f"({self.key.upper()} "
@@ -343,11 +364,13 @@ class Expression(metaclass=_Expression):
         args = {
             k: ", ".join(
                 v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v)
-                for v in ensure_list(vs)
+                for v in ensure_collection(vs)
                 if v is not None
             )
             for k, vs in self.args.items()
         }
+        args["comment"] = self.comment
+        args["type"] = self.type
         args = {k: v for k, v in args.items() if v or not hide_missing}
 
         right = ", ".join(f"{k}: {v}" for k, v in args.items())
@@ -578,17 +601,6 @@ class UDTF(DerivedTable, Unionable):
     pass
 
 
-class Annotation(Expression):
-    arg_types = {
-        "this": True,
-        "expression": True,
-    }
-
-    @property
-    def alias(self):
-        return self.expression.alias_or_name
-
-
 class Cache(Expression):
     arg_types = {
         "with": False,
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Set(Expression):
|
||||||
|
arg_types = {"expressions": True}
|
||||||
|
|
||||||
|
|
||||||
|
class SetItem(Expression):
|
||||||
|
arg_types = {
|
||||||
|
"this": True,
|
||||||
|
"kind": False,
|
||||||
|
"collate": False, # MySQL SET NAMES statement
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Show(Expression):
|
||||||
|
arg_types = {
|
||||||
|
"this": True,
|
||||||
|
"target": False,
|
||||||
|
"offset": False,
|
||||||
|
"limit": False,
|
||||||
|
"like": False,
|
||||||
|
"where": False,
|
||||||
|
"db": False,
|
||||||
|
"full": False,
|
||||||
|
"mutex": False,
|
||||||
|
"query": False,
|
||||||
|
"channel": False,
|
||||||
|
"global": False,
|
||||||
|
"log": False,
|
||||||
|
"position": False,
|
||||||
|
"types": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedFunction(Expression):
|
class UserDefinedFunction(Expression):
|
||||||
arg_types = {"this": True, "expressions": False}
|
arg_types = {"this": True, "expressions": False}
|
||||||
|
|
||||||
|
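These nodes back the MySQL SET statements and SHOW commands called out in the changelog. A hedged smoke test:

```python
import sqlglot

show = sqlglot.parse_one("SHOW FULL TABLES FROM mydb LIKE '%order%'", read="mysql")
print(type(show).__name__)  # Show

set_names = sqlglot.parse_one("SET NAMES utf8mb4 COLLATE 'utf8mb4_bin'", read="mysql")
print(type(set_names).__name__)  # Set
```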
@@ -864,18 +908,20 @@ class Literal(Condition):
 
     def __eq__(self, other):
         return (
-            isinstance(other, Literal) and self.this == other.this and self.args["is_string"] == other.args["is_string"]
+            isinstance(other, Literal)
+            and self.this == other.this
+            and self.args["is_string"] == other.args["is_string"]
         )
 
     def __hash__(self):
         return hash((self.key, self.this, self.args["is_string"]))
 
     @classmethod
-    def number(cls, number):
+    def number(cls, number) -> Literal:
         return cls(this=str(number), is_string=False)
 
     @classmethod
-    def string(cls, string):
+    def string(cls, string) -> Literal:
         return cls(this=str(string), is_string=True)
@@ -1087,7 +1133,7 @@ class Properties(Expression):
     }
 
     @classmethod
-    def from_dict(cls, properties_dict):
+    def from_dict(cls, properties_dict) -> Properties:
         expressions = []
         for key, value in properties_dict.items():
             property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty)
@@ -1323,7 +1369,7 @@ class Select(Subqueryable):
         **QUERY_MODIFIERS,
     }
 
-    def from_(self, *expressions, append=True, dialect=None, copy=True, **opts):
+    def from_(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
         """
         Set the FROM expression.
@@ -1356,7 +1402,7 @@ class Select(Subqueryable):
             **opts,
         )
 
-    def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
+    def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
         """
         Set the GROUP BY expression.
@@ -1392,7 +1438,7 @@ class Select(Subqueryable):
             **opts,
         )
 
-    def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
+    def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
         """
         Set the ORDER BY expression.
@@ -1425,7 +1471,7 @@ class Select(Subqueryable):
             **opts,
         )
 
-    def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
+    def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
         """
         Set the SORT BY expression.
@@ -1458,7 +1504,7 @@ class Select(Subqueryable):
             **opts,
         )
 
-    def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
+    def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
         """
         Set the CLUSTER BY expression.
@@ -1491,7 +1537,7 @@ class Select(Subqueryable):
             **opts,
         )
 
-    def limit(self, expression, dialect=None, copy=True, **opts):
+    def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
         """
         Set the LIMIT expression.
@@ -1522,7 +1568,7 @@ class Select(Subqueryable):
             **opts,
         )
 
-    def offset(self, expression, dialect=None, copy=True, **opts):
+    def offset(self, expression, dialect=None, copy=True, **opts) -> Select:
         """
         Set the OFFSET expression.
@@ -1553,7 +1599,7 @@ class Select(Subqueryable):
             **opts,
        )
 
-    def select(self, *expressions, append=True, dialect=None, copy=True, **opts):
+    def select(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
         """
         Append to or set the SELECT expressions.
@@ -1583,7 +1629,7 @@ class Select(Subqueryable):
             **opts,
         )
 
-    def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts):
+    def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
         """
         Append to or set the LATERAL expressions.
@@ -1626,7 +1672,7 @@ class Select(Subqueryable):
         dialect=None,
         copy=True,
         **opts,
-    ):
+    ) -> Select:
         """
         Append to or set the JOIN expressions.
@@ -1672,7 +1718,7 @@ class Select(Subqueryable):
             join.this.replace(join.this.subquery())
 
         if join_type:
-            natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)
+            natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)  # type: ignore
             if natural:
                 join.set("natural", True)
             if side:
@@ -1681,12 +1727,12 @@ class Select(Subqueryable):
             join.set("kind", kind.text)
 
         if on:
-            on = and_(*ensure_list(on), dialect=dialect, **opts)
+            on = and_(*ensure_collection(on), dialect=dialect, **opts)
             join.set("on", on)
 
         if using:
             join = _apply_list_builder(
-                *ensure_list(using),
+                *ensure_collection(using),
                 instance=join,
                 arg="using",
                 append=append,
@@ -1705,7 +1751,7 @@ class Select(Subqueryable):
             **opts,
         )
 
-    def where(self, *expressions, append=True, dialect=None, copy=True, **opts):
+    def where(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
         """
         Append to or set the WHERE expressions.
@@ -1737,7 +1783,7 @@ class Select(Subqueryable):
             **opts,
         )
 
-    def having(self, *expressions, append=True, dialect=None, copy=True, **opts):
+    def having(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
         """
         Append to or set the HAVING expressions.
@@ -1769,7 +1815,7 @@ class Select(Subqueryable):
             **opts,
         )
 
-    def distinct(self, distinct=True, copy=True):
+    def distinct(self, distinct=True, copy=True) -> Select:
         """
         Set the OFFSET expression.
@@ -1788,7 +1834,7 @@ class Select(Subqueryable):
         instance.set("distinct", Distinct() if distinct else None)
         return instance
 
-    def ctas(self, table, properties=None, dialect=None, copy=True, **opts):
+    def ctas(self, table, properties=None, dialect=None, copy=True, **opts) -> Create:
         """
         Convert this expression to a CREATE TABLE AS statement.
@@ -1826,11 +1872,11 @@ class Select(Subqueryable):
         )
 
     @property
-    def named_selects(self):
+    def named_selects(self) -> t.List[str]:
         return [e.alias_or_name for e in self.expressions if e.alias_or_name]
 
     @property
-    def selects(self):
+    def selects(self) -> t.List[Expression]:
         return self.expressions
@@ -1910,12 +1956,16 @@ class Parameter(Expression):
     pass
 
 
+class SessionParameter(Expression):
+    arg_types = {"this": True, "kind": False}
+
+
 class Placeholder(Expression):
     arg_types = {"this": False}
 
 
 class Null(Condition):
-    arg_types = {}
+    arg_types: t.Dict[str, t.Any] = {}
 
 
 class Boolean(Condition):
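The builder methods above now declare `-> Select`, so chained construction stays typed under the newly activated mypy checking. Usage itself is unchanged, e.g.:

```python
from sqlglot import condition, select

query = select("*").from_("sushi").where(condition("price > 1").and_("id < 100"))
print(query.sql())  # SELECT * FROM sushi WHERE price > 1 AND id < 100
```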
@@ -1936,6 +1986,7 @@ class DataType(Expression):
         NVARCHAR = auto()
         TEXT = auto()
         BINARY = auto()
+        VARBINARY = auto()
         INT = auto()
         TINYINT = auto()
         SMALLINT = auto()
@@ -1975,7 +2026,7 @@ class DataType(Expression):
         UNKNOWN = auto()  # Sentinel value, useful for type annotation
 
     @classmethod
-    def build(cls, dtype, **kwargs):
+    def build(cls, dtype, **kwargs) -> DataType:
         return DataType(
             this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()],
             **kwargs,
@@ -2077,6 +2128,18 @@ class EQ(Binary, Predicate):
     pass
 
 
+class NullSafeEQ(Binary, Predicate):
+    pass
+
+
+class NullSafeNEQ(Binary, Predicate):
+    pass
+
+
+class Distance(Binary):
+    pass
+
+
 class Escape(Binary):
     pass
@@ -2101,18 +2164,14 @@ class Is(Binary, Predicate):
     pass
 
 
+class Kwarg(Binary):
+    """Kwarg in special functions like func(kwarg => y)."""
+
+
 class Like(Binary, Predicate):
     pass
 
 
-class SimilarTo(Binary, Predicate):
-    pass
-
-
-class Distance(Binary):
-    pass
-
-
 class LT(Binary, Predicate):
     pass
@@ -2133,6 +2192,10 @@ class NEQ(Binary, Predicate):
     pass
 
 
+class SimilarTo(Binary, Predicate):
+    pass
+
+
 class Sub(Binary):
     pass
@@ -2189,7 +2252,13 @@ class Distinct(Expression):
 
 
 class In(Predicate):
-    arg_types = {"this": True, "expressions": False, "query": False, "unnest": False, "field": False}
+    arg_types = {
+        "this": True,
+        "expressions": False,
+        "query": False,
+        "unnest": False,
+        "field": False,
+    }
 
 
 class TimeUnit(Expression):
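`NullSafeEQ` and `NullSafeNEQ` model the null-safe equal operator (`<=>`) from the changelog. A hedged round trip:

```python
import sqlglot

expr = sqlglot.parse_one("a <=> b", read="mysql")
print(type(expr).__name__)        # NullSafeEQ
print(expr.sql(dialect="mysql"))  # a <=> b
```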
@@ -2255,7 +2324,9 @@ class Func(Condition):
     @classmethod
     def sql_names(cls):
         if cls is Func:
-            raise NotImplementedError("SQL name is only supported by concrete function implementations")
+            raise NotImplementedError(
+                "SQL name is only supported by concrete function implementations"
+            )
         if not hasattr(cls, "_sql_names"):
             cls._sql_names = [camel_to_snake_case(cls.__name__)]
         return cls._sql_names
@@ -2408,8 +2479,8 @@ class DateDiff(Func, TimeUnit):
     arg_types = {"this": True, "expression": True, "unit": False}
 
 
-class DateTrunc(Func, TimeUnit):
-    arg_types = {"this": True, "unit": True, "zone": False}
+class DateTrunc(Func):
+    arg_types = {"this": True, "expression": True, "zone": False}
 
 
 class DatetimeAdd(Func, TimeUnit):
@@ -2791,6 +2862,10 @@ class Year(Func):
     pass
 
 
+class Use(Expression):
+    pass
+
+
 def _norm_args(expression):
     args = {}
@@ -2822,7 +2897,7 @@ def maybe_parse(
     dialect=None,
     prefix=None,
     **opts,
-):
+) -> t.Optional[Expression]:
     """Gracefully handle a possible string or expression.
 
     Example:
@@ -3073,7 +3148,7 @@ def except_(left, right, distinct=True, dialect=None, **opts):
     return Except(this=left, expression=right, distinct=distinct)
 
 
-def select(*expressions, dialect=None, **opts):
+def select(*expressions, dialect=None, **opts) -> Select:
     """
     Initializes a syntax tree from one or multiple SELECT expressions.
@@ -3095,7 +3170,7 @@ def select(*expressions, dialect=None, **opts):
     return Select().select(*expressions, dialect=dialect, **opts)
 
 
-def from_(*expressions, dialect=None, **opts):
+def from_(*expressions, dialect=None, **opts) -> Select:
     """
     Initializes a syntax tree from a FROM expression.
@@ -3117,7 +3192,7 @@ def from_(*expressions, dialect=None, **opts):
     return Select().from_(*expressions, dialect=dialect, **opts)
 
 
-def update(table, properties, where=None, from_=None, dialect=None, **opts):
+def update(table, properties, where=None, from_=None, dialect=None, **opts) -> Update:
     """
     Creates an update statement.
@@ -3139,7 +3214,10 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts):
     update = Update(this=maybe_parse(table, into=Table, dialect=dialect))
     update.set(
         "expressions",
-        [EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) for k, v in properties.items()],
+        [
+            EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v))
+            for k, v in properties.items()
+        ],
     )
     if from_:
         update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts))
@@ -3150,7 +3228,7 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts):
     return update
 
 
-def delete(table, where=None, dialect=None, **opts):
+def delete(table, where=None, dialect=None, **opts) -> Delete:
     """
     Builds a delete statement.
@@ -3174,7 +3252,7 @@ def delete(table, where=None, dialect=None, **opts):
     )
 
 
-def condition(expression, dialect=None, **opts):
+def condition(expression, dialect=None, **opts) -> Condition:
     """
     Initialize a logical condition expression.
@@ -3199,7 +3277,7 @@ def condition(expression, dialect=None, **opts):
     Returns:
         Condition: the expression
     """
-    return maybe_parse(
+    return maybe_parse(  # type: ignore
         expression,
         into=Condition,
         dialect=dialect,
@@ -3207,7 +3285,7 @@ def condition(expression, dialect=None, **opts):
     )
 
 
-def and_(*expressions, dialect=None, **opts):
+def and_(*expressions, dialect=None, **opts) -> And:
     """
     Combine multiple conditions with an AND logical operator.
@@ -3227,7 +3305,7 @@ def and_(*expressions, dialect=None, **opts):
     return _combine(expressions, And, dialect, **opts)
 
 
-def or_(*expressions, dialect=None, **opts):
+def or_(*expressions, dialect=None, **opts) -> Or:
     """
     Combine multiple conditions with an OR logical operator.
@@ -3247,7 +3325,7 @@ def or_(*expressions, dialect=None, **opts):
     return _combine(expressions, Or, dialect, **opts)
 
 
-def not_(expression, dialect=None, **opts):
+def not_(expression, dialect=None, **opts) -> Not:
     """
     Wrap a condition with a NOT operator.
@@ -3272,14 +3350,14 @@ def not_(expression, dialect=None, **opts):
     return Not(this=_wrap_operator(this))
 
 
-def paren(expression):
+def paren(expression) -> Paren:
     return Paren(this=expression)
 
 
 SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$")
 
 
-def to_identifier(alias, quoted=None):
+def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
     if alias is None:
         return None
     if isinstance(alias, Identifier):
@@ -3293,16 +3371,16 @@ def to_identifier(alias, quoted=None):
     return identifier
 
 
-def to_table(sql_path: str, **kwargs) -> Table:
+def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]:
     """
     Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.
 
     If a table is passed in then that table is returned.
 
     Args:
-        sql_path(str|Table): `[catalog].[schema].[table]` string
+        sql_path: a `[catalog].[schema].[table]` string.
 
     Returns:
-        Table: A table expression
+        A table expression.
     """
     if sql_path is None or isinstance(sql_path, Table):
         return sql_path
@@ -3393,7 +3471,7 @@ def subquery(expression, alias=None, dialect=None, **opts):
     return Select().from_(expression, dialect=dialect, **opts)
 
 
-def column(col, table=None, quoted=None):
+def column(col, table=None, quoted=None) -> Column:
     """
     Build a Column.
     Args:
@@ -3408,7 +3486,7 @@ def column(col, table=None, quoted=None):
     )
 
 
-def table_(table, db=None, catalog=None, quoted=None, alias=None):
+def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
     """Build a Table.
 
     Args:
@@ -3427,7 +3505,7 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None):
     )
 
 
-def values(values, alias=None):
+def values(values, alias=None) -> Values:
     """Build VALUES statement.
 
     Example:
@@ -3449,7 +3527,7 @@ def values(values, alias=None):
     )
 
 
-def convert(value):
+def convert(value) -> Expression:
     """Convert a python value into an expression object.
 
     Raises an error if a conversion is not possible.
@@ -3500,15 +3578,14 @@ def replace_children(expression, fun):
 
         for cn in child_nodes:
             if isinstance(cn, Expression):
-                cns = ensure_list(fun(cn))
-                for child_node in cns:
+                for child_node in ensure_collection(fun(cn)):
                     new_child_nodes.append(child_node)
                     child_node.parent = expression
                     child_node.arg_key = k
             else:
                 new_child_nodes.append(cn)
 
-        expression.args[k] = new_child_nodes if is_list_arg else list_get(new_child_nodes, 0)
+        expression.args[k] = new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0)
@@ -3529,7 +3606,7 @@ def column_table_names(expression):
     return list(dict.fromkeys(column.table for column in expression.find_all(Column)))
 
 
-def table_name(table):
+def table_name(table) -> str:
     """Get the full name of a table as a string.
 
     Args:
@@ -3546,6 +3623,9 @@ def table_name(table):
 
     table = maybe_parse(table, into=Table)
 
+    if not table:
+        raise ValueError(f"Cannot parse {table}")
+
     return ".".join(
         part
         for part in (
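The module-level builders above likewise gain return annotations; behavior is unchanged. As a usage reminder, the `update` helper builds a full statement from plain Python values (invented values, indicative output):

```python
from sqlglot.expressions import update

stmt = update("users", {"active": False, "age": 30}, where="id = 1")
print(stmt.sql())  # UPDATE users SET active = FALSE, age = 30 WHERE id = 1
```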
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -1,4 +1,8 @@
+from __future__ import annotations
+
 import logging
+import re
+import typing as t
 
 from sqlglot import exp
 from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors
@@ -8,6 +12,8 @@ from sqlglot.tokens import TokenType
 
 logger = logging.getLogger("sqlglot")
 
+NEWLINE_RE = re.compile("\r\n?|\n")
+
 
 class Generator:
     """
@@ -47,8 +53,7 @@ class Generator:
             The default is on the smaller end because the length only represents a segment and not the true
             line length.
             Default: 80
-        annotations: Whether or not to show annotations in the SQL when `pretty` is True.
-            Annotations can only be shown in pretty mode otherwise they may clobber resulting sql.
+        comments: Whether or not to preserve comments in the ouput SQL code.
             Default: True
     """
 
@@ -65,14 +70,16 @@ class Generator:
         exp.VolatilityProperty: lambda self, e: self.sql(e.name),
     }
 
-    # whether 'CREATE ... TRANSIENT ... TABLE' is allowed
-    # can override in dialects
+    # Whether 'CREATE ... TRANSIENT ... TABLE' is allowed
     CREATE_TRANSIENT = False
-    # whether or not null ordering is supported in order by
+
+    # Whether or not null ordering is supported in order by
     NULL_ORDERING_SUPPORTED = True
-    # always do union distinct or union all
+
+    # Always do union distinct or union all
     EXPLICIT_UNION = False
-    # wrap derived values in parens, usually standard but spark doesn't support it
+
+    # Wrap derived values in parens, usually standard but spark doesn't support it
     WRAP_DERIVED_VALUES = True
 
     TYPE_MAPPING = {
@@ -80,7 +87,7 @@ class Generator:
         exp.DataType.Type.NVARCHAR: "VARCHAR",
     }
 
-    TOKEN_MAPPING = {}
+    TOKEN_MAPPING: t.Dict[TokenType, str] = {}
 
     STRUCT_DELIMITER = ("<", ">")
 
@@ -96,6 +103,8 @@ class Generator:
         exp.TableFormatProperty,
     }
 
+    WITH_SEPARATED_COMMENTS = (exp.Select,)
+
     __slots__ = (
         "time_mapping",
         "time_trie",
@@ -122,7 +131,7 @@ class Generator:
         "_escaped_quote_end",
         "_leading_comma",
         "_max_text_width",
-        "_annotations",
+        "_comments",
     )
 
     def __init__(
@@ -148,7 +157,7 @@ class Generator:
         max_unsupported=3,
         leading_comma=False,
         max_text_width=80,
-        annotations=True,
+        comments=True,
     ):
         import sqlglot
 
@@ -177,7 +186,7 @@ class Generator:
         self._escaped_quote_end = self.escape + self.quote_end
         self._leading_comma = leading_comma
         self._max_text_width = max_text_width
-        self._annotations = annotations
+        self._comments = comments
 
     def generate(self, expression):
         """
@@ -204,7 +213,6 @@ class Generator:
         return sql
 
     def unsupported(self, message):
-
         if self.unsupported_level == ErrorLevel.IMMEDIATE:
             raise UnsupportedError(message)
         self.unsupported_messages.append(message)
@@ -215,9 +223,31 @@ class Generator:
     def seg(self, sql, sep=" "):
         return f"{self.sep(sep)}{sql}"
 
+    def maybe_comment(self, sql, expression, single_line=False):
+        comment = expression.comment if self._comments else None
+
+        if not comment:
+            return sql
+
+        comment = " " + comment if comment[0].strip() else comment
+        comment = comment + " " if comment[-1].strip() else comment
+
+        if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
+            return f"/*{comment}*/{self.sep()}{sql}"
+
+        if not self.pretty:
+            return f"{sql} /*{comment}*/"
+
+        if not NEWLINE_RE.search(comment):
+            return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/"
+
+        return f"/*{comment}*/\n{sql}"
+
     def wrap(self, expression):
         this_sql = self.indent(
-            self.sql(expression) if isinstance(expression, (exp.Select, exp.Union)) else self.sql(expression, "this"),
+            self.sql(expression)
+            if isinstance(expression, (exp.Select, exp.Union))
+            else self.sql(expression, "this"),
             level=1,
             pad=0,
         )
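`maybe_comment` is the single seam where preserved comments re-enter the generated SQL, replacing the old pretty-only annotation printing. Indicative behavior:

```python
import sqlglot

sql = "SELECT a /* cost center */ FROM t"

# With comments enabled (the default), transpiled SQL keeps the comment:
print(sqlglot.transpile(sql)[0])
# SELECT a /* cost center */ FROM t
```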
@ -251,7 +281,7 @@ class Generator:
|
||||||
for i, line in enumerate(lines)
|
for i, line in enumerate(lines)
|
||||||
)
|
)
|
||||||
|
|
||||||
def sql(self, expression, key=None):
|
def sql(self, expression, key=None, comment=True):
|
||||||
if not expression:
|
if not expression:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
@ -264,29 +294,24 @@ class Generator:
|
||||||
transform = self.TRANSFORMS.get(expression.__class__)
|
transform = self.TRANSFORMS.get(expression.__class__)
|
||||||
|
|
||||||
if callable(transform):
|
if callable(transform):
|
||||||
return transform(self, expression)
|
sql = transform(self, expression)
|
||||||
if transform:
|
elif transform:
|
||||||
return transform
|
sql = transform
|
||||||
|
elif isinstance(expression, exp.Expression):
|
||||||
|
exp_handler_name = f"{expression.key}_sql"
|
||||||
|
|
||||||
if not isinstance(expression, exp.Expression):
|
if hasattr(self, exp_handler_name):
|
||||||
|
sql = getattr(self, exp_handler_name)(expression)
|
||||||
|
elif isinstance(expression, exp.Func):
|
||||||
|
sql = self.function_fallback_sql(expression)
|
||||||
|
elif isinstance(expression, exp.Property):
|
||||||
|
sql = self.property_sql(expression)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
|
||||||
|
else:
|
||||||
raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
|
raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
|
||||||
|
|
||||||
exp_handler_name = f"{expression.key}_sql"
|
return self.maybe_comment(sql, expression) if self._comments and comment else sql
|
||||||
if hasattr(self, exp_handler_name):
|
|
||||||
return getattr(self, exp_handler_name)(expression)
|
|
||||||
|
|
||||||
if isinstance(expression, exp.Func):
|
|
||||||
return self.function_fallback_sql(expression)
|
|
||||||
|
|
||||||
if isinstance(expression, exp.Property):
|
|
||||||
return self.property_sql(expression)
|
|
||||||
|
|
||||||
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
|
|
||||||
|
|
||||||
def annotation_sql(self, expression):
|
|
||||||
if self._annotations and self.pretty:
|
|
||||||
return f"{self.sql(expression, 'expression')} # {expression.name}"
|
|
||||||
return self.sql(expression, "expression")
|
|
||||||
|
|
||||||
def uncache_sql(self, expression):
|
def uncache_sql(self, expression):
|
||||||
table = self.sql(expression, "this")
|
table = self.sql(expression, "this")
|
||||||
|
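The hunks above are the heart of the new comment support: `sql()` now routes every generated fragment through `maybe_comment`, so comments attached to parsed expressions survive generation. A quick sanity check of the round trip (the exact output below is inferred from the logic above, not quoted from the docs):

```python
import sqlglot

# With pretty printing off, maybe_comment emits trailing /* ... */ blocks.
print(sqlglot.transpile("SELECT a /* keep me */ FROM t")[0])
# expected: SELECT a /* keep me */ FROM t
```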
@@ -371,7 +396,9 @@ class Generator:
         expression_sql = self.sql(expression, "expression")
         expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else ""
         temporary = " TEMPORARY" if expression.args.get("temporary") else ""
-        transient = " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
+        transient = (
+            " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
+        )
         replace = " OR REPLACE" if expression.args.get("replace") else ""
         exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
         unique = " UNIQUE" if expression.args.get("unique") else ""
@@ -434,7 +461,9 @@ class Generator:
     def delete_sql(self, expression):
         this = self.sql(expression, "this")
         using_sql = (
-            f" USING {self.expressions(expression, 'using', sep=', USING ')}" if expression.args.get("using") else ""
+            f" USING {self.expressions(expression, 'using', sep=', USING ')}"
+            if expression.args.get("using")
+            else ""
         )
         where_sql = self.sql(expression, "where")
         sql = f"DELETE FROM {this}{using_sql}{where_sql}"
@@ -481,15 +510,18 @@ class Generator:
         return f"{this} ON {table} {columns}"

     def identifier_sql(self, expression):
-        value = expression.name
-        value = value.lower() if self.normalize else value
+        text = expression.name
+        text = text.lower() if self.normalize else text
         if expression.args.get("quoted") or self.identify:
-            return f"{self.identifier_start}{value}{self.identifier_end}"
-        return value
+            text = f"{self.identifier_start}{text}{self.identifier_end}"
+        return text

     def partition_sql(self, expression):
         keys = csv(
-            *[f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] for k, v in expression.args.get("this")]
+            *[
+                f"""{prop.name}='{prop.text("value")}'""" if prop.text("value") else prop.name
+                for prop in expression.this
+            ]
         )
         return f"PARTITION({keys})"

@@ -504,9 +536,9 @@ class Generator:
         elif p_class in self.ROOT_PROPERTIES:
             root_properties.append(p)

-        return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties(
-            exp.Properties(expressions=with_properties)
-        )
+        return self.root_properties(
+            exp.Properties(expressions=root_properties)
+        ) + self.with_properties(exp.Properties(expressions=with_properties))

     def root_properties(self, properties):
         if properties.expressions:
@@ -551,7 +583,9 @@ class Generator:

         this = f"{this}{self.sql(expression, 'this')}"
         exists = " IF EXISTS " if expression.args.get("exists") else " "
-        partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else ""
+        partition_sql = (
+            self.sql(expression, "partition") if expression.args.get("partition") else ""
+        )
         expression_sql = self.sql(expression, "expression")
         sep = self.sep() if partition_sql else ""
         sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}"
@@ -669,7 +703,9 @@ class Generator:
     def group_sql(self, expression):
         group_by = self.op_expressions("GROUP BY", expression)
         grouping_sets = self.expressions(expression, key="grouping_sets", indent=False)
-        grouping_sets = f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
+        grouping_sets = (
+            f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
+        )
         cube = self.expressions(expression, key="cube", indent=False)
         cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else ""
         rollup = self.expressions(expression, key="rollup", indent=False)
@@ -711,10 +747,10 @@ class Generator:
         this_sql = self.sql(expression, "this")
         return f"{expression_sql}{op_sql} {this_sql}{on_sql}"

-    def lambda_sql(self, expression):
+    def lambda_sql(self, expression, arrow_sep="->"):
         args = self.expressions(expression, flat=True)
         args = f"({args})" if len(args.split(",")) > 1 else args
-        return self.no_identify(lambda: f"{args} -> {self.sql(expression, 'this')}")
+        return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}")

     def lateral_sql(self, expression):
         this = self.sql(expression, "this")
@@ -748,7 +784,7 @@ class Generator:
             if self._replace_backslash:
                 text = text.replace("\\", "\\\\")
             text = text.replace(self.quote_end, self._escaped_quote_end)
-            return f"{self.quote_start}{text}{self.quote_end}"
+            text = f"{self.quote_start}{text}{self.quote_end}"
         return text

     def loaddata_sql(self, expression):
@@ -796,13 +832,21 @@ class Generator:

         sort_order = " DESC" if desc else ""
         nulls_sort_change = ""
-        if nulls_first and ((asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last):
+        if nulls_first and (
+            (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last
+        ):
             nulls_sort_change = " NULLS FIRST"
-        elif nulls_last and ((asc and nulls_are_small) or (desc and nulls_are_large)) and not nulls_are_last:
+        elif (
+            nulls_last
+            and ((asc and nulls_are_small) or (desc and nulls_are_large))
+            and not nulls_are_last
+        ):
            nulls_sort_change = " NULLS LAST"

         if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED:
-            self.unsupported("Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect")
+            self.unsupported(
+                "Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect"
+            )
             nulls_sort_change = ""

         return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}"
@@ -835,7 +879,7 @@ class Generator:
         sql = self.query_modifiers(
             expression,
             f"SELECT{hint}{distinct}{expressions}",
-            self.sql(expression, "from"),
+            self.sql(expression, "from", comment=False),
         )
         return self.prepend_ctes(expression, sql)
@@ -858,6 +902,13 @@ class Generator:
     def parameter_sql(self, expression):
         return f"@{self.sql(expression, 'this')}"

+    def sessionparameter_sql(self, expression):
+        this = self.sql(expression, "this")
+        kind = expression.text("kind")
+        if kind:
+            kind = f"{kind}."
+        return f"@@{kind}{this}"
+
     def placeholder_sql(self, expression):
         return f":{expression.name}" if expression.name else "?"
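`sessionparameter_sql` backs the new MySQL SET/SHOW work: `@@` variables, optionally prefixed with a scope such as `GLOBAL` or `SESSION`, can now be regenerated. A small assumed example (MySQL-side parser support for `@@` is taken for granted here):

```python
import sqlglot

# @@ variables round-trip through exp.SessionParameter and this new method.
print(sqlglot.transpile("SELECT @@version", read="mysql")[0])
# expected: SELECT @@version
```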
@@ -931,7 +982,10 @@ class Generator:
     def window_spec_sql(self, expression):
         kind = self.sql(expression, "kind")
         start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
-        end = csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") or "CURRENT ROW"
+        end = (
+            csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ")
+            or "CURRENT ROW"
+        )
         return f"{kind} BETWEEN {start} AND {end}"

     def withingroup_sql(self, expression):
@@ -1020,7 +1074,9 @@ class Generator:
         return f"UNIQUE ({columns})"

     def if_sql(self, expression):
-        return self.case_sql(exp.Case(ifs=[expression.copy()], default=expression.args.get("false")))
+        return self.case_sql(
+            exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
+        )

     def in_sql(self, expression):
         query = expression.args.get("query")
@@ -1196,6 +1252,12 @@ class Generator:
     def neq_sql(self, expression):
         return self.binary(expression, "<>")

+    def nullsafeeq_sql(self, expression):
+        return self.binary(expression, "IS NOT DISTINCT FROM")
+
+    def nullsafeneq_sql(self, expression):
+        return self.binary(expression, "IS DISTINCT FROM")
+
     def or_sql(self, expression):
         return self.connector_sql(expression, "OR")
@@ -1205,6 +1267,9 @@ class Generator:
     def trycast_sql(self, expression):
         return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"

+    def use_sql(self, expression):
+        return f"USE {self.sql(expression, 'this')}"
+
     def binary(self, expression, op):
         return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
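These two methods give the changelog's null-safe equal operator a default rendering. A minimal sketch, assuming MySQL input with `<=>`:

```python
import sqlglot

# The base generator spells exp.NullSafeEQ with the ANSI form.
print(sqlglot.transpile("SELECT a <=> b FROM t", read="mysql")[0])
# expected: SELECT a IS NOT DISTINCT FROM b FROM t
```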
@@ -1240,17 +1305,27 @@ class Generator:
         if flat:
             return sep.join(self.sql(e) for e in expressions)

-        sql = (self.sql(e) for e in expressions)
-        # the only time leading_comma changes the output is if pretty print is enabled
-        if self._leading_comma and self.pretty:
-            pad = " " * self.pad
-            expressions = "\n".join(f"{sep}{s}" if i > 0 else f"{pad}{s}" for i, s in enumerate(sql))
-        else:
-            expressions = self.sep(sep).join(sql)
-
-        if indent:
-            return self.indent(expressions, skip_first=False)
-        return expressions
+        num_sqls = len(expressions)
+
+        # These are calculated once in case we have the leading_comma / pretty option set, correspondingly
+        pad = " " * self.pad
+        stripped_sep = sep.strip()
+
+        result_sqls = []
+        for i, e in enumerate(expressions):
+            sql = self.sql(e, comment=False)
+            comment = self.maybe_comment("", e, single_line=True)
+
+            if self.pretty:
+                if self._leading_comma:
+                    result_sqls.append(f"{sep if i > 0 else pad}{sql}{comment}")
+                else:
+                    result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comment}")
+            else:
+                result_sqls.append(f"{sql}{comment}{sep if i + 1 < num_sqls else ''}")
+
+        result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
+        return self.indent(result_sqls, skip_first=False) if indent else result_sqls

     def op_expressions(self, op, expression, flat=False):
         expressions_sql = self.expressions(expression, flat=flat)
@@ -1264,7 +1339,9 @@ class Generator:
     def set_operation(self, expression, op):
         this = self.sql(expression, "this")
         op = self.seg(op)
-        return self.query_modifiers(expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}")
+        return self.query_modifiers(
+            expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
+        )

     def token_sql(self, token_type):
         return self.TOKEN_MAPPING.get(token_type, token_type.name)
@@ -1283,3 +1360,6 @@ class Generator:
         this = self.sql(expression, "this")
         expressions = self.expressions(expression, flat=True)
         return f"{this}({expressions})"
+
+    def kwarg_sql(self, expression):
+        return self.binary(expression, "=>")
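The rewritten `expressions()` now renders each projection with `comment=False` and splices its comment back in via `maybe_comment`, while keeping the `leading_comma` pretty-print behavior. A sketch of the latter (output shape inferred from the branch above, and the `leading_comma` keyword is assumed to flow through `transpile` to the generator):

```python
import sqlglot

# leading_comma only changes the output when pretty printing is enabled.
print(sqlglot.transpile("SELECT a, b FROM t", pretty=True, leading_comma=True)[0])
# expected:
# SELECT
#   a
#   , b
# FROM t
```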
sqlglot/helper.py

@@ -1,48 +1,125 @@
+from __future__ import annotations
+
 import inspect
 import logging
 import re
 import sys
 import typing as t
+from collections.abc import Collection
 from contextlib import contextmanager
 from copy import copy
 from enum import Enum

+if t.TYPE_CHECKING:
+    from sqlglot.expressions import Expression, Table
+
+    T = t.TypeVar("T")
+    E = t.TypeVar("E", bound=Expression)
+
 CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
+PYTHON_VERSION = sys.version_info[:2]
 logger = logging.getLogger("sqlglot")


 class AutoName(Enum):
-    def _generate_next_value_(name, _start, _count, _last_values):
+    """This is used for creating enum classes where `auto()` is the string form of the corresponding value's name."""
+
+    def _generate_next_value_(name, _start, _count, _last_values):  # type: ignore
         return name


-def list_get(arr, index):
+def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
+    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
     try:
-        return arr[index]
+        return seq[index]
     except IndexError:
         return None


+@t.overload
+def ensure_list(value: t.Collection[T]) -> t.List[T]:
+    ...
+
+
+@t.overload
+def ensure_list(value: T) -> t.List[T]:
+    ...
+
+
 def ensure_list(value):
+    """
+    Ensures that a value is a list, otherwise casts or wraps it into one.
+
+    Args:
+        value: the value of interest.
+
+    Returns:
+        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
+    """
     if value is None:
         return []
-    return value if isinstance(value, (list, tuple, set)) else [value]
+    elif isinstance(value, (list, tuple)):
+        return list(value)
+
+    return [value]
+
+
+@t.overload
+def ensure_collection(value: t.Collection[T]) -> t.Collection[T]:
+    ...
+
+
+@t.overload
+def ensure_collection(value: T) -> t.Collection[T]:
+    ...
+
+
+def ensure_collection(value):
+    """
+    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
+
+    Args:
+        value: the value of interest.
+
+    Returns:
+        The value if it's a collection, or else the value wrapped in a list.
+    """
+    if value is None:
+        return []
+    return (
+        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
+    )


-def csv(*args, sep=", "):
+def csv(*args, sep: str = ", ") -> str:
+    """
+    Formats any number of string arguments as CSV.
+
+    Args:
+        args: the string arguments to format.
+        sep: the argument separator.
+
+    Returns:
+        The arguments formatted as a CSV string.
+    """
     return sep.join(arg for arg in args if arg)


-def subclasses(module_name, classes, exclude=()):
+def subclasses(
+    module_name: str,
+    classes: t.Type | t.Tuple[t.Type, ...],
+    exclude: t.Type | t.Tuple[t.Type, ...] = (),
+) -> t.List[t.Type]:
     """
-    Returns a list of all subclasses for a specified class set, posibly excluding some of them.
+    Returns all subclasses for a collection of classes, possibly excluding some of them.

     Args:
-        module_name (str): The name of the module to search for subclasses in.
-        classes (type|tuple[type]): Class(es) we want to find the subclasses of.
-        exclude (type|tuple[type]): Class(es) we want to exclude from the returned list.
+        module_name: the name of the module to search for subclasses in.
+        classes: class(es) we want to find the subclasses of.
+        exclude: class(es) we want to exclude from the returned list.

     Returns:
-        A list of all the target subclasses.
+        The target subclasses.
     """
     return [
         obj
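The breaking `list_get` to `seq_get` rename and the new `ensure_collection` are easiest to see side by side; a small sketch based on the definitions above:

```python
from sqlglot.helper import ensure_collection, seq_get

seq_get([1, 2, 3], 1)    # 2
seq_get([1, 2, 3], 7)    # None (out-of-bounds indexing does not raise)
ensure_collection("ab")  # ["ab"] (str/bytes are wrapped, not treated as collections)
ensure_collection({1})   # {1}, real collections pass through unchanged
```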
@@ -53,7 +130,18 @@ def subclasses(module_name, classes, exclude=()):
     ]


-def apply_index_offset(expressions, offset):
+def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]:
+    """
+    Applies an offset to a given integer literal expression.
+
+    Args:
+        expressions: the expression the offset will be applied to, wrapped in a list.
+        offset: the offset that will be applied.
+
+    Returns:
+        The original expression with the offset applied to it, wrapped in a list. If the provided
+        `expressions` argument contains more than one expressions, it's returned unaffected.
+    """
     if not offset or len(expressions) != 1:
         return expressions

@@ -64,14 +152,28 @@ def apply_index_offset(expressions, offset):
         logger.warning("Applying array index offset (%s)", offset)
         expression.args["this"] = str(int(expression.args["this"]) + offset)
         return [expression]

     return expressions


-def camel_to_snake_case(name):
+def camel_to_snake_case(name: str) -> str:
+    """Converts `name` from camelCase to snake_case and returns the result."""
     return CAMEL_CASE_PATTERN.sub("_", name).upper()


-def while_changing(expression, func):
+def while_changing(
+    expression: t.Optional[Expression], func: t.Callable[[t.Optional[Expression]], E]
+) -> E:
+    """
+    Applies a transformation to a given expression until a fix point is reached.
+
+    Args:
+        expression: the expression to be transformed.
+        func: the transformation to be applied.
+
+    Returns:
+        The transformed expression.
+    """
     while True:
         start = hash(expression)
         expression = func(expression)

@@ -80,10 +182,19 @@ def while_changing(expression, func):
         return expression


-def tsort(dag):
+def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]:
+    """
+    Sorts a given directed acyclic graph in topological order.
+
+    Args:
+        dag: the graph to be sorted.
+
+    Returns:
+        A list that contains all of the graph's nodes in topological order.
+    """
     result = []

-    def visit(node, visited):
+    def visit(node: T, visited: t.Set[T]) -> None:
         if node in result:
             return
         if node in visited:

@@ -103,10 +214,8 @@ def tsort(dag):
     return result


-def open_file(file_name):
-    """
-    Open a file that may be compressed as gzip and return in newline mode.
-    """
+def open_file(file_name: str) -> t.TextIO:
+    """Open a file that may be compressed as gzip and return it in universal newline mode."""
     with open(file_name, "rb") as f:
         gzipped = f.read(2) == b"\x1f\x8b"

@@ -119,14 +228,14 @@ def open_file(file_name):
 @contextmanager
-def csv_reader(table):
+def csv_reader(table: Table) -> t.Any:
     """
-    Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...])
+    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.

     Args:
-        table (exp.Table): A table expression with an anonymous function READ_CSV in it
+        table: a `Table` expression with an anonymous function `READ_CSV` in it.

-    Returns:
+    Yields:
         A python csv reader.
     """
     file, *args = table.this.expressions

@@ -147,13 +256,16 @@ def csv_reader(table):
         file.close()


-def find_new_name(taken, base):
+def find_new_name(taken: t.Sequence[str], base: str) -> str:
     """
     Searches for a new name.

     Args:
-        taken (Sequence[str]): set of taken names
-        base (str): base name to alter
+        taken: a collection of taken names.
+        base: base name to alter.
+
+    Returns:
+        The new, available name.
     """
     if base not in taken:
         return base

@@ -163,22 +275,26 @@ def find_new_name(taken, base):
     while new in taken:
         i += 1
         new = f"{base}_{i}"

     return new


-def object_to_dict(obj, **kwargs):
+def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
+    """Returns a dictionary created from an object's attributes."""
     return {**{k: copy(v) for k, v in vars(obj).copy().items()}, **kwargs}


-def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> t.List[t.Optional[str]]:
+def split_num_words(
+    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
+) -> t.List[t.Optional[str]]:
     """
-    Perform a split on a value and return N words as a result with None used for words that don't exist.
+    Perform a split on a value and return N words as a result with `None` used for words that don't exist.

     Args:
-        value: The value to be split
-        sep: The value to use to split on
-        min_num_words: The minimum number of words that are going to be in the result
-        fill_from_start: Indicates that if None values should be inserted at the start or end of the list
+        value: the value to be split.
+        sep: the value to use to split on.
+        min_num_words: the minimum number of words that are going to be in the result.
+        fill_from_start: indicates that if `None` values should be inserted at the start or end of the list.

     Examples:
         >>> split_num_words("db.table", ".", 3)

@@ -187,6 +303,9 @@ def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: b
         ['db', 'table', None]
         >>> split_num_words("db.table", ".", 1)
         ['db', 'table']
+
+    Returns:
+        The list of words returned by `split`, possibly augmented by a number of `None` values.
     """
     words = value.split(sep)
     if fill_from_start:

@@ -196,7 +315,7 @@ def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: b
 def is_iterable(value: t.Any) -> bool:
     """
-    Checks if the value is an iterable but does not include strings and bytes
+    Checks if the value is an iterable, excluding the types `str` and `bytes`.

     Examples:
         >>> is_iterable([1,2])

@@ -205,28 +324,30 @@ def is_iterable(value: t.Any) -> bool:
         False

     Args:
-        value: The value to check if it is an interable
+        value: the value to check if it is an iterable.

-    Returns: Bool indicating if it is an iterable
+    Returns:
+        A `bool` value indicating if it is an iterable.
     """
     return hasattr(value, "__iter__") and not isinstance(value, (str, bytes))


-def flatten(values: t.Iterable[t.Union[t.Iterable[t.Any], t.Any]]) -> t.Generator[t.Any, None, None]:
+def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any, None, None]:
     """
-    Flattens a list that can contain both iterables and non-iterable elements
+    Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
+    type `str` and `bytes` are not regarded as iterables.

     Examples:
-        >>> list(flatten([[1, 2], 3]))
-        [1, 2, 3]
+        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
+        [1, 2, 3, 4, 5, 'bla']
         >>> list(flatten([1, 2, 3]))
         [1, 2, 3]

     Args:
-        values: The value to be flattened
+        values: the value to be flattened.

-    Returns:
-        Yields non-iterable elements (not including str or byte as iterable)
+    Yields:
+        Non-iterable elements in `values`.
     """
     for value in values:
         if is_iterable(value):
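Two of the newly documented helpers, exercised per their docstrings (return values inferred from the code above):

```python
from sqlglot.helper import flatten, tsort

# tsort takes a node -> dependencies mapping and emits dependencies first.
tsort({"a": ["b"], "b": ["c"], "c": []})  # ['c', 'b', 'a']

# flatten treats str/bytes as scalars, matching the updated docstring.
list(flatten([[1, 2], 3, "xy"]))  # [1, 2, 3, 'xy']
```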
sqlglot/optimizer/annotate_types.py

@@ -1,5 +1,5 @@
 from sqlglot import exp
-from sqlglot.helper import ensure_list, subclasses
+from sqlglot.helper import ensure_collection, ensure_list, subclasses
 from sqlglot.optimizer.scope import Scope, traverse_scope
 from sqlglot.schema import ensure_schema

@@ -48,35 +48,65 @@ class TypeAnnotator:
         exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
         exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
         exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
-        exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
+        exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
+            expr, exp.DataType.Type.BIGINT
+        ),
         exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
         exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
         exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
         exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
-        exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
-        exp.CurrentTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
-        exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+        exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
+            expr, exp.DataType.Type.DATETIME
+        ),
+        exp.CurrentTime: lambda self, expr: self._annotate_with_type(
+            expr, exp.DataType.Type.TIMESTAMP
+        ),
+        exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
+            expr, exp.DataType.Type.TIMESTAMP
+        ),
         exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
         exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
         exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
-        exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
-        exp.DatetimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
+        exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
+            expr, exp.DataType.Type.DATETIME
+        ),
+        exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
+            expr, exp.DataType.Type.DATETIME
+        ),
         exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
         exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
-        exp.TimestampAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
-        exp.TimestampSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+        exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
+            expr, exp.DataType.Type.TIMESTAMP
+        ),
+        exp.TimestampSub: lambda self, expr: self._annotate_with_type(
+            expr, exp.DataType.Type.TIMESTAMP
+        ),
         exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
         exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
         exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
         exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
-        exp.DateStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
-        exp.DateToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+        exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
+            expr, exp.DataType.Type.DATE
+        ),
+        exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
+            expr, exp.DataType.Type.VARCHAR
+        ),
         exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
         exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
         exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
         exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
         exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
-        exp.If: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
+        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
+        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
+        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
+        exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
+        exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+        exp.GroupConcat: lambda self, expr: self._annotate_with_type(
+            expr, exp.DataType.Type.VARCHAR
+        ),
+        exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
+            expr, exp.DataType.Type.VARCHAR
+        ),
         exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
         exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
         exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
@@ -88,32 +118,52 @@ class TypeAnnotator:
         exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
         exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
         exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
-        exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
-        exp.RegexpLike: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
+        exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
+            expr, exp.DataType.Type.DOUBLE
+        ),
+        exp.RegexpLike: lambda self, expr: self._annotate_with_type(
+            expr, exp.DataType.Type.BOOLEAN
+        ),
         exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
         exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
         exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
         exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
         exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
-        exp.StrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+        exp.StrToTime: lambda self, expr: self._annotate_with_type(
+            expr, exp.DataType.Type.TIMESTAMP
+        ),
         exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
         exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
         exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
         exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
         exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
-        exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
-        exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
-        exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+        exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
+            expr, exp.DataType.Type.VARCHAR
+        ),
+        exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
+            expr, exp.DataType.Type.DATE
+        ),
+        exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
+            expr, exp.DataType.Type.TIMESTAMP
+        ),
         exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
-        exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+        exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
+            expr, exp.DataType.Type.VARCHAR
+        ),
         exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
         exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
         exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
-        exp.UnixToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
-        exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+        exp.UnixToTime: lambda self, expr: self._annotate_with_type(
+            expr, exp.DataType.Type.TIMESTAMP
+        ),
+        exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
+            expr, exp.DataType.Type.VARCHAR
+        ),
         exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
         exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
-        exp.VariancePop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+        exp.VariancePop: lambda self, expr: self._annotate_with_type(
+            expr, exp.DataType.Type.DOUBLE
+        ),
         exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
         exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
     }
@@ -124,7 +174,11 @@ class TypeAnnotator:
         exp.DataType.Type.TEXT: set(),
         exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
         exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
-        exp.DataType.Type.NCHAR: {exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
+        exp.DataType.Type.NCHAR: {
+            exp.DataType.Type.VARCHAR,
+            exp.DataType.Type.NVARCHAR,
+            exp.DataType.Type.TEXT,
+        },
         exp.DataType.Type.CHAR: {
             exp.DataType.Type.NCHAR,
             exp.DataType.Type.VARCHAR,
@@ -135,7 +189,11 @@ class TypeAnnotator:
         exp.DataType.Type.DOUBLE: set(),
         exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
         exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
-        exp.DataType.Type.BIGINT: {exp.DataType.Type.DECIMAL, exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
+        exp.DataType.Type.BIGINT: {
+            exp.DataType.Type.DECIMAL,
+            exp.DataType.Type.FLOAT,
+            exp.DataType.Type.DOUBLE,
+        },
         exp.DataType.Type.INT: {
             exp.DataType.Type.BIGINT,
             exp.DataType.Type.DECIMAL,
@@ -160,7 +218,10 @@ class TypeAnnotator:
         # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
         exp.DataType.Type.TIMESTAMPLTZ: set(),
         exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
-        exp.DataType.Type.TIMESTAMP: {exp.DataType.Type.TIMESTAMPTZ, exp.DataType.Type.TIMESTAMPLTZ},
+        exp.DataType.Type.TIMESTAMP: {
+            exp.DataType.Type.TIMESTAMPTZ,
+            exp.DataType.Type.TIMESTAMPLTZ,
+        },
         exp.DataType.Type.DATETIME: {
             exp.DataType.Type.TIMESTAMP,
             exp.DataType.Type.TIMESTAMPTZ,
@@ -219,7 +280,7 @@ class TypeAnnotator:
     def _annotate_args(self, expression):
         for value in expression.args.values():
-            for v in ensure_list(value):
+            for v in ensure_collection(value):
                 self._maybe_annotate(v)

         return expression
@@ -243,7 +304,9 @@ class TypeAnnotator:
         if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
             expression.type = exp.DataType.Type.NULL
         elif exp.DataType.Type.NULL in (left_type, right_type):
-            expression.type = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
+            expression.type = exp.DataType.build(
+                "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
+            )
         else:
             expression.type = exp.DataType.Type.BOOLEAN
     elif isinstance(expression, (exp.Condition, exp.Predicate)):
@@ -276,3 +339,17 @@ class TypeAnnotator:
     def _annotate_with_type(self, expression, target_type):
         expression.type = target_type
         return self._annotate_args(expression)
+
+    def _annotate_by_args(self, expression, *args):
+        self._annotate_args(expression)
+        expressions = []
+        for arg in args:
+            arg_expr = expression.args.get(arg)
+            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
+
+        last_datatype = None
+        for expr in expressions:
+            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
+
+        expression.type = last_datatype or exp.DataType.Type.UNKNOWN
+        return expression
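`_annotate_by_args` is what lets CASE, IF, COALESCE and friends inherit their type from their arguments instead of a hardcoded one. A sketch of the effect (the printed enum form is assumed):

```python
from sqlglot import parse_one
from sqlglot.optimizer.annotate_types import annotate_types

# IF now takes the coerced type of its branches: INT coerces with DOUBLE.
expression = annotate_types(parse_one("SELECT IF(TRUE, 1, 2.5)"))
print(expression.expressions[0].type)  # expected: Type.DOUBLE
```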
sqlglot/optimizer/eliminate_joins.py

@@ -60,7 +60,9 @@ def _join_is_used(scope, join, alias):
         on_clause_columns = set(id(column) for column in on.find_all(exp.Column))
     else:
         on_clause_columns = set()
-    return any(column for column in scope.source_columns(alias) if id(column) not in on_clause_columns)
+    return any(
+        column for column in scope.source_columns(alias) if id(column) not in on_clause_columns
+    )


 def _is_joined_on_all_unique_outputs(scope, join):

sqlglot/optimizer/eliminate_subqueries.py

@@ -45,7 +45,13 @@ def eliminate_subqueries(expression):

     # All table names are taken
     for scope in root.traverse():
-        taken.update({source.name: source for _, source in scope.sources.items() if isinstance(source, exp.Table)})
+        taken.update(
+            {
+                source.name: source
+                for _, source in scope.sources.items()
+                if isinstance(source, exp.Table)
+            }
+        )

     # Map of Expression->alias
     # Existing CTES in the root expression. We'll use this for deduplication.
@@ -70,7 +76,9 @@ def eliminate_subqueries(expression):
             new_ctes.append(cte_scope.expression.parent)

     # Now append the rest
-    for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.derived_table_scopes):
+    for scope in itertools.chain(
+        root.union_scopes, root.subquery_scopes, root.derived_table_scopes
+    ):
         for child_scope in scope.traverse():
             new_cte = _eliminate(child_scope, existing_ctes, taken)
             if new_cte:

sqlglot/optimizer/merge_subqueries.py

@@ -122,7 +122,9 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
         unmergable_window_columns = [
             column
             for column in outer_scope.columns
-            if column.find_ancestor(exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc)
+            if column.find_ancestor(
+                exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc
+            )
         ]
         window_expressions_in_unmergable = [
             column
@@ -147,7 +149,9 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
         and not (
             isinstance(from_or_join, exp.From)
             and inner_select.args.get("where")
-            and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []))
+            and any(
+                j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])
+            )
         )
         and not _is_a_window_expression_in_unmergable_operation()
     )
@@ -203,7 +207,9 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
         if table.alias_or_name == node_to_replace.alias_or_name:
             table.set("this", exp.to_identifier(new_subquery.alias_or_name))
     outer_scope.remove_source(alias)
-    outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
+    outer_scope.add_source(
+        new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]
+    )


 def _merge_joins(outer_scope, inner_scope, from_or_join):
@@ -296,7 +302,9 @@ def _merge_order(outer_scope, inner_scope):
         inner_scope (sqlglot.optimizer.scope.Scope)
     """
     if (
-        any(outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"])
+        any(
+            outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"]
+        )
         or len(outer_scope.selected_sources) != 1
         or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions)
     ):

sqlglot/optimizer/normalize.py

@@ -50,7 +50,9 @@ def normalization_distance(expression, dnf=False):
     Returns:
         int: difference
     """
-    return sum(_predicate_lengths(expression, dnf)) - (len(list(expression.find_all(exp.Connector))) + 1)
+    return sum(_predicate_lengths(expression, dnf)) - (
+        len(list(expression.find_all(exp.Connector))) + 1
+    )


 def _predicate_lengths(expression, dnf):
@@ -68,4 +68,8 @@ def normalize(expression):


 def other_table_names(join, exclude):
-    return [name for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) if name != exclude]
+    return [
+        name
+        for name in (exp.column_table_names(join.args.get("on") or exp.TRUE))
+        if name != exclude
+    ]

sqlglot/optimizer/optimizer.py

@@ -58,6 +58,8 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar

     # Find any additional rule parameters, beyond `expression`
     rule_params = rule.__code__.co_varnames
-    rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs}
+    rule_kwargs = {
+        param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
+    }
     expression = rule(expression, **rule_kwargs)
     return expression
@ -45,7 +45,11 @@ def pushdown(condition, sources, scope_ref_count):
|
||||||
condition = condition.replace(simplify(condition))
|
condition = condition.replace(simplify(condition))
|
||||||
cnf_like = normalized(condition) or not normalized(condition, dnf=True)
|
cnf_like = normalized(condition) or not normalized(condition, dnf=True)
|
||||||
|
|
||||||
predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition])
|
predicates = list(
|
||||||
|
condition.flatten()
|
||||||
|
if isinstance(condition, exp.And if cnf_like else exp.Or)
|
||||||
|
else [condition]
|
||||||
|
)
|
||||||
|
|
||||||
if cnf_like:
|
if cnf_like:
|
||||||
pushdown_cnf(predicates, sources, scope_ref_count)
|
pushdown_cnf(predicates, sources, scope_ref_count)
|
||||||
|
@ -104,11 +108,17 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
|
||||||
for column in predicate.find_all(exp.Column):
|
for column in predicate.find_all(exp.Column):
|
||||||
if column.table == table:
|
if column.table == table:
|
||||||
condition = column.find_ancestor(exp.Condition)
|
condition = column.find_ancestor(exp.Condition)
|
||||||
predicate_condition = exp.and_(predicate_condition, condition) if predicate_condition else condition
|
predicate_condition = (
|
||||||
|
exp.and_(predicate_condition, condition)
|
||||||
|
if predicate_condition
|
||||||
|
else condition
|
||||||
|
)
|
||||||
|
|
||||||
if predicate_condition:
|
if predicate_condition:
|
||||||
conditions[table] = (
|
conditions[table] = (
|
||||||
exp.or_(conditions[table], predicate_condition) if table in conditions else predicate_condition
|
exp.or_(conditions[table], predicate_condition)
|
||||||
|
if table in conditions
|
||||||
|
else predicate_condition
|
||||||
)
|
)
|
||||||
|
|
||||||
for name, node in nodes.items():
|
for name, node in nodes.items():
|
||||||
|
@ -146,10 +156,16 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
|
||||||
nodes[table] = node
|
nodes[table] = node
|
||||||
elif isinstance(node, exp.Select) and len(tables) == 1:
|
elif isinstance(node, exp.Select) and len(tables) == 1:
|
||||||
# We can't push down window expressions
|
# We can't push down window expressions
|
||||||
has_window_expression = any(select for select in node.selects if select.find(exp.Window))
|
has_window_expression = any(
|
||||||
|
select for select in node.selects if select.find(exp.Window)
|
||||||
|
)
|
||||||
# we can't push down predicates to select statements if they are referenced in
|
# we can't push down predicates to select statements if they are referenced in
|
||||||
# multiple places.
|
# multiple places.
|
||||||
if not node.args.get("group") and scope_ref_count[id(source)] < 2 and not has_window_expression:
|
if (
|
||||||
|
not node.args.get("group")
|
||||||
|
and scope_ref_count[id(source)] < 2
|
||||||
|
and not has_window_expression
|
||||||
|
):
|
||||||
nodes[table] = node
|
nodes[table] = node
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
|
@ -165,7 +181,7 @@ def replace_aliases(source, predicate):
|
||||||
|
|
||||||
def _replace_alias(column):
|
def _replace_alias(column):
|
||||||
if isinstance(column, exp.Column) and column.name in aliases:
|
if isinstance(column, exp.Column) and column.name in aliases:
|
||||||
return aliases[column.name]
|
return aliases[column.name].copy()
|
||||||
return column
|
return column
|
||||||
|
|
||||||
return predicate.transform(_replace_alias)
|
return predicate.transform(_replace_alias)
|
||||||
|
|
|
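The `.copy()` on the substituted alias expression is the one behavioral fix in this file: `transform` splices the returned node into the predicate, so returning the projection's own subtree would make two trees share AST nodes, and a later mutation of one would silently corrupt the other. A minimal sketch of the same substitution pattern using sqlglot's public API (the SQL text is illustrative only):

```python
# Replace references to the alias `y` with a *copy* of its definition, so the
# expanded predicate and the original projection don't share AST nodes.
from sqlglot import condition, exp, parse_one

projection = parse_one("SELECT x + 1 AS y FROM t").selects[0]  # Alias: x + 1 AS y
predicate = condition("y > 5")

expanded = predicate.transform(
    lambda node: projection.this.copy()
    if isinstance(node, exp.Column) and node.name == "y"
    else node
)
print(expanded.sql())  # x + 1 > 5
```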
@@ -98,7 +98,9 @@ def _remove_unused_selections(scope, parent_selections):
 
 
 def _remove_indexed_selections(scope, indexes_to_remove):
-    new_selections = [selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove]
+    new_selections = [
+        selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
+    ]
     if not new_selections:
         new_selections.append(DEFAULT_SELECTION)
     scope.expression.set("expressions", new_selections)
@@ -215,13 +215,21 @@ def _qualify_columns(scope, resolver):
     # Determine whether each reference in the order by clause is to a column or an alias.
     for ordered in scope.find_all(exp.Ordered):
         for column in ordered.find_all(exp.Column):
-            if not column.table and column.parent is not ordered and column.name in resolver.all_columns:
+            if (
+                not column.table
+                and column.parent is not ordered
+                and column.name in resolver.all_columns
+            ):
                 columns_missing_from_scope.append(column)
 
     # Determine whether each reference in the having clause is to a column or an alias.
     for having in scope.find_all(exp.Having):
         for column in having.find_all(exp.Column):
-            if not column.table and column.find_ancestor(exp.AggFunc) and column.name in resolver.all_columns:
+            if (
+                not column.table
+                and column.find_ancestor(exp.AggFunc)
+                and column.name in resolver.all_columns
+            ):
                 columns_missing_from_scope.append(column)
 
     for column in columns_missing_from_scope:
@@ -295,7 +303,9 @@ def _qualify_outputs(scope):
     """Ensure all output columns are aliased"""
     new_selections = []
 
-    for i, (selection, aliased_column) in enumerate(itertools.zip_longest(scope.selects, scope.outer_column_list)):
+    for i, (selection, aliased_column) in enumerate(
+        itertools.zip_longest(scope.selects, scope.outer_column_list)
+    ):
         if isinstance(selection, exp.Column):
             # convoluted setter because a simple selection.replace(alias) would require a copy
             alias_ = alias(exp.column(""), alias=selection.name)
@@ -343,14 +353,18 @@ class _Resolver:
            (str) table name
         """
         if self._unambiguous_columns is None:
-            self._unambiguous_columns = self._get_unambiguous_columns(self._get_all_source_columns())
+            self._unambiguous_columns = self._get_unambiguous_columns(
+                self._get_all_source_columns()
+            )
         return self._unambiguous_columns.get(column_name)
 
     @property
     def all_columns(self):
         """All available columns of all sources in this scope"""
         if self._all_columns is None:
-            self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns)
+            self._all_columns = set(
+                column for columns in self._get_all_source_columns().values() for column in columns
+            )
         return self._all_columns
 
     def get_source_columns(self, name, only_visible=False):
@@ -377,7 +391,9 @@ class _Resolver:
 
     def _get_all_source_columns(self):
         if self._source_columns is None:
-            self._source_columns = {k: self.get_source_columns(k) for k in self.scope.selected_sources}
+            self._source_columns = {
+                k: self.get_source_columns(k) for k in self.scope.selected_sources
+            }
         return self._source_columns
 
     def _get_unambiguous_columns(self, source_columns):
@@ -226,7 +226,9 @@ class Scope:
         self._ensure_collected()
         columns = self._raw_columns
 
-        external_columns = [column for scope in self.subquery_scopes for column in scope.external_columns]
+        external_columns = [
+            column for scope in self.subquery_scopes for column in scope.external_columns
+        ]
 
         named_outputs = {e.alias_or_name for e in self.expression.expressions}
@@ -278,7 +280,11 @@ class Scope:
         Returns:
             dict[str, Scope]: Mapping of source alias to Scope
         """
-        return {alias: scope for alias, scope in self.sources.items() if isinstance(scope, Scope) and scope.is_cte}
+        return {
+            alias: scope
+            for alias, scope in self.sources.items()
+            if isinstance(scope, Scope) and scope.is_cte
+        }
 
     @property
     def selects(self):
@@ -307,7 +313,9 @@ class Scope:
         sources in the current scope.
         """
         if self._external_columns is None:
-            self._external_columns = [c for c in self.columns if c.table not in self.selected_sources]
+            self._external_columns = [
+                c for c in self.columns if c.table not in self.selected_sources
+            ]
         return self._external_columns
 
     @property
@@ -229,7 +229,9 @@ def simplify_literals(expression):
             operands.append(a)
 
         if len(operands) < size:
-            return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
+            return functools.reduce(
+                lambda a, b: expression.__class__(this=a, expression=b), operands
+            )
     elif isinstance(expression, exp.Neg):
         this = expression.this
         if this.is_number:
@@ -255,6 +257,12 @@ def _simplify_binary(expression, a, b):
             return TRUE if not_ else FALSE
         if a == NULL:
             return FALSE if not_ else TRUE
+    elif isinstance(expression, exp.NullSafeEQ):
+        if a == b:
+            return TRUE
+    elif isinstance(expression, exp.NullSafeNEQ):
+        if a == b:
+            return FALSE
    elif NULL in (a, b):
        return NULL
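With `exp.NullSafeEQ`/`exp.NullSafeNEQ` handled before the generic `NULL in (a, b)` case, the simplifier can fold null-safe comparisons of identical operands, which plain `=` must not do. A hedged check, assuming the v10 `<=>` operator and the public `simplify` entry point:

```python
# Null-safe equality is TRUE for NULL <=> NULL; plain equality stays NULL.
from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

print(simplify(parse_one("NULL <=> NULL", read="mysql")).sql())  # expected: TRUE
print(simplify(parse_one("NULL = NULL")).sql())                  # expected: NULL
```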
@@ -357,7 +365,7 @@ def extract_date(cast):
 
 def extract_interval(interval):
     try:
-        from dateutil.relativedelta import relativedelta
+        from dateutil.relativedelta import relativedelta  # type: ignore
     except ModuleNotFoundError:
         return None
 
@@ -89,7 +89,11 @@ def decorrelate(select, parent_select, external_columns, sequence):
         return
 
     if isinstance(predicate, exp.Binary):
-        key = predicate.right if any(node is column for node, *_ in predicate.left.walk()) else predicate.left
+        key = (
+            predicate.right
+            if any(node is column for node, *_ in predicate.left.walk())
+            else predicate.left
+        )
     else:
         return
 
@@ -145,7 +149,9 @@ def decorrelate(select, parent_select, external_columns, sequence):
         else:
             parent_predicate = _replace(parent_predicate, "TRUE")
     elif isinstance(parent_predicate, exp.All):
-        parent_predicate = _replace(parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})")
+        parent_predicate = _replace(
+            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
+        )
     elif isinstance(parent_predicate, exp.Any):
         if value.this in group_by:
             parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
@@ -168,7 +174,9 @@ def decorrelate(select, parent_select, external_columns, sequence):
 
         if key in group_by:
             key.replace(nested)
-            parent_predicate = _replace(parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)")
+            parent_predicate = _replace(
+                parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
+            )
         elif isinstance(predicate, exp.EQ):
             parent_predicate = _replace(
                 parent_predicate,
@@ -1,9 +1,13 @@
+from __future__ import annotations
+
 import logging
+import typing as t
 
 from sqlglot import exp
 from sqlglot.errors import ErrorLevel, ParseError, concat_errors
-from sqlglot.helper import apply_index_offset, ensure_list, list_get
+from sqlglot.helper import apply_index_offset, ensure_collection, seq_get
 from sqlglot.tokens import Token, Tokenizer, TokenType
+from sqlglot.trie import in_trie, new_trie
 
 logger = logging.getLogger("sqlglot")
 
@@ -20,7 +24,15 @@ def parse_var_map(args):
     )
 
 
-class Parser:
+class _Parser(type):
+    def __new__(cls, clsname, bases, attrs):
+        klass = super().__new__(cls, clsname, bases, attrs)
+        klass._show_trie = new_trie(key.split(" ") for key in klass.SHOW_PARSERS)
+        klass._set_trie = new_trie(key.split(" ") for key in klass.SET_PARSERS)
+        return klass
+
+
+class Parser(metaclass=_Parser):
     """
     Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer`
     and produces a parsed syntax tree.
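The `_Parser` metaclass builds these tries once per class so that multi-word commands can be matched incrementally. `in_trie` reports 0 for a dead end, 1 for a valid prefix, and 2 for a complete key, which is exactly how `_find_parser` (later in this diff) decides whether to keep consuming tokens. A small sketch of the helper pair:

```python
# new_trie/in_trie back the SHOW/SET command dispatch added in this release.
from sqlglot.trie import in_trie, new_trie

trie = new_trie(key.split(" ") for key in ("SHOW TABLES", "SHOW FULL COLUMNS"))

print(in_trie(trie, ["SHOW"])[0])            # 1 -> prefix, keep consuming tokens
print(in_trie(trie, ["SHOW", "TABLES"])[0])  # 2 -> complete key, dispatch a parser
print(in_trie(trie, ["DESCRIBE"])[0])        # 0 -> no registered command
```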
@@ -45,16 +57,16 @@ class Parser:
     FUNCTIONS = {
         **{name: f.from_arg_list for f in exp.ALL_FUNCTIONS for name in f.sql_names()},
         "DATE_TO_DATE_STR": lambda args: exp.Cast(
-            this=list_get(args, 0),
+            this=seq_get(args, 0),
             to=exp.DataType(this=exp.DataType.Type.TEXT),
         ),
         "TIME_TO_TIME_STR": lambda args: exp.Cast(
-            this=list_get(args, 0),
+            this=seq_get(args, 0),
             to=exp.DataType(this=exp.DataType.Type.TEXT),
         ),
         "TS_OR_DS_TO_DATE_STR": lambda args: exp.Substring(
             this=exp.Cast(
-                this=list_get(args, 0),
+                this=seq_get(args, 0),
                 to=exp.DataType(this=exp.DataType.Type.TEXT),
             ),
             start=exp.Literal.number(1),
@@ -90,6 +102,7 @@ class Parser:
         TokenType.NVARCHAR,
         TokenType.TEXT,
         TokenType.BINARY,
+        TokenType.VARBINARY,
         TokenType.JSON,
         TokenType.INTERVAL,
         TokenType.TIMESTAMP,
@@ -243,6 +256,7 @@ class Parser:
     EQUALITY = {
         TokenType.EQ: exp.EQ,
         TokenType.NEQ: exp.NEQ,
+        TokenType.NULLSAFE_EQ: exp.NullSafeEQ,
     }
 
     COMPARISON = {
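Adding `TokenType.NULLSAFE_EQ` to `EQUALITY` is all it takes for `a <=> b` to parse with the same precedence as `=` and `<>`. A hedged round-trip check:

```python
# MySQL's null-safe equality lands on the new exp.NullSafeEQ node.
import sqlglot
from sqlglot import exp

ast = sqlglot.parse_one("SELECT a <=> b", read="mysql")
assert ast.find(exp.NullSafeEQ) is not None
print(ast.sql(dialect="mysql"))  # SELECT a <=> b
```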
@@ -298,6 +312,21 @@ class Parser:
         TokenType.ANTI,
     }
 
+    LAMBDAS = {
+        TokenType.ARROW: lambda self, expressions: self.expression(
+            exp.Lambda,
+            this=self._parse_conjunction().transform(
+                self._replace_lambda, {node.name for node in expressions}
+            ),
+            expressions=expressions,
+        ),
+        TokenType.FARROW: lambda self, expressions: self.expression(
+            exp.Kwarg,
+            this=exp.Var(this=expressions[0].name),
+            expression=self._parse_conjunction(),
+        ),
+    }
+
     COLUMN_OPERATORS = {
         TokenType.DOT: None,
         TokenType.DCOLON: lambda self, this, to: self.expression(
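The new `LAMBDAS` table means `_parse_lambda` no longer hardcodes `->`: `x -> expr` builds an `exp.Lambda` (with bare names rewritten by `_replace_lambda`), while `name => value` builds an `exp.Kwarg` for keyword-style arguments. A sketch with made-up function names (`FOO` and `BAR` are just anonymous functions here, not real builtins):

```python
# `->` produces a Lambda argument; `=>` produces a Kwarg (named argument).
from sqlglot import exp, parse_one

lam = parse_one("SELECT FOO(xs, i -> i + 1)").find(exp.Lambda)
print(lam.sql())  # i -> i + 1

kwarg = parse_one("SELECT BAR(param => 10)").find(exp.Kwarg)
print(kwarg.sql())  # param => 10
```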
@@ -362,20 +391,30 @@ class Parser:
         TokenType.DELETE: lambda self: self._parse_delete(),
         TokenType.CACHE: lambda self: self._parse_cache(),
         TokenType.UNCACHE: lambda self: self._parse_uncache(),
+        TokenType.USE: lambda self: self._parse_use(),
     }
 
     PRIMARY_PARSERS = {
-        TokenType.STRING: lambda _, token: exp.Literal.string(token.text),
-        TokenType.NUMBER: lambda _, token: exp.Literal.number(token.text),
-        TokenType.STAR: lambda self, _: exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()}),
-        TokenType.NULL: lambda *_: exp.Null(),
-        TokenType.TRUE: lambda *_: exp.Boolean(this=True),
-        TokenType.FALSE: lambda *_: exp.Boolean(this=False),
-        TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()),
-        TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text),
-        TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text),
-        TokenType.BYTE_STRING: lambda _, token: exp.ByteString(this=token.text),
+        TokenType.STRING: lambda self, token: self.expression(
+            exp.Literal, this=token.text, is_string=True
+        ),
+        TokenType.NUMBER: lambda self, token: self.expression(
+            exp.Literal, this=token.text, is_string=False
+        ),
+        TokenType.STAR: lambda self, _: self.expression(
+            exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()}
+        ),
+        TokenType.NULL: lambda self, _: self.expression(exp.Null),
+        TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True),
+        TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False),
+        TokenType.PARAMETER: lambda self, _: self.expression(
+            exp.Parameter, this=self._parse_var() or self._parse_primary()
+        ),
+        TokenType.BIT_STRING: lambda self, token: self.expression(exp.BitString, this=token.text),
+        TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text),
+        TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text),
         TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token),
+        TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
     }
 
     RANGE_PARSERS = {
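Rewriting every primary to go through `self.expression` is what makes comment preservation work: `expression()` (further down in this diff) attaches the pending `_prev_comment` to whatever node is built next. An end-to-end check, assuming default settings; the exact placement of the re-emitted comment may differ from the input:

```python
# In v10, a parsed comment rides along on the AST node and is re-emitted,
# instead of being dropped during transpilation as in v9.
import sqlglot

print(sqlglot.transpile("SELECT /* keep me */ a FROM t")[0])
# -> the output still contains "/* keep me */"
```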
@@ -411,16 +450,24 @@ class Parser:
         TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
         TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
         TokenType.FORMAT: lambda self: self._parse_property_assignment(exp.FileFormatProperty),
-        TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
+        TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment(
+            exp.TableFormatProperty
+        ),
         TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
         TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty),
         TokenType.EXECUTE: lambda self: self._parse_execute_as(),
         TokenType.DETERMINISTIC: lambda self: self.expression(
             exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
         ),
-        TokenType.IMMUTABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")),
-        TokenType.STABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("STABLE")),
-        TokenType.VOLATILE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")),
+        TokenType.IMMUTABLE: lambda self: self.expression(
+            exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
+        ),
+        TokenType.STABLE: lambda self: self.expression(
+            exp.VolatilityProperty, this=exp.Literal.string("STABLE")
+        ),
+        TokenType.VOLATILE: lambda self: self.expression(
+            exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
+        ),
     }
 
     CONSTRAINT_PARSERS = {
@@ -450,7 +497,8 @@ class Parser:
         "group": lambda self: self._parse_group(),
         "having": lambda self: self._parse_having(),
         "qualify": lambda self: self._parse_qualify(),
-        "window": lambda self: self._match(TokenType.WINDOW) and self._parse_window(self._parse_id_var(), alias=True),
+        "window": lambda self: self._match(TokenType.WINDOW)
+        and self._parse_window(self._parse_id_var(), alias=True),
         "distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute),
         "sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort),
         "cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster),
@@ -459,6 +507,9 @@ class Parser:
         "offset": lambda self: self._parse_offset(),
     }
 
+    SHOW_PARSERS: t.Dict[str, t.Callable] = {}
+    SET_PARSERS: t.Dict[str, t.Callable] = {}
+
     MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
 
     CREATABLES = {
@@ -488,7 +539,9 @@ class Parser:
         "_curr",
         "_next",
         "_prev",
-        "_greedy_subqueries",
+        "_prev_comment",
+        "_show_trie",
+        "_set_trie",
     )
 
     def __init__(
@@ -519,7 +572,7 @@ class Parser:
         self._curr = None
         self._next = None
         self._prev = None
-        self._greedy_subqueries = False
+        self._prev_comment = None
 
     def parse(self, raw_tokens, sql=None):
         """
@@ -533,10 +586,12 @@ class Parser:
         Returns
             the list of syntax trees (:class:`~sqlglot.expressions.Expression`).
         """
-        return self._parse(parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql)
+        return self._parse(
+            parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql
+        )
 
     def parse_into(self, expression_types, raw_tokens, sql=None):
-        for expression_type in ensure_list(expression_types):
+        for expression_type in ensure_collection(expression_types):
             parser = self.EXPRESSION_PARSERS.get(expression_type)
             if not parser:
                 raise TypeError(f"No parser registered for {expression_type}")
@@ -597,6 +652,9 @@ class Parser:
 
     def expression(self, exp_class, **kwargs):
         instance = exp_class(**kwargs)
+        if self._prev_comment:
+            instance.comment = self._prev_comment
+            self._prev_comment = None
         self.validate_expression(instance)
         return instance
 
@@ -633,14 +691,16 @@ class Parser:
 
         return index
 
-    def _get_token(self, index):
-        return list_get(self._tokens, index)
-
     def _advance(self, times=1):
         self._index += times
-        self._curr = self._get_token(self._index)
-        self._next = self._get_token(self._index + 1)
-        self._prev = self._get_token(self._index - 1) if self._index > 0 else None
+        self._curr = seq_get(self._tokens, self._index)
+        self._next = seq_get(self._tokens, self._index + 1)
+        if self._index > 0:
+            self._prev = self._tokens[self._index - 1]
+            self._prev_comment = self._prev.comment
+        else:
+            self._prev = None
+            self._prev_comment = None
 
     def _retreat(self, index):
         self._advance(index - self._index)
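`seq_get` (the renamed `list_get`) is a bounds-safe lookup: it returns `None` instead of raising `IndexError`, which is why `_advance` can read one past the final token without guarding. The `_prev_comment` bookkeeping added here is what feeds `expression()` above. Sketch of the helper:

```python
# seq_get never raises on an out-of-range index -- it just yields None,
# e.g. self._curr = seq_get(self._tokens, self._index) at end of input.
from sqlglot.helper import seq_get

tokens = ["SELECT", "1"]
print(seq_get(tokens, 1))  # 1
print(seq_get(tokens, 5))  # None
```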
@@ -661,6 +721,7 @@ class Parser:
 
         expression = self._parse_expression()
         expression = self._parse_set_operations(expression) if expression else self._parse_select()
+
         self._parse_query_modifiers(expression)
         return expression
 
@@ -682,7 +743,11 @@ class Parser:
         )
 
     def _parse_exists(self, not_=False):
-        return self._match(TokenType.IF) and (not not_ or self._match(TokenType.NOT)) and self._match(TokenType.EXISTS)
+        return (
+            self._match(TokenType.IF)
+            and (not not_ or self._match(TokenType.NOT))
+            and self._match(TokenType.EXISTS)
+        )
 
     def _parse_create(self):
         replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE)
@@ -931,7 +996,9 @@ class Parser:
         return self.expression(
             exp.Delete,
             this=self._parse_table(schema=True),
-            using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table(schema=True)),
+            using=self._parse_csv(
+                lambda: self._match(TokenType.USING) and self._parse_table(schema=True)
+            ),
             where=self._parse_where(),
         )
 
@@ -983,11 +1050,13 @@ class Parser:
             return None
 
         def parse_values():
-            k = self._parse_var()
+            key = self._parse_var()
+            value = None
+
             if self._match(TokenType.EQ):
-                v = self._parse_string()
-                return (k, v)
-            return (k, None)
+                value = self._parse_string()
+
+            return exp.Property(this=key, value=value)
 
         self._match_l_paren()
         values = self._parse_csv(parse_values)
@@ -1019,6 +1088,8 @@ class Parser:
                 self.raise_error(f"{this.key} does not support CTE")
             this = cte
         elif self._match(TokenType.SELECT):
+            comment = self._prev_comment
+
             hint = self._parse_hint()
             all_ = self._match(TokenType.ALL)
             distinct = self._match(TokenType.DISTINCT)
@@ -1033,7 +1104,7 @@ class Parser:
                 self.raise_error("Cannot specify both ALL and DISTINCT after SELECT")
 
             limit = self._parse_limit(top=True)
-            expressions = self._parse_csv(lambda: self._parse_annotation(self._parse_expression()))
+            expressions = self._parse_csv(self._parse_expression)
 
             this = self.expression(
                 exp.Select,
@@ -1042,6 +1113,7 @@ class Parser:
                 expressions=expressions,
                 limit=limit,
             )
+            this.comment = comment
             from_ = self._parse_from()
             if from_:
                 this.set("from", from_)
@@ -1072,8 +1144,10 @@ class Parser:
         while True:
             expressions.append(self._parse_cte())
 
-            if not self._match(TokenType.COMMA):
+            if not self._match(TokenType.COMMA) and not self._match(TokenType.WITH):
                 break
+            else:
+                self._match(TokenType.WITH)
 
         return self.expression(
             exp.With,
@@ -1111,11 +1185,7 @@ class Parser:
         if not alias and not columns:
             return None
 
-        return self.expression(
-            exp.TableAlias,
-            this=alias,
-            columns=columns,
-        )
+        return self.expression(exp.TableAlias, this=alias, columns=columns)
 
     def _parse_subquery(self, this):
         return self.expression(
@@ -1150,12 +1220,6 @@ class Parser:
         if expression:
             this.set(key, expression)
 
-    def _parse_annotation(self, expression):
-        if self._match(TokenType.ANNOTATION):
-            return self.expression(exp.Annotation, this=self._prev.text.strip(), expression=expression)
-
-        return expression
-
     def _parse_hint(self):
         if self._match(TokenType.HINT):
             hints = self._parse_csv(self._parse_function)
@@ -1295,7 +1359,9 @@ class Parser:
         if not table:
             self.raise_error("Expected table name")
 
-        this = self.expression(exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots())
+        this = self.expression(
+            exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
+        )
 
         if schema:
             return self._parse_schema(this=this)
@@ -1500,7 +1566,9 @@ class Parser:
         if not skip_order_token and not self._match(TokenType.ORDER_BY):
             return this
 
-        return self.expression(exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered))
+        return self.expression(
+            exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)
+        )
 
     def _parse_sort(self, token_type, exp_class):
         if not self._match(token_type):
@@ -1521,7 +1589,8 @@ class Parser:
         if (
             not explicitly_null_ordered
             and (
-                (asc and self.null_ordering == "nulls_are_small") or (desc and self.null_ordering != "nulls_are_small")
+                (asc and self.null_ordering == "nulls_are_small")
+                or (desc and self.null_ordering != "nulls_are_small")
             )
             and self.null_ordering != "nulls_are_last"
         ):
@@ -1606,6 +1675,9 @@ class Parser:
 
     def _parse_is(self, this):
         negate = self._match(TokenType.NOT)
+        if self._match(TokenType.DISTINCT_FROM):
+            klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
+            return self.expression(klass, this=this, expression=self._parse_expression())
         this = self.expression(
             exp.Is,
             this=this,
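`_parse_is` now folds the ANSI spellings into the same null-safe nodes used by `<=>`: `IS NOT DISTINCT FROM` becomes `exp.NullSafeEQ` and `IS DISTINCT FROM` becomes `exp.NullSafeNEQ`. A quick check:

```python
# Both ANSI spellings map onto the null-safe comparison nodes.
from sqlglot import exp, parse_one

assert parse_one("a IS NOT DISTINCT FROM b").find(exp.NullSafeEQ) is not None
assert parse_one("a IS DISTINCT FROM b").find(exp.NullSafeNEQ) is not None
```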
@@ -1653,9 +1725,13 @@ class Parser:
                 expression=self._parse_term(),
             )
         elif self._match_pair(TokenType.LT, TokenType.LT):
-            this = self.expression(exp.BitwiseLeftShift, this=this, expression=self._parse_term())
+            this = self.expression(
+                exp.BitwiseLeftShift, this=this, expression=self._parse_term()
+            )
         elif self._match_pair(TokenType.GT, TokenType.GT):
-            this = self.expression(exp.BitwiseRightShift, this=this, expression=self._parse_term())
+            this = self.expression(
+                exp.BitwiseRightShift, this=this, expression=self._parse_term()
+            )
         else:
             break
 
@@ -1685,7 +1761,7 @@ class Parser:
         )
 
         index = self._index
-        type_token = self._parse_types()
+        type_token = self._parse_types(check_func=True)
         this = self._parse_column()
 
         if type_token:
@@ -1698,7 +1774,7 @@ class Parser:
 
         return this
 
-    def _parse_types(self):
+    def _parse_types(self, check_func=False):
         index = self._index
 
         if not self._match_set(self.TYPE_TOKENS):
@@ -1708,10 +1784,13 @@ class Parser:
         nested = type_token in self.NESTED_TYPE_TOKENS
         is_struct = type_token == TokenType.STRUCT
         expressions = None
+        maybe_func = False
 
         if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
             return exp.DataType(
-                this=exp.DataType.Type.ARRAY, expressions=[exp.DataType.build(type_token.value)], nested=True
+                this=exp.DataType.Type.ARRAY,
+                expressions=[exp.DataType.build(type_token.value)],
+                nested=True,
             )
 
         if self._match(TokenType.L_BRACKET):
@@ -1731,6 +1810,7 @@ class Parser:
                 return None
 
             self._match_r_paren()
+            maybe_func = True
 
         if nested and self._match(TokenType.LT):
             if is_struct:
@@ -1741,26 +1821,47 @@ class Parser:
             if not self._match(TokenType.GT):
                 self.raise_error("Expecting >")
 
+        value = None
         if type_token in self.TIMESTAMPS:
-            tz = self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ
-            if tz:
-                return exp.DataType(
+            if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ:
+                value = exp.DataType(
                     this=exp.DataType.Type.TIMESTAMPTZ,
                     expressions=expressions,
                 )
-            ltz = self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ
-            if ltz:
-                return exp.DataType(
+            elif (
+                self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ
+            ):
+                value = exp.DataType(
                     this=exp.DataType.Type.TIMESTAMPLTZ,
                     expressions=expressions,
                 )
-            self._match(TokenType.WITHOUT_TIME_ZONE)
-
-            return exp.DataType(
-                this=exp.DataType.Type.TIMESTAMP,
-                expressions=expressions,
-            )
+            elif self._match(TokenType.WITHOUT_TIME_ZONE):
+                value = exp.DataType(
+                    this=exp.DataType.Type.TIMESTAMP,
+                    expressions=expressions,
+                )
+
+            maybe_func = maybe_func and value is None
+
+            if value is None:
+                value = exp.DataType(
+                    this=exp.DataType.Type.TIMESTAMP,
+                    expressions=expressions,
+                )
+
+        if maybe_func and check_func:
+            index2 = self._index
+            peek = self._parse_string()
+
+            if not peek:
+                self._retreat(index)
+                return None
+
+            self._retreat(index2)
+
+        if value:
+            return value
 
         return exp.DataType(
             this=exp.DataType.Type[type_token.value.upper()],
             expressions=expressions,
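The `maybe_func`/`check_func` dance resolves an ambiguity: a type keyword followed by parentheses might be a parameterized type or a plain function call. When `check_func=True` (set by `_parse_type`), the parser peeks for a trailing string literal; if none follows a parenthesized type, it retreats and lets the tokens re-parse as a function. A hedged illustration of the resulting behavior:

```python
# A trailing string keeps the type reading (typed literal -> Cast);
# otherwise DATE(...) should re-parse as an ordinary function call.
from sqlglot import exp, parse_one

typed = parse_one("SELECT DATE '2022-01-01'")
print(typed.find(exp.Cast) is not None)  # expected: True

func = parse_one("SELECT DATE(created_at)")
print(func.find(exp.Cast) is None)       # expected: True
```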
@@ -1826,22 +1927,29 @@ class Parser:
             return exp.Literal.number(f"0.{self._prev.text}")
 
         if self._match(TokenType.L_PAREN):
+            comment = self._prev_comment
             query = self._parse_select()
 
             if query:
                 expressions = [query]
             else:
-                expressions = self._parse_csv(lambda: self._parse_alias(self._parse_conjunction(), explicit=True))
+                expressions = self._parse_csv(
+                    lambda: self._parse_alias(self._parse_conjunction(), explicit=True)
+                )
 
-            this = list_get(expressions, 0)
+            this = seq_get(expressions, 0)
             self._parse_query_modifiers(this)
             self._match_r_paren()
 
             if isinstance(this, exp.Subqueryable):
-                return self._parse_set_operations(self._parse_subquery(this))
-            if len(expressions) > 1:
-                return self.expression(exp.Tuple, expressions=expressions)
-            return self.expression(exp.Paren, this=this)
+                this = self._parse_set_operations(self._parse_subquery(this))
+            elif len(expressions) > 1:
+                this = self.expression(exp.Tuple, expressions=expressions)
+            else:
+                this = self.expression(exp.Paren, this=this)
+
+            if comment:
+                this.comment = comment
+            return this
 
         return None
 
@@ -1894,7 +2002,8 @@ class Parser:
             self.validate_expression(this, args)
         else:
             this = self.expression(exp.Anonymous, this=this, expressions=args)
-        self._match_r_paren()
+
+        self._match_r_paren(this)
         return self._parse_window(this)
 
     def _parse_user_defined_function(self):
@@ -1920,6 +2029,18 @@ class Parser:
 
         return self.expression(exp.Identifier, this=token.text)
 
+    def _parse_session_parameter(self):
+        kind = None
+        this = self._parse_id_var() or self._parse_primary()
+        if self._match(TokenType.DOT):
+            kind = this.name
+            this = self._parse_var() or self._parse_primary()
+        return self.expression(
+            exp.SessionParameter,
+            this=this,
+            kind=kind,
+        )
+
     def _parse_udf_kwarg(self):
         this = self._parse_id_var()
         kind = self._parse_types()
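`_parse_session_parameter` backs the new `TokenType.SESSION_PARAMETER` primary, covering variables like `@@sql_mode` and the scoped `@@SESSION.sql_mode`, where the dotted prefix becomes the node's `kind`. A hedged sketch (tokenization of `@@` is dialect-dependent, so this assumes the MySQL dialect):

```python
# @@-prefixed variables parse into exp.SessionParameter; a dotted scope
# such as SESSION or GLOBAL is stored as the node's `kind`.
from sqlglot import exp, parse_one

param = parse_one("SELECT @@SESSION.sql_mode", read="mysql").find(exp.SessionParameter)
print(param.sql(dialect="mysql"))  # expected: @@SESSION.sql_mode
```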
@@ -1938,11 +2059,15 @@ class Parser:
         else:
             expressions = [self._parse_id_var()]
 
-        if not self._match(TokenType.ARROW):
+        if self._match_set(self.LAMBDAS):
+            return self.LAMBDAS[self._prev.token_type](self, expressions)
+
         self._retreat(index)
 
         if self._match(TokenType.DISTINCT):
-            this = self.expression(exp.Distinct, expressions=self._parse_csv(self._parse_conjunction))
+            this = self.expression(
+                exp.Distinct, expressions=self._parse_csv(self._parse_conjunction)
+            )
         else:
             this = self._parse_conjunction()
@@ -1953,20 +2078,15 @@ class Parser:
 
         return self._parse_alias(self._parse_limit(self._parse_order(this)))
 
-        conjunction = self._parse_conjunction().transform(self._replace_lambda, {node.name for node in expressions})
-        return self.expression(
-            exp.Lambda,
-            this=conjunction,
-            expressions=expressions,
-        )
-
     def _parse_schema(self, this=None):
         index = self._index
         if not self._match(TokenType.L_PAREN) or self._match(TokenType.SELECT):
             self._retreat(index)
             return this
 
-        args = self._parse_csv(lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True)))
+        args = self._parse_csv(
+            lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True))
+        )
         self._match_r_paren()
         return self.expression(exp.Schema, this=this, expressions=args)
@@ -2104,6 +2224,7 @@ class Parser:
         if not self._match(TokenType.R_BRACKET):
             self.raise_error("Expected ]")
 
+        this.comment = self._prev_comment
         return self._parse_bracket(this)
 
     def _parse_case(self):
@@ -2124,7 +2245,9 @@ class Parser:
         if not self._match(TokenType.END):
             self.raise_error("Expected END after CASE", self._prev)
 
-        return self._parse_window(self.expression(exp.Case, this=expression, ifs=ifs, default=default))
+        return self._parse_window(
+            self.expression(exp.Case, this=expression, ifs=ifs, default=default)
+        )
 
     def _parse_if(self):
         if self._match(TokenType.L_PAREN):
@@ -2331,7 +2454,9 @@ class Parser:
         self._match(TokenType.BETWEEN)
 
         return {
-            "value": (self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text)
+            "value": (
+                self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text
+            )
             or self._parse_bitwise(),
             "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text,
         }
@@ -2348,7 +2473,7 @@ class Parser:
                 this=this,
                 expressions=self._parse_csv(lambda: self._parse_id_var(any_token)),
             )
-            self._match_r_paren()
+            self._match_r_paren(aliases)
             return aliases
 
         alias = self._parse_id_var(any_token)
@@ -2365,28 +2490,29 @@ class Parser:
             return identifier
 
         if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS:
-            return self._advance() or exp.Identifier(this=self._prev.text, quoted=False)
-
-        return self._match_set(tokens or self.ID_VAR_TOKENS) and exp.Identifier(this=self._prev.text, quoted=False)
+            self._advance()
+        elif not self._match_set(tokens or self.ID_VAR_TOKENS):
+            return None
+        return exp.Identifier(this=self._prev.text, quoted=False)
 
     def _parse_string(self):
         if self._match(TokenType.STRING):
-            return exp.Literal.string(self._prev.text)
+            return self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev)
         return self._parse_placeholder()
 
     def _parse_number(self):
         if self._match(TokenType.NUMBER):
-            return exp.Literal.number(self._prev.text)
+            return self.PRIMARY_PARSERS[TokenType.NUMBER](self, self._prev)
         return self._parse_placeholder()
 
     def _parse_identifier(self):
         if self._match(TokenType.IDENTIFIER):
-            return exp.Identifier(this=self._prev.text, quoted=True)
+            return self.expression(exp.Identifier, this=self._prev.text, quoted=True)
         return self._parse_placeholder()
 
     def _parse_var(self):
         if self._match(TokenType.VAR):
-            return exp.Var(this=self._prev.text)
+            return self.expression(exp.Var, this=self._prev.text)
         return self._parse_placeholder()
 
     def _parse_var_or_string(self):
@@ -2394,27 +2520,27 @@ class Parser:
 
     def _parse_null(self):
         if self._match(TokenType.NULL):
-            return exp.Null()
+            return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev)
         return None
 
     def _parse_boolean(self):
         if self._match(TokenType.TRUE):
-            return exp.Boolean(this=True)
+            return self.PRIMARY_PARSERS[TokenType.TRUE](self, self._prev)
         if self._match(TokenType.FALSE):
-            return exp.Boolean(this=False)
+            return self.PRIMARY_PARSERS[TokenType.FALSE](self, self._prev)
         return None
 
     def _parse_star(self):
         if self._match(TokenType.STAR):
-            return exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()})
+            return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev)
         return None
 
     def _parse_placeholder(self):
         if self._match(TokenType.PLACEHOLDER):
-            return exp.Placeholder()
+            return self.expression(exp.Placeholder)
         elif self._match(TokenType.COLON):
             self._advance()
-            return exp.Placeholder(this=self._prev.text)
+            return self.expression(exp.Placeholder, this=self._prev.text)
         return None
 
     def _parse_except(self):
@@ -2432,22 +2558,27 @@ class Parser:
         self._match_r_paren()
         return columns
 
-    def _parse_csv(self, parse):
-        parse_result = parse()
+    def _parse_csv(self, parse_method):
+        parse_result = parse_method()
         items = [parse_result] if parse_result is not None else []
 
         while self._match(TokenType.COMMA):
-            parse_result = parse()
+            if parse_result and self._prev_comment is not None:
+                parse_result.comment = self._prev_comment
+
+            parse_result = parse_method()
             if parse_result is not None:
                 items.append(parse_result)
 
         return items
 
-    def _parse_tokens(self, parse, expressions):
-        this = parse()
+    def _parse_tokens(self, parse_method, expressions):
+        this = parse_method()
 
         while self._match_set(expressions):
-            this = self.expression(expressions[self._prev.token_type], this=this, expression=parse())
+            this = self.expression(
+                expressions[self._prev.token_type], this=this, expression=parse_method()
+            )
 
         return this
 
@@ -2460,6 +2591,47 @@ class Parser:
     def _parse_select_or_expression(self):
         return self._parse_select() or self._parse_expression()
 
+    def _parse_use(self):
+        return self.expression(exp.Use, this=self._parse_id_var())
+
+    def _parse_show(self):
+        parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
+        if parser:
+            return parser(self)
+        self._advance()
+        return self.expression(exp.Show, this=self._prev.text.upper())
+
+    def _default_parse_set_item(self):
+        return self.expression(
+            exp.SetItem,
+            this=self._parse_statement(),
+        )
+
+    def _parse_set_item(self):
+        parser = self._find_parser(self.SET_PARSERS, self._set_trie)
+        return parser(self) if parser else self._default_parse_set_item()
+
+    def _parse_set(self):
+        return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
+
+    def _find_parser(self, parsers, trie):
+        index = self._index
+        this = []
+        while True:
+            # The current token might be multiple words
+            curr = self._curr.text.upper()
+            key = curr.split(" ")
+            this.append(curr)
+            self._advance()
+            result, trie = in_trie(trie, key)
+            if result == 0:
+                break
+            if result == 2:
+                subparser = parsers[" ".join(this)]
+                return subparser
+        self._retreat(index)
+        return None
+
     def _match(self, token_type):
         if not self._curr:
             return None
@ -2491,13 +2663,17 @@ class Parser:
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _match_l_paren(self):
|
def _match_l_paren(self, expression=None):
|
||||||
if not self._match(TokenType.L_PAREN):
|
if not self._match(TokenType.L_PAREN):
|
||||||
self.raise_error("Expecting (")
|
self.raise_error("Expecting (")
|
||||||
|
if expression and self._prev_comment:
|
||||||
|
expression.comment = self._prev_comment
|
||||||
|
|
||||||
def _match_r_paren(self):
|
def _match_r_paren(self, expression=None):
|
||||||
if not self._match(TokenType.R_PAREN):
|
if not self._match(TokenType.R_PAREN):
|
||||||
self.raise_error("Expecting )")
|
self.raise_error("Expecting )")
|
||||||
|
if expression and self._prev_comment:
|
||||||
|
expression.comment = self._prev_comment
|
||||||
|
|
||||||
def _match_text(self, *texts):
|
def _match_text(self, *texts):
|
||||||
index = self._index
|
index = self._index
|
||||||
|
|
|
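The new `_parse_use`, `_parse_show` and `_parse_set` entry points above are dispatched through word tries built from each dialect's `SHOW_PARSERS`/`SET_PARSERS` tables. A minimal sketch of how this surfaces through the public API, assuming MySQL registers parsers for these commands (the statements shown are illustrative):

```python
import sqlglot

# SHOW and SET now parse into dedicated Show/Set expressions instead of
# falling back to opaque commands.
show = sqlglot.parse_one("SHOW TABLES", read="mysql")
set_ = sqlglot.parse_one("SET sql_mode = 'TRADITIONAL'", read="mysql")

print(type(show).__name__)  # Show
print(type(set_).__name__)  # Set, containing SetItem children
```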
@@ -72,7 +72,9 @@ class Step:
         if from_:
             from_ = from_.expressions
             if len(from_) > 1:
-                raise UnsupportedError("Multi-from statements are unsupported. Run it through the optimizer")
+                raise UnsupportedError(
+                    "Multi-from statements are unsupported. Run it through the optimizer"
+                )

             step = Scan.from_expression(from_[0], ctes)
         else:

@@ -102,7 +104,7 @@ class Step:
                     continue
                 if operand not in operands:
                     operands[operand] = f"_a_{next(sequence)}"
-                operand.replace(exp.column(operands[operand], step.name, quoted=True))
+                operand.replace(exp.column(operands[operand], quoted=True))
             else:
                 projections.append(e)

@@ -117,9 +119,11 @@ class Step:
             aggregate = Aggregate()
             aggregate.source = step.name
             aggregate.name = step.name
-            aggregate.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
+            aggregate.operands = tuple(
+                alias(operand, alias_) for operand, alias_ in operands.items()
+            )
             aggregate.aggregations = aggregations
-            aggregate.group = [exp.column(e.alias_or_name, step.name, quoted=True) for e in group.expressions]
+            aggregate.group = group.expressions
             aggregate.add_dependency(step)
             step = aggregate

@@ -136,9 +140,6 @@ class Step:
             sort.key = order.expressions
             sort.add_dependency(step)
             step = sort
-            for k in sort.key + projections:
-                for column in k.find_all(exp.Column):
-                    column.set("table", exp.to_identifier(step.name, quoted=True))

         step.projections = projections

@@ -203,7 +204,9 @@ class Scan(Step):
         alias_ = expression.alias

         if not alias_:
-            raise UnsupportedError("Tables/Subqueries must be aliased. Run it through the optimizer")
+            raise UnsupportedError(
+                "Tables/Subqueries must be aliased. Run it through the optimizer"
+            )

         if isinstance(expression, exp.Subquery):
             table = expression.this
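Both `UnsupportedError` paths above point the caller at the optimizer: the planner only accepts single-FROM, fully aliased trees. A minimal sketch of that flow, assuming the public `sqlglot.optimizer.optimize` and `sqlglot.planner.Plan` entry points:

```python
import sqlglot
from sqlglot.optimizer import optimize
from sqlglot.planner import Plan

schema = {"x": {"a": "int", "b": "int"}}

# optimize() qualifies tables and columns and aliases subqueries, which
# satisfies the preconditions checked by Step/Scan above.
expression = optimize(sqlglot.parse_one("SELECT a, SUM(b) FROM x GROUP BY a"), schema=schema)
plan = Plan(expression)
print(plan.root)  # the root Step of the physical plan
```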
sqlglot/py.typed (new empty file, 0 bytes)
@@ -1,44 +1,60 @@
+from __future__ import annotations
+
 import abc
+import typing as t

 from sqlglot import expressions as exp
-from sqlglot.errors import OptimizeError
+from sqlglot.errors import SchemaError
 from sqlglot.helper import csv_reader
+from sqlglot.trie import in_trie, new_trie
+
+if t.TYPE_CHECKING:
+    from sqlglot.dataframe.sql.types import StructType
+
+    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
+
+TABLE_ARGS = ("this", "db", "catalog")


 class Schema(abc.ABC):
     """Abstract base class for database schemas"""

     @abc.abstractmethod
-    def add_table(self, table, column_mapping=None):
+    def add_table(
+        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
+    ) -> None:
         """
-        Register or update a table. Some implementing classes may require column information to also be provided
+        Register or update a table. Some implementing classes may require column information to also be provided.

         Args:
-            table (sqlglot.expressions.Table|str): Table expression instance or string representing the table
-            column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table
+            table: table expression instance or string representing the table.
+            column_mapping: a column mapping that describes the structure of the table.
         """

     @abc.abstractmethod
-    def column_names(self, table, only_visible=False):
+    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
         """
         Get the column names for a table.

         Args:
-            table (sqlglot.expressions.Table): Table expression instance
-            only_visible (bool): Whether to include invisible columns
+            table: the `Table` expression instance.
+            only_visible: whether to include invisible columns.

         Returns:
-            list[str]: list of column names
+            The list of column names.
         """

     @abc.abstractmethod
-    def get_column_type(self, table, column):
+    def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType.Type:
         """
-        Get the exp.DataType type of a column in the schema.
+        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

         Args:
-            table (sqlglot.expressions.Table): The source table.
-            column (sqlglot.expressions.Column): The target column.
+            table: the source table.
+            column: the target column.

         Returns:
-            sqlglot.expressions.DataType.Type: The resulting column type.
+            The resulting column type.
         """
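As a sketch of how the abstract interface behaves on the concrete `MappingSchema` defined below, assuming the usual nested-dict schema shape (outputs written out by hand, so treat them as illustrative):

```python
from sqlglot.schema import MappingSchema

# {db: {table: {column: type}}} has depth 3, so tables are addressed
# by (db, table).
schema = MappingSchema({"db": {"users": {"id": "int", "name": "text"}}})

print(schema.column_names("db.users"))           # ['id', 'name']
print(schema.get_column_type("db.users", "id"))  # Type.INT
```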
@@ -60,132 +76,179 @@ class MappingSchema(Schema):
         dialect (str): The dialect to be used for custom type mappings.
     """

-    def __init__(self, schema=None, visible=None, dialect=None):
+    def __init__(
+        self,
+        schema: t.Optional[t.Dict] = None,
+        visible: t.Optional[t.Dict] = None,
+        dialect: t.Optional[str] = None,
+    ) -> None:
         self.schema = schema or {}
-        self.visible = visible
+        self.visible = visible or {}
+        self.schema_trie = self._build_trie(self.schema)
         self.dialect = dialect
-        self._type_mapping_cache = {}
-        self.supported_table_args = []
-        self.forbidden_table_args = set()
-        if self.schema:
-            self._initialize_supported_args()
+        self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {}
+        self._supported_table_args: t.Tuple[str, ...] = tuple()

     @classmethod
-    def from_mapping_schema(cls, mapping_schema):
+    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
         return MappingSchema(
-            schema=mapping_schema.schema, visible=mapping_schema.visible, dialect=mapping_schema.dialect
+            schema=mapping_schema.schema,
+            visible=mapping_schema.visible,
+            dialect=mapping_schema.dialect,
         )

-    def copy(self, **kwargs):
-        return MappingSchema(**{"schema": self.schema.copy(), **kwargs})
+    def copy(self, **kwargs) -> MappingSchema:
+        return MappingSchema(
+            **{  # type: ignore
+                "schema": self.schema.copy(),
+                "visible": self.visible.copy(),
+                "dialect": self.dialect,
+                **kwargs,
+            }
+        )
+
+    @property
+    def supported_table_args(self):
+        if not self._supported_table_args and self.schema:
+            depth = _dict_depth(self.schema)
+
+            if not depth or depth == 1:  # {}
+                self._supported_table_args = tuple()
+            elif 2 <= depth <= 4:
+                self._supported_table_args = TABLE_ARGS[: depth - 1]
+            else:
+                raise SchemaError(f"Invalid schema shape. Depth: {depth}")
+
+        return self._supported_table_args

-    def add_table(self, table, column_mapping=None):
+    def add_table(
+        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
+    ) -> None:
         """
         Register or update a table. Updates are only performed if a new column mapping is provided.

         Args:
-            table (sqlglot.expressions.Table|str): Table expression instance or string representing the table
-            column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table
+            table: the `Table` expression instance or string representing the table.
+            column_mapping: a column mapping that describes the structure of the table.
         """
-        table = exp.to_table(table)
-        self._validate_table(table)
+        table_ = self._ensure_table(table)
         column_mapping = ensure_column_mapping(column_mapping)
-        table_args = [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)]
-        existing_column_mapping = _nested_get(
-            self.schema, *zip(self.supported_table_args, table_args), raise_on_missing=False
-        )
-        if existing_column_mapping and not column_mapping:
+        schema = self.find_schema(table_, raise_on_missing=False)
+
+        if schema and not column_mapping:
             return

         _nested_set(
             self.schema,
-            [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)],
+            list(reversed(self.table_parts(table_))),
             column_mapping,
         )
-        self._initialize_supported_args()
+        self.schema_trie = self._build_trie(self.schema)

-    def _get_table_args_from_table(self, table):
-        if table.args.get("catalog") is not None:
-            return "catalog", "db", "this"
-        if table.args.get("db") is not None:
-            return "db", "this"
-        return ("this",)
-
-    def _validate_table(self, table):
-        if not self.supported_table_args and isinstance(table, exp.Table):
-            return
-        for forbidden in self.forbidden_table_args:
-            if table.text(forbidden):
-                raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
-        for expected in self.supported_table_args:
-            if not table.text(expected):
-                raise ValueError(f"Table is expected to have {expected}. Received: {table.sql()} ")
-
-    def column_names(self, table, only_visible=False):
-        table = exp.to_table(table)
-        if not isinstance(table.this, exp.Identifier):
-            return fs_get(table)
-
-        args = tuple(table.text(p) for p in self.supported_table_args)
-
-        for forbidden in self.forbidden_table_args:
-            if table.text(forbidden):
-                raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
-
-        columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
+    def _ensure_table(self, table: exp.Table | str) -> exp.Table:
+        table_ = exp.to_table(table)
+
+        if not table_:
+            raise SchemaError(f"Not a valid table '{table}'")
+
+        return table_
+
+    def table_parts(self, table: exp.Table) -> t.List[str]:
+        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
+
+    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
+        table_ = self._ensure_table(table)
+
+        if not isinstance(table_.this, exp.Identifier):
+            return fs_get(table)  # type: ignore
+
+        schema = self.find_schema(table_)
+
+        if schema is None:
+            raise SchemaError(f"Could not find table schema {table}")
+
         if not only_visible or not self.visible:
-            return columns
+            return list(schema)

-        visible = _nested_get(self.visible, *zip(self.supported_table_args, args))
-        return [col for col in columns if col in visible]
+        visible = self._nested_get(self.table_parts(table_), self.visible)
+        return [col for col in schema if col in visible]  # type: ignore

-    def get_column_type(self, table, column):
-        try:
-            schema_type = self.schema.get(table.name, {}).get(column.name).upper()
-            return self._convert_type(schema_type)
-        except:
-            raise OptimizeError(f"Failed to get type for column {column.sql()}")
+    def find_schema(
+        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
+    ) -> t.Optional[t.Dict[str, str]]:
+        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
+        value, trie = in_trie(self.schema_trie if trie is None else trie, parts)
+
+        if value == 0:
+            if raise_on_missing:
+                raise SchemaError(f"Cannot find schema for {table}.")
+            else:
+                return None
+        elif value == 1:
+            possibilities = flatten_schema(trie)
+            if len(possibilities) == 1:
+                parts.extend(possibilities[0])
+            else:
+                message = ", ".join(".".join(parts) for parts in possibilities)
+                if raise_on_missing:
+                    raise SchemaError(f"Ambiguous schema for {table}: {message}.")
+                return None
+
+        return self._nested_get(parts, raise_on_missing=raise_on_missing)
+
+    def get_column_type(
+        self, table: exp.Table | str, column: exp.Column | str
+    ) -> exp.DataType.Type:
+        column_name = column if isinstance(column, str) else column.name
+        table_ = exp.to_table(table)
+        if table_:
+            table_schema = self.find_schema(table_)
+            schema_type = table_schema.get(column_name).upper()  # type: ignore
+            return self._convert_type(schema_type)
+
+        raise SchemaError(f"Could not convert table '{table}'")

-    def _convert_type(self, schema_type):
+    def _convert_type(self, schema_type: str) -> exp.DataType.Type:
         """
-        Convert a type represented as a string to the corresponding exp.DataType.Type object.
+        Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.

         Args:
-            schema_type (str): The type we want to convert.
+            schema_type: the type we want to convert.

         Returns:
-            sqlglot.expressions.DataType.Type: The resulting expression type.
+            The resulting expression type.
         """
         if schema_type not in self._type_mapping_cache:
             try:
-                self._type_mapping_cache[schema_type] = exp.maybe_parse(
-                    schema_type, into=exp.DataType, dialect=self.dialect
-                ).this
+                expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
+                if expression is None:
+                    raise ValueError(f"Could not parse {schema_type}")
+                self._type_mapping_cache[schema_type] = expression.this
             except AttributeError:
-                raise OptimizeError(f"Failed to convert type {schema_type}")
+                raise SchemaError(f"Failed to convert type {schema_type}")

         return self._type_mapping_cache[schema_type]

-    def _initialize_supported_args(self):
-        if not self.supported_table_args:
-            depth = _dict_depth(self.schema)
-
-            all_args = ["this", "db", "catalog"]
-            if not depth or depth == 1:  # {}
-                self.supported_table_args = []
-            elif 2 <= depth <= 4:
-                self.supported_table_args = tuple(reversed(all_args[: depth - 1]))
-            else:
-                raise OptimizeError(f"Invalid schema shape. Depth: {depth}")
-
-            self.forbidden_table_args = {"catalog", "db", "this"} - set(self.supported_table_args)
+    def _build_trie(self, schema: t.Dict):
+        return new_trie(tuple(reversed(t)) for t in flatten_schema(schema))
+
+    def _nested_get(
+        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
+    ) -> t.Optional[t.Any]:
+        return _nested_get(
+            d or self.schema,
+            *zip(self.supported_table_args, reversed(parts)),
+            raise_on_missing=raise_on_missing,
+        )


-def ensure_schema(schema):
+def ensure_schema(schema: t.Any) -> Schema:
     if isinstance(schema, Schema):
         return schema

     return MappingSchema(schema)


-def ensure_column_mapping(mapping):
+def ensure_column_mapping(mapping: t.Optional[ColumnMapping]):
     if isinstance(mapping, dict):
         return mapping
     elif isinstance(mapping, str):

@@ -196,7 +259,7 @@ def ensure_column_mapping(mapping):
         }
     # Check if mapping looks like a DataFrame StructType
     elif hasattr(mapping, "simpleString"):
-        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
+        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}  # type: ignore
     elif isinstance(mapping, list):
         return {x.strip(): None for x in mapping}
     elif mapping is None:

@@ -204,7 +267,20 @@ def ensure_column_mapping(mapping):
     raise ValueError(f"Invalid mapping provided: {type(mapping)}")


-def fs_get(table):
+def flatten_schema(schema: t.Dict, keys: t.Optional[t.List[str]] = None) -> t.List[t.List[str]]:
+    tables = []
+    keys = keys or []
+    depth = _dict_depth(schema)
+
+    for k, v in schema.items():
+        if depth >= 3:
+            tables.extend(flatten_schema(v, keys + [k]))
+        elif depth == 2:
+            tables.append(keys + [k])
+    return tables
+
+
+def fs_get(table: exp.Table) -> t.List[str]:
     name = table.this.name

     if name.upper() == "READ_CSV":

@@ -214,21 +290,23 @@ def fs_get(table):
     raise ValueError(f"Cannot read schema for {table}")


-def _nested_get(d, *path, raise_on_missing=True):
+def _nested_get(
+    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
+) -> t.Optional[t.Any]:
     """
     Get a value for a nested dictionary.

     Args:
-        d (dict): dictionary
-        *path (tuple[str, str]): tuples of (name, key)
+        d: the dictionary to search.
+        *path: tuples of (name, key), where:
             `key` is the key in the dictionary to get.
             `name` is a string to use in the error if `key` isn't found.

     Returns:
-        The value or None if it doesn't exist
+        The value or None if it doesn't exist.
     """
     for name, key in path:
-        d = d.get(key)
+        d = d.get(key)  # type: ignore
         if d is None:
             if raise_on_missing:
                 name = "table" if name == "this" else name

@@ -237,36 +315,44 @@ def _nested_get(d, *path, raise_on_missing=True):
     return d


-def _nested_set(d, keys, value):
+def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict:
     """
     In-place set a value for a nested dictionary

-    Ex:
+    Example:
         >>> _nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

-        d (dict): dictionary
-        keys (Iterable[str]): ordered iterable of keys that makeup path to value
-        value (Any): The value to set in the dictionary for the given key path
+    Args:
+        d: dictionary to update.
+        keys: the keys that make up the path to `value`.
+        value: the value to set in the dictionary for the given key path.
+
+    Returns:
+        The (possibly) updated dictionary.
     """
     if not keys:
-        return
+        return d

     if len(keys) == 1:
         d[keys[0]] = value
-        return
+        return d

     subd = d
     for key in keys[:-1]:
         if key not in subd:
             subd = subd.setdefault(key, {})
         else:
             subd = subd[key]

     subd[keys[-1]] = value
     return d


-def _dict_depth(d):
+def _dict_depth(d: t.Dict) -> int:
     """
     Get the nesting depth of a dictionary.
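The trie-based lookup above is driven by `flatten_schema`, which reduces the nested mapping to table key paths (the column level never becomes part of a path). An illustrative pairing of input and output, written out by hand:

```python
from sqlglot.schema import flatten_schema

nested = {"db1": {"t1": {"c": "int"}, "t2": {"c": "int"}}, "db2": {"t3": {"c": "int"}}}

# Depth 3 recurses one level and collects [db, table] paths.
print(flatten_schema(nested))  # [['db1', 't1'], ['db1', 't2'], ['db2', 't3']]
```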
@@ -1,9 +1,13 @@
-# the generic time format is based on python time.strftime
+import typing as t
+
+# The generic time format is based on python time.strftime.
 # https://docs.python.org/3/library/time.html#time.strftime
 from sqlglot.trie import in_trie, new_trie


-def format_time(string, mapping, trie=None):
+def format_time(
+    string: str, mapping: t.Dict[str, str], trie: t.Optional[t.Dict] = None
+) -> t.Optional[str]:
     """
     Converts a time string given a mapping.

@@ -11,11 +15,16 @@ def format_time(string, mapping, trie=None):
     >>> format_time("%Y", {"%Y": "YYYY"})
     'YYYY'

-    mapping: Dictionary of time format to target time format
-    trie: Optional trie, can be passed in for performance
+    Args:
+        mapping: dictionary of time format to target time format.
+        trie: optional trie, can be passed in for performance.
+
+    Returns:
+        The converted time string.
     """
     if not string:
         return None

     start = 0
     end = 1
     size = len(string)
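A short sketch of the `trie` parameter's intended use: when converting many format strings against one mapping, the trie can be built once and reused (the mappings shown are illustrative):

```python
from sqlglot.time import format_time
from sqlglot.trie import new_trie

mapping = {"%Y": "yyyy", "%m": "MM", "%d": "dd"}
trie = new_trie(mapping)  # build once; iterating a dict yields its keys

for fmt in ("%Y-%m-%d", "%d/%m/%Y"):
    print(format_time(fmt, mapping, trie))  # yyyy-MM-dd, then dd/MM/yyyy
```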
@@ -1,3 +1,6 @@
+from __future__ import annotations
+
+import typing as t
 from enum import auto

 from sqlglot.helper import AutoName

@@ -27,6 +30,7 @@ class TokenType(AutoName):
     NOT = auto()
     EQ = auto()
     NEQ = auto()
+    NULLSAFE_EQ = auto()
     AND = auto()
     OR = auto()
     AMP = auto()

@@ -36,12 +40,14 @@ class TokenType(AutoName):
     TILDA = auto()
     ARROW = auto()
     DARROW = auto()
+    FARROW = auto()
+    HASH = auto()
     HASH_ARROW = auto()
     DHASH_ARROW = auto()
     LR_ARROW = auto()
-    ANNOTATION = auto()
     DOLLAR = auto()
     PARAMETER = auto()
+    SESSION_PARAMETER = auto()

     SPACE = auto()
     BREAK = auto()

@@ -73,7 +79,7 @@ class TokenType(AutoName):
     NVARCHAR = auto()
     TEXT = auto()
     BINARY = auto()
-    BYTEA = auto()
+    VARBINARY = auto()
     JSON = auto()
     TIMESTAMP = auto()
     TIMESTAMPTZ = auto()

@@ -142,6 +148,7 @@ class TokenType(AutoName):
     DESCRIBE = auto()
     DETERMINISTIC = auto()
     DISTINCT = auto()
+    DISTINCT_FROM = auto()
     DISTRIBUTE_BY = auto()
     DIV = auto()
     DROP = auto()

@@ -238,6 +245,7 @@ class TokenType(AutoName):
     RETURNS = auto()
     RIGHT = auto()
     RLIKE = auto()
+    ROLLBACK = auto()
     ROLLUP = auto()
     ROW = auto()
     ROWS = auto()

@@ -287,37 +295,49 @@ class TokenType(AutoName):


 class Token:
-    __slots__ = ("token_type", "text", "line", "col")
+    __slots__ = ("token_type", "text", "line", "col", "comment")

     @classmethod
-    def number(cls, number):
+    def number(cls, number: int) -> Token:
+        """Returns a NUMBER token with `number` as its text."""
         return cls(TokenType.NUMBER, str(number))

     @classmethod
-    def string(cls, string):
+    def string(cls, string: str) -> Token:
+        """Returns a STRING token with `string` as its text."""
         return cls(TokenType.STRING, string)

     @classmethod
-    def identifier(cls, identifier):
+    def identifier(cls, identifier: str) -> Token:
+        """Returns an IDENTIFIER token with `identifier` as its text."""
         return cls(TokenType.IDENTIFIER, identifier)

     @classmethod
-    def var(cls, var):
+    def var(cls, var: str) -> Token:
+        """Returns a VAR token with `var` as its text."""
         return cls(TokenType.VAR, var)

-    def __init__(self, token_type, text, line=1, col=1):
+    def __init__(
+        self,
+        token_type: TokenType,
+        text: str,
+        line: int = 1,
+        col: int = 1,
+        comment: t.Optional[str] = None,
+    ) -> None:
         self.token_type = token_type
         self.text = text
         self.line = line
         self.col = max(col - len(text), 1)
+        self.comment = comment

-    def __repr__(self):
+    def __repr__(self) -> str:
         attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
         return f"<Token {attributes}>"


 class _Tokenizer(type):
-    def __new__(cls, clsname, bases, attrs):
+    def __new__(cls, clsname, bases, attrs):  # type: ignore
         klass = super().__new__(cls, clsname, bases, attrs)

         klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES)

@@ -325,27 +345,29 @@ class _Tokenizer(type):
         klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
         klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS)
         klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS)
+        klass._ESCAPES = set(klass.ESCAPES)
         klass._COMMENTS = dict(
-            (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) for comment in klass.COMMENTS
+            (comment, None) if isinstance(comment, str) else (comment[0], comment[1])
+            for comment in klass.COMMENTS
         )

         klass.KEYWORD_TRIE = new_trie(
             key.upper()
-            for key, value in {
+            for key in {
                 **klass.KEYWORDS,
                 **{comment: TokenType.COMMENT for comment in klass._COMMENTS},
                 **{quote: TokenType.QUOTE for quote in klass._QUOTES},
                 **{bit_string: TokenType.BIT_STRING for bit_string in klass._BIT_STRINGS},
                 **{hex_string: TokenType.HEX_STRING for hex_string in klass._HEX_STRINGS},
                 **{byte_string: TokenType.BYTE_STRING for byte_string in klass._BYTE_STRINGS},
-            }.items()
+            }
             if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
         )

         return klass

     @staticmethod
-    def _delimeter_list_to_dict(list):
+    def _delimeter_list_to_dict(list: t.List[str | t.Tuple[str, str]]) -> t.Dict[str, str]:
         return dict((item, item) if isinstance(item, str) else (item[0], item[1]) for item in list)


@@ -375,26 +397,26 @@ class Tokenizer(metaclass=_Tokenizer):
         "*": TokenType.STAR,
         "~": TokenType.TILDA,
         "?": TokenType.PLACEHOLDER,
-        "#": TokenType.ANNOTATION,
         "@": TokenType.PARAMETER,
         # used for breaking a var like x'y' but nothing else
         # the token type doesn't matter
         "'": TokenType.QUOTE,
         "`": TokenType.IDENTIFIER,
         '"': TokenType.IDENTIFIER,
+        "#": TokenType.HASH,
     }

-    QUOTES = ["'"]
+    QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]

-    BIT_STRINGS = []
+    BIT_STRINGS: t.List[str | t.Tuple[str, str]] = []

-    HEX_STRINGS = []
+    HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []

-    BYTE_STRINGS = []
+    BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = []

-    IDENTIFIERS = ['"']
+    IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']

-    ESCAPE = "'"
+    ESCAPES = ["'"]

     KEYWORDS = {
         "/*+": TokenType.HINT,

@@ -406,8 +428,10 @@ class Tokenizer(metaclass=_Tokenizer):
         "<=": TokenType.LTE,
         "<>": TokenType.NEQ,
         "!=": TokenType.NEQ,
+        "<=>": TokenType.NULLSAFE_EQ,
         "->": TokenType.ARROW,
         "->>": TokenType.DARROW,
+        "=>": TokenType.FARROW,
         "#>": TokenType.HASH_ARROW,
         "#>>": TokenType.DHASH_ARROW,
         "<->": TokenType.LR_ARROW,
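With `<=>` registered as `NULLSAFE_EQ`, MySQL's null-safe comparison can now be lexed and round-tripped; a hedged sketch (the output shown assumes MySQL-to-MySQL transpilation keeps the operator as-is):

```python
import sqlglot

print(sqlglot.transpile("SELECT a <=> b FROM t", read="mysql", write="mysql")[0])
# SELECT a <=> b FROM t
```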
@@ -454,6 +478,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "DESCRIBE": TokenType.DESCRIBE,
         "DETERMINISTIC": TokenType.DETERMINISTIC,
         "DISTINCT": TokenType.DISTINCT,
+        "DISTINCT FROM": TokenType.DISTINCT_FROM,
         "DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
         "DIV": TokenType.DIV,
         "DROP": TokenType.DROP,

@@ -543,6 +568,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "RETURNS": TokenType.RETURNS,
         "RIGHT": TokenType.RIGHT,
         "RLIKE": TokenType.RLIKE,
+        "ROLLBACK": TokenType.ROLLBACK,
         "ROLLUP": TokenType.ROLLUP,
         "ROW": TokenType.ROW,
         "ROWS": TokenType.ROWS,

@@ -622,8 +648,9 @@ class Tokenizer(metaclass=_Tokenizer):
         "TEXT": TokenType.TEXT,
         "CLOB": TokenType.TEXT,
         "BINARY": TokenType.BINARY,
-        "BLOB": TokenType.BINARY,
-        "BYTEA": TokenType.BINARY,
+        "BLOB": TokenType.VARBINARY,
+        "BYTEA": TokenType.VARBINARY,
+        "VARBINARY": TokenType.VARBINARY,
         "TIMESTAMP": TokenType.TIMESTAMP,
         "TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
         "TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,

@@ -655,13 +682,13 @@ class Tokenizer(metaclass=_Tokenizer):
         TokenType.SET,
         TokenType.SHOW,
         TokenType.TRUNCATE,
-        TokenType.USE,
         TokenType.VACUUM,
+        TokenType.ROLLBACK,
     }

     # handle numeric literals like in hive (3L = BIGINT)
-    NUMERIC_LITERALS = {}
-    ENCODE = None
+    NUMERIC_LITERALS: t.Dict[str, str] = {}
+    ENCODE: t.Optional[str] = None

     COMMENTS = ["--", ("/*", "*/")]
     KEYWORD_TRIE = None  # autofilled

@@ -674,33 +701,39 @@ class Tokenizer(metaclass=_Tokenizer):
         "_current",
         "_line",
         "_col",
+        "_comment",
         "_char",
         "_end",
         "_peek",
+        "_prev_token_line",
+        "_prev_token_comment",
         "_prev_token_type",
+        "_replace_backslash",
     )

-    def __init__(self):
-        """
-        Tokenizer consumes a sql string and produces an array of :class:`~sqlglot.tokens.Token`
-        """
+    def __init__(self) -> None:
+        self._replace_backslash = "\\" in self._ESCAPES  # type: ignore
         self.reset()

-    def reset(self):
+    def reset(self) -> None:
         self.sql = ""
         self.size = 0
-        self.tokens = []
+        self.tokens: t.List[Token] = []
         self._start = 0
         self._current = 0
         self._line = 1
         self._col = 1
+        self._comment = None

         self._char = None
         self._end = None
         self._peek = None
+        self._prev_token_line = -1
+        self._prev_token_comment = None
         self._prev_token_type = None

-    def tokenize(self, sql):
+    def tokenize(self, sql: str) -> t.List[Token]:
+        """Returns a list of tokens corresponding to the SQL string `sql`."""
         self.reset()
         self.sql = sql
         self.size = len(sql)

@@ -712,14 +745,14 @@ class Tokenizer(metaclass=_Tokenizer):
             if not self._char:
                 break

-            white_space = self.WHITE_SPACE.get(self._char)
-            identifier_end = self._IDENTIFIERS.get(self._char)
+            white_space = self.WHITE_SPACE.get(self._char)  # type: ignore
+            identifier_end = self._IDENTIFIERS.get(self._char)  # type: ignore

             if white_space:
                 if white_space == TokenType.BREAK:
                     self._col = 1
                     self._line += 1
-            elif self._char.isdigit():
+            elif self._char.isdigit():  # type:ignore
                 self._scan_number()
             elif identifier_end:
                 self._scan_identifier(identifier_end)

@@ -727,38 +760,51 @@ class Tokenizer(metaclass=_Tokenizer):
                 self._scan_keywords()
         return self.tokens

-    def _chars(self, size):
+    def _chars(self, size: int) -> str:
         if size == 1:
-            return self._char
+            return self._char  # type: ignore
         start = self._current - 1
         end = start + size
         if end <= self.size:
             return self.sql[start:end]
         return ""

-    def _advance(self, i=1):
+    def _advance(self, i: int = 1) -> None:
         self._col += i
         self._current += i
-        self._end = self._current >= self.size
-        self._char = self.sql[self._current - 1]
-        self._peek = self.sql[self._current] if self._current < self.size else ""
+        self._end = self._current >= self.size  # type: ignore
+        self._char = self.sql[self._current - 1]  # type: ignore
+        self._peek = self.sql[self._current] if self._current < self.size else ""  # type: ignore

     @property
-    def _text(self):
+    def _text(self) -> str:
         return self.sql[self._start : self._current]

-    def _add(self, token_type, text=None):
-        self._prev_token_type = token_type
-        self.tokens.append(Token(token_type, self._text if text is None else text, self._line, self._col))
+    def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
+        self._prev_token_line = self._line
+        self._prev_token_comment = self._comment
+        self._prev_token_type = token_type  # type: ignore
+        self.tokens.append(
+            Token(
+                token_type,
+                self._text if text is None else text,
+                self._line,
+                self._col,
+                self._comment,
+            )
+        )
+        self._comment = None

-        if token_type in self.COMMANDS and (len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON):
+        if token_type in self.COMMANDS and (
+            len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
+        ):
             self._start = self._current
             while not self._end and self._peek != ";":
                 self._advance()
             if self._start < self._current:
                 self._add(TokenType.STRING)

-    def _scan_keywords(self):
+    def _scan_keywords(self) -> None:
         size = 0
         word = None
         chars = self._text

@@ -771,7 +817,7 @@ class Tokenizer(metaclass=_Tokenizer):
             if skip:
                 result = 1
             else:
-                result, trie = in_trie(trie, char.upper())
+                result, trie = in_trie(trie, char.upper())  # type: ignore

             if result == 0:
                 break

@@ -793,15 +839,11 @@ class Tokenizer(metaclass=_Tokenizer):
             else:
                 skip = True
         else:
-            chars = None
+            chars = None  # type: ignore

         if not word:
             if self._char in self.SINGLE_TOKENS:
-                token = self.SINGLE_TOKENS[self._char]
-                if token == TokenType.ANNOTATION:
-                    self._scan_annotation()
-                    return
-                self._add(token)
+                self._add(self.SINGLE_TOKENS[self._char])  # type: ignore
                 return
             self._scan_var()
             return

@@ -816,31 +858,41 @@ class Tokenizer(metaclass=_Tokenizer):
         self._advance(size - 1)
         self._add(self.KEYWORDS[word.upper()])

-    def _scan_comment(self, comment_start):
-        if comment_start not in self._COMMENTS:
+    def _scan_comment(self, comment_start: str) -> bool:
+        if comment_start not in self._COMMENTS:  # type: ignore
             return False

-        comment_end = self._COMMENTS[comment_start]
+        comment_start_line = self._line
+        comment_start_size = len(comment_start)
+        comment_end = self._COMMENTS[comment_start]  # type: ignore

         if comment_end:
             comment_end_size = len(comment_end)

             while not self._end and self._chars(comment_end_size) != comment_end:
                 self._advance()

+            self._comment = self._text[comment_start_size : -comment_end_size + 1]  # type: ignore
             self._advance(comment_end_size - 1)
         else:
-            while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK:
+            while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK:  # type: ignore
                 self._advance()
+            self._comment = self._text[comment_start_size:]  # type: ignore
+
+        # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. If both
+        # types of comment can be attached to a token, the trailing one is discarded in favour of the leading one.
+
+        if comment_start_line == self._prev_token_line:
+            if self._prev_token_comment is None:
+                self.tokens[-1].comment = self._comment
+
+            self._comment = None

         return True

-    def _scan_annotation(self):
-        while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK and self._peek != ",":
-            self._advance()
-        self._add(TokenType.ANNOTATION, self._text[1:])
-
-    def _scan_number(self):
+    def _scan_number(self) -> None:
         if self._char == "0":
-            peek = self._peek.upper()
+            peek = self._peek.upper()  # type: ignore
             if peek == "B":
                 return self._scan_bits()
             elif peek == "X":

@@ -850,7 +902,7 @@ class Tokenizer(metaclass=_Tokenizer):
         scientific = 0

         while True:
-            if self._peek.isdigit():
+            if self._peek.isdigit():  # type: ignore
                 self._advance()
             elif self._peek == "." and not decimal:
                 decimal = True

@@ -858,25 +910,25 @@ class Tokenizer(metaclass=_Tokenizer):
             elif self._peek in ("-", "+") and scientific == 1:
                 scientific += 1
                 self._advance()
-            elif self._peek.upper() == "E" and not scientific:
+            elif self._peek.upper() == "E" and not scientific:  # type: ignore
                 scientific += 1
                 self._advance()
-            elif self._peek.isalpha():
+            elif self._peek.isalpha():  # type: ignore
                 self._add(TokenType.NUMBER)
                 literal = []
-                while self._peek.isalpha():
-                    literal.append(self._peek.upper())
+                while self._peek.isalpha():  # type: ignore
+                    literal.append(self._peek.upper())  # type: ignore
                     self._advance()
-                literal = "".join(literal)
-                token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))
+                literal = "".join(literal)  # type: ignore
+                token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))  # type: ignore
                 if token_type:
                     self._add(TokenType.DCOLON, "::")
-                    return self._add(token_type, literal)
+                    return self._add(token_type, literal)  # type: ignore
                 return self._advance(-len(literal))
             else:
                 return self._add(TokenType.NUMBER)

-    def _scan_bits(self):
+    def _scan_bits(self) -> None:
         self._advance()
         value = self._extract_value()
         try:

@@ -884,7 +936,7 @@ class Tokenizer(metaclass=_Tokenizer):
         except ValueError:
             self._add(TokenType.IDENTIFIER)

-    def _scan_hex(self):
+    def _scan_hex(self) -> None:
         self._advance()
         value = self._extract_value()
         try:

@@ -892,9 +944,9 @@ class Tokenizer(metaclass=_Tokenizer):
         except ValueError:
             self._add(TokenType.IDENTIFIER)

-    def _extract_value(self):
+    def _extract_value(self) -> str:
         while True:
-            char = self._peek.strip()
+            char = self._peek.strip()  # type: ignore
             if char and char not in self.SINGLE_TOKENS:
                 self._advance()
             else:

@@ -902,31 +954,30 @@ class Tokenizer(metaclass=_Tokenizer):

         return self._text

-    def _scan_string(self, quote):
-        quote_end = self._QUOTES.get(quote)
+    def _scan_string(self, quote: str) -> bool:
+        quote_end = self._QUOTES.get(quote)  # type: ignore
         if quote_end is None:
             return False

         self._advance(len(quote))
         text = self._extract_string(quote_end)
-        text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
-        text = text.replace("\\\\", "\\") if self.ESCAPE == "\\" else text
+        text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text  # type: ignore
+        text = text.replace("\\\\", "\\") if self._replace_backslash else text
         self._add(TokenType.STRING, text)
         return True

     # X'1234, b'0110', E'\\\\\' etc.
-    def _scan_formatted_string(self, string_start):
-        if string_start in self._HEX_STRINGS:
-            delimiters = self._HEX_STRINGS
+    def _scan_formatted_string(self, string_start: str) -> bool:
+        if string_start in self._HEX_STRINGS:  # type: ignore
+            delimiters = self._HEX_STRINGS  # type: ignore
             token_type = TokenType.HEX_STRING
             base = 16
-        elif string_start in self._BIT_STRINGS:
-            delimiters = self._BIT_STRINGS
+        elif string_start in self._BIT_STRINGS:  # type: ignore
+            delimiters = self._BIT_STRINGS  # type: ignore
             token_type = TokenType.BIT_STRING
             base = 2
-        elif string_start in self._BYTE_STRINGS:
-            delimiters = self._BYTE_STRINGS
+        elif string_start in self._BYTE_STRINGS:  # type: ignore
+            delimiters = self._BYTE_STRINGS  # type: ignore
             token_type = TokenType.BYTE_STRING
             base = None
         else:

@@ -942,11 +993,13 @@ class Tokenizer(metaclass=_Tokenizer):
         try:
             self._add(token_type, f"{int(text, base)}")
         except:
-            raise RuntimeError(f"Numeric string contains invalid characters from {self._line}:{self._start}")
+            raise RuntimeError(
+                f"Numeric string contains invalid characters from {self._line}:{self._start}"
+            )

         return True

-    def _scan_identifier(self, identifier_end):
+    def _scan_identifier(self, identifier_end: str) -> None:
         while self._peek != identifier_end:
             if self._end:
                 raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}")

@@ -954,9 +1007,9 @@ class Tokenizer(metaclass=_Tokenizer):
             self._advance()
         self._add(TokenType.IDENTIFIER, self._text[1:-1])

-    def _scan_var(self):
+    def _scan_var(self) -> None:
         while True:
-            char = self._peek.strip()
+            char = self._peek.strip()  # type: ignore
             if char and char not in self.SINGLE_TOKENS:
                 self._advance()
             else:

@@ -967,12 +1020,12 @@ class Tokenizer(metaclass=_Tokenizer):
             else self.KEYWORDS.get(self._text.upper(), TokenType.VAR)
         )

-    def _extract_string(self, delimiter):
+    def _extract_string(self, delimiter: str) -> str:
         text = ""
         delim_size = len(delimiter)

         while True:
-            if self._char == self.ESCAPE and self._peek == delimiter:
+            if self._char in self._ESCAPES and self._peek == delimiter:  # type: ignore
                 text += delimiter
                 self._advance(2)
             else:

@@ -983,7 +1036,7 @@ class Tokenizer(metaclass=_Tokenizer):

             if self._end:
                 raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}")
-            text += self._char
+            text += self._char  # type: ignore
             self._advance()

         return text
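Tying the `_comment` bookkeeping together: comments are now attached to tokens, and from there to parsed expressions, so they can be re-emitted rather than dropped. A minimal sketch; the exact attachment point follows the leading/trailing rule commented in `_scan_comment` above:

```python
import sqlglot

sql = "SELECT /* keep me */ a FROM t"
print(sqlglot.transpile(sql)[0])  # the comment survives transpilation
```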
@@ -1,7 +1,14 @@
+from __future__ import annotations
+
+import typing as t
+
+if t.TYPE_CHECKING:
+    from sqlglot.generator import Generator
+
 from sqlglot import expressions as exp


-def unalias_group(expression):
+def unalias_group(expression: exp.Expression) -> exp.Expression:
     """
     Replace references to select aliases in GROUP BY clauses.

@@ -9,6 +16,12 @@ def unalias_group(expression):
         >>> import sqlglot
         >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
         'SELECT a AS b FROM x GROUP BY 1'

+    Args:
+        expression: the expression that will be transformed.
+
+    Returns:
+        The transformed expression.
     """
     if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
         aliased_selects = {
@@ -30,18 +43,19 @@ def unalias_group(expression):
     return expression


-def preprocess(transforms, to_sql):
+def preprocess(
+    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
+    to_sql: t.Callable[[Generator, exp.Expression], str],
+) -> t.Callable[[Generator, exp.Expression], str]:
     """
-    Create a new transform function that can be used a value in `Generator.TRANSFORMS`
-    to convert expressions to SQL.
+    Creates a new transform by chaining a sequence of transformations and converts the resulting
+    expression to SQL, using an appropriate `Generator.TRANSFORMS` function.

     Args:
-        transforms (list[(exp.Expression) -> exp.Expression]):
-            Sequence of transform functions. These will be called in order.
-        to_sql ((sqlglot.generator.Generator, exp.Expression) -> str):
-            Final transform that converts the resulting expression to a SQL string.
+        transforms: sequence of transform functions. These will be called in order.
+        to_sql: final transform that converts the resulting expression to a SQL string.

     Returns:
-        (sqlglot.generator.Generator, exp.Expression) -> str:
-            Function that can be used as a generator transform.
+        Function that can be used as a generator transform.
     """
@@ -54,12 +68,10 @@ def preprocess(transforms, to_sql):
     return _to_sql


-def delegate(attr):
+def delegate(attr: str) -> t.Callable:
     """
-    Create a new method that delegates to `attr`.
-
-    This is useful for creating `Generator.TRANSFORMS` functions that delegate
-    to existing generator methods.
+    Create a new method that delegates to `attr`. This is useful for creating `Generator.TRANSFORMS`
+    functions that delegate to existing generator methods.
     """

     def _transform(self, *args, **kwargs):

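Taken together, `preprocess` and `delegate` are building blocks for dialect generators: a chain of expression rewrites followed by a fallback to a stock generator method. A hedged sketch of how a dialect might wire them up (the `exp.Group`/`"group_sql"` pairing is illustrative of the pattern, not a line from this commit):

    from sqlglot import expressions as exp
    from sqlglot.transforms import delegate, preprocess, unalias_group

    # Run unalias_group over every GROUP BY, then fall back to the generator's
    # regular group_sql method to render the rewritten expression.
    TRANSFORMS = {
        exp.Group: preprocess([unalias_group], delegate("group_sql")),
    }
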
@@ -1,5 +1,26 @@
-def new_trie(keywords):
-    trie = {}
+import typing as t
+
+key = t.Sequence[t.Hashable]
+
+
+def new_trie(keywords: t.Iterable[key]) -> t.Dict:
+    """
+    Creates a new trie out of a collection of keywords.
+
+    The trie is represented as a sequence of nested dictionaries keyed by either single character
+    strings, or by 0, which is used to designate that a keyword is in the trie.
+
+    Example:
+        >>> new_trie(["bla", "foo", "blab"])
+        {'b': {'l': {'a': {0: True, 'b': {0: True}}}}, 'f': {'o': {'o': {0: True}}}}
+
+    Args:
+        keywords: the keywords to create the trie from.
+
+    Returns:
+        The trie corresponding to `keywords`.
+    """
+    trie: t.Dict = {}

     for key in keywords:
         current = trie
@@ -11,7 +32,28 @@ def new_trie(keywords):
     return trie


-def in_trie(trie, key):
+def in_trie(trie: t.Dict, key: key) -> t.Tuple[int, t.Dict]:
+    """
+    Checks whether a key is in a trie.
+
+    Examples:
+        >>> in_trie(new_trie(["cat"]), "bob")
+        (0, {'c': {'a': {'t': {0: True}}}})
+
+        >>> in_trie(new_trie(["cat"]), "ca")
+        (1, {'t': {0: True}})
+
+        >>> in_trie(new_trie(["cat"]), "cat")
+        (2, {0: True})
+
+    Args:
+        trie: the trie to be searched.
+        key: the target key.
+
+    Returns:
+        A pair `(value, subtrie)`, where `subtrie` is the sub-trie we get at the point where the
+        search stops, and `value` is either 0 (search was unsuccessful), 1 (`key` is a prefix of
+        a keyword in `trie`) or 2 (`key` is in `trie`).
+    """
     if not key:
         return (0, trie)

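The documented 0/1/2 return values are what make a tokenizer-style longest-match scan possible. As a usage sketch (the scanning loop below is illustrative, not library code):

    import typing as t

    from sqlglot.trie import in_trie, new_trie

    trie = new_trie(["CAT", "CATS"])

    def longest_keyword(text: str) -> t.Optional[str]:
        """Scan prefixes of `text`, remembering the longest exact keyword match."""
        match = None
        for end in range(1, len(text) + 1):
            result, _ = in_trie(trie, text[:end])
            if result == 0:  # dead end: no keyword starts with this prefix
                break
            if result == 2:  # exact match; keep scanning for a longer keyword
                match = text[:end]
        return match

    assert longest_keyword("CATSUP") == "CATS"
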
@@ -1,9 +1,9 @@
-import sys
 import typing as t
 import unittest
 import warnings

 import sqlglot
+from sqlglot.helper import PYTHON_VERSION
 from tests.helpers import SKIP_INTEGRATION

 if t.TYPE_CHECKING:
@@ -11,7 +11,8 @@ if t.TYPE_CHECKING:


 @unittest.skipIf(
-    SKIP_INTEGRATION or sys.version_info[:2] > (3, 10), "Skipping Integration Tests since `SKIP_INTEGRATION` is set"
+    SKIP_INTEGRATION or PYTHON_VERSION > (3, 10),
+    "Skipping Integration Tests since `SKIP_INTEGRATION` is set",
 )
 class DataFrameValidator(unittest.TestCase):
     spark = None
@@ -36,7 +37,12 @@ class DataFrameValidator(unittest.TestCase):

         # This is for test `test_branching_root_dataframes`
         config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")])
-        cls.spark = SparkSession.builder.master("local[*]").appName("Unit-tests").config(conf=config).getOrCreate()
+        cls.spark = (
+            SparkSession.builder.master("local[*]")
+            .appName("Unit-tests")
+            .config(conf=config)
+            .getOrCreate()
+        )
         cls.spark.sparkContext.setLogLevel("ERROR")
         cls.sqlglot = SqlglotSparkSession()
         cls.spark_employee_schema = types.StructType(
@@ -50,7 +56,9 @@ class DataFrameValidator(unittest.TestCase):
         )
         cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType(
             [
-                sqlglotSparkTypes.StructField("employee_id", sqlglotSparkTypes.IntegerType(), False),
+                sqlglotSparkTypes.StructField(
+                    "employee_id", sqlglotSparkTypes.IntegerType(), False
+                ),
                 sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False),
                 sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False),
                 sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False),
@@ -64,8 +72,12 @@ class DataFrameValidator(unittest.TestCase):
             (4, "Claire", "Littleton", 27, 2),
             (5, "Hugo", "Reyes", 29, 100),
         ]
-        cls.df_employee = cls.spark.createDataFrame(data=employee_data, schema=cls.spark_employee_schema)
-        cls.dfs_employee = cls.sqlglot.createDataFrame(data=employee_data, schema=cls.sqlglot_employee_schema)
+        cls.df_employee = cls.spark.createDataFrame(
+            data=employee_data, schema=cls.spark_employee_schema
+        )
+        cls.dfs_employee = cls.sqlglot.createDataFrame(
+            data=employee_data, schema=cls.sqlglot_employee_schema
+        )
         cls.df_employee.createOrReplaceTempView("employee")

         cls.spark_store_schema = types.StructType(
@@ -80,7 +92,9 @@ class DataFrameValidator(unittest.TestCase):
             [
                 sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
                 sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False),
-                sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
+                sqlglotSparkTypes.StructField(
+                    "district_id", sqlglotSparkTypes.IntegerType(), False
+                ),
                 sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False),
             ]
         )
@@ -89,7 +103,9 @@ class DataFrameValidator(unittest.TestCase):
             (2, "Arrow", 2, 2000),
         ]
         cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema)
-        cls.dfs_store = cls.sqlglot.createDataFrame(data=store_data, schema=cls.sqlglot_store_schema)
+        cls.dfs_store = cls.sqlglot.createDataFrame(
+            data=store_data, schema=cls.sqlglot_store_schema
+        )
         cls.df_store.createOrReplaceTempView("store")

         cls.spark_district_schema = types.StructType(
@@ -101,17 +117,27 @@ class DataFrameValidator(unittest.TestCase):
         )
         cls.sqlglot_district_schema = sqlglotSparkTypes.StructType(
             [
-                sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
-                sqlglotSparkTypes.StructField("district_name", sqlglotSparkTypes.StringType(), False),
-                sqlglotSparkTypes.StructField("manager_name", sqlglotSparkTypes.StringType(), False),
+                sqlglotSparkTypes.StructField(
+                    "district_id", sqlglotSparkTypes.IntegerType(), False
+                ),
+                sqlglotSparkTypes.StructField(
+                    "district_name", sqlglotSparkTypes.StringType(), False
+                ),
+                sqlglotSparkTypes.StructField(
+                    "manager_name", sqlglotSparkTypes.StringType(), False
+                ),
             ]
         )
         district_data = [
             (1, "Temple", "Dogen"),
             (2, "Lighthouse", "Jacob"),
         ]
-        cls.df_district = cls.spark.createDataFrame(data=district_data, schema=cls.spark_district_schema)
-        cls.dfs_district = cls.sqlglot.createDataFrame(data=district_data, schema=cls.sqlglot_district_schema)
+        cls.df_district = cls.spark.createDataFrame(
+            data=district_data, schema=cls.spark_district_schema
+        )
+        cls.dfs_district = cls.sqlglot.createDataFrame(
+            data=district_data, schema=cls.sqlglot_district_schema
+        )
         cls.df_district.createOrReplaceTempView("district")
         sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema)
         sqlglot.schema.add_table("store", cls.sqlglot_store_schema)

@@ -41,22 +41,32 @@ class TestDataframeFunc(DataFrameValidator):

     def test_alias_with_select(self):
         df_employee = self.df_spark_employee.alias("df_employee").select(
-            self.df_spark_employee["employee_id"], F.col("df_employee.fname"), self.df_spark_employee.lname
+            self.df_spark_employee["employee_id"],
+            F.col("df_employee.fname"),
+            self.df_spark_employee.lname,
         )
         dfs_employee = self.df_sqlglot_employee.alias("dfs_employee").select(
-            self.df_sqlglot_employee["employee_id"], SF.col("dfs_employee.fname"), self.df_sqlglot_employee.lname
+            self.df_sqlglot_employee["employee_id"],
+            SF.col("dfs_employee.fname"),
+            self.df_sqlglot_employee.lname,
         )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

     def test_case_when_otherwise(self):
         df = self.df_spark_employee.select(
-            F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60"))
+            F.when(
+                (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
+                F.lit("between 40 and 60"),
+            )
             .when(F.col("age") < F.lit(40), "less than 40")
             .otherwise("greater than 60")
         )

         dfs = self.df_sqlglot_employee.select(
-            SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60"))
+            SF.when(
+                (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
+                SF.lit("between 40 and 60"),
+            )
             .when(SF.col("age") < SF.lit(40), "less than 40")
             .otherwise("greater than 60")
         )
@@ -65,15 +75,17 @@ class TestDataframeFunc(DataFrameValidator):

     def test_case_when_no_otherwise(self):
         df = self.df_spark_employee.select(
-            F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")).when(
-                F.col("age") < F.lit(40), "less than 40"
-            )
+            F.when(
+                (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
+                F.lit("between 40 and 60"),
+            ).when(F.col("age") < F.lit(40), "less than 40")
         )

         dfs = self.df_sqlglot_employee.select(
-            SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")).when(
-                SF.col("age") < SF.lit(40), "less than 40"
-            )
+            SF.when(
+                (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
+                SF.lit("between 40 and 60"),
+            ).when(SF.col("age") < SF.lit(40), "less than 40")
         )

         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -84,7 +96,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

     def test_where_clause_multiple_and(self):
-        df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack")))
+        df_employee = self.df_spark_employee.where(
+            (F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack"))
+        )
         dfs_employee = self.df_sqlglot_employee.where(
             (SF.col("age") == SF.lit(37)) & (SF.col("fname") == SF.lit("Jack"))
         )
@@ -106,7 +120,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

     def test_where_clause_multiple_or(self):
-        df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate")))
+        df_employee = self.df_spark_employee.where(
+            (F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate"))
+        )
         dfs_employee = self.df_sqlglot_employee.where(
             (SF.col("age") == SF.lit(37)) | (SF.col("fname") == SF.lit("Kate"))
         )
@@ -172,28 +188,43 @@ class TestDataframeFunc(DataFrameValidator):
         dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] == SF.lit(37))
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

-        df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] % F.lit(5) == F.lit(0))
-        dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0))
+        df_employee = self.df_spark_employee.where(
+            self.df_spark_employee["age"] % F.lit(5) == F.lit(0)
+        )
+        dfs_employee = self.df_sqlglot_employee.where(
+            self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0)
+        )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

-        df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] + F.lit(5) > F.lit(28))
-        dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28))
+        df_employee = self.df_spark_employee.where(
+            self.df_spark_employee["age"] + F.lit(5) > F.lit(28)
+        )
+        dfs_employee = self.df_sqlglot_employee.where(
+            self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28)
+        )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

-        df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] - F.lit(5) > F.lit(28))
-        dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28))
+        df_employee = self.df_spark_employee.where(
+            self.df_spark_employee["age"] - F.lit(5) > F.lit(28)
+        )
+        dfs_employee = self.df_sqlglot_employee.where(
+            self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28)
+        )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

         df_employee = self.df_spark_employee.where(
             self.df_spark_employee["age"] * F.lit(0.5) == self.df_spark_employee["age"] / F.lit(2)
         )
         dfs_employee = self.df_sqlglot_employee.where(
-            self.df_sqlglot_employee["age"] * SF.lit(0.5) == self.df_sqlglot_employee["age"] / SF.lit(2)
+            self.df_sqlglot_employee["age"] * SF.lit(0.5)
+            == self.df_sqlglot_employee["age"] / SF.lit(2)
         )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

     def test_join_inner(self):
-        df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="inner").select(
+        df_joined = self.df_spark_employee.join(
+            self.df_spark_store, on=["store_id"], how="inner"
+        ).select(
             self.df_spark_employee.employee_id,
             self.df_spark_employee["fname"],
             F.col("lname"),
@@ -202,7 +233,9 @@ class TestDataframeFunc(DataFrameValidator):
             self.df_spark_store.store_name,
             self.df_spark_store["num_sales"],
         )
-        dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="inner").select(
+        dfs_joined = self.df_sqlglot_employee.join(
+            self.df_sqlglot_store, on=["store_id"], how="inner"
+        ).select(
             self.df_sqlglot_employee.employee_id,
             self.df_sqlglot_employee["fname"],
             SF.col("lname"),
@@ -214,17 +247,27 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df_joined, dfs_joined)

     def test_join_inner_no_select(self):
-        df_joined = self.df_spark_employee.select(F.col("store_id"), F.col("fname"), F.col("lname")).join(
-            self.df_spark_store.select(F.col("store_id"), F.col("store_name")), on=["store_id"], how="inner"
+        df_joined = self.df_spark_employee.select(
+            F.col("store_id"), F.col("fname"), F.col("lname")
+        ).join(
+            self.df_spark_store.select(F.col("store_id"), F.col("store_name")),
+            on=["store_id"],
+            how="inner",
         )
-        dfs_joined = self.df_sqlglot_employee.select(SF.col("store_id"), SF.col("fname"), SF.col("lname")).join(
-            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), on=["store_id"], how="inner"
+        dfs_joined = self.df_sqlglot_employee.select(
+            SF.col("store_id"), SF.col("fname"), SF.col("lname")
+        ).join(
+            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")),
+            on=["store_id"],
+            how="inner",
         )
         self.compare_spark_with_sqlglot(df_joined, dfs_joined)

     def test_join_inner_equality_single(self):
         df_joined = self.df_spark_employee.join(
-            self.df_spark_store, on=self.df_spark_employee.store_id == self.df_spark_store.store_id, how="inner"
+            self.df_spark_store,
+            on=self.df_spark_employee.store_id == self.df_spark_store.store_id,
+            how="inner",
         ).select(
             self.df_spark_employee.employee_id,
             self.df_spark_employee["fname"],
@@ -235,7 +278,9 @@ class TestDataframeFunc(DataFrameValidator):
             self.df_spark_store["num_sales"],
         )
         dfs_joined = self.df_sqlglot_employee.join(
-            self.df_sqlglot_store, on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, how="inner"
+            self.df_sqlglot_store,
+            on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id,
+            how="inner",
         ).select(
             self.df_sqlglot_employee.employee_id,
             self.df_sqlglot_employee["fname"],
@@ -343,7 +388,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df_joined, dfs_joined)

     def test_join_full_outer(self):
-        df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="full_outer").select(
+        df_joined = self.df_spark_employee.join(
+            self.df_spark_store, on=["store_id"], how="full_outer"
+        ).select(
             self.df_spark_employee.employee_id,
             self.df_spark_employee["fname"],
             F.col("lname"),
@@ -352,7 +399,9 @@ class TestDataframeFunc(DataFrameValidator):
             self.df_spark_store.store_name,
             self.df_spark_store["num_sales"],
         )
-        dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="full_outer").select(
+        dfs_joined = self.df_sqlglot_employee.join(
+            self.df_sqlglot_store, on=["store_id"], how="full_outer"
+        ).select(
             self.df_sqlglot_employee.employee_id,
             self.df_sqlglot_employee["fname"],
             SF.col("lname"),
@@ -365,7 +414,9 @@ class TestDataframeFunc(DataFrameValidator):

     def test_triple_join(self):
         df = (
-            self.df_employee.join(self.df_store, on=self.df_employee.employee_id == self.df_store.store_id)
+            self.df_employee.join(
+                self.df_store, on=self.df_employee.employee_id == self.df_store.store_id
+            )
             .join(self.df_district, on=self.df_store.store_id == self.df_district.district_id)
             .select(
                 self.df_employee.employee_id,
@@ -377,7 +428,9 @@ class TestDataframeFunc(DataFrameValidator):
             )
         )
         dfs = (
-            self.dfs_employee.join(self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id)
+            self.dfs_employee.join(
+                self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id
+            )
             .join(self.dfs_district, on=self.dfs_store.store_id == self.dfs_district.district_id)
             .select(
                 self.dfs_employee.employee_id,
@@ -391,13 +444,13 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs)

     def test_join_select_and_select_start(self):
-        df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")).join(
-            self.df_spark_store, "store_id", "inner"
-        )
+        df = self.df_spark_employee.select(
+            F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")
+        ).join(self.df_spark_store, "store_id", "inner")

-        dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")).join(
-            self.df_sqlglot_store, "store_id", "inner"
-        )
+        dfs = self.df_sqlglot_employee.select(
+            SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
+        ).join(self.df_sqlglot_store, "store_id", "inner")

         self.compare_spark_with_sqlglot(df, dfs)

@@ -485,13 +538,17 @@ class TestDataframeFunc(DataFrameValidator):
         dfs_unioned = (
             self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"))
             .unionAll(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")))
-            .unionAll(self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name")))
+            .unionAll(
+                self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name"))
+            )
         )

         self.compare_spark_with_sqlglot(df_unioned, dfs_unioned)

     def test_union_by_name(self):
-        df = self.df_spark_employee.select(F.col("employee_id"), F.col("fname"), F.col("lname")).unionByName(
+        df = self.df_spark_employee.select(
+            F.col("employee_id"), F.col("fname"), F.col("lname")
+        ).unionByName(
             self.df_spark_store.select(
                 F.col("store_name").alias("lname"),
                 F.col("store_id").alias("employee_id"),
@@ -499,7 +556,9 @@ class TestDataframeFunc(DataFrameValidator):
             )
         )

-        dfs = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"), SF.col("lname")).unionByName(
+        dfs = self.df_sqlglot_employee.select(
+            SF.col("employee_id"), SF.col("fname"), SF.col("lname")
+        ).unionByName(
             self.df_sqlglot_store.select(
                 SF.col("store_name").alias("lname"),
                 SF.col("store_id").alias("employee_id"),
@@ -537,10 +596,16 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs)

     def test_order_by_default(self):
-        df = self.df_spark_store.groupBy(F.col("district_id")).agg(F.min("num_sales")).orderBy(F.col("district_id"))
+        df = (
+            self.df_spark_store.groupBy(F.col("district_id"))
+            .agg(F.min("num_sales"))
+            .orderBy(F.col("district_id"))
+        )

         dfs = (
-            self.df_sqlglot_store.groupBy(SF.col("district_id")).agg(SF.min("num_sales")).orderBy(SF.col("district_id"))
+            self.df_sqlglot_store.groupBy(SF.col("district_id"))
+            .agg(SF.min("num_sales"))
+            .orderBy(SF.col("district_id"))
         )

         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -594,13 +659,17 @@ class TestDataframeFunc(DataFrameValidator):
         df = (
             self.df_spark_store.groupBy(F.col("district_id"))
             .agg(F.min("num_sales").alias("total_sales"))
-            .orderBy(F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last())
+            .orderBy(
+                F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last()
+            )
         )

         dfs = (
             self.df_sqlglot_store.groupBy(SF.col("district_id"))
             .agg(SF.min("num_sales").alias("total_sales"))
-            .orderBy(SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last())
+            .orderBy(
+                SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last()
+            )
         )

         self.compare_spark_with_sqlglot(df, dfs)
@@ -609,81 +678,87 @@ class TestDataframeFunc(DataFrameValidator):
         df = (
             self.df_spark_store.groupBy(F.col("district_id"))
             .agg(F.min("num_sales").alias("total_sales"))
-            .orderBy(F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first())
+            .orderBy(
+                F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first()
+            )
         )

         dfs = (
             self.df_sqlglot_store.groupBy(SF.col("district_id"))
             .agg(SF.min("num_sales").alias("total_sales"))
-            .orderBy(SF.when(SF.col("district_id") == SF.lit(1), SF.col("district_id")).desc_nulls_first())
+            .orderBy(
+                SF.when(
+                    SF.col("district_id") == SF.lit(1), SF.col("district_id")
+                ).desc_nulls_first()
+            )
         )

         self.compare_spark_with_sqlglot(df, dfs)

     def test_intersect(self):
-        df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
-            self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
-        )
+        df_employee_duplicate = self.df_spark_employee.select(
+            F.col("employee_id"), F.col("store_id")
+        ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))

-        df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
-            self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
-        )
+        df_store_duplicate = self.df_spark_store.select(
+            F.col("store_id"), F.col("district_id")
+        ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))

         df = df_employee_duplicate.intersect(df_store_duplicate)

-        dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
-            self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
-        )
+        dfs_employee_duplicate = self.df_sqlglot_employee.select(
+            SF.col("employee_id"), SF.col("store_id")
+        ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))

-        dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
-            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
-        )
+        dfs_store_duplicate = self.df_sqlglot_store.select(
+            SF.col("store_id"), SF.col("district_id")
+        ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))

         dfs = dfs_employee_duplicate.intersect(dfs_store_duplicate)

         self.compare_spark_with_sqlglot(df, dfs)

     def test_intersect_all(self):
-        df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
-            self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
-        )
+        df_employee_duplicate = self.df_spark_employee.select(
+            F.col("employee_id"), F.col("store_id")
+        ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))

-        df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
-            self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
-        )
+        df_store_duplicate = self.df_spark_store.select(
+            F.col("store_id"), F.col("district_id")
+        ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))

         df = df_employee_duplicate.intersectAll(df_store_duplicate)

-        dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
-            self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
-        )
+        dfs_employee_duplicate = self.df_sqlglot_employee.select(
+            SF.col("employee_id"), SF.col("store_id")
+        ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))

-        dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
-            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
-        )
+        dfs_store_duplicate = self.df_sqlglot_store.select(
+            SF.col("store_id"), SF.col("district_id")
+        ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))

         dfs = dfs_employee_duplicate.intersectAll(dfs_store_duplicate)

         self.compare_spark_with_sqlglot(df, dfs)

     def test_except_all(self):
-        df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
-            self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
-        )
+        df_employee_duplicate = self.df_spark_employee.select(
+            F.col("employee_id"), F.col("store_id")
+        ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))

-        df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
-            self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
-        )
+        df_store_duplicate = self.df_spark_store.select(
+            F.col("store_id"), F.col("district_id")
+        ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))

         df = df_employee_duplicate.exceptAll(df_store_duplicate)

-        dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
-            self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
-        )
+        dfs_employee_duplicate = self.df_sqlglot_employee.select(
+            SF.col("employee_id"), SF.col("store_id")
+        ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))

-        dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
-            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
-        )
+        dfs_store_duplicate = self.df_sqlglot_store.select(
+            SF.col("store_id"), SF.col("district_id")
+        ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))

         dfs = dfs_employee_duplicate.exceptAll(dfs_store_duplicate)

@@ -721,7 +796,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs)

     def test_drop_na_default(self):
-        df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).dropna()
+        df = self.df_spark_employee.select(
+            F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+        ).dropna()

         dfs = self.df_sqlglot_employee.select(
             SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -746,7 +823,9 @@ class TestDataframeFunc(DataFrameValidator):
         ).dropna(how="any", thresh=2)

         dfs = self.df_sqlglot_employee.select(
-            SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
+            SF.lit(None),
+            SF.lit(1),
+            SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
         ).dropna(how="any", thresh=2)

         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -757,13 +836,17 @@ class TestDataframeFunc(DataFrameValidator):
         ).dropna(thresh=1, subset="the_age")

         dfs = self.df_sqlglot_employee.select(
-            SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
+            SF.lit(None),
+            SF.lit(1),
+            SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
         ).dropna(thresh=1, subset="the_age")

         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)

     def test_dropna_na_function(self):
-        df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.drop()
+        df = self.df_spark_employee.select(
+            F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+        ).na.drop()

         dfs = self.df_sqlglot_employee.select(
             SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -772,7 +855,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs)

     def test_fillna_default(self):
-        df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).fillna(100)
+        df = self.df_spark_employee.select(
+            F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+        ).fillna(100)

         dfs = self.df_sqlglot_employee.select(
             SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -798,7 +883,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)

     def test_fillna_na_func(self):
-        df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.fill(100)
+        df = self.df_spark_employee.select(
+            F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+        ).na.fill(100)

         dfs = self.df_sqlglot_employee.select(
             SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -807,7 +894,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs)

     def test_replace_basic(self):
-        df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(to_replace=37, value=100)
+        df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
+            to_replace=37, value=100
+        )

         dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
             to_replace=37, value=100
@@ -827,9 +916,13 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)

     def test_replace_mapping(self):
-        df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace({37: 100})
+        df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
+            {37: 100}
+        )

-        dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace({37: 100})
+        dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
+            {37: 100}
+        )

         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)

@@ -849,9 +942,9 @@ class TestDataframeFunc(DataFrameValidator):
             to_replace=37, value=100
         )

-        dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).na.replace(
-            to_replace=37, value=100
-        )
+        dfs = self.df_sqlglot_employee.select(
+            SF.col("age"), SF.lit(37).alias("test_col")
+        ).na.replace(to_replace=37, value=100)

         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)

@@ -881,16 +974,18 @@ class TestDataframeFunc(DataFrameValidator):
             "first_name", "first_name_again"
         )

-        dfs = self.df_sqlglot_employee.select(SF.col("fname").alias("first_name")).withColumnRenamed(
-            "first_name", "first_name_again"
-        )
+        dfs = self.df_sqlglot_employee.select(
+            SF.col("fname").alias("first_name")
+        ).withColumnRenamed("first_name", "first_name_again")

         self.compare_spark_with_sqlglot(df, dfs)

     def test_drop_column_single(self):
         df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age")).drop("age")

-        dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop("age")
+        dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop(
+            "age"
+        )

         self.compare_spark_with_sqlglot(df, dfs)

@@ -906,7 +1001,9 @@ class TestDataframeFunc(DataFrameValidator):
         df_sqlglot_employee_cols = self.df_sqlglot_employee.select(
             SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
         )
-        df_sqlglot_store_cols = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name"))
+        df_sqlglot_store_cols = self.df_sqlglot_store.select(
+            SF.col("store_id"), SF.col("store_name")
+        )
         dfs = df_sqlglot_employee_cols.join(df_sqlglot_store_cols, on="store_id", how="inner").drop(
             df_sqlglot_employee_cols.age,
         )

@@ -23,6 +23,14 @@ class TestSessionFunc(DataFrameValidator):
         ON
             e.store_id = s.store_id
         """
-        df = self.spark.sql(query).groupBy(F.col("store_id")).agg(F.countDistinct(F.col("employee_id")))
-        dfs = self.sqlglot.sql(query).groupBy(SF.col("store_id")).agg(SF.countDistinct(SF.col("employee_id")))
+        df = (
+            self.spark.sql(query)
+            .groupBy(F.col("store_id"))
+            .agg(F.countDistinct(F.col("employee_id")))
+        )
+        dfs = (
+            self.sqlglot.sql(query)
+            .groupBy(SF.col("store_id"))
+            .agg(SF.countDistinct(SF.col("employee_id")))
+        )
         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)

@@ -25,11 +25,17 @@ class DataFrameSQLValidator(unittest.TestCase):
             (4, "Claire", "Littleton", 27, 2),
             (5, "Hugo", "Reyes", 29, 100),
         ]
-        self.df_employee = self.spark.createDataFrame(data=employee_data, schema=self.employee_schema)
+        self.df_employee = self.spark.createDataFrame(
+            data=employee_data, schema=self.employee_schema
+        )

-    def compare_sql(self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False):
+    def compare_sql(
+        self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False
+    ):
         actual_sqls = df.sql(pretty=pretty)
-        expected_statements = [expected_statements] if isinstance(expected_statements, str) else expected_statements
+        expected_statements = (
+            [expected_statements] if isinstance(expected_statements, str) else expected_statements
+        )
         self.assertEqual(len(expected_statements), len(actual_sqls))
         for expected, actual in zip(expected_statements, actual_sqls):
             self.assertEqual(expected, actual)

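The `expected_statements` normalization above lets callers pass either one statement or a list and still zip against the generated SQL. A tiny self-contained sketch of that idiom:

    import typing as t

    def normalize(expected: t.Union[str, t.List[str]]) -> t.List[str]:
        # a lone string becomes a one-element list so zip() always works
        return [expected] if isinstance(expected, str) else expected

    assert normalize("SELECT 1") == ["SELECT 1"]
    assert normalize(["SELECT 1", "SELECT 2"]) == ["SELECT 1", "SELECT 2"]
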
@@ -26,12 +26,14 @@ class TestDataframeColumn(unittest.TestCase):

     def test_and(self):
         self.assertEqual(
-            "cola = colb AND colc = cold", ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql()
+            "cola = colb AND colc = cold",
+            ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql(),
         )

     def test_or(self):
         self.assertEqual(
-            "cola = colb OR colc = cold", ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql()
+            "cola = colb OR colc = cold",
+            ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql(),
         )

     def test_mod(self):
@@ -112,7 +114,9 @@ class TestDataframeColumn(unittest.TestCase):

     def test_when_otherwise(self):
         self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.when(F.col("cola") == 1, 2).sql())
-        self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql())
+        self.assertEqual(
+            "CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql()
+        )
         self.assertEqual(
             "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END",
             (F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3)).sql(),
@@ -148,7 +152,9 @@ class TestDataframeColumn(unittest.TestCase):
         self.assertEqual(
             "cola BETWEEN CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP) "
             "AND CAST('2022-03-01 01:01:01.000000' AS TIMESTAMP)",
-            F.col("cola").between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)).sql(),
+            F.col("cola")
+            .between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1))
+            .sql(),
         )

     def test_over(self):

@@ -9,7 +9,9 @@ class TestDataframe(DataFrameSQLValidator):
         self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression))

     def test_columns(self):
-        self.assertEqual(["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns)
+        self.assertEqual(
+            ["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns
+        )

     def test_cache(self):
         df = self.df_employee.select("fname").cache()

@@ -925,12 +925,17 @@ class TestFunctions(unittest.TestCase):
         col = SF.window(SF.col("cola"), "10 minutes")
         self.assertEqual("WINDOW(cola, '10 minutes')", col.sql())
         col_all_values = SF.window("cola", "2 minutes 30 seconds", "30 seconds", "15 seconds")
-        self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql())
+        self.assertEqual(
+            "WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql()
+        )
         col_no_start_time = SF.window("cola", "2 minutes 30 seconds", "30 seconds")
-        self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql())
+        self.assertEqual(
+            "WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql()
+        )
         col_no_slide = SF.window("cola", "2 minutes 30 seconds", startTime="15 seconds")
         self.assertEqual(
-            "WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')", col_no_slide.sql()
+            "WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')",
+            col_no_slide.sql(),
         )

     def test_session_window(self):
@@ -1359,9 +1364,13 @@ class TestFunctions(unittest.TestCase):

     def test_from_json(self):
         col_str = SF.from_json("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy"))
-        self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
+        self.assertEqual(
+            "FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
+        )
         col = SF.from_json(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy"))
-        self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
+        self.assertEqual(
+            "FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()
+        )
         col_no_option = SF.from_json("cola", "cola INT")
         self.assertEqual("FROM_JSON(cola, 'cola INT')", col_no_option.sql())

@@ -1375,7 +1384,9 @@ class TestFunctions(unittest.TestCase):

     def test_schema_of_json(self):
         col_str = SF.schema_of_json("cola", dict(timestampFormat="dd/MM/yyyy"))
-        self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
+        self.assertEqual(
+            "SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
+        )
         col = SF.schema_of_json(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy"))
         self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
         col_no_option = SF.schema_of_json("cola")
@@ -1429,7 +1440,10 @@ class TestFunctions(unittest.TestCase):
         col = SF.array_sort(SF.col("cola"))
         self.assertEqual("ARRAY_SORT(cola)", col.sql())
         col_comparator = SF.array_sort(
-            "cola", lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(SF.length(y) - SF.length(x))
+            "cola",
+            lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(
+                SF.length(y) - SF.length(x)
+            ),
         )
         self.assertEqual(
             "ARRAY_SORT(cola, (x, y) -> CASE WHEN x IS NULL OR y IS NULL THEN 0 ELSE LENGTH(y) - LENGTH(x) END)",
@@ -1504,9 +1518,13 @@ class TestFunctions(unittest.TestCase):

     def test_from_csv(self):
         col_str = SF.from_csv("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy"))
-        self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
+        self.assertEqual(
+            "FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
+        )
         col = SF.from_csv(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy"))
-        self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
+        self.assertEqual(
+            "FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()
+        )
         col_no_option = SF.from_csv("cola", "cola INT")
         self.assertEqual("FROM_CSV(cola, 'cola INT')", col_no_option.sql())

@@ -1535,7 +1553,9 @@ class TestFunctions(unittest.TestCase):
         self.assertEqual("TRANSFORM(cola, (x, i) -> x * i)", col.sql())
         col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count)

-        self.assertEqual("TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql())
+        self.assertEqual(
+            "TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql()
+        )

     def test_exists(self):
         col_str = SF.exists("cola", lambda x: x % 2 == 0)
@@ -1558,10 +1578,13 @@ class TestFunctions(unittest.TestCase):
         self.assertEqual("FILTER(cola, x -> MONTH(TO_DATE(x)) > 6)", col_str.sql())
         col = SF.filter(SF.col("cola"), lambda x, i: SF.month(SF.to_date(x)) > SF.lit(i))
         self.assertEqual("FILTER(cola, (x, i) -> MONTH(TO_DATE(x)) > i)", col.sql())
-        col_custom_names = SF.filter("cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count))
+        col_custom_names = SF.filter(
+            "cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count)
+        )

         self.assertEqual(
-            "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)", col_custom_names.sql()
+            "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)",
+            col_custom_names.sql(),
         )

     def test_zip_with(self):
@@ -1570,7 +1593,9 @@ class TestFunctions(unittest.TestCase):
         col = SF.zip_with(SF.col("cola"), SF.col("colb"), lambda x, y: SF.concat_ws("_", x, y))
         self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col.sql())
|
||||||
col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r))
|
col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r))
|
||||||
self.assertEqual("ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql())
|
self.assertEqual(
|
||||||
|
"ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql()
|
||||||
|
)
|
||||||
|
|
||||||
def test_transform_keys(self):
|
def test_transform_keys(self):
|
||||||
col_str = SF.transform_keys("cola", lambda k, v: SF.upper(k))
|
col_str = SF.transform_keys("cola", lambda k, v: SF.upper(k))
|
||||||
|
@ -1586,7 +1611,9 @@ class TestFunctions(unittest.TestCase):
|
||||||
col = SF.transform_values(SF.col("cola"), lambda k, v: SF.upper(v))
|
col = SF.transform_values(SF.col("cola"), lambda k, v: SF.upper(v))
|
||||||
self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col.sql())
|
self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col.sql())
|
||||||
col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value))
|
col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value))
|
||||||
self.assertEqual("TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql())
|
self.assertEqual(
|
||||||
|
"TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql()
|
||||||
|
)
|
||||||
|
|
||||||
def test_map_filter(self):
|
def test_map_filter(self):
|
||||||
col_str = SF.map_filter("cola", lambda k, v: k > v)
|
col_str = SF.map_filter("cola", lambda k, v: k > v)
|
||||||
|
|
|
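Each of these tests follows one pattern: an `SF` helper builds a column expression and `.sql()` renders the Spark SQL string. A minimal standalone sketch, assuming the same `functions` import path these tests use:

```python
from sqlglot.dataframe.sql import functions as SF  # import path assumed from these tests

# A column helper renders directly to a Spark SQL string.
col = SF.from_json("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy"))
print(col.sql())  # FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))

# Python lambdas become SQL higher-order functions; parameter names are kept.
col = SF.filter("cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count))
print(col.sql())  # FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)
```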
@@ -21,9 +21,7 @@ class TestDataframeSession(DataFrameSQLValidator):
 
     def test_cdf_no_schema(self):
         df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]])
-        expected = (
-            "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
-        )
+        expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
         self.compare_sql(df, expected)
 
     def test_cdf_row_mixed_primitives(self):
@@ -77,7 +75,8 @@ class TestDataframeSession(DataFrameSQLValidator):
         sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
         df = self.spark.sql(query)
         self.assertIn(
-            "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`", df.sql(pretty=False)
+            "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`",
+            df.sql(pretty=False),
         )
 
     @mock.patch("sqlglot.schema", MappingSchema())
@@ -104,9 +103,7 @@ class TestDataframeSession(DataFrameSQLValidator):
         query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1"
         sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
         df = self.spark.sql(query)
-        expected = (
-            "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
-        )
+        expected = "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
         self.compare_sql(df, expected)
 
     def test_session_create_builder_patterns(self):
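The session tests rely on a table schema registered up front so the generated SQL can qualify columns. A sketch of that flow, with the `SparkSession` import path assumed:

```python
import sqlglot
from sqlglot.dataframe.sql.session import SparkSession  # import path assumed

# Register the table schema so generated SQL can qualify and alias columns.
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})

df = SparkSession().sql("SELECT cola, colb FROM table")
print(df.sql(pretty=False))
# ['SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`']
```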
@@ -53,7 +53,10 @@ class TestDataframeTypes(unittest.TestCase):
         self.assertEqual("array<int>", types.ArrayType(types.IntegerType()).simpleString())
 
     def test_map(self):
-        self.assertEqual("map<int, string>", types.MapType(types.IntegerType(), types.StringType()).simpleString())
+        self.assertEqual(
+            "map<int, string>",
+            types.MapType(types.IntegerType(), types.StringType()).simpleString(),
+        )
 
     def test_struct_field(self):
         self.assertEqual("cola:int", types.StructField("cola", types.IntegerType()).simpleString())
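For reference, `simpleString()` mirrors PySpark's compact type notation. A sketch, assuming the `types` module import these tests use:

```python
from sqlglot.dataframe.sql import types  # import path assumed from these tests

print(types.ArrayType(types.IntegerType()).simpleString())                    # array<int>
print(types.MapType(types.IntegerType(), types.StringType()).simpleString())  # map<int, string>
print(types.StructField("cola", types.IntegerType()).simpleString())          # cola:int
```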
@@ -39,22 +39,38 @@ class TestDataframeWindow(unittest.TestCase):
 
     def test_window_rows_unbounded(self):
         rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2)
-        self.assertEqual("OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", rows_between_unbounded_start.sql())
-        rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
-        self.assertEqual("OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_end.sql())
-        rows_between_unbounded_both = Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
         self.assertEqual(
-            "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_both.sql()
+            "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
+            rows_between_unbounded_start.sql(),
+        )
+        rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
+        self.assertEqual(
+            "OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
+            rows_between_unbounded_end.sql(),
+        )
+        rows_between_unbounded_both = Window.rowsBetween(
+            Window.unboundedPreceding, Window.unboundedFollowing
+        )
+        self.assertEqual(
+            "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
+            rows_between_unbounded_both.sql(),
         )
 
     def test_window_range_unbounded(self):
         range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2)
         self.assertEqual(
-            "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", range_between_unbounded_start.sql()
+            "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
+            range_between_unbounded_start.sql(),
         )
         range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing)
-        self.assertEqual("OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_end.sql())
-        range_between_unbounded_both = Window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing)
         self.assertEqual(
-            "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_both.sql()
+            "OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
+            range_between_unbounded_end.sql(),
+        )
+        range_between_unbounded_both = Window.rangeBetween(
+            Window.unboundedPreceding, Window.unboundedFollowing
+        )
+        self.assertEqual(
+            "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
+            range_between_unbounded_both.sql(),
         )
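A window spec renders to an OVER-clause fragment on its own. A minimal sketch, with the `Window` import path assumed:

```python
from sqlglot.dataframe.sql.window import Window  # import path assumed

# Window specs render to an OVER clause fragment.
spec = Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
print(spec.sql())  # OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)
```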
@@ -157,6 +157,14 @@ class TestBigQuery(Validator):
             },
         )
+        self.validate_all(
+            "DIV(x, y)",
+            write={
+                "bigquery": "DIV(x, y)",
+                "duckdb": "CAST(x / y AS INT)",
+            },
+        )
         self.validate_identity(
             "SELECT ROW() OVER (y ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM x WINDOW y AS (PARTITION BY CATEGORY)"
         )
@@ -284,4 +292,6 @@ class TestBigQuery(Validator):
             "CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 NOT DETERMINISTIC LANGUAGE js AS 'return x*y;'"
         )
         self.validate_identity("CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)")
-        self.validate_identity("CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t")
+        self.validate_identity(
+            "CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t"
+        )
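The new DIV case maps to the equivalent standalone call; a minimal sketch using the top-level API:

```python
import sqlglot

# BigQuery integer division has no direct DuckDB counterpart, so it is
# rewritten as a cast over ordinary division.
print(sqlglot.transpile("DIV(x, y)", read="bigquery", write="duckdb")[0])
# CAST(x / y AS INT)
```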
@@ -18,7 +18,6 @@ class TestClickhouse(Validator):
                 "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
             },
         )
-
         self.validate_all(
             "CAST(1 AS NULLABLE(Int64))",
             write={
@@ -31,3 +30,7 @@ class TestClickhouse(Validator):
                 "clickhouse": "CAST(1 AS Nullable(DateTime64(6, 'UTC')))",
             },
         )
+        self.validate_all(
+            "SELECT x #! comment",
+            write={"": "SELECT x /* comment */"},
+        )
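The added case shows ClickHouse line comments being preserved; a sketch of the same conversion via the public API:

```python
import sqlglot

# ClickHouse '#!' line comments are kept and normalized to /* ... */ form
# when writing the default dialect.
print(sqlglot.transpile("SELECT x #! comment", read="clickhouse")[0])
# SELECT x /* comment */
```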
@@ -22,7 +22,8 @@ class TestDatabricks(Validator):
             },
         )
         self.validate_all(
-            "SELECT DATEDIFF('end', 'start')", write={"databricks": "SELECT DATEDIFF(DAY, 'start', 'end')"}
+            "SELECT DATEDIFF('end', 'start')",
+            write={"databricks": "SELECT DATEDIFF(DAY, 'start', 'end')"},
        )
         self.validate_all(
             "SELECT DATE_ADD('2020-01-01', 1)",
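The DATEDIFF case corresponds to this round trip through the public API; a minimal sketch:

```python
import sqlglot

# DATEDIFF('end', 'start') is normalized to an explicit DAY unit with the
# arguments reordered to (unit, start, end).
sql = "SELECT DATEDIFF('end', 'start')"
print(sqlglot.transpile(sql, read="databricks", write="databricks")[0])
# SELECT DATEDIFF(DAY, 'start', 'end')
```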
@@ -1,20 +1,18 @@
 import unittest
 
-from sqlglot import (
-    Dialect,
-    Dialects,
-    ErrorLevel,
-    UnsupportedError,
-    parse_one,
-    transpile,
-)
+from sqlglot import Dialect, Dialects, ErrorLevel, UnsupportedError, parse_one
 
 
 class Validator(unittest.TestCase):
     dialect = None
 
-    def validate_identity(self, sql):
-        self.assertEqual(transpile(sql, read=self.dialect, write=self.dialect)[0], sql)
+    def parse_one(self, sql):
+        return parse_one(sql, read=self.dialect)
+
+    def validate_identity(self, sql, write_sql=None):
+        expression = self.parse_one(sql)
+        self.assertEqual(write_sql or sql, expression.sql(dialect=self.dialect))
+        return expression
 
     def validate_all(self, sql, read=None, write=None, pretty=False):
         """
@@ -28,12 +26,14 @@ class Validator(unittest.TestCase):
             read (dict): Mapping of dialect -> SQL
             write (dict): Mapping of dialect -> SQL
         """
-        expression = parse_one(sql, read=self.dialect)
+        expression = self.parse_one(sql)
 
         for read_dialect, read_sql in (read or {}).items():
             with self.subTest(f"{read_dialect} -> {sql}"):
                 self.assertEqual(
-                    parse_one(read_sql, read_dialect).sql(self.dialect, unsupported_level=ErrorLevel.IGNORE),
+                    parse_one(read_sql, read_dialect).sql(
+                        self.dialect, unsupported_level=ErrorLevel.IGNORE, pretty=pretty
+                    ),
                     sql,
                 )
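Outside the harness, the refactored `validate_identity` amounts to a parse/render round trip through the public API; a minimal sketch:

```python
from sqlglot import parse_one

def check_identity(sql, dialect, write_sql=None):
    # Parse in the dialect, render back, and compare -- what validate_identity does.
    expression = parse_one(sql, read=dialect)
    assert (write_sql or sql) == expression.sql(dialect=dialect)
    return expression

check_identity("SELECT TRIM(BOTH 'bla' FROM ' XXX ')", "mysql")
```

Returning the parsed expression is what lets the new SHOW/SET tests inspect the AST after validating the SQL text.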
@@ -83,10 +83,6 @@ class TestDialect(Validator):
         )
         self.validate_all(
             "CAST(a AS BINARY(4))",
-            read={
-                "presto": "CAST(a AS VARBINARY(4))",
-                "sqlite": "CAST(a AS VARBINARY(4))",
-            },
             write={
                 "bigquery": "CAST(a AS BINARY(4))",
                 "clickhouse": "CAST(a AS BINARY(4))",
@@ -103,6 +99,24 @@ class TestDialect(Validator):
                 "starrocks": "CAST(a AS BINARY(4))",
             },
         )
+        self.validate_all(
+            "CAST(a AS VARBINARY(4))",
+            write={
+                "bigquery": "CAST(a AS VARBINARY(4))",
+                "clickhouse": "CAST(a AS VARBINARY(4))",
+                "duckdb": "CAST(a AS VARBINARY(4))",
+                "mysql": "CAST(a AS VARBINARY(4))",
+                "hive": "CAST(a AS BINARY(4))",
+                "oracle": "CAST(a AS BLOB(4))",
+                "postgres": "CAST(a AS BYTEA(4))",
+                "presto": "CAST(a AS VARBINARY(4))",
+                "redshift": "CAST(a AS VARBYTE(4))",
+                "snowflake": "CAST(a AS VARBINARY(4))",
+                "sqlite": "CAST(a AS BLOB(4))",
+                "spark": "CAST(a AS BINARY(4))",
+                "starrocks": "CAST(a AS VARBINARY(4))",
+            },
+        )
         self.validate_all(
             "CAST(MAP('a', '1') AS MAP(TEXT, TEXT))",
             write={
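The new VARBINARY block is equivalent to looping over dialects with `transpile`; a minimal sketch:

```python
import sqlglot

# Each dialect maps VARBINARY onto its native binary type.
for dialect, expected in [
    ("postgres", "CAST(a AS BYTEA(4))"),
    ("redshift", "CAST(a AS VARBYTE(4))"),
    ("sqlite", "CAST(a AS BLOB(4))"),
]:
    assert sqlglot.transpile("CAST(a AS VARBINARY(4))", write=dialect)[0] == expected
```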
@@ -472,45 +486,57 @@ class TestDialect(Validator):
             },
         )
         self.validate_all(
-            "DATE_TRUNC(x, 'day')",
+            "DATE_TRUNC('day', x)",
             write={
                 "mysql": "DATE(x)",
-                "starrocks": "DATE(x)",
             },
         )
         self.validate_all(
-            "DATE_TRUNC(x, 'week')",
+            "DATE_TRUNC('week', x)",
             write={
                 "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')",
-                "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')",
             },
         )
         self.validate_all(
-            "DATE_TRUNC(x, 'month')",
+            "DATE_TRUNC('month', x)",
             write={
                 "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')",
-                "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')",
             },
         )
         self.validate_all(
-            "DATE_TRUNC(x, 'quarter')",
+            "DATE_TRUNC('quarter', x)",
             write={
                 "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')",
-                "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')",
             },
         )
         self.validate_all(
-            "DATE_TRUNC(x, 'year')",
+            "DATE_TRUNC('year', x)",
             write={
                 "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')",
-                "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')",
             },
         )
         self.validate_all(
-            "DATE_TRUNC(x, 'millenium')",
+            "DATE_TRUNC('millenium', x)",
             write={
                 "mysql": UnsupportedError,
-                "starrocks": UnsupportedError,
             },
         )
+        self.validate_all(
+            "DATE_TRUNC('year', x)",
+            read={
+                "starrocks": "DATE_TRUNC('year', x)",
+            },
+            write={
+                "starrocks": "DATE_TRUNC('year', x)",
+            },
+        )
+        self.validate_all(
+            "DATE_TRUNC(x, year)",
+            read={
+                "bigquery": "DATE_TRUNC(x, year)",
+            },
+            write={
+                "bigquery": "DATE_TRUNC(x, year)",
+            },
+        )
         self.validate_all(
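The DATE_TRUNC tests now put the unit first, while BigQuery keeps its own column-first shape; a minimal sketch of both:

```python
import sqlglot

# Unit-first DATE_TRUNC; MySQL truncation to day becomes a plain DATE() call.
print(sqlglot.transpile("DATE_TRUNC('day', x)", write="mysql")[0])  # DATE(x)

# BigQuery keeps its column-first argument order end to end.
print(sqlglot.transpile("DATE_TRUNC(x, year)", read="bigquery", write="bigquery")[0])
# DATE_TRUNC(x, year)
```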
@@ -564,6 +590,22 @@ class TestDialect(Validator):
                 "spark": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
             },
         )
+        self.validate_all(
+            "TIMESTAMP '2022-01-01'",
+            write={
+                "mysql": "CAST('2022-01-01' AS TIMESTAMP)",
+                "starrocks": "CAST('2022-01-01' AS DATETIME)",
+                "hive": "CAST('2022-01-01' AS TIMESTAMP)",
+            },
+        )
+        self.validate_all(
+            "TIMESTAMP('2022-01-01')",
+            write={
+                "mysql": "TIMESTAMP('2022-01-01')",
+                "starrocks": "TIMESTAMP('2022-01-01')",
+                "hive": "TIMESTAMP('2022-01-01')",
+            },
+        )
 
         for unit in ("DAY", "MONTH", "YEAR"):
             self.validate_all(
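The two added cases distinguish the TIMESTAMP literal from the function call; a minimal sketch:

```python
import sqlglot

# A TIMESTAMP literal becomes a cast; StarRocks casts to DATETIME instead.
print(sqlglot.transpile("TIMESTAMP '2022-01-01'", write="starrocks")[0])
# CAST('2022-01-01' AS DATETIME)

# The function-call form is left as a function call.
print(sqlglot.transpile("TIMESTAMP('2022-01-01')", write="starrocks")[0])
# TIMESTAMP('2022-01-01')
```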
@@ -1002,7 +1044,10 @@ class TestDialect(Validator):
         )
 
     def test_limit(self):
-        self.validate_all("SELECT * FROM data LIMIT 10, 20", write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"})
+        self.validate_all(
+            "SELECT * FROM data LIMIT 10, 20",
+            write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"},
+        )
         self.validate_all(
             "SELECT x FROM y LIMIT 10",
             write={
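The limit case corresponds to this one-liner; a minimal sketch:

```python
import sqlglot

# The comma form of LIMIT is rewritten with an explicit OFFSET for SQLite.
print(sqlglot.transpile("SELECT * FROM data LIMIT 10, 20", write="sqlite")[0])
# SELECT * FROM data LIMIT 10 OFFSET 20
```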
@@ -1132,3 +1177,56 @@ class TestDialect(Validator):
                 "sqlite": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c",
             },
         )
+
+    def test_nullsafe_eq(self):
+        self.validate_all(
+            "SELECT a IS NOT DISTINCT FROM b",
+            read={
+                "mysql": "SELECT a <=> b",
+                "postgres": "SELECT a IS NOT DISTINCT FROM b",
+            },
+            write={
+                "mysql": "SELECT a <=> b",
+                "postgres": "SELECT a IS NOT DISTINCT FROM b",
+            },
+        )
+
+    def test_nullsafe_neq(self):
+        self.validate_all(
+            "SELECT a IS DISTINCT FROM b",
+            read={
+                "postgres": "SELECT a IS DISTINCT FROM b",
+            },
+            write={
+                "mysql": "SELECT NOT a <=> b",
+                "postgres": "SELECT a IS DISTINCT FROM b",
+            },
+        )
+
+    def test_hash_comments(self):
+        self.validate_all(
+            "SELECT 1 /* arbitrary content,,, until end-of-line */",
+            read={
+                "mysql": "SELECT 1 # arbitrary content,,, until end-of-line",
+                "bigquery": "SELECT 1 # arbitrary content,,, until end-of-line",
+                "clickhouse": "SELECT 1 #! arbitrary content,,, until end-of-line",
+            },
+        )
+        self.validate_all(
+            """/* comment1 */
+SELECT
+  x, -- comment2
+  y -- comment3""",
+            read={
+                "mysql": """SELECT # comment1
+  x, # comment2
+  y # comment3""",
+                "bigquery": """SELECT # comment1
+  x, # comment2
+  y # comment3""",
+                "clickhouse": """SELECT # comment1
+  x, # comment2
+  y # comment3""",
+            },
+            pretty=True,
+        )
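The null-safe operators now translate both ways; a minimal sketch of what the new tests assert:

```python
import sqlglot

# MySQL's null-safe operator and the standard predicate are interchangeable.
print(sqlglot.transpile("SELECT a <=> b", read="mysql", write="postgres")[0])
# SELECT a IS NOT DISTINCT FROM b

print(sqlglot.transpile("SELECT a IS DISTINCT FROM b", read="postgres", write="mysql")[0])
# SELECT NOT a <=> b
```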
@@ -1,3 +1,4 @@
+from sqlglot import expressions as exp
 from tests.dialects.test_dialect import Validator
 
 
@@ -20,6 +21,52 @@ class TestMySQL(Validator):
         self.validate_identity("SELECT TRIM(TRAILING 'bla' FROM ' XXX ')")
         self.validate_identity("SELECT TRIM(BOTH 'bla' FROM ' XXX ')")
         self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')")
+        self.validate_identity("@@GLOBAL.max_connections")
+
+        # SET Commands
+        self.validate_identity("SET @var_name = expr")
+        self.validate_identity("SET @name = 43")
+        self.validate_identity("SET @total_tax = (SELECT SUM(tax) FROM taxable_transactions)")
+        self.validate_identity("SET GLOBAL max_connections = 1000")
+        self.validate_identity("SET @@GLOBAL.max_connections = 1000")
+        self.validate_identity("SET SESSION sql_mode = 'TRADITIONAL'")
+        self.validate_identity("SET LOCAL sql_mode = 'TRADITIONAL'")
+        self.validate_identity("SET @@SESSION.sql_mode = 'TRADITIONAL'")
+        self.validate_identity("SET @@LOCAL.sql_mode = 'TRADITIONAL'")
+        self.validate_identity("SET @@sql_mode = 'TRADITIONAL'")
+        self.validate_identity("SET sql_mode = 'TRADITIONAL'")
+        self.validate_identity("SET PERSIST max_connections = 1000")
+        self.validate_identity("SET @@PERSIST.max_connections = 1000")
+        self.validate_identity("SET PERSIST_ONLY back_log = 100")
+        self.validate_identity("SET @@PERSIST_ONLY.back_log = 100")
+        self.validate_identity("SET @@SESSION.max_join_size = DEFAULT")
+        self.validate_identity("SET @@SESSION.max_join_size = @@GLOBAL.max_join_size")
+        self.validate_identity("SET @x = 1, SESSION sql_mode = ''")
+        self.validate_identity(
+            "SET GLOBAL sort_buffer_size = 1000000, SESSION sort_buffer_size = 1000000"
+        )
+        self.validate_identity(
+            "SET @@GLOBAL.sort_buffer_size = 1000000, @@LOCAL.sort_buffer_size = 1000000"
+        )
+        self.validate_identity("SET GLOBAL max_connections = 1000, sort_buffer_size = 1000000")
+        self.validate_identity("SET @@GLOBAL.sort_buffer_size = 50000, sort_buffer_size = 1000000")
+        self.validate_identity("SET CHARACTER SET 'utf8'")
+        self.validate_identity("SET CHARACTER SET utf8")
+        self.validate_identity("SET CHARACTER SET DEFAULT")
+        self.validate_identity("SET NAMES 'utf8'")
+        self.validate_identity("SET NAMES DEFAULT")
+        self.validate_identity("SET NAMES 'utf8' COLLATE 'utf8_unicode_ci'")
+        self.validate_identity("SET NAMES utf8 COLLATE utf8_unicode_ci")
+        self.validate_identity("SET autocommit = ON")
+
+    def test_escape(self):
+        self.validate_all(
+            r"'a \' b '' '",
+            write={
+                "mysql": r"'a '' b '' '",
+                "spark": r"'a \' b \' '",
+            },
+        )
 
     def test_introducers(self):
         self.validate_all(
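Each of these identities is a parse/render round trip through the MySQL dialect; a minimal sketch:

```python
import sqlglot

# SET statements now parse and round-trip through the MySQL dialect.
sql = "SET @@GLOBAL.sort_buffer_size = 1000000, @@LOCAL.sort_buffer_size = 1000000"
assert sqlglot.transpile(sql, read="mysql", write="mysql")[0] == sql
```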
@@ -115,14 +162,6 @@ class TestMySQL(Validator):
             },
         )
 
-    def test_hash_comments(self):
-        self.validate_all(
-            "SELECT 1 # arbitrary content,,, until end-of-line",
-            write={
-                "mysql": "SELECT 1",
-            },
-        )
-
     def test_mysql(self):
         self.validate_all(
             "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
@@ -174,3 +213,242 @@ COMMENT='客户账户表'"""
             },
             pretty=True,
         )
+
+    def test_show_simple(self):
+        for key, write_key in [
+            ("BINARY LOGS", "BINARY LOGS"),
+            ("MASTER LOGS", "BINARY LOGS"),
+            ("STORAGE ENGINES", "ENGINES"),
+            ("ENGINES", "ENGINES"),
+            ("EVENTS", "EVENTS"),
+            ("MASTER STATUS", "MASTER STATUS"),
+            ("PLUGINS", "PLUGINS"),
+            ("PRIVILEGES", "PRIVILEGES"),
+            ("PROFILES", "PROFILES"),
+            ("REPLICAS", "REPLICAS"),
+            ("SLAVE HOSTS", "REPLICAS"),
+        ]:
+            show = self.validate_identity(f"SHOW {key}", f"SHOW {write_key}")
+            self.assertIsInstance(show, exp.Show)
+            self.assertEqual(show.name, write_key)
+
+    def test_show_events(self):
+        for key in ["BINLOG", "RELAYLOG"]:
+            show = self.validate_identity(f"SHOW {key} EVENTS")
+            self.assertIsInstance(show, exp.Show)
+            self.assertEqual(show.name, f"{key} EVENTS")
+
+            show = self.validate_identity(f"SHOW {key} EVENTS IN 'log' FROM 1 LIMIT 2, 3")
+            self.assertEqual(show.text("log"), "log")
+            self.assertEqual(show.text("position"), "1")
+            self.assertEqual(show.text("limit"), "3")
+            self.assertEqual(show.text("offset"), "2")
+
+            show = self.validate_identity(f"SHOW {key} EVENTS LIMIT 1")
+            self.assertEqual(show.text("limit"), "1")
+            self.assertIsNone(show.args.get("offset"))
+
+    def test_show_like_or_where(self):
+        for key, write_key in [
+            ("CHARSET", "CHARACTER SET"),
+            ("CHARACTER SET", "CHARACTER SET"),
+            ("COLLATION", "COLLATION"),
+            ("DATABASES", "DATABASES"),
+            ("FUNCTION STATUS", "FUNCTION STATUS"),
+            ("PROCEDURE STATUS", "PROCEDURE STATUS"),
+            ("GLOBAL STATUS", "GLOBAL STATUS"),
+            ("SESSION STATUS", "STATUS"),
+            ("STATUS", "STATUS"),
+            ("GLOBAL VARIABLES", "GLOBAL VARIABLES"),
+            ("SESSION VARIABLES", "VARIABLES"),
+            ("VARIABLES", "VARIABLES"),
+        ]:
+            expected_name = write_key.strip("GLOBAL").strip()
+            template = "SHOW {}"
+            show = self.validate_identity(template.format(key), template.format(write_key))
+            self.assertIsInstance(show, exp.Show)
+            self.assertEqual(show.name, expected_name)
+
+            template = "SHOW {} LIKE '%foo%'"
+            show = self.validate_identity(template.format(key), template.format(write_key))
+            self.assertIsInstance(show, exp.Show)
+            self.assertIsInstance(show.args["like"], exp.Literal)
+            self.assertEqual(show.text("like"), "%foo%")
+
+            template = "SHOW {} WHERE Column_name LIKE '%foo%'"
+            show = self.validate_identity(template.format(key), template.format(write_key))
+            self.assertIsInstance(show, exp.Show)
+            self.assertIsInstance(show.args["where"], exp.Where)
+            self.assertEqual(show.args["where"].sql(), "WHERE Column_name LIKE '%foo%'")
+
+    def test_show_columns(self):
+        show = self.validate_identity("SHOW COLUMNS FROM tbl_name")
+        self.assertIsInstance(show, exp.Show)
+        self.assertEqual(show.name, "COLUMNS")
+        self.assertEqual(show.text("target"), "tbl_name")
+        self.assertFalse(show.args["full"])
+
+        show = self.validate_identity("SHOW FULL COLUMNS FROM tbl_name FROM db_name LIKE '%foo%'")
+        self.assertIsInstance(show, exp.Show)
+        self.assertEqual(show.text("target"), "tbl_name")
+        self.assertTrue(show.args["full"])
+        self.assertEqual(show.text("db"), "db_name")
+        self.assertIsInstance(show.args["like"], exp.Literal)
+        self.assertEqual(show.text("like"), "%foo%")
+
+    def test_show_name(self):
+        for key in [
+            "CREATE DATABASE",
+            "CREATE EVENT",
+            "CREATE FUNCTION",
+            "CREATE PROCEDURE",
+            "CREATE TABLE",
+            "CREATE TRIGGER",
+            "CREATE VIEW",
+            "FUNCTION CODE",
+            "PROCEDURE CODE",
+        ]:
+            show = self.validate_identity(f"SHOW {key} foo")
+            self.assertIsInstance(show, exp.Show)
+            self.assertEqual(show.name, key)
+            self.assertEqual(show.text("target"), "foo")
+
+    def test_show_grants(self):
+        show = self.validate_identity(f"SHOW GRANTS FOR foo")
+        self.assertIsInstance(show, exp.Show)
+        self.assertEqual(show.name, "GRANTS")
+        self.assertEqual(show.text("target"), "foo")
+
+    def test_show_engine(self):
+        show = self.validate_identity("SHOW ENGINE foo STATUS")
+        self.assertIsInstance(show, exp.Show)
+        self.assertEqual(show.name, "ENGINE")
+        self.assertEqual(show.text("target"), "foo")
+        self.assertFalse(show.args["mutex"])
+
+        show = self.validate_identity("SHOW ENGINE foo MUTEX")
+        self.assertEqual(show.name, "ENGINE")
+        self.assertEqual(show.text("target"), "foo")
+        self.assertTrue(show.args["mutex"])
+
+    def test_show_errors(self):
+        for key in ["ERRORS", "WARNINGS"]:
+            show = self.validate_identity(f"SHOW {key}")
+            self.assertIsInstance(show, exp.Show)
+            self.assertEqual(show.name, key)
+
+            show = self.validate_identity(f"SHOW {key} LIMIT 2, 3")
+            self.assertEqual(show.text("limit"), "3")
+            self.assertEqual(show.text("offset"), "2")
+
+    def test_show_index(self):
+        show = self.validate_identity("SHOW INDEX FROM foo")
+        self.assertIsInstance(show, exp.Show)
+        self.assertEqual(show.name, "INDEX")
+        self.assertEqual(show.text("target"), "foo")
+
+        show = self.validate_identity("SHOW INDEX FROM foo FROM bar")
+        self.assertEqual(show.text("db"), "bar")
+
+    def test_show_db_like_or_where_sql(self):
+        for key in [
+            "OPEN TABLES",
+            "TABLE STATUS",
+            "TRIGGERS",
+        ]:
+            show = self.validate_identity(f"SHOW {key}")
+            self.assertIsInstance(show, exp.Show)
+            self.assertEqual(show.name, key)
+
+            show = self.validate_identity(f"SHOW {key} FROM db_name")
+            self.assertEqual(show.name, key)
+            self.assertEqual(show.text("db"), "db_name")
+
+            show = self.validate_identity(f"SHOW {key} LIKE '%foo%'")
+            self.assertEqual(show.name, key)
+            self.assertIsInstance(show.args["like"], exp.Literal)
+            self.assertEqual(show.text("like"), "%foo%")
+
+            show = self.validate_identity(f"SHOW {key} WHERE Column_name LIKE '%foo%'")
+            self.assertEqual(show.name, key)
+            self.assertIsInstance(show.args["where"], exp.Where)
+            self.assertEqual(show.args["where"].sql(), "WHERE Column_name LIKE '%foo%'")
+
+    def test_show_processlist(self):
+        show = self.validate_identity("SHOW PROCESSLIST")
+        self.assertIsInstance(show, exp.Show)
+        self.assertEqual(show.name, "PROCESSLIST")
+        self.assertFalse(show.args["full"])
+
+        show = self.validate_identity("SHOW FULL PROCESSLIST")
+        self.assertEqual(show.name, "PROCESSLIST")
+        self.assertTrue(show.args["full"])
+
+    def test_show_profile(self):
+        show = self.validate_identity("SHOW PROFILE")
+        self.assertIsInstance(show, exp.Show)
+        self.assertEqual(show.name, "PROFILE")
+
+        show = self.validate_identity("SHOW PROFILE BLOCK IO")
+        self.assertEqual(show.args["types"][0].name, "BLOCK IO")
+
+        show = self.validate_identity(
+            "SHOW PROFILE BLOCK IO, PAGE FAULTS FOR QUERY 1 OFFSET 2 LIMIT 3"
+        )
+        self.assertEqual(show.args["types"][0].name, "BLOCK IO")
+        self.assertEqual(show.args["types"][1].name, "PAGE FAULTS")
+        self.assertEqual(show.text("query"), "1")
+        self.assertEqual(show.text("offset"), "2")
+        self.assertEqual(show.text("limit"), "3")
+
+    def test_show_replica_status(self):
+        show = self.validate_identity("SHOW REPLICA STATUS")
+        self.assertIsInstance(show, exp.Show)
+        self.assertEqual(show.name, "REPLICA STATUS")
+
+        show = self.validate_identity("SHOW SLAVE STATUS", "SHOW REPLICA STATUS")
+        self.assertIsInstance(show, exp.Show)
+        self.assertEqual(show.name, "REPLICA STATUS")
+
+        show = self.validate_identity("SHOW REPLICA STATUS FOR CHANNEL channel_name")
+        self.assertEqual(show.text("channel"), "channel_name")
+
+    def test_show_tables(self):
+        show = self.validate_identity("SHOW TABLES")
+        self.assertIsInstance(show, exp.Show)
+        self.assertEqual(show.name, "TABLES")
+
+        show = self.validate_identity("SHOW FULL TABLES FROM db_name LIKE '%foo%'")
+        self.assertTrue(show.args["full"])
+        self.assertEqual(show.text("db"), "db_name")
+        self.assertIsInstance(show.args["like"], exp.Literal)
+        self.assertEqual(show.text("like"), "%foo%")
+
+    def test_set_variable(self):
+        cmd = self.parse_one("SET SESSION x = 1")
+        item = cmd.expressions[0]
+        self.assertEqual(item.text("kind"), "SESSION")
+        self.assertIsInstance(item.this, exp.EQ)
+        self.assertEqual(item.this.left.name, "x")
+        self.assertEqual(item.this.right.name, "1")
+
+        cmd = self.parse_one("SET @@GLOBAL.x = @@GLOBAL.y")
+        item = cmd.expressions[0]
+        self.assertEqual(item.text("kind"), "")
+        self.assertIsInstance(item.this, exp.EQ)
+        self.assertIsInstance(item.this.left, exp.SessionParameter)
+        self.assertIsInstance(item.this.right, exp.SessionParameter)
+
+        cmd = self.parse_one("SET NAMES 'charset_name' COLLATE 'collation_name'")
+        item = cmd.expressions[0]
+        self.assertEqual(item.text("kind"), "NAMES")
+        self.assertEqual(item.name, "charset_name")
+        self.assertEqual(item.text("collate"), "collation_name")
+
+        cmd = self.parse_one("SET CHARSET DEFAULT")
+        item = cmd.expressions[0]
+        self.assertEqual(item.text("kind"), "CHARACTER SET")
+        self.assertEqual(item.this.name, "DEFAULT")
+
+        cmd = self.parse_one("SET x = 1, y = 2")
+        self.assertEqual(len(cmd.expressions), 2)
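As the tests above show, SHOW commands now parse into structured `exp.Show` nodes rather than opaque commands; a minimal sketch of inspecting one:

```python
from sqlglot import exp, parse_one

# SHOW commands parse into exp.Show nodes whose arguments are inspectable.
show = parse_one("SHOW FULL TABLES FROM db_name LIKE '%foo%'", read="mysql")
assert isinstance(show, exp.Show)
print(show.name)          # TABLES
print(show.text("db"))    # db_name
print(show.text("like"))  # %foo%
print(show.args["full"])  # True
```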
@@ -8,7 +8,9 @@ class TestPostgres(Validator):
     def test_ddl(self):
         self.validate_all(
             "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)",
-            write={"postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"},
+            write={
+                "postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"
+            },
         )
         self.validate_all(
             "CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)",
@@ -59,15 +61,27 @@ class TestPostgres(Validator):
 
     def test_postgres(self):
         self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END")
-        self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END")
-        self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END")
-        self.validate_identity('SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')')
+        self.validate_identity(
+            "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END"
+        )
+        self.validate_identity(
+            "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END"
+        )
+        self.validate_identity(
+            'SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')'
+        )
         self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')")
-        self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')")
-        self.validate_identity("SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))")
+        self.validate_identity(
+            "SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')"
+        )
+        self.validate_identity(
+            "SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))"
+        )
         self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')")
         self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)")
-        self.validate_identity("SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')")
+        self.validate_identity(
+            "SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')"
+        )
         self.validate_identity("COMMENT ON TABLE mytable IS 'this'")
         self.validate_identity("SELECT e'\\xDEADBEEF'")
         self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)")
@@ -75,7 +89,7 @@ class TestPostgres(Validator):
         self.validate_all(
             "CREATE TABLE x (a UUID, b BYTEA)",
             write={
-                "duckdb": "CREATE TABLE x (a UUID, b BINARY)",
+                "duckdb": "CREATE TABLE x (a UUID, b VARBINARY)",
                 "presto": "CREATE TABLE x (a UUID, b VARBINARY)",
                 "hive": "CREATE TABLE x (a UUID, b BINARY)",
                 "spark": "CREATE TABLE x (a UUID, b BINARY)",
@@ -153,7 +167,9 @@ class TestPostgres(Validator):
         )
         self.validate_all(
             "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss",
-            read={"postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"},
+            read={
+                "postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"
+            },
         )
         self.validate_all(
             "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL",
@@ -169,11 +185,15 @@ class TestPostgres(Validator):
         )
         self.validate_all(
             "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL",
-            read={"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE deleted NOTNULL"},
+            read={
+                "postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE deleted NOTNULL"
+            },
         )
         self.validate_all(
             "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL",
-            read={"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted ISNULL"},
+            read={
+                "postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted ISNULL"
+            },
         )
         self.validate_all(
             "'[1,2,3]'::json->2",
@@ -184,7 +204,8 @@ class TestPostgres(Validator):
             write={"postgres": """CAST('{"a":1,"b":2}' AS JSON)->'b'"""},
         )
         self.validate_all(
-            """'{"x": {"y": 1}}'::json->'x'->'y'""", write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON)->'x'->'y'"""}
+            """'{"x": {"y": 1}}'::json->'x'->'y'""",
+            write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON)->'x'->'y'"""},
         )
         self.validate_all(
             """'{"x": {"y": 1}}'::json->'x'::json->'y'""",
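The JSON-arrow cases correspond to this normalization; a minimal sketch:

```python
import sqlglot

# The :: cast shorthand is normalized to CAST while -> JSON access is kept.
sql = """'{"x": {"y": 1}}'::json->'x'->'y'"""
print(sqlglot.transpile(sql, read="postgres", write="postgres")[0])
# CAST('{"x": {"y": 1}}' AS JSON)->'x'->'y'
```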
@@ -61,4 +61,6 @@ class TestRedshift(Validator):
             "SELECT caldate + INTERVAL '1 second' AS dateplus FROM date WHERE caldate = '12-31-2008'"
         )
         self.validate_identity("CREATE TABLE datetable (start_date DATE, end_date DATE)")
-        self.validate_identity("SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'")
+        self.validate_identity(
+            "SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'"
+        )
@@ -336,7 +336,8 @@ class TestSnowflake(Validator):
     def test_table_literal(self):
         # All examples from https://docs.snowflake.com/en/sql-reference/literals-table.html
         self.validate_all(
-            r"""SELECT * FROM TABLE('MYTABLE')""", write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""}
+            r"""SELECT * FROM TABLE('MYTABLE')""",
+            write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""},
         )
 
         self.validate_all(
@@ -352,15 +353,123 @@ class TestSnowflake(Validator):
             write={"snowflake": r"""SELECT * FROM TABLE('MYDB. "MYSCHEMA"."MYTABLE"')"""},
         )
 
-        self.validate_all(r"""SELECT * FROM TABLE($MYVAR)""", write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""})
-        self.validate_all(r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""})
+        self.validate_all(
+            r"""SELECT * FROM TABLE($MYVAR)""",
+            write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""},
+        )
+
         self.validate_all(
-            r"""SELECT * FROM TABLE(:BINDING)""", write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""}
+            r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""}
+        )
+
+        self.validate_all(
+            r"""SELECT * FROM TABLE(:BINDING)""",
+            write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""},
         )
 
         self.validate_all(
             r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10""",
             write={"snowflake": r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10"""},
         )
+
+    def test_flatten(self):
+        self.validate_all(
+            """
+            select
+              dag_report.acct_id,
+              dag_report.report_date,
+              dag_report.report_uuid,
+              dag_report.airflow_name,
+              dag_report.dag_id,
+              f.value::varchar as operator
+            from cs.telescope.dag_report,
+            table(flatten(input=>split(operators, ','))) f
+            """,
+            write={
+                "snowflake": """SELECT
+  dag_report.acct_id,
+  dag_report.report_date,
+  dag_report.report_uuid,
+  dag_report.airflow_name,
+  dag_report.dag_id,
+  CAST(f.value AS VARCHAR) AS operator
+FROM cs.telescope.dag_report, TABLE(FLATTEN(input => SPLIT(operators, ','))) AS f"""
+            },
+            pretty=True,
+        )
+
+        # All examples from https://docs.snowflake.com/en/sql-reference/functions/flatten.html#syntax
+        self.validate_all(
+            "SELECT * FROM TABLE(FLATTEN(input => parse_json('[1, ,77]'))) f",
+            write={
+                "snowflake": "SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[1, ,77]'))) AS f"
+            },
+        )
+
+        self.validate_all(
+            """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88]}'), outer => true)) f""",
+            write={
+                "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88]}'), outer => TRUE)) AS f"""
+            },
+        )
+
+        self.validate_all(
+            """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88]}'), path => 'b')) f""",
+            write={
+                "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88]}'), path => 'b')) AS f"""
+            },
+        )
+
+        self.validate_all(
+            """SELECT * FROM TABLE(FLATTEN(input => parse_json('[]'))) f""",
+            write={"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[]'))) AS f"""},
+        )
+
+        self.validate_all(
+            """SELECT * FROM TABLE(FLATTEN(input => parse_json('[]'), outer => true)) f""",
+            write={
+                "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[]'), outer => TRUE)) AS f"""
+            },
+        )
+
+        self.validate_all(
+            """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'))) f""",
+            write={
+                "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'))) AS f"""
+            },
+        )
+
+        self.validate_all(
+            """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => true)) f""",
+            write={
+                "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => TRUE)) AS f"""
+            },
+        )
+
+        self.validate_all(
+            """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => true, mode => 'object')) f""",
+            write={
+                "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => TRUE, mode => 'object')) AS f"""
+            },
+        )
+
+        self.validate_all(
+            """
+            SELECT id as "ID",
+              f.value AS "Contact",
+              f1.value:type AS "Type",
+              f1.value:content AS "Details"
+            FROM persons p,
+              lateral flatten(input => p.c, path => 'contact') f,
+              lateral flatten(input => f.value:business) f1
+            """,
+            write={
+                "snowflake": """SELECT
+  id AS "ID",
+  f.value AS "Contact",
+  f1.value['type'] AS "Type",
+  f1.value['content'] AS "Details"
+FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') f, LATERAL FLATTEN(input => f.value['business']) f1""",
+            },
+            pretty=True,
+        )
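The flatten cases boil down to a Snowflake-to-Snowflake round trip; a minimal sketch:

```python
import sqlglot

# FLATTEN keyword arguments (input =>, outer =>, ...) survive the round trip,
# and the lateral table alias gains an explicit AS.
sql = "SELECT * FROM TABLE(FLATTEN(input => parse_json('[1, ,77]'))) f"
print(sqlglot.transpile(sql, read="snowflake", write="snowflake")[0])
# SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[1, ,77]'))) AS f
```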
@@ -284,4 +284,6 @@ TBLPROPERTIES (
         )
 
     def test_iif(self):
-        self.validate_all("SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"})
+        self.validate_all(
+            "SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"}
+        )
@@ -6,3 +6,6 @@ class TestMySQL(Validator):
 
     def test_identity(self):
         self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo")
+
+    def test_time(self):
+        self.validate_identity("TIMESTAMP('2022-01-01')")
@@ -278,12 +278,19 @@ class TestTSQL(Validator):
     def test_add_date(self):
         self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')")
         self.validate_all(
-            "SELECT DATEADD(year, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"}
+            "SELECT DATEADD(year, 1, '2017/08/25')",
+            write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"},
+        )
+        self.validate_all(
+            "SELECT DATEADD(qq, 1, '2017/08/25')",
+            write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"},
         )
-        self.validate_all("SELECT DATEADD(qq, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"})
         self.validate_all(
             "SELECT DATEADD(wk, 1, '2017/08/25')",
-            write={"spark": "SELECT DATE_ADD('2017/08/25', 7)", "databricks": "SELECT DATEADD(week, 1, '2017/08/25')"},
+            write={
+                "spark": "SELECT DATE_ADD('2017/08/25', 7)",
+                "databricks": "SELECT DATEADD(week, 1, '2017/08/25')",
+            },
         )
 
     def test_date_diff(self):
@@ -370,13 +377,21 @@ class TestTSQL(Validator):
             "SELECT FORMAT(1000000.01,'###,###.###')",
             write={"spark": "SELECT FORMAT_NUMBER(1000000.01, '###,###.###')"},
         )
-        self.validate_all("SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"})
+        self.validate_all(
+            "SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"}
+        )
         self.validate_all(
             "SELECT FORMAT('01-01-1991', 'dd.mm.yyyy')",
             write={"spark": "SELECT DATE_FORMAT('01-01-1991', 'dd.mm.yyyy')"},
         )
         self.validate_all(
-            "SELECT FORMAT(date_col, 'dd.mm.yyyy')", write={"spark": "SELECT DATE_FORMAT(date_col, 'dd.mm.yyyy')"}
+            "SELECT FORMAT(date_col, 'dd.mm.yyyy')",
+            write={"spark": "SELECT DATE_FORMAT(date_col, 'dd.mm.yyyy')"},
+        )
+        self.validate_all(
+            "SELECT FORMAT(date_col, 'm')",
+            write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"},
+        )
+        self.validate_all(
+            "SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"}
         )
-        self.validate_all("SELECT FORMAT(date_col, 'm')", write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"})
-        self.validate_all("SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"})
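The DATEADD cases resolve T-SQL date-part aliases before translating; a minimal sketch:

```python
import sqlglot

# T-SQL date-part aliases are resolved first: one quarter is emitted as
# three months for Spark.
print(sqlglot.transpile("SELECT DATEADD(qq, 1, '2017/08/25')", read="tsql", write="spark")[0])
# SELECT ADD_MONTHS('2017/08/25', 3)
```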
12
tests/fixtures/identity.sql
vendored
12
tests/fixtures/identity.sql
vendored
|
@@ -523,6 +523,8 @@ DROP VIEW a.b
 DROP VIEW IF EXISTS a
 DROP VIEW IF EXISTS a.b
 SHOW TABLES
+USE db
+ROLLBACK
 EXPLAIN SELECT * FROM x
 INSERT INTO x SELECT * FROM y
 INSERT INTO x (SELECT * FROM y)

@@ -569,3 +571,13 @@ SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)
 SELECT * FROM (tbl1 JOIN tbl2 JOIN tbl3)
 SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo)
 SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)
+SELECT CAST(x AS INT) /* comment */ FROM foo
+SELECT a /* x */, b /* x */
+SELECT * FROM foo /* x */, bla /* x */
+SELECT 1 /* comment */ + 1
+SELECT 1 /* c1 */ + 2 /* c2 */
+SELECT 1 /* c1 */ + 2 /* c2 */ + 3 /* c3 */
+SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */
+SELECT x FROM a.b.c /* x */, e.f.g /* x */
+SELECT FOO(x /* c */) /* FOO */, b /* b */
+SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */
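
These identity fixtures exercise the new comment preservation. A minimal sketch (not part of the diff) of the round trip, using the public sqlglot.transpile entry point:

    import sqlglot

    # line comments are kept and normalized to /* ... */ blocks on output
    print(sqlglot.transpile("SELECT 1 FROM foo -- comment")[0])
    # SELECT 1 FROM foo /* comment */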

10  tests/fixtures/optimizer/qualify_columns.sql  vendored
@@ -104,6 +104,16 @@ SELECT x.b AS b, x.a AS a FROM x AS x LEFT JOIN y AS y ON x.b = y.b QUALIFY ROW_
 SELECT AGGREGATE(ARRAY(a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x;
 SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + x.a) AS sum_agg FROM x AS x;

+# dialect: starrocks
+# execute: false
+SELECT DATE_TRUNC('week', a) AS a FROM x;
+SELECT DATE_TRUNC('week', x.a) AS a FROM x AS x;
+
+# dialect: bigquery
+# execute: false
+SELECT DATE_TRUNC(a, MONTH) AS a FROM x;
+SELECT DATE_TRUNC(x.a, MONTH) AS a FROM x AS x;
+
 --------------------------------------
 -- Derived tables
 --------------------------------------
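
A sketch (not part of the diff) of how fixtures like these are driven, assuming the optimizer's qualify_columns rule; the column type in the schema below is a placeholder:

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_columns import qualify_columns

    # the unqualified column a resolves against the schema and is rewritten as x.a
    sql = "SELECT DATE_TRUNC('week', a) AS a FROM x"
    print(qualify_columns(parse_one(sql, read="starrocks"), schema={"x": {"a": "date"}}).sql())
    # SELECT DATE_TRUNC('week', x.a) AS a FROM x AS x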

9  tests/fixtures/optimizer/simplify.sql  vendored
@@ -79,6 +79,15 @@ NULL;
 NULL = NULL;
 NULL;

+NULL <=> NULL;
+TRUE;
+
+a IS NOT DISTINCT FROM a;
+TRUE;
+
+NULL IS DISTINCT FROM NULL;
+FALSE;
+
 NOT (NOT TRUE);
 TRUE;
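
A sketch (not part of the diff) of the new null-safe equality simplification, assuming the optimizer's simplify rule:

    from sqlglot import parse_one
    from sqlglot.optimizer.simplify import simplify

    # <=> is null-safe, so NULL <=> NULL folds to TRUE, while plain NULL = NULL stays NULL
    print(simplify(parse_one("NULL <=> NULL")).sql())  # TRUE
    print(simplify(parse_one("NULL = NULL")).sql())    # NULL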

28  tests/fixtures/pretty.sql  vendored
@@ -287,3 +287,31 @@ SELECT
     "fffffff"
   )
 );
+/*
+multi
+line
+comment
+*/
+SELECT * FROM foo;
+/*
+multi
+line
+comment
+*/
+SELECT
+  *
+FROM foo;
+SELECT x FROM a.b.c /*x*/, e.f.g /*x*/;
+SELECT
+  x
+FROM a.b.c /* x */, e.f.g /* x */;
+SELECT x FROM (SELECT * FROM bla /*x*/WHERE id = 1) /*x*/;
+SELECT
+  x
+FROM (
+  SELECT
+    *
+  FROM bla /* x */
+  WHERE
+    id = 1
+) /* x */;
@@ -100,15 +100,21 @@ class TestBuild(unittest.TestCase):
             "SELECT x FROM tbl LEFT OUTER JOIN tbl2",
         ),
         (
-            lambda: select("x").from_("tbl").join(exp.Table(this="tbl2"), join_type="left outer"),
+            lambda: select("x")
+            .from_("tbl")
+            .join(exp.Table(this="tbl2"), join_type="left outer"),
             "SELECT x FROM tbl LEFT OUTER JOIN tbl2",
         ),
         (
-            lambda: select("x").from_("tbl").join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"),
+            lambda: select("x")
+            .from_("tbl")
+            .join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"),
             "SELECT x FROM tbl LEFT OUTER JOIN tbl2 AS foo",
         ),
         (
-            lambda: select("x").from_("tbl").join(select("y").from_("tbl2"), join_type="left outer"),
+            lambda: select("x")
+            .from_("tbl")
+            .join(select("y").from_("tbl2"), join_type="left outer"),
             "SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2)",
         ),
         (

@@ -131,7 +137,9 @@ class TestBuild(unittest.TestCase):
             "SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2) AS aliased",
         ),
         (
-            lambda: select("x").from_("tbl").join(parse_one("left join x", into=exp.Join), on="a=b"),
+            lambda: select("x")
+            .from_("tbl")
+            .join(parse_one("left join x", into=exp.Join), on="a=b"),
             "SELECT x FROM tbl LEFT JOIN x ON a = b",
         ),
         (

@@ -139,7 +147,9 @@ class TestBuild(unittest.TestCase):
             "SELECT x FROM tbl LEFT JOIN x ON a = b",
         ),
         (
-            lambda: select("x").from_("tbl").join("select b from tbl2", on="a=b", join_type="left"),
+            lambda: select("x")
+            .from_("tbl")
+            .join("select b from tbl2", on="a=b", join_type="left"),
             "SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) ON a = b",
         ),
         (

@@ -162,7 +172,10 @@ class TestBuild(unittest.TestCase):
         (
             lambda: select("x", "y", "z")
             .from_("merged_df")
-            .join("vte_diagnosis_df", using=[exp.to_identifier("patient_id"), exp.to_identifier("encounter_id")]),
+            .join(
+                "vte_diagnosis_df",
+                using=[exp.to_identifier("patient_id"), exp.to_identifier("encounter_id")],
+            ),
             "SELECT x, y, z FROM merged_df JOIN vte_diagnosis_df USING (patient_id, encounter_id)",
         ),
         (

@@ -222,7 +235,10 @@ class TestBuild(unittest.TestCase):
             "SELECT x, y, z, a FROM tbl ORDER BY x, y, z, a",
         ),
         (
-            lambda: select("x", "y", "z", "a").from_("tbl").cluster_by("x, y", "z").cluster_by("a"),
+            lambda: select("x", "y", "z", "a")
+            .from_("tbl")
+            .cluster_by("x, y", "z")
+            .cluster_by("a"),
             "SELECT x, y, z, a FROM tbl CLUSTER BY x, y, z, a",
         ),
         (

@@ -239,7 +255,9 @@ class TestBuild(unittest.TestCase):
             "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
         ),
         (
-            lambda: select("x").from_("tbl").with_("tbl", as_="SELECT x FROM tbl2", recursive=True),
+            lambda: select("x")
+            .from_("tbl")
+            .with_("tbl", as_="SELECT x FROM tbl2", recursive=True),
             "WITH RECURSIVE tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
         ),
         (

@@ -247,7 +265,9 @@ class TestBuild(unittest.TestCase):
             "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
         ),
         (
-            lambda: select("x").from_("tbl").with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")),
+            lambda: select("x")
+            .from_("tbl")
+            .with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")),
             "WITH tbl(x, y) AS (SELECT x, y FROM tbl2) SELECT x FROM tbl",
         ),
         (

@@ -258,7 +278,10 @@ class TestBuild(unittest.TestCase):
             "WITH tbl AS (SELECT x FROM tbl2), tbl2 AS (SELECT x FROM tbl3) SELECT x FROM tbl",
         ),
         (
-            lambda: select("x").from_("tbl").with_("tbl", as_=select("x", "y").from_("tbl2")).select("y"),
+            lambda: select("x")
+            .from_("tbl")
+            .with_("tbl", as_=select("x", "y").from_("tbl2"))
+            .select("y"),
             "WITH tbl AS (SELECT x, y FROM tbl2) SELECT x, y FROM tbl",
         ),
         (

@@ -266,35 +289,59 @@ class TestBuild(unittest.TestCase):
             "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
         ),
         (
-            lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").group_by("x"),
+            lambda: select("x")
+            .with_("tbl", as_=select("x").from_("tbl2"))
+            .from_("tbl")
+            .group_by("x"),
             "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl GROUP BY x",
         ),
         (
-            lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").order_by("x"),
+            lambda: select("x")
+            .with_("tbl", as_=select("x").from_("tbl2"))
+            .from_("tbl")
+            .order_by("x"),
             "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl ORDER BY x",
         ),
         (
-            lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").limit(10),
+            lambda: select("x")
+            .with_("tbl", as_=select("x").from_("tbl2"))
+            .from_("tbl")
+            .limit(10),
             "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl LIMIT 10",
         ),
         (
-            lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").offset(10),
+            lambda: select("x")
+            .with_("tbl", as_=select("x").from_("tbl2"))
+            .from_("tbl")
+            .offset(10),
             "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl OFFSET 10",
         ),
         (
-            lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").join("tbl3"),
+            lambda: select("x")
+            .with_("tbl", as_=select("x").from_("tbl2"))
+            .from_("tbl")
+            .join("tbl3"),
             "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl JOIN tbl3",
         ),
         (
-            lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").distinct(),
+            lambda: select("x")
+            .with_("tbl", as_=select("x").from_("tbl2"))
+            .from_("tbl")
+            .distinct(),
             "WITH tbl AS (SELECT x FROM tbl2) SELECT DISTINCT x FROM tbl",
         ),
         (
-            lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").where("x > 10"),
+            lambda: select("x")
+            .with_("tbl", as_=select("x").from_("tbl2"))
+            .from_("tbl")
+            .where("x > 10"),
             "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl WHERE x > 10",
         ),
         (
-            lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").having("x > 20"),
+            lambda: select("x")
+            .with_("tbl", as_=select("x").from_("tbl2"))
+            .from_("tbl")
+            .having("x > 20"),
             "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl HAVING x > 20",
         ),
         (lambda: select("x").from_("tbl").subquery(), "(SELECT x FROM tbl)"),

@@ -354,7 +401,9 @@ class TestBuild(unittest.TestCase):
             "SELECT x FROM (SELECT x FROM tbl) AS foo WHERE x > 0",
         ),
         (
-            lambda: exp.subquery("select x from tbl UNION select x from bar", "unioned").select("x"),
+            lambda: exp.subquery("select x from tbl UNION select x from bar", "unioned").select(
+                "x"
+            ),
             "SELECT x FROM (SELECT x FROM tbl UNION SELECT x FROM bar) AS unioned",
         ),
         (
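
The reflowed lambdas above all target the same fluent builder API; a minimal sketch (not part of the diff):

    from sqlglot import exp, select

    query = (
        select("x")
        .from_("tbl")
        .join(exp.Table(this="tbl2"), join_type="left outer")
    )
    print(query.sql())  # SELECT x FROM tbl LEFT OUTER JOIN tbl2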
@@ -33,7 +33,10 @@ class TestExecutor(unittest.TestCase):
         )

         cls.cache = {}
-        cls.sqls = [(sql, expected) for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")]
+        cls.sqls = [
+            (sql, expected)
+            for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")
+        ]

     @classmethod
     def tearDownClass(cls):

@@ -63,7 +66,9 @@ class TestExecutor(unittest.TestCase):
     def test_execute_tpch(self):
         def to_csv(expression):
             if isinstance(expression, exp.Table):
-                return parse_one(f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}")
+                return parse_one(
+                    f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}"
+                )
             return expression

         for sql, _ in self.sqls[0:3]:
@@ -30,7 +30,9 @@ class TestExpressions(unittest.TestCase):
         self.assertEqual(parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)"))
         self.assertEqual(exp.Table(pivots=[]), exp.Table())
         self.assertNotEqual(exp.Table(pivots=[None]), exp.Table())
-        self.assertEqual(exp.DataType.build("int"), exp.DataType(this=exp.DataType.Type.INT, nested=False))
+        self.assertEqual(
+            exp.DataType.build("int"), exp.DataType(this=exp.DataType.Type.INT, nested=False)
+        )

     def test_find(self):
         expression = parse_one("CREATE TABLE x STORED AS PARQUET AS SELECT * FROM y")

@@ -89,7 +91,9 @@ class TestExpressions(unittest.TestCase):
         self.assertIsNone(column.find_ancestor(exp.Join))

     def test_alias_or_name(self):
-        expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz")
+        expression = parse_one(
+            "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz"
+        )
         self.assertEqual(
             [e.alias_or_name for e in expression.expressions],
             ["a", "B", "e", "*", "zz", "z"],

@@ -166,7 +170,9 @@ class TestExpressions(unittest.TestCase):
             "SELECT * FROM foo WHERE ? > 100",
         )
         self.assertEqual(
-            exp.replace_placeholders(parse_one("select * from :name WHERE ? > 100"), another_name="bla").sql(),
+            exp.replace_placeholders(
+                parse_one("select * from :name WHERE ? > 100"), another_name="bla"
+            ).sql(),
             "SELECT * FROM :name WHERE ? > 100",
         )
         self.assertEqual(

@@ -183,7 +189,9 @@ class TestExpressions(unittest.TestCase):
         )

     def test_named_selects(self):
-        expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz")
+        expression = parse_one(
+            "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz"
+        )
         self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"])

         expression = parse_one(

@@ -367,7 +375,9 @@ class TestExpressions(unittest.TestCase):
         self.assertEqual(len(list(expression.walk())), 9)
         self.assertEqual(len(list(expression.walk(bfs=False))), 9)
         self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk()))
-        self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False)))
+        self.assertTrue(
+            all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False))
+        )

     def test_functions(self):
         self.assertIsInstance(parse_one("ABS(a)"), exp.Abs)

@@ -512,14 +522,21 @@ class TestExpressions(unittest.TestCase):
             ),
             exp.Properties(
                 expressions=[
-                    exp.FileFormatProperty(this=exp.Literal.string("FORMAT"), value=exp.Literal.string("parquet")),
+                    exp.FileFormatProperty(
+                        this=exp.Literal.string("FORMAT"), value=exp.Literal.string("parquet")
+                    ),
                     exp.PartitionedByProperty(
                         this=exp.Literal.string("PARTITIONED_BY"),
-                        value=exp.Tuple(expressions=[exp.to_identifier("a"), exp.to_identifier("b")]),
+                        value=exp.Tuple(
+                            expressions=[exp.to_identifier("a"), exp.to_identifier("b")]
+                        ),
+                    ),
+                    exp.AnonymousProperty(
+                        this=exp.Literal.string("custom"), value=exp.Literal.number(1)
                     ),
-                    exp.AnonymousProperty(this=exp.Literal.string("custom"), value=exp.Literal.number(1)),
                     exp.TableFormatProperty(
-                        this=exp.Literal.string("TABLE_FORMAT"), value=exp.to_identifier("test_format")
+                        this=exp.Literal.string("TABLE_FORMAT"),
+                        value=exp.to_identifier("test_format"),
                     ),
                     exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.NULL),
                     exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.TRUE),

@@ -538,7 +555,10 @@ class TestExpressions(unittest.TestCase):
             ((1, "2", None), "(1, '2', NULL)"),
             ([1, "2", None], "ARRAY(1, '2', NULL)"),
             ({"x": None}, "MAP('x', NULL)"),
-            (datetime.datetime(2022, 10, 1, 1, 1, 1), "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000')"),
+            (
+                datetime.datetime(2022, 10, 1, 1, 1, 1),
+                "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000')",
+            ),
             (
                 datetime.datetime(2022, 10, 1, 1, 1, 1, tzinfo=datetime.timezone.utc),
                 "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000+0000')",

@@ -548,30 +568,48 @@ class TestExpressions(unittest.TestCase):
             with self.subTest(value):
                 self.assertEqual(exp.convert(value).sql(), expected)

-    def test_annotation_alias(self):
-        sql = "SELECT a, b AS B, c # comment, d AS D # another_comment FROM foo"
+    def test_comment_alias(self):
+        sql = """
+        SELECT
+            a,
+            b AS B,
+            c, /*comment*/
+            d AS D, -- another comment
+            CAST(x AS INT) -- final comment
+        FROM foo
+        """
         expression = parse_one(sql)
         self.assertEqual(
             [e.alias_or_name for e in expression.expressions],
-            ["a", "B", "c", "D"],
+            ["a", "B", "c", "D", "x"],
         )
-        self.assertEqual(expression.sql(), "SELECT a, b AS B, c, d AS D")
-        self.assertEqual(expression.expressions[2].name, "comment")
         self.assertEqual(
-            expression.sql(pretty=True, annotations=False),
+            expression.sql(),
+            "SELECT a, b AS B, c /* comment */, d AS D /* another comment */, CAST(x AS INT) /* final comment */ FROM foo",
+        )
+        self.assertEqual(
+            expression.sql(comments=False),
+            "SELECT a, b AS B, c, d AS D, CAST(x AS INT) FROM foo",
+        )
+        self.assertEqual(
+            expression.sql(pretty=True, comments=False),
             """SELECT
   a,
   b AS B,
   c,
-  d AS D""",
+  d AS D,
+  CAST(x AS INT)
+FROM foo""",
         )
         self.assertEqual(
             expression.sql(pretty=True),
             """SELECT
   a,
   b AS B,
-  c # comment,
-  d AS D # another_comment FROM foo""",
+  c, -- comment
+  d AS D, -- another comment
+  CAST(x AS INT) -- final comment
+FROM foo""",
         )

     def test_to_table(self):

@@ -605,5 +643,9 @@ class TestExpressions(unittest.TestCase):
         self.assertIsInstance(expression, exp.Union)
         self.assertEqual(expression.named_selects, ["cola", "colb"])
         self.assertEqual(
-            expression.selects, [exp.Column(this=exp.to_identifier("cola")), exp.Column(this=exp.to_identifier("colb"))]
+            expression.selects,
+            [
+                exp.Column(this=exp.to_identifier("cola")),
+                exp.Column(this=exp.to_identifier("colb")),
+            ],
         )
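
A minimal sketch (not part of the diff) of the comment controls these assertions cover:

    from sqlglot import parse_one

    expression = parse_one("SELECT a /* keep me */ FROM foo")
    print(expression.sql())                # SELECT a /* keep me */ FROM foo
    print(expression.sql(comments=False))  # SELECT a FROM foo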
@@ -67,7 +67,9 @@ class TestOptimizer(unittest.TestCase):
         }

     def check_file(self, file, func, pretty=False, execute=False, **kwargs):
-        for i, (meta, sql, expected) in enumerate(load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1):
+        for i, (meta, sql, expected) in enumerate(
+            load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1
+        ):
             title = meta.get("title") or f"{i}, {sql}"
             dialect = meta.get("dialect")
             leave_tables_isolated = meta.get("leave_tables_isolated")

@@ -90,7 +92,9 @@ class TestOptimizer(unittest.TestCase):

             if string_to_bool(should_execute):
                 with self.subTest(f"(execute) {title}"):
-                    df1 = self.conn.execute(sqlglot.transpile(sql, read=dialect, write="duckdb")[0]).df()
+                    df1 = self.conn.execute(
+                        sqlglot.transpile(sql, read=dialect, write="duckdb")[0]
+                    ).df()
                     df2 = self.conn.execute(optimized.sql(pretty=pretty, dialect="duckdb")).df()
                     assert_frame_equal(df1, df2)

@@ -268,7 +272,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
         self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
         self.assertEqual(scopes[2].expression.sql(), "(VALUES (1, 'test')) AS tab(cola, colb)")
         self.assertEqual(
-            scopes[3].expression.sql(), "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)"
+            scopes[3].expression.sql(),
+            "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
         )
         self.assertEqual(scopes[4].expression.sql(), "SELECT y.c AS b FROM y")
         self.assertEqual(scopes[5].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")

@@ -287,7 +292,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')

         # Check that we can walk in scope from an arbitrary node
         self.assertEqual(
-            {node.sql() for node, *_ in walk_in_scope(expression.find(exp.Where)) if isinstance(node, exp.Column)},
+            {
+                node.sql()
+                for node, *_ in walk_in_scope(expression.find(exp.Where))
+                if isinstance(node, exp.Column)
+            },
             {"s.b"},
         )

@@ -324,7 +333,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
         self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT)

     def test_cache_annotation(self):
-        expression = annotate_types(parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1"))
+        expression = annotate_types(
+            parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")
+        )
         self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT)

     def test_binary_annotation(self):

@@ -384,7 +395,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
         """

         expression = annotate_types(parse_one(sql), schema=schema)
-        self.assertEqual(expression.expressions[0].type, exp.DataType.Type.TEXT)  # tbl.cola + tbl.colb + 'foo' AS col
+        self.assertEqual(
+            expression.expressions[0].type, exp.DataType.Type.TEXT
+        )  # tbl.cola + tbl.colb + 'foo' AS col

         outer_addition = expression.expressions[0].this  # (tbl.cola + tbl.colb) + 'foo'
         self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT)

@@ -396,7 +409,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
         self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT)

         cte_select = expression.args["with"].expressions[0].this
-        self.assertEqual(cte_select.expressions[0].type, exp.DataType.Type.VARCHAR)  # x.cola + 'bla' AS cola
+        self.assertEqual(
+            cte_select.expressions[0].type, exp.DataType.Type.VARCHAR
+        )  # x.cola + 'bla' AS cola
         self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT)  # y.colb AS colb

         cte_select_addition = cte_select.expressions[0].this  # x.cola + 'bla'

@@ -405,7 +420,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
         self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR)

         # Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively
-        for d, t in zip(cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]):
+        for d, t in zip(
+            cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]
+        ):
             self.assertEqual(d.this.expressions[0].this.type, t)

     def test_function_annotation(self):

@@ -421,6 +438,19 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
         self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR)  # TRIM(x.colb)
         self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR)  # x.colb

+        sql = "SELECT CASE WHEN 1=1 THEN x.cola ELSE x.colb END AS col FROM x AS x"
+
+        case_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
+        self.assertEqual(case_expr_alias.type, exp.DataType.Type.VARCHAR)
+
+        case_expr = case_expr_alias.this
+        self.assertEqual(case_expr.type, exp.DataType.Type.VARCHAR)
+        self.assertEqual(case_expr.args["default"].type, exp.DataType.Type.CHAR)
+
+        case_ifs_expr = case_expr.args["ifs"][0]
+        self.assertEqual(case_ifs_expr.type, exp.DataType.Type.VARCHAR)
+        self.assertEqual(case_ifs_expr.args["true"].type, exp.DataType.Type.VARCHAR)
+
     def test_unknown_annotation(self):
         schema = {"x": {"cola": "VARCHAR"}}
         sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x"

@@ -431,8 +461,12 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
         concat_expr = concat_expr_alias.this
         self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN)
         self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR)  # x.cola
-        self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN)  # SOME_ANONYMOUS_FUNC(x.cola)
-        self.assertEqual(concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR)  # x.cola (arg)
+        self.assertEqual(
+            concat_expr.right.type, exp.DataType.Type.UNKNOWN
+        )  # SOME_ANONYMOUS_FUNC(x.cola)
+        self.assertEqual(
+            concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR
+        )  # x.cola (arg)

     def test_null_annotation(self):
         expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this
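
A sketch (not part of the diff) of the CASE type inheritance the new assertions check, assuming the annotate_types rule and a schema shaped like the one this test uses:

    from sqlglot import parse_one
    from sqlglot.optimizer.annotate_types import annotate_types

    schema = {"x": {"cola": "VARCHAR", "colb": "CHAR"}}
    expression = annotate_types(
        parse_one("SELECT CASE WHEN 1 = 1 THEN x.cola ELSE x.colb END AS col FROM x AS x"),
        schema=schema,
    )
    # the CASE expression inherits the type of its branches
    print(expression.expressions[0].type)  # Type.VARCHAR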
@@ -23,8 +23,6 @@ class TestParser(unittest.TestCase):

     def test_float(self):
         self.assertEqual(parse_one(".2"), parse_one("0.2"))
-        self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)"))
-        self.assertEqual(parse_one("int.5"), parse_one("CAST(0.5 AS INT)"))

     def test_table(self):
         tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)]

@@ -33,7 +31,9 @@ class TestParser(unittest.TestCase):
     def test_select(self):
         self.assertIsNotNone(parse_one("select 1 natural"))
         self.assertIsNotNone(parse_one("select * from (select 1) x order by x.y").args["order"])
-        self.assertIsNotNone(parse_one("select * from x where a = (select 1) order by x.y").args["order"])
+        self.assertIsNotNone(
+            parse_one("select * from x where a = (select 1) order by x.y").args["order"]
+        )
         self.assertEqual(len(parse_one("select * from (select 1) x cross join y").args["joins"]), 1)
         self.assertEqual(
             parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(),

@@ -125,26 +125,70 @@ class TestParser(unittest.TestCase):
     def test_var(self):
         self.assertEqual(parse_one("SELECT @JOIN, @'foo'").sql(), "SELECT @JOIN, @'foo'")

-    def test_annotations(self):
+    def test_comments(self):
         expression = parse_one(
             """
-            SELECT
-                a #annotation1,
-                b as B #annotation2:testing ,
-                "test#annotation",c#annotation3, d #annotation4,
-                e #,
-                f # space
+            --comment1
+            SELECT /* this won't be used */
+                a, --comment2
+                b as B, --comment3:testing
+                "test--annotation",
+                c, --comment4 --foo
+                e, --
+                f -- space
             FROM foo
             """
         )

-        assert expression.expressions[0].name == "annotation1"
-        assert expression.expressions[1].name == "annotation2:testing"
-        assert expression.expressions[2].name == "test#annotation"
-        assert expression.expressions[3].name == "annotation3"
-        assert expression.expressions[4].name == "annotation4"
-        assert expression.expressions[5].name == ""
-        assert expression.expressions[6].name == "space"
+        self.assertEqual(expression.comment, "comment1")
+        self.assertEqual(expression.expressions[0].comment, "comment2")
+        self.assertEqual(expression.expressions[1].comment, "comment3:testing")
+        self.assertEqual(expression.expressions[2].comment, None)
+        self.assertEqual(expression.expressions[3].comment, "comment4 --foo")
+        self.assertEqual(expression.expressions[4].comment, "")
+        self.assertEqual(expression.expressions[5].comment, " space")
+
+    def test_type_literals(self):
+        self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)"))
+        self.assertEqual(parse_one("int.5"), parse_one("CAST(0.5 AS INT)"))
+        self.assertEqual(
+            parse_one("TIMESTAMP '2022-01-01'").sql(), "CAST('2022-01-01' AS TIMESTAMP)"
+        )
+        self.assertEqual(
+            parse_one("TIMESTAMP(1) '2022-01-01'").sql(), "CAST('2022-01-01' AS TIMESTAMP(1))"
+        )
+        self.assertEqual(
+            parse_one("TIMESTAMP WITH TIME ZONE '2022-01-01'").sql(),
+            "CAST('2022-01-01' AS TIMESTAMPTZ)",
+        )
+        self.assertEqual(
+            parse_one("TIMESTAMP WITH LOCAL TIME ZONE '2022-01-01'").sql(),
+            "CAST('2022-01-01' AS TIMESTAMPLTZ)",
+        )
+        self.assertEqual(
+            parse_one("TIMESTAMP WITHOUT TIME ZONE '2022-01-01'").sql(),
+            "CAST('2022-01-01' AS TIMESTAMP)",
+        )
+        self.assertEqual(
+            parse_one("TIMESTAMP(1) WITH TIME ZONE '2022-01-01'").sql(),
+            "CAST('2022-01-01' AS TIMESTAMPTZ(1))",
+        )
+        self.assertEqual(
+            parse_one("TIMESTAMP(1) WITH LOCAL TIME ZONE '2022-01-01'").sql(),
+            "CAST('2022-01-01' AS TIMESTAMPLTZ(1))",
+        )
+        self.assertEqual(
+            parse_one("TIMESTAMP(1) WITHOUT TIME ZONE '2022-01-01'").sql(),
+            "CAST('2022-01-01' AS TIMESTAMP(1))",
+        )
+        self.assertEqual(parse_one("TIMESTAMP(1) WITH TIME ZONE").sql(), "TIMESTAMPTZ(1)")
+        self.assertEqual(parse_one("TIMESTAMP(1) WITH LOCAL TIME ZONE").sql(), "TIMESTAMPLTZ(1)")
+        self.assertEqual(parse_one("TIMESTAMP(1) WITHOUT TIME ZONE").sql(), "TIMESTAMP(1)")
+        self.assertEqual(parse_one("""JSON '{"x":"y"}'""").sql(), """CAST('{"x":"y"}' AS JSON)""")
+        self.assertIsInstance(parse_one("TIMESTAMP(1)"), exp.Func)
+        self.assertIsInstance(parse_one("TIMESTAMP('2022-01-01')"), exp.Func)
+        self.assertIsInstance(parse_one("TIMESTAMP()"), exp.Func)
+        self.assertIsInstance(parse_one("map.x"), exp.Column)

     def test_pretty_config_override(self):
         self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT col FROM x")
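
A minimal sketch (not part of the diff) of the typed-literal parsing these new assertions cover:

    from sqlglot import parse_one

    # a typed string literal parses into an explicit CAST
    print(parse_one("TIMESTAMP '2022-01-01'").sql())
    # CAST('2022-01-01' AS TIMESTAMP)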
@@ -1,281 +1,141 @@
 import unittest

-from sqlglot import table
-from sqlglot.dataframe.sql import types as df_types
+from sqlglot import exp, to_table
+from sqlglot.errors import SchemaError
 from sqlglot.schema import MappingSchema, ensure_schema


 class TestSchema(unittest.TestCase):
+    def assert_column_names(self, schema, *table_results):
+        for table, result in table_results:
+            with self.subTest(f"{table} -> {result}"):
+                self.assertEqual(schema.column_names(to_table(table)), result)
+
+    def assert_column_names_raises(self, schema, *tables):
+        for table in tables:
+            with self.subTest(table):
+                with self.assertRaises(SchemaError):
+                    schema.column_names(to_table(table))
+
     def test_schema(self):
         schema = ensure_schema(
             {
                 "x": {
                     "a": "uint64",
-                }
-            }
-        )
-        self.assertEqual(
-            schema.column_names(
-                table(
-                    "x",
-                )
-            ),
-            ["a"],
-        )
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x", db="db", catalog="c"))
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x", db="db"))
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x2"))
-
-        with self.assertRaises(ValueError):
-            schema.add_table(table("y", db="db"), {"b": "string"})
-        with self.assertRaises(ValueError):
-            schema.add_table(table("y", db="db", catalog="c"), {"b": "string"})
-
-        schema.add_table(table("y"), {"b": "string"})
-        schema_with_y = {
-            "x": {
-                "a": "uint64",
-            },
-            "y": {
-                "b": "string",
-            },
-        }
-        self.assertEqual(schema.schema, schema_with_y)
-
-        new_schema = schema.copy()
-        new_schema.add_table(table("z"), {"c": "string"})
-        self.assertEqual(schema.schema, schema_with_y)
-        self.assertEqual(
-            new_schema.schema,
-            {
-                "x": {
-                    "a": "uint64",
-                },
-                "y": {
-                    "b": "string",
-                },
-                "z": {
-                    "c": "string",
-                },
-            },
-        )
-        schema.add_table(table("m"), {"d": "string"})
-        schema.add_table(table("n"), {"e": "string"})
-        schema_with_m_n = {
-            "x": {
-                "a": "uint64",
-            },
-            "y": {
-                "b": "string",
-            },
-            "m": {
-                "d": "string",
-            },
-            "n": {
-                "e": "string",
-            },
-        }
-        self.assertEqual(schema.schema, schema_with_m_n)
-
-        new_schema = schema.copy()
-        new_schema.add_table(table("o"), {"f": "string"})
-        new_schema.add_table(table("p"), {"g": "string"})
-        self.assertEqual(schema.schema, schema_with_m_n)
-        self.assertEqual(
-            new_schema.schema,
-            {
-                "x": {
-                    "a": "uint64",
-                },
-                "y": {
-                    "b": "string",
-                },
-                "m": {
-                    "d": "string",
-                },
-                "n": {
-                    "e": "string",
-                },
-                "o": {
-                    "f": "string",
-                },
-                "p": {
-                    "g": "string",
-                },
-            },
-        )
+                },
+                "y": {
+                    "b": "uint64",
+                    "c": "uint64",
+                },
+            },
+        )
+
+        self.assert_column_names(
+            schema,
+            ("x", ["a"]),
+            ("y", ["b", "c"]),
+            ("z.x", ["a"]),
+            ("z.x.y", ["b", "c"]),
+        )
+
+        self.assert_column_names_raises(
+            schema,
+            "z",
+            "z.z",
+            "z.z.z",
+        )
+
+    def test_schema_db(self):
+        schema = ensure_schema(
+            {
+                "d1": {
+                    "x": {
+                        "a": "uint64",
+                    },
+                    "y": {
+                        "b": "uint64",
+                    },
+                },
+                "d2": {
+                    "x": {
+                        "c": "uint64",
+                    },
+                },
+            },
+        )
+
+        self.assert_column_names(
+            schema,
+            ("d1.x", ["a"]),
+            ("d2.x", ["c"]),
+            ("y", ["b"]),
+            ("d1.y", ["b"]),
+            ("z.d1.y", ["b"]),
+        )
+
+        self.assert_column_names_raises(
+            schema,
+            "x",
+            "z.x",
+            "z.y",
+        )

-        schema = ensure_schema(
-            {
-                "db": {
-                    "x": {
-                        "a": "uint64",
-                    }
-                }
-            }
-        )
-        self.assertEqual(schema.column_names(table("x", db="db")), ["a"])
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x", db="db", catalog="c"))
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x"))
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x", db="db2"))
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x2", db="db"))
-
-        with self.assertRaises(ValueError):
-            schema.add_table(table("y"), {"b": "string"})
-        with self.assertRaises(ValueError):
-            schema.add_table(table("y", db="db", catalog="c"), {"b": "string"})
-
-        schema.add_table(table("y", db="db"), {"b": "string"})
-        self.assertEqual(
-            schema.schema,
-            {
-                "db": {
-                    "x": {
-                        "a": "uint64",
-                    },
-                    "y": {
-                        "b": "string",
-                    },
-                }
-            },
-        )
-
-        schema = ensure_schema(
-            {
-                "c": {
-                    "db": {
-                        "x": {
-                            "a": "uint64",
-                        }
-                    }
-                }
-            }
-        )
-        self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"])
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x", db="db"))
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x"))
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x", db="db", catalog="c2"))
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x", db="db2"))
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x2", db="db"))
-
-        with self.assertRaises(ValueError):
-            schema.add_table(table("x"), {"b": "string"})
-        with self.assertRaises(ValueError):
-            schema.add_table(table("x", db="db"), {"b": "string"})
-
-        schema.add_table(table("y", db="db", catalog="c"), {"a": "string", "b": "int"})
-        self.assertEqual(
-            schema.schema,
-            {
-                "c": {
-                    "db": {
-                        "x": {
-                            "a": "uint64",
-                        },
-                        "y": {
-                            "a": "string",
-                            "b": "int",
-                        },
-                    }
-                }
-            },
-        )
-        schema.add_table(table("z", db="db2", catalog="c"), {"c": "string", "d": "int"})
-        schema.add_table(table("m", db="db2", catalog="c2"), {"e": "string", "f": "int"})
-        self.assertEqual(
-            schema.schema,
-            {
-                "c": {
-                    "db": {
-                        "x": {
-                            "a": "uint64",
-                        },
-                        "y": {
-                            "a": "string",
-                            "b": "int",
-                        },
-                    },
-                    "db2": {
-                        "z": {
-                            "c": "string",
-                            "d": "int",
-                        }
-                    },
-                },
-                "c2": {
-                    "db2": {
-                        "m": {
-                            "e": "string",
-                            "f": "int",
-                        }
-                    }
-                },
-            },
-        )
-
-        schema = ensure_schema(
-            {
-                "x": {
-                    "a": "uint64",
-                }
-            }
-        )
-        self.assertEqual(schema.column_names(table("x")), ["a"])
-
-        schema = MappingSchema()
-        schema.add_table(table("x"), {"a": "string"})
-        schema.add_table(table("y"), df_types.StructType([df_types.StructField("b", df_types.StringType())]))
-        self.assertEqual(
-            schema.schema,
-            {
-                "x": {
-                    "a": "string",
-                },
-                "y": {
-                    "b": "string",
-                },
-            },
-        )
+    def test_schema_catalog(self):
+        schema = ensure_schema(
+            {
+                "c1": {
+                    "d1": {
+                        "x": {
+                            "a": "uint64",
+                        },
+                        "y": {
+                            "b": "uint64",
+                        },
+                        "z": {
+                            "c": "uint64",
+                        },
+                    },
+                },
+                "c2": {
+                    "d1": {
+                        "y": {
+                            "d": "uint64",
+                        },
+                        "z": {
+                            "e": "uint64",
+                        },
+                    },
+                    "d2": {
+                        "z": {
+                            "f": "uint64",
+                        },
+                    },
+                },
+            }
+        )
+
+        self.assert_column_names(
+            schema,
+            ("x", ["a"]),
+            ("d1.x", ["a"]),
+            ("c1.d1.x", ["a"]),
+            ("c1.d1.y", ["b"]),
+            ("c1.d1.z", ["c"]),
+            ("c2.d1.y", ["d"]),
+            ("c2.d1.z", ["e"]),
+            ("d2.z", ["f"]),
+            ("c2.d2.z", ["f"]),
+        )
+
+        self.assert_column_names_raises(
+            schema,
+            "q",
+            "d2.x",
+            "y",
+            "z",
+            "d1.y",
+            "d1.z",
+            "a.b.c",
+        )

     def test_schema_add_table_with_and_without_mapping(self):

@@ -288,3 +148,34 @@ class TestSchema(unittest.TestCase):
         self.assertEqual(schema.column_names("test"), ["x", "y"])
         schema.add_table("test")
         self.assertEqual(schema.column_names("test"), ["x", "y"])
+
+    def test_schema_get_column_type(self):
+        schema = MappingSchema({"a": {"b": "varchar"}})
+        self.assertEqual(schema.get_column_type("a", "b"), exp.DataType.Type.VARCHAR)
+        self.assertEqual(
+            schema.get_column_type(exp.Table(this="a"), exp.Column(this="b")),
+            exp.DataType.Type.VARCHAR,
+        )
+        self.assertEqual(
+            schema.get_column_type("a", exp.Column(this="b")), exp.DataType.Type.VARCHAR
+        )
+        self.assertEqual(
+            schema.get_column_type(exp.Table(this="a"), "b"), exp.DataType.Type.VARCHAR
+        )
+        schema = MappingSchema({"a": {"b": {"c": "varchar"}}})
+        self.assertEqual(
+            schema.get_column_type(exp.Table(this="b", db="a"), exp.Column(this="c")),
+            exp.DataType.Type.VARCHAR,
+        )
+        self.assertEqual(
+            schema.get_column_type(exp.Table(this="b", db="a"), "c"), exp.DataType.Type.VARCHAR
+        )
+        schema = MappingSchema({"a": {"b": {"c": {"d": "varchar"}}}})
+        self.assertEqual(
+            schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), exp.Column(this="d")),
+            exp.DataType.Type.VARCHAR,
+        )
+        self.assertEqual(
+            schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d"),
+            exp.DataType.Type.VARCHAR,
+        )
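
A minimal sketch (not part of the diff) of the new get_column_type lookups, mirroring the assertions above:

    from sqlglot import exp
    from sqlglot.schema import MappingSchema

    schema = MappingSchema({"a": {"b": "varchar"}})
    # the table and column can be passed as strings or as expression nodes
    assert schema.get_column_type("a", "b") == exp.DataType.Type.VARCHAR
    assert schema.get_column_type(exp.Table(this="a"), "b") == exp.DataType.Type.VARCHAR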

18  tests/test_tokens.py  Normal file
@@ -0,0 +1,18 @@
+import unittest
+
+from sqlglot.tokens import Tokenizer
+
+
+class TestTokens(unittest.TestCase):
+    def test_comment_attachment(self):
+        tokenizer = Tokenizer()
+        sql_comment = [
+            ("/*comment*/ foo", "comment"),
+            ("/*comment*/ foo --test", "comment"),
+            ("--comment\nfoo --test", "comment"),
+            ("foo --comment", "comment"),
+            ("foo", None),
+        ]
+
+        for sql, comment in sql_comment:
+            self.assertEqual(tokenizer.tokenize(sql)[0].comment, comment)

@@ -49,6 +49,12 @@ class TestTranspile(unittest.TestCase):
             leading_comma=True,
             pretty=True,
         )
+        self.validate(
+            "SELECT FOO, /*x*/\nBAR, /*y*/\nBAZ",
+            "SELECT\n  FOO -- x\n  , BAR -- y\n  , BAZ",
+            leading_comma=True,
+            pretty=True,
+        )
         # without pretty, this should be a no-op
         self.validate(
             "SELECT FOO, BAR, BAZ",

@@ -63,24 +69,61 @@ class TestTranspile(unittest.TestCase):
         self.validate("SELECT 3>=3", "SELECT 3 >= 3")

     def test_comments(self):
-        self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo")
-        self.validate("SELECT 1 /* inline */ FROM foo -- comment", "SELECT 1 FROM foo")
+        self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */")
+        self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo")
+        self.validate("SELECT --!5\nx FROM foo", "/* !5 */ SELECT x FROM foo")
+        self.validate(
+            "SELECT 1 /* inline */ FROM foo -- comment",
+            "SELECT 1 /* inline */ FROM foo /* comment */",
+        )
+        self.validate(
+            "SELECT FUN(x) /*x*/, [1,2,3] /*y*/", "SELECT FUN(x) /* x */, ARRAY(1, 2, 3) /* y */"
+        )
         self.validate(
             """
             SELECT 1 -- comment
             FROM foo -- comment
             """,
-            "SELECT 1 FROM foo",
+            "SELECT 1 /* comment */ FROM foo /* comment */",
         )
-
         self.validate(
             """
             SELECT 1 /* big comment
             like this */
             FROM foo -- comment
             """,
-            "SELECT 1 FROM foo",
+            """SELECT 1 /* big comment
+like this */ FROM foo /* comment */""",
+        )
+        self.validate(
+            "select x from foo -- x",
+            "SELECT x FROM foo /* x */",
+        )
+        self.validate(
+            """
+            /* multi
+            line
+            comment
+            */
+            SELECT
+              tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
+              CAST(x AS INT), # comment 3
+              y -- comment 4
+            FROM
+              bar /* comment 5 */,
+              tbl # comment 6
+            """,
+            """/* multi
+line
+comment
+*/
+SELECT
+  tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
+  CAST(x AS INT), -- comment 3
+  y -- comment 4
+FROM bar /* comment 5 */, tbl /* comment 6 */""",
+            read="mysql",
+            pretty=True,
         )

     def test_types(self):

@@ -146,6 +189,16 @@ class TestTranspile(unittest.TestCase):
     def test_ignore_nulls(self):
         self.validate("SELECT COUNT(x RESPECT NULLS)", "SELECT COUNT(x)")

+    def test_with(self):
+        self.validate(
+            "WITH a AS (SELECT 1) WITH b AS (SELECT 2) SELECT *",
+            "WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *",
+        )
+        self.validate(
+            "WITH a AS (SELECT 1), WITH b AS (SELECT 2) SELECT *",
+            "WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *",
+        )
+
     def test_time(self):
         self.validate("TIMESTAMP '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMP)")
         self.validate("TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)")