Merging upstream version 10.5.10.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent 8588db6332
commit 4d496b7a6a

43 changed files with 1384 additions and 356 deletions
CONTRIBUTING.md
@@ -32,6 +32,7 @@ We use GitHub issues to track public bugs. Report a bug by opening a new issue.
 - What you expected would happen
 - What actually happens
 - Notes (possibly including why you think this might be happening, or stuff you tried that didn't work)
+- References (e.g. documentation pages related to the issue)
 
 ## Start a discussion using Github's [discussions](https://github.com/tobymao/sqlglot/discussions)
 [We use GitHub discussions](https://github.com/tobymao/sqlglot/discussions/190) to discuss about the current state
Makefile
@@ -18,7 +18,7 @@ style:
 check: style test
 
 docs:
-	pdoc/cli.py -o pdoc/docs
+	python pdoc/cli.py -o pdoc/docs
 
 docs-serve:
-	pdoc/cli.py
+	python pdoc/cli.py
README.md
@@ -1,6 +1,6 @@
 # SQLGlot
 
-SQLGlot is a no dependency Python SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between [18 different dialects](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py) like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects.
+SQLGlot is a no dependency Python SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between [19 different dialects](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py) like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects.
 
 It is a very comprehensive generic SQL parser with a robust [test suite](https://github.com/tobymao/sqlglot/blob/main/tests/). It is also quite [performant](#benchmarks) while being written purely in Python.
 
@@ -189,7 +189,6 @@ except sqlglot.errors.ParseError as e:
     print(e.errors)
 ```
 
-Output:
 ```python
 [{
   'description': 'Expecting )',
pdoc/docs/expressions.md (deleted)
@@ -1,41 +0,0 @@
-# Expressions
-
-Every AST node in SQLGlot is represented by a subclass of `Expression`. Each such expression encapsulates any necessary context, such as its child expressions, their names, or arg keys, and whether each child expression is optional or not.
-
-Furthermore, the following attributes are common across all expressions:
-
-#### key
-
-A unique key for each class in the `Expression` hierarchy. This is useful for hashing and representing expressions as strings.
-
-#### args
-
-A dictionary used for mapping child arg keys, to the corresponding expressions. A value in this mapping is usually either a single or a list of `Expression` instances, but SQLGlot doesn't impose any constraints on the actual type of the value.
-
-#### arg_types
-
-A dictionary used for mapping arg keys to booleans that determine whether the corresponding expressions are optional or not. Consider the following example:
-
-```python
-class Limit(Expression):
-    arg_types = {"this": False, "expression": True}
-
-```
-
-Here, `Limit` declares that it expects to have one optional and one required child expression, which can be referenced through `this` and `expression`, respectively. The arg keys are generally arbitrary, but there are helper methods for keys like `this`, `expression` and `expressions` that abstract away dictionary lookups and related checks. For this reason, these keys are common throughout SQLGlot's codebase.
-
-#### parent
-
-A reference to the parent expression (may be `None`).
-
-#### arg_key
-
-The arg key an expression is associated with, i.e. the name its parent expression uses to refer to it.
-
-#### comments
-
-A list of comments that are associated with a given expression. This is used in order to preserve comments when transpiling SQL code.
-
-#### type
-
-The data type of an expression, as inferred by SQLGlot's optimizer.
sqlglot/__init__.py
@@ -1,5 +1,6 @@
 """
 .. include:: ../README.md
+----
 """
 
 from __future__ import annotations
@@ -29,14 +30,16 @@ from sqlglot.expressions import table_ as table
 from sqlglot.expressions import to_column, to_table, union
 from sqlglot.generator import Generator
 from sqlglot.parser import Parser
-from sqlglot.schema import MappingSchema
+from sqlglot.schema import MappingSchema, Schema
 from sqlglot.tokens import Tokenizer, TokenType
 
-__version__ = "10.5.6"
+__version__ = "10.5.10"
 
 pretty = False
+"""Whether to format generated SQL by default."""
 
 schema = MappingSchema()
+"""The default schema used by SQLGlot (e.g. in the optimizer)."""
 
 
 def parse(
@@ -48,7 +51,7 @@ def parse(
     Args:
         sql: the SQL code string to parse.
         read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
-        **opts: other options.
+        **opts: other `sqlglot.parser.Parser` options.
 
     Returns:
         The resulting syntax tree collection.
@@ -60,7 +63,7 @@ def parse(
 def parse_one(
     sql: str,
     read: t.Optional[str | Dialect] = None,
-    into: t.Optional[t.Type[Expression] | str] = None,
+    into: t.Optional[exp.IntoType] = None,
     **opts,
 ) -> Expression:
     """
@@ -70,7 +73,7 @@ def parse_one(
         sql: the SQL code string to parse.
         read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
         into: the SQLGlot Expression to parse into.
-        **opts: other options.
+        **opts: other `sqlglot.parser.Parser` options.
 
     Returns:
         The syntax tree for the first parsed statement.
@@ -110,7 +113,7 @@ def transpile(
         identity: if set to `True` and if the target dialect is not specified the source dialect will be used as both:
             the source and the target dialect.
         error_level: the desired error level of the parser.
-        **opts: other options.
+        **opts: other `sqlglot.generator.Generator` options.
 
     Returns:
         The list of transpiled SQL statements.
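A short usage sketch of the API these docstrings describe (the statements and dialect names are illustrative):

```python
import sqlglot

# parse_one returns the syntax tree of the first statement; the tree can then
# be rendered in any supported dialect.
tree = sqlglot.parse_one("SELECT a FROM t", read="mysql")
print(tree.sql(dialect="duckdb"))

# transpile combines parsing and generation, returning one SQL string per statement.
print(sqlglot.transpile("SELECT EPOCH_MS(1618088028295)", read="duckdb", write="hive"))
```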
sqlglot/dataframe/README.md
@@ -1,29 +1,29 @@
 # PySpark DataFrame SQL Generator
 
-This is a drop-in replacement for the PysPark DataFrame API that will generate SQL instead of executing DataFrame operations directly. This, when combined with the transpiling support in SQLGlot, allows one to write PySpark DataFrame code and execute it on other engines like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/).
+This is a drop-in replacement for the PySpark DataFrame API that will generate SQL instead of executing DataFrame operations directly. This, when combined with the transpiling support in SQLGlot, allows one to write PySpark DataFrame code and execute it on other engines like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/).
 
 Currently many of the common operations are covered and more functionality will be added over time. Please [open an issue](https://github.com/tobymao/sqlglot/issues) or [PR](https://github.com/tobymao/sqlglot/pulls) with your feedback or contribution to help influence what should be prioritized next and make sure your use case is properly supported.
 
 # How to use
 
 ## Instructions
-* [Install SQLGlot](https://github.com/tobymao/sqlglot/blob/main/README.md#install) and that is all that is required to just generate SQL. [The examples](#examples) show generating SQL and then executing that SQL on a specific engine and that will require that engine's client library
+* [Install SQLGlot](https://github.com/tobymao/sqlglot/blob/main/README.md#install) and that is all that is required to just generate SQL. [The examples](#examples) show generating SQL and then executing that SQL on a specific engine and that will require that engine's client library.
-* Find/replace all `from pyspark.sql` with `from sqlglot.dataframe`
+* Find/replace all `from pyspark.sql` with `from sqlglot.dataframe`.
-* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('<table_name>', <column_structure>)`
+* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('<table_name>', <column_structure>)`.
   * The column structure can be defined the following ways:
-    * Dictionary where the keys are column names and values are string of the Spark SQL type name
+    * Dictionary where the keys are column names and values are string of the Spark SQL type name.
-      * Ex: {'cola': 'string', 'colb': 'int'}
+      * Ex: `{'cola': 'string', 'colb': 'int'}`
-    * PySpark DataFrame `StructType` similar to when using `createDataFrame`
+    * PySpark DataFrame `StructType` similar to when using `createDataFrame`.
       * Ex: `StructType([StructField('cola', StringType()), StructField('colb', IntegerType())])`
-    * A string of names and types similar to what is supported in `createDataFrame`
+    * A string of names and types similar to what is supported in `createDataFrame`.
       * Ex: `cola: STRING, colb: INT`
-    * [Not Recommended] A list of string column names without type
+    * [Not Recommended] A list of string column names without type.
-      * Ex: ['cola', 'colb']
+      * Ex: `['cola', 'colb']`
-      * The lack of types may limit functionality in future releases
+      * The lack of types may limit functionality in future releases.
-  * See [Registering Custom Schema](#registering-custom-schema-class) for information on how to skip this step if the information is stored externally
+  * See [Registering Custom Schema](#registering-custom-schema-class) for information on how to skip this step if the information is stored externally.
-* Add `.sql(pretty=True)` to your final DataFrame command to return a list of sql statements to run that command
+* Add `.sql(pretty=True)` to your final DataFrame command to return a list of sql statements to run that command.
   * In most cases a single SQL statement is returned. Currently the only exception is when caching DataFrames which isn't supported in other dialects.
-* Spark is the default output dialect. See [dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) for a full list of dialects
+* Spark is the default output dialect. See [dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) for a full list of dialects.
   * Ex: `.sql(pretty=True, dialect='bigquery')`
 
 ## Examples
@@ -51,7 +51,7 @@ df = (
 
 print(df.sql(pretty=True))  # Spark will be the dialect used by default
 ```
 Output:
 ```sparksql
 SELECT
   `employee`.`age` AS `age`,
@@ -206,7 +206,7 @@ sql_statements = (
     .createDataFrame(data, schema)
     .groupBy(F.col("age"))
     .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
-    .sql(dialect="bigquery")
+    .sql(dialect="spark")
 )
 
 pyspark = PySparkSession.builder.master("local[*]").getOrCreate()
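A hedged sketch of the workflow described in the instructions above; the table name and columns are illustrative:

```python
import sqlglot
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.session import SparkSession

# Register the table's schema before reading it, as required above.
sqlglot.schema.add_table("employee", {"employee_id": "int", "age": "int"})

spark = SparkSession()
df = (
    spark.table("employee")
    .groupBy(F.col("age"))
    .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
)

# .sql() returns a list of statements; Spark is the default output dialect.
for statement in df.sql(pretty=True):
    print(statement)
```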
sqlglot/dataframe/sql/dataframe.py
@@ -111,16 +111,13 @@ class DataFrame:
         return DataFrameNaFunctions(self)
 
     def _replace_cte_names_with_hashes(self, expression: exp.Select):
-        expression = expression.copy()
-        ctes = expression.ctes
         replacement_mapping = {}
-        for cte in ctes:
+        for cte in expression.ctes:
             old_name_id = cte.args["alias"].this
             new_hashed_id = exp.to_identifier(
                 self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
             )
             replacement_mapping[old_name_id] = new_hashed_id
-            cte.set("alias", exp.TableAlias(this=new_hashed_id))
         expression = expression.transform(replace_id_value, replacement_mapping)
         return expression
 
@@ -183,7 +180,7 @@ class DataFrame:
         expression = df.expression
         hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
         for hint in df.pending_partition_hints:
-            hint_expression.args.get("expressions").append(hint)
+            hint_expression.append("expressions", hint)
             df.pending_hints.remove(hint)
 
         join_aliases = {
@@ -209,7 +206,7 @@ class DataFrame:
                     sequence_id_expression.set("this", matching_cte.args["alias"].this)
                     df.pending_hints.remove(hint)
                     break
-            hint_expression.args.get("expressions").append(hint)
+            hint_expression.append("expressions", hint)
         if hint_expression.expressions:
             expression.set("hint", hint_expression)
         return df
sqlglot/dataframe/sql/session.py
@@ -129,7 +129,7 @@ class SparkSession:
 
     @property
     def _random_name(self) -> str:
-        return f"a{str(uuid.uuid4())[:8]}"
+        return "r" + uuid.uuid4().hex
 
     @property
    def _random_branch_id(self) -> str:
@@ -145,7 +145,7 @@ class SparkSession:
 
     @property
     def _random_id(self) -> str:
-        id = f"a{str(uuid.uuid4())[:8]}"
+        id = self._random_name
         self.known_ids.add(id)
         return id
 
sqlglot/dialects/__init__.py
@@ -1,3 +1,64 @@
+"""
+## Dialects
+
+One of the core abstractions in SQLGlot is the concept of a "dialect". The `Dialect` class essentially implements a
+"SQLGlot dialect", which aims to be as generic and ANSI-compliant as possible. It relies on the base `Tokenizer`,
+`Parser` and `Generator` classes to achieve this goal, so these need to be very lenient when it comes to consuming
+SQL code.
+
+However, there are cases where the syntax of different SQL dialects varies wildly, even for common tasks. One such
+example is the date/time functions, which can be hard to deal with. For this reason, it's sometimes necessary to
+override the base dialect in order to specialize its behavior. This can be easily done in SQLGlot: supporting new
+dialects is as simple as subclassing from `Dialect` and overriding its various components (e.g. the `Parser` class),
+in order to implement the target behavior.
+
+
+### Implementing a custom Dialect
+
+Consider the following example:
+
+```python
+from sqlglot import exp
+from sqlglot.dialects.dialect import Dialect
+from sqlglot.generator import Generator
+from sqlglot.tokens import Tokenizer, TokenType
+
+
+class Custom(Dialect):
+    class Tokenizer(Tokenizer):
+        QUOTES = ["'", '"']
+        IDENTIFIERS = ["`"]
+
+        KEYWORDS = {
+            **Tokenizer.KEYWORDS,
+            "INT64": TokenType.BIGINT,
+            "FLOAT64": TokenType.DOUBLE,
+        }
+
+    class Generator(Generator):
+        TRANSFORMS = {exp.Array: lambda self, e: f"[{self.expressions(e)}]"}
+
+        TYPE_MAPPING = {
+            exp.DataType.Type.TINYINT: "INT64",
+            exp.DataType.Type.SMALLINT: "INT64",
+            exp.DataType.Type.INT: "INT64",
+            exp.DataType.Type.BIGINT: "INT64",
+            exp.DataType.Type.DECIMAL: "NUMERIC",
+            exp.DataType.Type.FLOAT: "FLOAT64",
+            exp.DataType.Type.DOUBLE: "FLOAT64",
+            exp.DataType.Type.BOOLEAN: "BOOL",
+            exp.DataType.Type.TEXT: "STRING",
+        }
+```
+
+This is a typical example of adding a new dialect implementation in SQLGlot: we specify its identifier and string
+delimiters, as well as what tokens it uses for its types and how they're associated with SQLGlot types. Since
+the `Expression` classes are common for each dialect supported in SQLGlot, we may also need to override the generation
+logic for some expressions; this is usually done by adding new entries to the `TRANSFORMS` mapping.
+
+----
+"""
+
 from sqlglot.dialects.bigquery import BigQuery
 from sqlglot.dialects.clickhouse import ClickHouse
 from sqlglot.dialects.databricks import Databricks
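To illustrate how such a dialect is consumed once defined, a minimal sketch (it assumes the `Custom` class from the docstring above has been executed, since subclassing `Dialect` registers it under its lowercase class name):

```python
import sqlglot

# "custom" resolves to the Custom dialect defined above.
print(sqlglot.transpile("SELECT CAST(1 AS BIGINT)", write="custom")[0])
# expected to apply TYPE_MAPPING, e.g. SELECT CAST(1 AS INT64)
```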
sqlglot/dialects/bigquery.py
@@ -124,7 +124,6 @@ class BigQuery(Dialect):
             "FLOAT64": TokenType.DOUBLE,
             "INT64": TokenType.BIGINT,
             "NOT DETERMINISTIC": TokenType.VOLATILE,
-            "QUALIFY": TokenType.QUALIFY,
             "UNKNOWN": TokenType.NULL,
         }
         KEYWORDS.pop("DIV")
sqlglot/dialects/clickhouse.py
@@ -73,13 +73,8 @@ class ClickHouse(Dialect):
 
         return this
 
-    def _parse_position(self) -> exp.Expression:
-        this = super()._parse_position()
-        # clickhouse position args are swapped
-        substr = this.this
-        this.args["this"] = this.args.get("substr")
-        this.args["substr"] = substr
-        return this
+    def _parse_position(self, haystack_first: bool = False) -> exp.Expression:
+        return super()._parse_position(haystack_first=True)
 
     # https://clickhouse.com/docs/en/sql-reference/statements/select/with/
     def _parse_cte(self) -> exp.Expression:
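The rewrite above delegates the argument swap to the base parser via `haystack_first`. A hedged sketch of the behavior it targets (the comments state the expectation, not a verified output):

```python
import sqlglot

# ClickHouse's position(haystack, needle) takes the haystack first; with
# haystack_first=True the shared parser builds the canonical node directly
# instead of swapping the args by hand afterwards.
expression = sqlglot.parse_one("SELECT position(haystack, needle)", read="clickhouse")
print(repr(expression))  # the parsed node should carry the haystack as "this"
```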
sqlglot/dialects/mysql.py
@@ -124,6 +124,8 @@ class MySQL(Dialect):
             **tokens.Tokenizer.KEYWORDS,
             "MEDIUMTEXT": TokenType.MEDIUMTEXT,
             "LONGTEXT": TokenType.LONGTEXT,
+            "MEDIUMBLOB": TokenType.MEDIUMBLOB,
+            "LONGBLOB": TokenType.LONGBLOB,
             "START": TokenType.BEGIN,
             "SEPARATOR": TokenType.SEPARATOR,
             "_ARMSCII8": TokenType.INTRODUCER,
@@ -459,6 +461,8 @@ class MySQL(Dialect):
         TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy()
         TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT)
         TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT)
+        TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMBLOB)
+        TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB)
 
         WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()
 
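A hedged sketch of what the new keywords and `TYPE_MAPPING.pop` calls accomplish: MySQL should round-trip these blob types verbatim, while dialects that keep the base mapping can fall back to a generic blob type:

```python
import sqlglot

# Round-tripping through MySQL is expected to preserve the exact type name.
print(sqlglot.transpile("CREATE TABLE t (c LONGBLOB)", read="mysql", write="mysql"))
# expected: ['CREATE TABLE t (c LONGBLOB)']
```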
sqlglot/dialects/snowflake.py
@@ -194,7 +194,8 @@ class Snowflake(Dialect):
 
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
-            "QUALIFY": TokenType.QUALIFY,
+            "EXCLUDE": TokenType.EXCEPT,
+            "RENAME": TokenType.REPLACE,
             "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
             "TIMESTAMP_NTZ": TokenType.TIMESTAMP,
             "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
@@ -232,6 +233,11 @@ class Snowflake(Dialect):
             exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
         }
 
+        STAR_MAPPING = {
+            "except": "EXCLUDE",
+            "replace": "RENAME",
+        }
+
         ROOT_PROPERTIES = {
             exp.PartitionedByProperty,
             exp.ReturnsProperty,
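A hedged sketch of the Snowflake star modifiers these mappings are meant to support (the comment shows the expected round trip, not a verified result):

```python
import sqlglot

# EXCLUDE/RENAME tokenize as the generic star EXCEPT/REPLACE and are rendered
# back through STAR_MAPPING when generating Snowflake SQL.
sql = "SELECT * EXCLUDE (a) RENAME (b AS c) FROM t"
print(sqlglot.transpile(sql, read="snowflake", write="snowflake")[0])
# expected: SELECT * EXCLUDE (a) RENAME (b AS c) FROM t
```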
sqlglot/dialects/tsql.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import re
+import typing as t
 
 from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func
@@ -251,6 +252,7 @@ class TSQL(Dialect):
             "NTEXT": TokenType.TEXT,
             "NVARCHAR(MAX)": TokenType.TEXT,
             "PRINT": TokenType.COMMAND,
+            "PROC": TokenType.PROCEDURE,
             "REAL": TokenType.FLOAT,
             "ROWVERSION": TokenType.ROWVERSION,
             "SMALLDATETIME": TokenType.DATETIME,
@@ -263,6 +265,11 @@ class TSQL(Dialect):
             "XML": TokenType.XML,
         }
 
+        # TSQL allows @, # to appear as a variable/identifier prefix
+        SINGLE_TOKENS = tokens.Tokenizer.SINGLE_TOKENS.copy()
+        SINGLE_TOKENS.pop("@")
+        SINGLE_TOKENS.pop("#")
+
     class Parser(parser.Parser):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,  # type: ignore
@@ -293,26 +300,82 @@ class TSQL(Dialect):
             DataType.Type.NCHAR,
         }
 
-        # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-temporary#create-a-temporary-table
-        TABLE_PREFIX_TOKENS = {TokenType.HASH, TokenType.PARAMETER}
+        RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - {  # type: ignore
+            TokenType.TABLE,
+            *parser.Parser.TYPE_TOKENS,  # type: ignore
+        }
 
-        def _parse_convert(self, strict):
+        STATEMENT_PARSERS = {
+            **parser.Parser.STATEMENT_PARSERS,  # type: ignore
+            TokenType.END: lambda self: self._parse_command(),
+        }
+
+        def _parse_system_time(self) -> t.Optional[exp.Expression]:
+            if not self._match_text_seq("FOR", "SYSTEM_TIME"):
+                return None
+
+            if self._match_text_seq("AS", "OF"):
+                system_time = self.expression(
+                    exp.SystemTime, this=self._parse_bitwise(), kind="AS OF"
+                )
+            elif self._match_set((TokenType.FROM, TokenType.BETWEEN)):
+                kind = self._prev.text
+                this = self._parse_bitwise()
+                self._match_texts(("TO", "AND"))
+                expression = self._parse_bitwise()
+                system_time = self.expression(
+                    exp.SystemTime, this=this, expression=expression, kind=kind
+                )
+            elif self._match_text_seq("CONTAINED", "IN"):
+                args = self._parse_wrapped_csv(self._parse_bitwise)
+                system_time = self.expression(
+                    exp.SystemTime,
+                    this=seq_get(args, 0),
+                    expression=seq_get(args, 1),
+                    kind="CONTAINED IN",
+                )
+            elif self._match(TokenType.ALL):
+                system_time = self.expression(exp.SystemTime, kind="ALL")
+            else:
+                system_time = None
+                self.raise_error("Unable to parse FOR SYSTEM_TIME clause")
+
+            return system_time
+
+        def _parse_table_parts(self, schema: bool = False) -> exp.Expression:
+            table = super()._parse_table_parts(schema=schema)
+            table.set("system_time", self._parse_system_time())
+            return table
+
+        def _parse_returns(self) -> exp.Expression:
+            table = self._parse_id_var(any_token=False, tokens=self.RETURNS_TABLE_TOKENS)
+            returns = super()._parse_returns()
+            returns.set("table", table)
+            return returns
+
+        def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
             to = self._parse_types()
             self._match(TokenType.COMMA)
             this = self._parse_conjunction()
 
+            if not to or not this:
+                return None
+
             # Retrieve length of datatype and override to default if not specified
             if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
                 to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False)
 
             # Check whether a conversion with format is applicable
             if self._match(TokenType.COMMA):
-                format_val = self._parse_number().name
-                if format_val not in TSQL.convert_format_mapping:
+                format_val = self._parse_number()
+                format_val_name = format_val.name if format_val else ""
+
+                if format_val_name not in TSQL.convert_format_mapping:
                     raise ValueError(
-                        f"CONVERT function at T-SQL does not support format style {format_val}"
+                        f"CONVERT function at T-SQL does not support format style {format_val_name}"
                     )
-                format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val])
+
+                format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val_name])
 
                 # Check whether the convert entails a string to date format
                 if to.this == DataType.Type.DATE:
@@ -333,6 +396,21 @@ class TSQL(Dialect):
             # Entails a simple cast without any format requirement
             return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
 
+        def _parse_user_defined_function(
+            self, kind: t.Optional[TokenType] = None
+        ) -> t.Optional[exp.Expression]:
+            this = super()._parse_user_defined_function(kind=kind)
+
+            if (
+                kind == TokenType.FUNCTION
+                or isinstance(this, exp.UserDefinedFunction)
+                or self._match(TokenType.ALIAS, advance=False)
+            ):
+                return this
+
+            expressions = self._parse_csv(self._parse_udf_kwarg)
+            return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
+
     class Generator(generator.Generator):
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
@@ -354,3 +432,27 @@ class TSQL(Dialect):
             exp.TimeToStr: _format_sql,
             exp.GroupConcat: _string_agg_sql,
         }
+
+        TRANSFORMS.pop(exp.ReturnsProperty)
+
+        def systemtime_sql(self, expression: exp.SystemTime) -> str:
+            kind = expression.args["kind"]
+            if kind == "ALL":
+                return "FOR SYSTEM_TIME ALL"
+
+            start = self.sql(expression, "this")
+            if kind == "AS OF":
+                return f"FOR SYSTEM_TIME AS OF {start}"
+
+            end = self.sql(expression, "expression")
+            if kind == "FROM":
+                return f"FOR SYSTEM_TIME FROM {start} TO {end}"
+            if kind == "BETWEEN":
+                return f"FOR SYSTEM_TIME BETWEEN {start} AND {end}"
+
+            return f"FOR SYSTEM_TIME CONTAINED IN ({start}, {end})"
+
+        def returnsproperty_sql(self, expression: exp.ReturnsProperty) -> str:
+            table = expression.args.get("table")
+            table = f"{table} " if table else ""
+            return f"RETURNS {table}{self.sql(expression, 'this')}"
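A hedged sketch of the temporal-table syntax the new `_parse_system_time`/`systemtime_sql` hooks are meant to round-trip (the table name and timestamp are illustrative):

```python
import sqlglot

sql = "SELECT * FROM employee FOR SYSTEM_TIME AS OF '2023-01-01'"
print(sqlglot.transpile(sql, read="tsql", write="tsql")[0])
# expected: SELECT * FROM employee FOR SYSTEM_TIME AS OF '2023-01-01'
```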
sqlglot/diff.py
@@ -1,5 +1,6 @@
 """
 .. include:: ../posts/sql_diff.md
+----
 """
 
 from __future__ import annotations
@@ -75,12 +76,13 @@ def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
     ]
 
     Args:
-        source (sqlglot.Expression): the source expression.
-        target (sqlglot.Expression): the target expression against which the diff should be calculated.
+        source: the source expression.
+        target: the target expression against which the diff should be calculated.
 
     Returns:
-        the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the target expression trees.
-        This list represents a sequence of steps needed to transform the source expression tree into the target one.
+        the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the
+        target expression trees. This list represents a sequence of steps needed to transform the source
+        expression tree into the target one.
     """
     return ChangeDistiller().diff(source.copy(), target.copy())
 
@@ -258,7 +260,7 @@ class ChangeDistiller:
         return bigram_histo
 
 
-def _get_leaves(expression: exp.Expression) -> t.Generator[exp.Expression, None, None]:
+def _get_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]:
     has_child_exprs = False
 
     for a in expression.args.values():
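A minimal use of the documented `diff` API, comparing two parsed trees and printing the resulting edit script:

```python
import sqlglot
from sqlglot.diff import diff

edits = diff(sqlglot.parse_one("SELECT a + b"), sqlglot.parse_one("SELECT a - b"))
for edit in edits:
    # Each entry is an Insert, Remove, Move, Update or Keep step.
    print(type(edit).__name__)
```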
sqlglot/executor/context.py
@@ -63,7 +63,7 @@ class Context:
             reader = table[i]
             yield reader, self
 
-    def table_iter(self, table: str) -> t.Generator[t.Tuple[TableIter, Context], None, None]:
+    def table_iter(self, table: str) -> t.Iterator[t.Tuple[TableIter, Context]]:
         self.env["scope"] = self.row_readers
 
         for reader in self.tables[table]:
sqlglot/expressions.py
@@ -1,5 +1,12 @@
 """
-.. include:: ../pdoc/docs/expressions.md
+## Expressions
+
+Every AST node in SQLGlot is represented by a subclass of `Expression`.
+
+This module contains the implementation of all supported `Expression` types. Additionally,
+it exposes a number of helper functions, which are mainly used to programmatically build
+SQL expressions, such as `sqlglot.expressions.select`.
+
+----
 """
 
 from __future__ import annotations
@@ -27,35 +34,66 @@ from sqlglot.tokens import Token
 if t.TYPE_CHECKING:
     from sqlglot.dialects.dialect import Dialect
 
+    IntoType = t.Union[
+        str,
+        t.Type[Expression],
+        t.Collection[t.Union[str, t.Type[Expression]]],
+    ]
+
 
 class _Expression(type):
     def __new__(cls, clsname, bases, attrs):
         klass = super().__new__(cls, clsname, bases, attrs)
+
+        # When an Expression class is created, its key is automatically set to be
+        # the lowercase version of the class' name.
         klass.key = clsname.lower()
+
+        # This is so that docstrings are not inherited in pdoc
+        klass.__doc__ = klass.__doc__ or ""
+
         return klass
 
 
 class Expression(metaclass=_Expression):
     """
-    The base class for all expressions in a syntax tree.
+    The base class for all expressions in a syntax tree. Each Expression encapsulates any necessary
+    context, such as its child expressions, their names (arg keys), and whether a given child expression
+    is optional or not.
 
     Attributes:
-        arg_types (dict): determines arguments supported by this expression.
-            The key in a dictionary defines a unique key of an argument using
-            which the argument's value can be retrieved. The value is a boolean
-            flag which indicates whether the argument's value is required (True)
-            or optional (False).
+        key: a unique key for each class in the Expression hierarchy. This is useful for hashing
+            and representing expressions as strings.
+        arg_types: determines what arguments (child nodes) are supported by an expression. It
+            maps arg keys to booleans that indicate whether the corresponding args are optional.
+
+    Example:
+        >>> class Foo(Expression):
+        ...     arg_types = {"this": True, "expression": False}
+
+        The above definition informs us that Foo is an Expression that requires an argument called
+        "this" and may also optionally receive an argument called "expression".
+
+    Args:
+        args: a mapping used for retrieving the arguments of an expression, given their arg keys.
+        parent: a reference to the parent expression (or None, in case of root expressions).
+        arg_key: the arg key an expression is associated with, i.e. the name its parent expression
+            uses to refer to it.
+        comments: a list of comments that are associated with a given expression. This is used in
+            order to preserve comments when transpiling SQL code.
+        _type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the
+            optimizer, in order to enable some transformations that require type information.
     """
 
-    key = "Expression"
+    key = "expression"
     arg_types = {"this": True}
     __slots__ = ("args", "parent", "arg_key", "comments", "_type")
 
-    def __init__(self, **args):
-        self.args = args
-        self.parent = None
-        self.arg_key = None
-        self.comments = None
+    def __init__(self, **args: t.Any):
+        self.args: t.Dict[str, t.Any] = args
+        self.parent: t.Optional[Expression] = None
+        self.arg_key: t.Optional[str] = None
+        self.comments: t.Optional[t.List[str]] = None
         self._type: t.Optional[DataType] = None
 
         for arg_key, value in self.args.items():
@@ -76,17 +114,30 @@ class Expression(metaclass=_Expression):
 
     @property
     def this(self):
+        """
+        Retrieves the argument with key "this".
+        """
         return self.args.get("this")
 
     @property
     def expression(self):
+        """
+        Retrieves the argument with key "expression".
+        """
         return self.args.get("expression")
 
     @property
     def expressions(self):
+        """
+        Retrieves the argument with key "expressions".
+        """
         return self.args.get("expressions") or []
 
     def text(self, key):
+        """
+        Returns a textual representation of the argument corresponding to "key". This can only be used
+        for args that are strings or leaf Expression instances, such as identifiers and literals.
+        """
         field = self.args.get(key)
         if isinstance(field, str):
             return field
@@ -96,14 +147,23 @@ class Expression(metaclass=_Expression):
 
     @property
     def is_string(self):
+        """
+        Checks whether a Literal expression is a string.
+        """
         return isinstance(self, Literal) and self.args["is_string"]
 
     @property
     def is_number(self):
+        """
+        Checks whether a Literal expression is a number.
+        """
         return isinstance(self, Literal) and not self.args["is_string"]
 
     @property
     def is_int(self):
+        """
+        Checks whether a Literal expression is an integer.
+        """
         if self.is_number:
             try:
                 int(self.name)
@@ -114,6 +174,9 @@ class Expression(metaclass=_Expression):
 
     @property
     def alias(self):
+        """
+        Returns the alias of the expression, or an empty string if it's not aliased.
+        """
         if isinstance(self.args.get("alias"), TableAlias):
             return self.args["alias"].name
         return self.text("alias")
@@ -128,6 +191,24 @@ class Expression(metaclass=_Expression):
             return "NULL"
         return self.alias or self.name
 
+    @property
+    def output_name(self):
+        """
+        Name of the output column if this expression is a selection.
+
+        If the Expression has no output name, an empty string is returned.
+
+        Example:
+            >>> from sqlglot import parse_one
+            >>> parse_one("SELECT a").expressions[0].output_name
+            'a'
+            >>> parse_one("SELECT b AS c").expressions[0].output_name
+            'c'
+            >>> parse_one("SELECT 1 + 2").expressions[0].output_name
+            ''
+        """
+        return ""
+
     @property
     def type(self) -> t.Optional[DataType]:
         return self._type
@@ -145,6 +226,9 @@ class Expression(metaclass=_Expression):
         return copy
 
     def copy(self):
+        """
+        Returns a deep copy of the expression.
+        """
         new = deepcopy(self)
         for item, parent, _ in new.bfs():
             if isinstance(item, Expression) and parent:
@@ -169,7 +253,7 @@ class Expression(metaclass=_Expression):
         Sets `arg_key` to `value`.
 
         Args:
-            arg_key (str): name of the expression arg
+            arg_key (str): name of the expression arg.
             value: value to set the arg to.
         """
         self.args[arg_key] = value
@@ -203,8 +287,7 @@ class Expression(metaclass=_Expression):
             expression_types (type): the expression type(s) to match.
 
         Returns:
-            the node which matches the criteria or None if no node matching
-            the criteria was found.
+            The node which matches the criteria or None if no such node was found.
         """
         return next(self.find_all(*expression_types, bfs=bfs), None)
 
@@ -217,7 +300,7 @@ class Expression(metaclass=_Expression):
             expression_types (type): the expression type(s) to match.
 
         Returns:
-            the generator object.
+            The generator object.
         """
         for expression, _, _ in self.walk(bfs=bfs):
             if isinstance(expression, expression_types):
@@ -231,7 +314,7 @@ class Expression(metaclass=_Expression):
             expression_types (type): the expression type(s) to match.
 
         Returns:
-            the parent node
+            The parent node.
         """
         ancestor = self.parent
         while ancestor and not isinstance(ancestor, expression_types):
@@ -269,7 +352,7 @@ class Expression(metaclass=_Expression):
         the DFS (Depth-first) order.
 
         Returns:
-            the generator object.
+            The generator object.
         """
         parent = parent or self.parent
         yield self, parent, key
@@ -287,7 +370,7 @@ class Expression(metaclass=_Expression):
         the BFS (Breadth-first) order.
 
         Returns:
-            the generator object.
+            The generator object.
         """
         queue = deque([(self, self.parent, None)])
 
@@ -341,32 +424,33 @@ class Expression(metaclass=_Expression):
         return self.sql()
 
     def __repr__(self):
-        return self.to_s()
+        return self._to_s()
 
     def sql(self, dialect: Dialect | str | None = None, **opts) -> str:
         """
         Returns SQL string representation of this tree.
 
-        Args
-            dialect (str): the dialect of the output SQL string
-                (eg. "spark", "hive", "presto", "mysql").
-            opts (dict): other :class:`~sqlglot.generator.Generator` options.
+        Args:
+            dialect: the dialect of the output SQL string (eg. "spark", "hive", "presto", "mysql").
+            opts: other `sqlglot.generator.Generator` options.
 
-        Returns
-            the SQL string.
+        Returns:
+            The SQL string.
         """
         from sqlglot.dialects import Dialect
 
         return Dialect.get_or_raise(dialect)().generate(self, **opts)
 
-    def to_s(self, hide_missing: bool = True, level: int = 0) -> str:
+    def _to_s(self, hide_missing: bool = True, level: int = 0) -> str:
         indent = "" if not level else "\n"
         indent += "".join(["  "] * level)
         left = f"({self.key.upper()} "
 
         args: t.Dict[str, t.Any] = {
             k: ", ".join(
-                v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v)
+                v._to_s(hide_missing=hide_missing, level=level + 1)
+                if hasattr(v, "_to_s")
+                else str(v)
                 for v in ensure_collection(vs)
                 if v is not None
             )
@@ -394,7 +478,7 @@ class Expression(metaclass=_Expression):
         modified in place.
 
         Returns:
-            the transformed tree.
+            The transformed tree.
         """
         node = self.copy() if copy else self
         new_node = fun(node, *args, **kwargs)
@@ -423,8 +507,8 @@ class Expression(metaclass=_Expression):
         Args:
             expression (Expression|None): new node
 
-        Returns :
-            the new expression or expressions
+        Returns:
+            The new expression or expressions.
         """
         if not self.parent:
             return expression
@@ -458,6 +542,40 @@ class Expression(metaclass=_Expression):
         assert isinstance(self, type_)
         return self
 
+    def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]:
+        """
+        Checks if this expression is valid (e.g. all mandatory args are set).
+
+        Args:
+            args: a sequence of values that were used to instantiate a Func expression. This is used
+                to check that the provided arguments don't exceed the function argument limit.
+
+        Returns:
+            A list of error messages for all possible errors that were found.
+        """
+        errors: t.List[str] = []
+
+        for k in self.args:
+            if k not in self.arg_types:
+                errors.append(f"Unexpected keyword: '{k}' for {self.__class__}")
+        for k, mandatory in self.arg_types.items():
+            v = self.args.get(k)
+            if mandatory and (v is None or (isinstance(v, list) and not v)):
+                errors.append(f"Required keyword: '{k}' missing for {self.__class__}")
+
+        if (
+            args
+            and isinstance(self, Func)
+            and len(args) > len(self.arg_types)
+            and not self.is_var_len_args
+        ):
+            errors.append(
+                f"The number of provided arguments ({len(args)}) is greater than "
+                f"the maximum number of supported arguments ({len(self.arg_types)})"
+            )
+
+        return errors
+
     def dump(self):
         """
         Dump this Expression to a JSON-serializable dict.
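A small sketch of how the new validation helper behaves (the printed message is the expected shape, not a verified literal):

```python
from sqlglot import exp

# Column declares "this" as mandatory, so instantiating it empty should
# surface a "Required keyword" message instead of failing silently.
print(exp.Column().error_messages())
# expected to contain: "Required keyword: 'this' missing for <class 'sqlglot.expressions.Column'>"
```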
@ -552,7 +670,7 @@ class DerivedTable(Expression):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def named_selects(self):
|
def named_selects(self):
|
||||||
return [select.alias_or_name for select in self.selects]
|
return [select.output_name for select in self.selects]
|
||||||
|
|
||||||
|
|
||||||
class Unionable(Expression):
|
class Unionable(Expression):
|
||||||
|
@ -654,6 +772,7 @@ class Create(Expression):
|
||||||
"no_primary_index": False,
|
"no_primary_index": False,
|
||||||
"indexes": False,
|
"indexes": False,
|
||||||
"no_schema_binding": False,
|
"no_schema_binding": False,
|
||||||
|
"begin": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -696,7 +815,7 @@ class Show(Expression):
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedFunction(Expression):
|
class UserDefinedFunction(Expression):
|
||||||
arg_types = {"this": True, "expressions": False}
|
arg_types = {"this": True, "expressions": False, "wrapped": False}
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedFunctionKwarg(Expression):
|
class UserDefinedFunctionKwarg(Expression):
|
||||||
|
@ -750,6 +869,10 @@ class Column(Condition):
|
||||||
def table(self):
|
def table(self):
|
||||||
return self.text("table")
|
return self.text("table")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_name(self):
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
class ColumnDef(Expression):
|
class ColumnDef(Expression):
|
||||||
arg_types = {
|
arg_types = {
|
||||||
|
@ -865,6 +988,10 @@ class ForeignKey(Expression):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PrimaryKey(Expression):
|
||||||
|
arg_types = {"expressions": True, "options": False}
|
||||||
|
|
||||||
|
|
||||||
class Unique(Expression):
|
class Unique(Expression):
|
||||||
arg_types = {"expressions": True}
|
arg_types = {"expressions": True}
|
||||||
|
|
||||||
|
@ -904,6 +1031,10 @@ class Identifier(Expression):
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash((self.key, self.this.lower()))
|
return hash((self.key, self.this.lower()))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_name(self):
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
class Index(Expression):
|
class Index(Expression):
|
||||||
arg_types = {
|
arg_types = {
|
||||||
|
@ -996,6 +1127,10 @@ class Literal(Condition):
|
||||||
def string(cls, string) -> Literal:
|
def string(cls, string) -> Literal:
|
||||||
return cls(this=str(string), is_string=True)
|
return cls(this=str(string), is_string=True)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_name(self):
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
class Join(Expression):
|
class Join(Expression):
|
||||||
arg_types = {
|
arg_types = {
|
||||||
|
@ -1186,7 +1321,7 @@ class SchemaCommentProperty(Property):
|
||||||
|
|
||||||
|
|
||||||
class ReturnsProperty(Property):
|
class ReturnsProperty(Property):
|
||||||
arg_types = {"this": True, "is_table": False}
|
arg_types = {"this": True, "is_table": False, "table": False}
|
||||||
|
|
||||||
|
|
||||||
class LanguageProperty(Property):
|
class LanguageProperty(Property):
|
||||||
|
@ -1262,8 +1397,13 @@ class Qualify(Expression):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# https://www.ibm.com/docs/en/ias?topic=procedures-return-statement-in-sql
|
||||||
|
class Return(Expression):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Reference(Expression):
|
class Reference(Expression):
|
||||||
arg_types = {"this": True, "expressions": True}
|
arg_types = {"this": True, "expressions": False, "options": False}
|
||||||
|
|
||||||
|
|
||||||
class Tuple(Expression):
|
class Tuple(Expression):
|
||||||
|
@ -1397,6 +1537,16 @@ class Table(Expression):
|
||||||
"joins": False,
|
"joins": False,
|
||||||
"pivots": False,
|
"pivots": False,
|
||||||
"hints": False,
|
"hints": False,
|
||||||
|
"system_time": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# See the TSQL "Querying data in a system-versioned temporal table" page
|
||||||
|
class SystemTime(Expression):
|
||||||
|
arg_types = {
|
||||||
|
"this": False,
|
||||||
|
"expression": False,
|
||||||
|
"kind": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@@ -2027,7 +2177,7 @@ class Select(Subqueryable):

     @property
     def named_selects(self) -> t.List[str]:
-        return [e.alias_or_name for e in self.expressions if e.alias_or_name]
+        return [e.output_name for e in self.expressions if e.alias_or_name]

     @property
     def selects(self) -> t.List[Expression]:
@@ -2051,6 +2201,10 @@ class Subquery(DerivedTable, Unionable):
             expression = expression.this
         return expression

+    @property
+    def output_name(self):
+        return self.alias
+

 class TableSample(Expression):
     arg_types = {
@@ -2066,6 +2220,16 @@ class TableSample(Expression):
     }


+class Tag(Expression):
+    """Tags are used for generating arbitrary sql like SELECT <span>x</span>."""
+
+    arg_types = {
+        "this": False,
+        "prefix": False,
+        "postfix": False,
+    }
+
+
 class Pivot(Expression):
     arg_types = {
         "this": False,
@@ -2106,6 +2270,10 @@ class Star(Expression):
     def name(self):
         return "*"

+    @property
+    def output_name(self):
+        return self.name
+

 class Parameter(Expression):
     pass
@@ -2143,6 +2311,8 @@ class DataType(Expression):
         TEXT = auto()
         MEDIUMTEXT = auto()
         LONGTEXT = auto()
+        MEDIUMBLOB = auto()
+        LONGBLOB = auto()
         BINARY = auto()
         VARBINARY = auto()
         INT = auto()
@@ -2282,11 +2452,11 @@ class Rollback(Expression):


 class AlterTable(Expression):
-    arg_types = {
-        "this": True,
-        "actions": True,
-        "exists": False,
-    }
+    arg_types = {"this": True, "actions": True, "exists": False}
+
+
+class AddConstraint(Expression):
+    arg_types = {"this": False, "expression": False, "enforced": False}


 # Binary expressions like (ADD a b)
@@ -2456,6 +2626,10 @@ class Neg(Unary):
 class Alias(Expression):
     arg_types = {"this": True, "alias": False}

+    @property
+    def output_name(self):
+        return self.alias
+

 class Aliases(Expression):
     arg_types = {"this": True, "expressions": True}
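With `output_name` now defined on columns, identifiers, literals, subqueries, stars, and aliases, callers can ask any projection for the name it will surface in a result set without type-specific branching; the `Select.named_selects` change above relies on exactly this. A quick illustration of the intended semantics (an editor's sketch, not part of the diff):

```python
from sqlglot import parse_one

parse_one("SELECT a FROM t").expressions[0].output_name       # 'a' -- a Column reports its name
parse_one("SELECT b AS c FROM t").expressions[0].output_name  # 'c' -- an Alias reports its alias
parse_one("SELECT 1 FROM t").expressions[0].output_name       # '1' -- a Literal reports its text
```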
@@ -2523,16 +2697,13 @@ class Func(Condition):
     """
     The base class for all function expressions.

-    Attributes
-        is_var_len_args (bool): if set to True the last argument defined in
-            arg_types will be treated as a variable length argument and the
-            argument's value will be stored as a list.
-        _sql_names (list): determines the SQL name (1st item in the list) and
-            aliases (subsequent items) for this function expression. These
-            values are used to map this node to a name during parsing as well
-            as to provide the function's name during SQL string generation. By
-            default the SQL name is set to the expression's class name transformed
-            to snake case.
+    Attributes:
+        is_var_len_args (bool): if set to True the last argument defined in arg_types will be
+            treated as a variable length argument and the argument's value will be stored as a list.
+        _sql_names (list): determines the SQL name (1st item in the list) and aliases (subsequent items)
+            for this function expression. These values are used to map this node to a name during parsing
+            as well as to provide the function's name during SQL string generation. By default the SQL
+            name is set to the expression's class name transformed to snake case.
     """

     is_var_len_args = False
@@ -2558,7 +2729,7 @@ class Func(Condition):
             raise NotImplementedError(
                 "SQL name is only supported by concrete function implementations"
             )
-        if not hasattr(cls, "_sql_names"):
+        if "_sql_names" not in cls.__dict__:
            cls._sql_names = [camel_to_snake_case(cls.__name__)]
         return cls._sql_names
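The `hasattr` to `cls.__dict__` change matters for subclasses: `hasattr` sees the attribute inherited from a parent, so a subclass of a function with explicit `_sql_names` would silently reuse the parent's SQL names instead of deriving its own. A minimal sketch of the behavior the fix restores (illustrative class names, not from the diff):

```python
from sqlglot import exp


class Base(exp.Func):
    _sql_names = ["BASE"]


class Child(Base):  # defines no _sql_names of its own
    pass


# With hasattr(), Child.sql_names() returned ["BASE"] (inherited).
# With "_sql_names" not in cls.__dict__, Child derives its own snake_cased name:
print(Child.sql_names())  # ['child']
```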
@@ -2658,6 +2829,10 @@ class Cast(Func):
     def to(self):
         return self.args["to"]

+    @property
+    def output_name(self):
+        return self.name
+

 class Collate(Binary):
     pass
@@ -2956,6 +3131,14 @@ class Pow(Func):
     _sql_names = ["POWER", "POW"]


+class PercentileCont(AggFunc):
+    pass
+
+
+class PercentileDisc(AggFunc):
+    pass
+
+
 class Quantile(AggFunc):
     arg_types = {"this": True, "quantile": True}
@@ -3213,12 +3396,13 @@ def _norm_arg(arg):
 ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))


+# Helpers
 def maybe_parse(
-    sql_or_expression,
+    sql_or_expression: str | Expression,
     *,
-    into=None,
-    dialect=None,
-    prefix=None,
+    into: t.Optional[IntoType] = None,
+    dialect: t.Optional[str] = None,
+    prefix: t.Optional[str] = None,
     **opts,
 ) -> Expression:
     """Gracefully handle a possible string or expression.
@@ -3230,11 +3414,11 @@ def maybe_parse(
        (IDENTIFIER this: x, quoted: False)

     Args:
-        sql_or_expression (str | Expression): the SQL code string or an expression
-        into (Expression): the SQLGlot Expression to parse into
-        dialect (str): the dialect used to parse the input expressions (in the case that an
+        sql_or_expression: the SQL code string or an expression
+        into: the SQLGlot Expression to parse into
+        dialect: the dialect used to parse the input expressions (in the case that an
             input expression is a SQL string).
-        prefix (str): a string to prefix the sql with before it gets parsed
+        prefix: a string to prefix the sql with before it gets parsed
             (automatically includes a space)
         **opts: other options to use to parse the input expressions (again, in the case
             that an input expression is a SQL string).
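`maybe_parse` keeps accepting either raw SQL or an already-built node, which is what lets most builder APIs take both. For example (consistent with the docstring above):

```python
from sqlglot import exp
from sqlglot.expressions import maybe_parse

maybe_parse("1 + 2").sql()           # parses the string: '1 + 2'
maybe_parse(exp.to_identifier("x"))  # an Expression passes through unchanged
```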
@@ -3993,7 +4177,7 @@ def table_name(table) -> str:
     """Get the full name of a table as a string.

     Args:
-        table (exp.Table | str): Table expression node or string.
+        table (exp.Table | str): table expression node or string.

     Examples:
         >>> from sqlglot import exp, parse_one
@@ -4001,7 +4185,7 @@ def table_name(table) -> str:
         'a.b.c'

     Returns:
-        str: the table name
+        The table name.
     """

     table = maybe_parse(table, into=Table)
@@ -4024,8 +4208,8 @@ def replace_tables(expression, mapping):
     """Replace all tables in expression according to the mapping.

     Args:
-        expression (sqlglot.Expression): Expression node to be transformed and replaced
-        mapping (Dict[str, str]): Mapping of table names
+        expression (sqlglot.Expression): expression node to be transformed and replaced.
+        mapping (Dict[str, str]): mapping of table names.

     Examples:
         >>> from sqlglot import exp, parse_one
@@ -4033,7 +4217,7 @@ def replace_tables(expression, mapping):
         'SELECT * FROM c'

     Returns:
-        The mapped expression
+        The mapped expression.
     """

     def _replace_tables(node):
@@ -4053,9 +4237,9 @@ def replace_placeholders(expression, *args, **kwargs):
     """Replace placeholders in an expression.

     Args:
-        expression (sqlglot.Expression): Expression node to be transformed and replaced
-        args: Positional names that will substitute unnamed placeholders in the given order
-        kwargs: Keyword arguments that will substitute named placeholders
+        expression (sqlglot.Expression): expression node to be transformed and replaced.
+        args: positional names that will substitute unnamed placeholders in the given order.
+        kwargs: keyword arguments that will substitute named placeholders.

     Examples:
         >>> from sqlglot import exp, parse_one
@@ -4065,7 +4249,7 @@ def replace_placeholders(expression, *args, **kwargs):
         'SELECT * FROM foo WHERE a = b'

     Returns:
-        The mapped expression
+        The mapped expression.
     """

     def _replace_placeholders(node, args, **kwargs):
@@ -4084,15 +4268,101 @@ def replace_placeholders(expression, *args, **kwargs):
     return expression.transform(_replace_placeholders, iter(args), **kwargs)


+def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True) -> Expression:
+    """Transforms an expression by expanding all referenced sources into subqueries.
+
+    Examples:
+        >>> from sqlglot import parse_one
+        >>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y")}).sql()
+        'SELECT * FROM (SELECT * FROM y) AS z /* source: x */'
+
+    Args:
+        expression: The expression to expand.
+        sources: A dictionary of name to Subqueryables.
+        copy: Whether or not to copy the expression during transformation. Defaults to True.
+
+    Returns:
+        The transformed expression.
+    """
+
+    def _expand(node: Expression):
+        if isinstance(node, Table):
+            name = table_name(node)
+            source = sources.get(name)
+            if source:
+                subquery = source.subquery(node.alias or name)
+                subquery.comments = [f"source: {name}"]
+                return subquery
+        return node
+
+    return expression.transform(_expand, copy=copy)
+
+
+def func(name: str, *args, dialect: t.Optional[Dialect | str] = None, **kwargs) -> Func:
+    """
+    Returns a Func expression.
+
+    Examples:
+        >>> func("abs", 5).sql()
+        'ABS(5)'
+
+        >>> func("cast", this=5, to=DataType.build("DOUBLE")).sql()
+        'CAST(5 AS DOUBLE)'
+
+    Args:
+        name: the name of the function to build.
+        args: the args used to instantiate the function of interest.
+        dialect: the source dialect.
+        kwargs: the kwargs used to instantiate the function of interest.
+
+    Note:
+        The arguments `args` and `kwargs` are mutually exclusive.
+
+    Returns:
+        An instance of the function of interest, or an anonymous function, if `name` doesn't
+        correspond to an existing `sqlglot.expressions.Func` class.
+    """
+    if args and kwargs:
+        raise ValueError("Can't use both args and kwargs to instantiate a function.")
+
+    from sqlglot.dialects.dialect import Dialect
+
+    args = tuple(convert(arg) for arg in args)
+    kwargs = {key: convert(value) for key, value in kwargs.items()}
+
+    parser = Dialect.get_or_raise(dialect)().parser()
+    from_args_list = parser.FUNCTIONS.get(name.upper())
+
+    if from_args_list:
+        function = from_args_list(args) if args else from_args_list.__self__(**kwargs)  # type: ignore
+    else:
+        kwargs = kwargs or {"expressions": args}
+        function = Anonymous(this=name, **kwargs)
+
+    for error_message in function.error_messages(args):
+        raise ValueError(error_message)
+
+    return function
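`func` resolves the name against the dialect's known functions and validates the arguments via `error_messages`, falling back to `Anonymous` for unknown names. Usage, taken from the docstring plus one fallback case (the last output is this editor's expectation, not shown in the diff):

```python
from sqlglot.expressions import DataType, func

func("abs", 5).sql()                                     # 'ABS(5)'
func("cast", this=5, to=DataType.build("DOUBLE")).sql()  # 'CAST(5 AS DOUBLE)'
func("my_udf", 1, 2).sql()                               # anonymous fallback: 'MY_UDF(1, 2)'
```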


 def true():
+    """
+    Returns a true Boolean expression.
+    """
     return Boolean(this=True)


 def false():
+    """
+    Returns a false Boolean expression.
+    """
     return Boolean(this=False)


 def null():
+    """
+    Returns a Null expression.
+    """
     return Null()


sqlglot/generator.py
@@ -16,7 +16,7 @@ class Generator:
     """
     Generator interprets the given syntax tree and produces a SQL string as an output.

-    Args
+    Args:
         time_mapping (dict): the dictionary of custom time mappings in which the key
             represents a python time format and the output the target time format
         time_trie (trie): a trie of the time_mapping keys
@@ -84,6 +84,13 @@ class Generator:
         exp.DataType.Type.NVARCHAR: "VARCHAR",
         exp.DataType.Type.MEDIUMTEXT: "TEXT",
         exp.DataType.Type.LONGTEXT: "TEXT",
+        exp.DataType.Type.MEDIUMBLOB: "BLOB",
+        exp.DataType.Type.LONGBLOB: "BLOB",
+    }
+
+    STAR_MAPPING = {
+        "except": "EXCEPT",
+        "replace": "REPLACE",
     }

     TOKEN_MAPPING: t.Dict[TokenType, str] = {}
@@ -106,6 +113,8 @@ class Generator:
         exp.TableFormatProperty,
     }

+    WITH_SINGLE_ALTER_TABLE_ACTION = (exp.AlterColumn, exp.RenameTable, exp.AddConstraint)
+
     WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
     SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"

@@ -241,15 +250,17 @@ class Generator:
             return sql

         sep = "\n" if self.pretty else " "
-        comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments if comment)
+        comments_sql = sep.join(
+            f"/*{self.pad_comment(comment)}*/" for comment in comments if comment
+        )

-        if not comments:
+        if not comments_sql:
             return sql

         if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
-            return f"{comments}{self.sep()}{sql}"
+            return f"{comments_sql}{self.sep()}{sql}"

-        return f"{sql} {comments}"
+        return f"{sql} {comments_sql}"

     def wrap(self, expression: exp.Expression | str) -> str:
         this_sql = self.indent(
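Renaming the joined string to `comments_sql` stops it from shadowing the `comments` list it is built from; observable output is unchanged. For reference, the behavior this method implements (editor's sketch):

```python
import sqlglot

# Inline comments survive transpilation; nodes in WITH_SEPARATED_COMMENTS
# get their comments emitted before the SQL, others trail it.
sqlglot.transpile("SELECT 1 /* one */")[0]  # 'SELECT 1 /* one */'
```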
@@ -433,8 +444,9 @@ class Generator:
     def create_sql(self, expression: exp.Create) -> str:
         this = self.sql(expression, "this")
         kind = self.sql(expression, "kind").upper()
+        begin = " BEGIN" if expression.args.get("begin") else ""
         expression_sql = self.sql(expression, "expression")
-        expression_sql = f" AS{self.sep()}{expression_sql}" if expression_sql else ""
+        expression_sql = f" AS{begin}{self.sep()}{expression_sql}" if expression_sql else ""
         temporary = " TEMPORARY" if expression.args.get("temporary") else ""
         transient = (
             " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
@@ -741,12 +753,14 @@ class Generator:
         laterals = self.expressions(expression, key="laterals", sep="")
         joins = self.expressions(expression, key="joins", sep="")
         pivots = self.expressions(expression, key="pivots", sep="")
+        system_time = expression.args.get("system_time")
+        system_time = f" {self.sql(expression, 'system_time')}" if system_time else ""

         if alias and pivots:
             pivots = f"{pivots}{alias}"
             alias = ""

-        return f"{table}{alias}{hints}{laterals}{joins}{pivots}"
+        return f"{table}{system_time}{alias}{hints}{laterals}{joins}{pivots}"

     def tablesample_sql(self, expression: exp.TableSample) -> str:
         if self.alias_post_tablesample and expression.this.alias:
@@ -1009,9 +1023,9 @@ class Generator:

     def star_sql(self, expression: exp.Star) -> str:
         except_ = self.expressions(expression, key="except", flat=True)
-        except_ = f"{self.seg('EXCEPT')} ({except_})" if except_ else ""
+        except_ = f"{self.seg(self.STAR_MAPPING['except'])} ({except_})" if except_ else ""
         replace = self.expressions(expression, key="replace", flat=True)
-        replace = f"{self.seg('REPLACE')} ({replace})" if replace else ""
+        replace = f"{self.seg(self.STAR_MAPPING['replace'])} ({replace})" if replace else ""
         return f"*{except_}{replace}"

     def structkwarg_sql(self, expression: exp.StructKwarg) -> str:
@@ -1193,6 +1207,12 @@ class Generator:
         update = f" ON UPDATE {update}" if update else ""
         return f"FOREIGN KEY ({expressions}){reference}{delete}{update}"

+    def primarykey_sql(self, expression: exp.ForeignKey) -> str:
+        expressions = self.expressions(expression, flat=True)
+        options = self.expressions(expression, "options", flat=True, sep=" ")
+        options = f" {options}" if options else ""
+        return f"PRIMARY KEY ({expressions}){options}"
+
     def unique_sql(self, expression: exp.Unique) -> str:
         columns = self.expressions(expression, key="expressions")
         return f"UNIQUE ({columns})"
@@ -1229,10 +1249,16 @@ class Generator:
         unit = f" {unit}" if unit else ""
         return f"INTERVAL{this}{unit}"

+    def return_sql(self, expression: exp.Return) -> str:
+        return f"RETURN {self.sql(expression, 'this')}"
+
     def reference_sql(self, expression: exp.Reference) -> str:
         this = self.sql(expression, "this")
         expressions = self.expressions(expression, flat=True)
-        return f"REFERENCES {this}({expressions})"
+        expressions = f"({expressions})" if expressions else ""
+        options = self.expressions(expression, "options", flat=True, sep=" ")
+        options = f" {options}" if options else ""
+        return f"REFERENCES {this}{expressions}{options}"

     def anonymous_sql(self, expression: exp.Anonymous) -> str:
         args = self.format_args(*expression.expressions)
@@ -1362,7 +1388,7 @@ class Generator:
             actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ")
         elif isinstance(actions[0], exp.Drop):
             actions = self.expressions(expression, "actions")
-        elif isinstance(actions[0], (exp.AlterColumn, exp.RenameTable)):
+        elif isinstance(actions[0], self.WITH_SINGLE_ALTER_TABLE_ACTION):
             actions = self.sql(actions[0])
         else:
             self.unsupported(f"Unsupported ALTER TABLE action {actions[0].__class__.__name__}")
@@ -1370,6 +1396,17 @@ class Generator:
         exists = " IF EXISTS" if expression.args.get("exists") else ""
         return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}"

+    def addconstraint_sql(self, expression: exp.AddConstraint) -> str:
+        this = self.sql(expression, "this")
+        expression_ = self.sql(expression, "expression")
+        add_constraint = f"ADD CONSTRAINT {this}" if this else "ADD"
+
+        enforced = expression.args.get("enforced")
+        if enforced is not None:
+            return f"{add_constraint} CHECK ({expression_}){' ENFORCED' if enforced else ''}"
+
+        return f"{add_constraint} {expression_}"
+
     def distinct_sql(self, expression: exp.Distinct) -> str:
         this = self.expressions(expression, flat=True)
         this = f" {this}" if this else ""
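With `exp.AddConstraint` included in `WITH_SINGLE_ALTER_TABLE_ACTION`, an `ALTER TABLE ... ADD CONSTRAINT` action renders through the new `addconstraint_sql`. A hand-built sketch (the expected output is this editor's reading of the code above, not a quoted test):

```python
from sqlglot import exp, parse_one

constraint = exp.AddConstraint(
    this=exp.to_identifier("positive_x"),
    expression=parse_one("x > 0"),
    enforced=True,
)
alter = exp.AlterTable(this=exp.Table(this=exp.to_identifier("t")), actions=[constraint])
print(alter.sql())
# ALTER TABLE t ADD CONSTRAINT positive_x CHECK (x > 0) ENFORCED
```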
@@ -1550,13 +1587,19 @@ class Generator:
             expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
         )

+    def tag_sql(self, expression: exp.Tag) -> str:
+        return f"{expression.args.get('prefix')}{self.sql(expression.this)}{expression.args.get('postfix')}"
+
     def token_sql(self, token_type: TokenType) -> str:
         return self.TOKEN_MAPPING.get(token_type, token_type.name)

     def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str:
         this = self.sql(expression, "this")
         expressions = self.no_identify(lambda: self.expressions(expression))
-        return f"{this}({expressions})"
+        expressions = (
+            self.wrap(expressions) if expression.args.get("wrapped") else f" {expressions}"
+        )
+        return f"{this}{expressions}"

     def userdefinedfunctionkwarg_sql(self, expression: exp.UserDefinedFunctionKwarg) -> str:
         this = self.sql(expression, "this")


sqlglot/helper.py
@@ -332,7 +332,7 @@ def is_iterable(value: t.Any) -> bool:
     return hasattr(value, "__iter__") and not isinstance(value, (str, bytes))


-def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any, None, None]:
+def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
     """
     Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
     type `str` and `bytes` are not regarded as iterables.
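Only the annotation changes here; `t.Iterator[t.Any]` is the conventional spelling for a generator's return type and behavior is identical:

```python
from sqlglot.helper import flatten

list(flatten([1, [2, 3], "ab", [4]]))  # [1, 2, 3, 'ab', 4] -- strings are not flattened
```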
sqlglot/lineage.py (new file, 228 lines)

@@ -0,0 +1,228 @@
+from __future__ import annotations
+
+import json
+import typing as t
+from dataclasses import dataclass, field
+
+from sqlglot import Schema, exp, maybe_parse
+from sqlglot.optimizer import Scope, build_scope, optimize
+from sqlglot.optimizer.qualify_columns import qualify_columns
+from sqlglot.optimizer.qualify_tables import qualify_tables
+
+
+@dataclass(frozen=True)
+class Node:
+    name: str
+    expression: exp.Expression
+    source: exp.Expression
+    downstream: t.List[Node] = field(default_factory=list)
+
+    def walk(self) -> t.Iterator[Node]:
+        yield self
+
+        for d in self.downstream:
+            if isinstance(d, Node):
+                yield from d.walk()
+            else:
+                yield d
+
+    def to_html(self, **opts) -> LineageHTML:
+        return LineageHTML(self, **opts)
+
+
+def lineage(
+    column: str | exp.Column,
+    sql: str | exp.Expression,
+    schema: t.Optional[t.Dict | Schema] = None,
+    sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
+    rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns),
+    dialect: t.Optional[str] = None,
+) -> Node:
+    """Build the lineage graph for a column of a SQL query.
+
+    Args:
+        column: The column to build the lineage for.
+        sql: The SQL string or expression.
+        schema: The schema of tables.
+        sources: A mapping of queries which will be used to continue building lineage.
+        rules: Optimizer rules to apply, by default only qualifying tables and columns.
+        dialect: The dialect of input SQL.
+
+    Returns:
+        A lineage node.
+    """
+
+    expression = maybe_parse(sql, dialect=dialect)
+
+    if sources:
+        expression = exp.expand(
+            expression,
+            {
+                k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
+                for k, v in sources.items()
+            },
+        )
+
+    optimized = optimize(expression, schema=schema, rules=rules)
+    scope = build_scope(optimized)
+    tables: t.Dict[str, Node] = {}
+
+    def to_node(
+        column_name: str,
+        scope: Scope,
+        scope_name: t.Optional[str] = None,
+        upstream: t.Optional[Node] = None,
+    ) -> Node:
+        if isinstance(scope.expression, exp.Union):
+            for scope in scope.union_scopes:
+                node = to_node(
+                    column_name,
+                    scope=scope,
+                    scope_name=scope_name,
+                    upstream=upstream,
+                )
+            return node
+
+        select = next(select for select in scope.selects if select.alias_or_name == column_name)
+        source = optimize(scope.expression.select(select, append=False), schema=schema, rules=rules)
+        select = source.selects[0]
+
+        node = Node(
+            name=f"{scope_name}.{column_name}" if scope_name else column_name,
+            source=source,
+            expression=select,
+        )
+
+        if upstream:
+            upstream.downstream.append(node)
+
+        for c in set(select.find_all(exp.Column)):
+            table = c.table
+            source = scope.sources[table]
+
+            if isinstance(source, Scope):
+                to_node(
+                    c.name,
+                    scope=source,
+                    scope_name=table,
+                    upstream=node,
+                )
+            else:
+                if table not in tables:
+                    tables[table] = Node(name=table, source=source, expression=source)
+                node.downstream.append(tables[table])
+
+        return node
+
+    return to_node(column if isinstance(column, str) else column.name, scope)
+
+
+class LineageHTML:
+    """Node to HTML generator using vis.js.
+
+    https://visjs.github.io/vis-network/docs/network/
+    """
+
+    def __init__(
+        self,
+        node: Node,
+        dialect: t.Optional[str] = None,
+        imports: bool = True,
+        **opts: t.Any,
+    ):
+        self.node = node
+        self.imports = imports
+
+        self.options = {
+            "height": "500px",
+            "width": "100%",
+            "layout": {
+                "hierarchical": {
+                    "enabled": True,
+                    "nodeSpacing": 200,
+                    "sortMethod": "directed",
+                },
+            },
+            "interaction": {
+                "dragNodes": False,
+                "selectable": False,
+            },
+            "physics": {
+                "enabled": False,
+            },
+            "edges": {
+                "arrows": "to",
+            },
+            "nodes": {
+                "font": "20px monaco",
+                "shape": "box",
+                "widthConstraint": {
+                    "maximum": 300,
+                },
+            },
+            **opts,
+        }
+
+        self.nodes = {}
+        self.edges = []
+
+        for node in node.walk():
+            if isinstance(node.expression, exp.Table):
+                label = f"FROM {node.expression.this}"
+                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
+                group = 1
+            else:
+                label = node.expression.sql(pretty=True, dialect=dialect)
+                source = node.source.transform(
+                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
+                    if n is node.expression
+                    else n,
+                    copy=False,
+                ).sql(pretty=True, dialect=dialect)
+                title = f"<pre>{source}</pre>"
+                group = 0
+
+            node_id = id(node)
+
+            self.nodes[node_id] = {
+                "id": node_id,
+                "label": label,
+                "title": title,
+                "group": group,
+            }
+
+            for d in node.downstream:
+                self.edges.append({"from": node_id, "to": id(d)})
+
+    def __str__(self):
+        nodes = json.dumps(list(self.nodes.values()))
+        edges = json.dumps(self.edges)
+        options = json.dumps(self.options)
+        imports = (
+            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
+  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
+  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
+            if self.imports
+            else ""
+        )
+
+        return f"""<div>
+  <div id="sqlglot-lineage"></div>
+  {imports}
+  <script type="text/javascript">
+    var nodes = new vis.DataSet({nodes})
+    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])
+
+    new vis.Network(
+        document.getElementById("sqlglot-lineage"),
+        {{
+            nodes: nodes,
+            edges: new vis.DataSet({edges})
+        }},
+        {options},
+    )
+  </script>
+</div>"""
+
+    def _repr_html_(self) -> str:
+        return self.__str__()
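Putting the new module together: `lineage` expands `sources` into inline subqueries (via the new `exp.expand`), qualifies the query, then walks the selected column back through each scope; `Node.to_html` renders the resulting graph with vis.js. A usage sketch with illustrative names:

```python
from sqlglot.lineage import lineage

node = lineage(
    "b",
    "SELECT b FROM (SELECT a AS b FROM x) AS y",
    schema={"x": {"a": "int"}},
)

for n in node.walk():
    print(n.name)  # 'b', then 'y.b', then the source table 'x'

# In a notebook, node.to_html() embeds the interactive vis.js network.
```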
sqlglot/optimizer/__init__.py

@@ -1 +1,2 @@
 from sqlglot.optimizer.optimizer import RULES, optimize
+from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope
sqlglot/optimizer/isolate_table_selects.py

@@ -1,15 +1,18 @@
 from sqlglot import alias, exp
 from sqlglot.errors import OptimizeError
 from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.schema import ensure_schema


-def isolate_table_selects(expression):
+def isolate_table_selects(expression, schema=None):
+    schema = ensure_schema(schema)
+
     for scope in traverse_scope(expression):
         if len(scope.selected_sources) == 1:
             continue

         for (_, source) in scope.selected_sources.values():
-            if not isinstance(source, exp.Table):
+            if not isinstance(source, exp.Table) or not schema.column_names(source):
                 continue

             if not source.alias:
@ -1,7 +1,8 @@
|
||||||
import itertools
|
import itertools
|
||||||
|
import typing as t
|
||||||
|
|
||||||
from sqlglot import alias, exp
|
from sqlglot import alias, exp
|
||||||
from sqlglot.errors import OptimizeError, SchemaError
|
from sqlglot.errors import OptimizeError
|
||||||
from sqlglot.optimizer.scope import Scope, traverse_scope
|
from sqlglot.optimizer.scope import Scope, traverse_scope
|
||||||
from sqlglot.schema import ensure_schema
|
from sqlglot.schema import ensure_schema
|
||||||
|
|
||||||
|
@ -190,20 +191,15 @@ def _qualify_columns(scope, resolver):
|
||||||
column_table = column.table
|
column_table = column.table
|
||||||
column_name = column.name
|
column_name = column.name
|
||||||
|
|
||||||
if (
|
if column_table and column_table in scope.sources:
|
||||||
column_table
|
source_columns = resolver.get_source_columns(column_table)
|
||||||
and column_table in scope.sources
|
if source_columns and column_name not in source_columns:
|
||||||
and column_name not in resolver.get_source_columns(column_table)
|
|
||||||
):
|
|
||||||
raise OptimizeError(f"Unknown column: {column_name}")
|
raise OptimizeError(f"Unknown column: {column_name}")
|
||||||
|
|
||||||
if not column_table:
|
if not column_table:
|
||||||
column_table = resolver.get_table(column_name)
|
column_table = resolver.get_table(column_name)
|
||||||
|
|
||||||
if not scope.is_subquery and not scope.is_udtf:
|
if not scope.is_subquery and not scope.is_udtf:
|
||||||
if column_name not in resolver.all_columns:
|
|
||||||
raise OptimizeError(f"Unknown column: {column_name}")
|
|
||||||
|
|
||||||
if column_table is None:
|
if column_table is None:
|
||||||
raise OptimizeError(f"Ambiguous column: {column_name}")
|
raise OptimizeError(f"Ambiguous column: {column_name}")
|
||||||
|
|
||||||
|
@ -265,6 +261,10 @@ def _expand_stars(scope, resolver):
|
||||||
if table not in scope.sources:
|
if table not in scope.sources:
|
||||||
raise OptimizeError(f"Unknown table: {table}")
|
raise OptimizeError(f"Unknown table: {table}")
|
||||||
columns = resolver.get_source_columns(table, only_visible=True)
|
columns = resolver.get_source_columns(table, only_visible=True)
|
||||||
|
if not columns:
|
||||||
|
raise OptimizeError(
|
||||||
|
f"Table has no schema/columns. Cannot expand star for table: {table}."
|
||||||
|
)
|
||||||
table_id = id(table)
|
table_id = id(table)
|
||||||
for name in columns:
|
for name in columns:
|
||||||
if name not in except_columns.get(table_id, set()):
|
if name not in except_columns.get(table_id, set()):
|
||||||
|
@ -306,16 +306,11 @@ def _qualify_outputs(scope):
|
||||||
for i, (selection, aliased_column) in enumerate(
|
for i, (selection, aliased_column) in enumerate(
|
||||||
itertools.zip_longest(scope.selects, scope.outer_column_list)
|
itertools.zip_longest(scope.selects, scope.outer_column_list)
|
||||||
):
|
):
|
||||||
if isinstance(selection, exp.Column):
|
if isinstance(selection, exp.Subquery):
|
||||||
# convoluted setter because a simple selection.replace(alias) would require a copy
|
if not selection.output_name:
|
||||||
alias_ = alias(exp.column(""), alias=selection.name)
|
|
||||||
alias_.set("this", selection)
|
|
||||||
selection = alias_
|
|
||||||
elif isinstance(selection, exp.Subquery):
|
|
||||||
if not selection.alias:
|
|
||||||
selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
|
selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
|
||||||
elif not isinstance(selection, exp.Alias):
|
elif not isinstance(selection, exp.Alias):
|
||||||
alias_ = alias(exp.column(""), f"_col_{i}")
|
alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
|
||||||
alias_.set("this", selection)
|
alias_.set("this", selection)
|
||||||
selection = alias_
|
selection = alias_
|
||||||
|
|
||||||
|
@@ -346,20 +341,30 @@ class _Resolver:
         self._unambiguous_columns = None
         self._all_columns = None

-    def get_table(self, column_name):
+    def get_table(self, column_name: str) -> t.Optional[str]:
         """
         Get the table for a column name.

         Args:
-            column_name (str)
+            column_name: The column name to find the table for.
         Returns:
-            (str) table name
+            The table name if it can be found/inferred.
         """
         if self._unambiguous_columns is None:
             self._unambiguous_columns = self._get_unambiguous_columns(
                 self._get_all_source_columns()
             )
-        return self._unambiguous_columns.get(column_name)
+
+        table = self._unambiguous_columns.get(column_name)
+
+        if not table:
+            sources_without_schema = tuple(
+                source for source, columns in self._get_all_source_columns().items() if not columns
+            )
+            if len(sources_without_schema) == 1:
+                return sources_without_schema[0]
+
+        return table

     @property
     def all_columns(self):
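The fallback means a column that cannot be attributed via the schema is now credited to the only schema-less source instead of raising, which is what lets qualification proceed without full schema information. A sketch of the editor's expectation of the new behavior:

```python
from sqlglot import parse_one
from sqlglot.optimizer.qualify_columns import qualify_columns

# x has no schema entry, but it is the only source, so `a` resolves to it.
qualify_columns(parse_one("SELECT a FROM x"), schema={}).sql()
# 'SELECT x.a AS a FROM x'
```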
@@ -379,10 +384,7 @@ class _Resolver:

         # If referencing a table, return the columns from the schema
         if isinstance(source, exp.Table):
-            try:
-                return self.schema.column_names(source, only_visible)
-            except Exception as e:
-                raise SchemaError(str(e)) from e
+            return self.schema.column_names(source, only_visible)

         if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
             return source.expression.alias_column_names


sqlglot/optimizer/scope.py
@@ -230,7 +230,7 @@ class Scope:
             column for scope in self.subquery_scopes for column in scope.external_columns
         ]

-        named_outputs = {e.alias_or_name for e in self.expression.expressions}
+        named_selects = set(self.expression.named_selects)

         self._columns = []
         for column in columns + external_columns:
@@ -238,7 +238,7 @@ class Scope:
             if (
                 not ancestor
                 or column.table
-                or (column.name not in named_outputs and not isinstance(ancestor, exp.Hint))
+                or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
             ):
                 self._columns.append(column)


sqlglot/parser.py
@@ -40,22 +40,23 @@ class _Parser(type):

 class Parser(metaclass=_Parser):
     """
-    Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer`
-    and produces a parsed syntax tree.
+    Parser consumes a list of tokens produced by the `sqlglot.tokens.Tokenizer` and produces
+    a parsed syntax tree.

-    Args
-        error_level (ErrorLevel): the desired error level. Default: ErrorLevel.RAISE.
-        error_message_context (int): determines the amount of context to capture from
-            a query string when displaying the error message (in number of characters).
+    Args:
+        error_level: the desired error level.
+            Default: ErrorLevel.RAISE
+        error_message_context: determines the amount of context to capture from a
+            query string when displaying the error message (in number of characters).
             Default: 50.
-        index_offset (int): Index offset for arrays eg ARRAY[0] vs ARRAY[1] as the head of a list
+        index_offset: Index offset for arrays eg ARRAY[0] vs ARRAY[1] as the head of a list.
             Default: 0
-        alias_post_tablesample (bool): If the table alias comes after tablesample
+        alias_post_tablesample: If the table alias comes after tablesample.
             Default: False
-        max_errors (int): Maximum number of error messages to include in a raised ParseError.
+        max_errors: Maximum number of error messages to include in a raised ParseError.
             This is only relevant if error_level is ErrorLevel.RAISE.
             Default: 3
-        null_ordering (str): Indicates the default null ordering method to use if not explicitly set.
+        null_ordering: Indicates the default null ordering method to use if not explicitly set.
             Options are "nulls_are_small", "nulls_are_large", "nulls_are_last".
             Default: "nulls_are_small"
     """
@@ -109,6 +110,8 @@ class Parser(metaclass=_Parser):
         TokenType.TEXT,
         TokenType.MEDIUMTEXT,
         TokenType.LONGTEXT,
+        TokenType.MEDIUMBLOB,
+        TokenType.LONGBLOB,
         TokenType.BINARY,
         TokenType.VARBINARY,
         TokenType.JSON,
@@ -176,6 +179,7 @@ class Parser(metaclass=_Parser):
         TokenType.DIV,
         TokenType.DISTKEY,
         TokenType.DISTSTYLE,
+        TokenType.END,
         TokenType.EXECUTE,
         TokenType.ENGINE,
         TokenType.ESCAPE,
@@ -468,9 +472,6 @@ class Parser(metaclass=_Parser):
         TokenType.NULL: lambda self, _: self.expression(exp.Null),
         TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True),
         TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False),
-        TokenType.PARAMETER: lambda self, _: self.expression(
-            exp.Parameter, this=self._parse_var() or self._parse_primary()
-        ),
         TokenType.BIT_STRING: lambda self, token: self.expression(exp.BitString, this=token.text),
         TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text),
         TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text),
@@ -479,6 +480,16 @@ class Parser(metaclass=_Parser):
         TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
     }

+    PLACEHOLDER_PARSERS = {
+        TokenType.PLACEHOLDER: lambda self: self.expression(exp.Placeholder),
+        TokenType.PARAMETER: lambda self: self.expression(
+            exp.Parameter, this=self._parse_var() or self._parse_primary()
+        ),
+        TokenType.COLON: lambda self: self.expression(exp.Placeholder, this=self._prev.text)
+        if self._match_set((TokenType.NUMBER, TokenType.VAR))
+        else None,
+    }
+
     RANGE_PARSERS = {
         TokenType.BETWEEN: lambda self, this: self._parse_between(this),
         TokenType.IN: lambda self, this: self._parse_in(this),
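Centralizing these in `PLACEHOLDER_PARSERS` gives dialects one table to extend, and adds `:name`-style placeholders alongside `?` and `@param`. Combined with `replace_placeholders` from the expressions module (sketch, grounded in the docstring shown earlier):

```python
from sqlglot import exp, parse_one
from sqlglot.expressions import replace_placeholders

expr = parse_one("SELECT * FROM :tbl WHERE ? = ?")
replace_placeholders(expr, exp.to_identifier("a"), "b", tbl="foo").sql()
# 'SELECT * FROM foo WHERE a = b'
```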
@@ -601,8 +612,7 @@ class Parser(metaclass=_Parser):

     WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}

-    # allows tables to have special tokens as prefixes
-    TABLE_PREFIX_TOKENS: t.Set[TokenType] = set()
+    ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}

     STRICT_CAST = True

@@ -677,7 +687,7 @@ class Parser(metaclass=_Parser):

     def parse_into(
         self,
-        expression_types: str | exp.Expression | t.Collection[exp.Expression | str],
+        expression_types: exp.IntoType,
         raw_tokens: t.List[Token],
         sql: t.Optional[str] = None,
     ) -> t.List[t.Optional[exp.Expression]]:
@@ -820,24 +830,8 @@ class Parser(metaclass=_Parser):
         if self.error_level == ErrorLevel.IGNORE:
             return

-        for k in expression.args:
-            if k not in expression.arg_types:
-                self.raise_error(f"Unexpected keyword: '{k}' for {expression.__class__}")
-        for k, mandatory in expression.arg_types.items():
-            v = expression.args.get(k)
-            if mandatory and (v is None or (isinstance(v, list) and not v)):
-                self.raise_error(f"Required keyword: '{k}' missing for {expression.__class__}")
-
-        if (
-            args
-            and isinstance(expression, exp.Func)
-            and len(args) > len(expression.arg_types)
-            and not expression.is_var_len_args
-        ):
-            self.raise_error(
-                f"The number of provided arguments ({len(args)}) is greater than "
-                f"the maximum number of supported arguments ({len(expression.arg_types)})"
-            )
+        for error_message in expression.error_messages(args):
+            self.raise_error(error_message)

     def _find_token(self, token: Token, sql: str) -> int:
         line = 1
@@ -868,6 +862,9 @@ class Parser(metaclass=_Parser):
     def _retreat(self, index: int) -> None:
         self._advance(index - self._index)

+    def _parse_command(self) -> exp.Expression:
+        return self.expression(exp.Command, this=self._prev.text, expression=self._parse_string())
+
     def _parse_statement(self) -> t.Optional[exp.Expression]:
         if self._curr is None:
             return None
@@ -876,11 +873,7 @@ class Parser(metaclass=_Parser):
             return self.STATEMENT_PARSERS[self._prev.token_type](self)

         if self._match_set(Tokenizer.COMMANDS):
-            return self.expression(
-                exp.Command,
-                this=self._prev.text,
-                expression=self._parse_string(),
-            )
+            return self._parse_command()

         expression = self._parse_expression()
         expression = self._parse_set_operations(expression) if expression else self._parse_select()
@@ -942,12 +935,18 @@ class Parser(metaclass=_Parser):
         no_primary_index = None
         indexes = None
         no_schema_binding = None
+        begin = None

         if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
-            this = self._parse_user_defined_function()
+            this = self._parse_user_defined_function(kind=create_token.token_type)
             properties = self._parse_properties()
             if self._match(TokenType.ALIAS):
-                expression = self._parse_select_or_expression()
+                begin = self._match(TokenType.BEGIN)
+                return_ = self._match_text_seq("RETURN")
+                expression = self._parse_statement()
+
+                if return_:
+                    expression = self.expression(exp.Return, this=expression)
         elif create_token.token_type == TokenType.INDEX:
             this = self._parse_index()
         elif create_token.token_type in (
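`CREATE FUNCTION/PROCEDURE ... AS BEGIN ...` bodies are now captured: `begin` is recorded on the `Create` node (re-emitted by `create_sql` above) and a trailing `RETURN <statement>` becomes an `exp.Return`. A hedged sketch; the exact dialect round-tripping lives in the T-SQL dialect, which is outside this excerpt:

```python
import sqlglot

# Illustrative statement; real T-SQL bodies also close with END.
create = sqlglot.parse_one("CREATE PROCEDURE p AS BEGIN SELECT 1", read="tsql")
print(create.args.get("begin"))    # True
print(create.sql(dialect="tsql"))  # CREATE PROCEDURE p AS BEGIN SELECT 1
```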
@@ -1002,6 +1001,7 @@ class Parser(metaclass=_Parser):
             no_primary_index=no_primary_index,
             indexes=indexes,
             no_schema_binding=no_schema_binding,
+            begin=begin,
         )

     def _parse_property(self) -> t.Optional[exp.Expression]:
@@ -1087,7 +1087,7 @@ class Parser(metaclass=_Parser):
                 if not self._match(TokenType.GT):
                     self.raise_error("Expecting >")
             else:
-                value = self._parse_schema(exp.Literal.string("TABLE"))
+                value = self._parse_schema(exp.Var(this="TABLE"))
         else:
             value = self._parse_types()

@@ -1550,7 +1550,7 @@ class Parser(metaclass=_Parser):
             return None
         index = self._parse_id_var()
         columns = None
-        if self._curr and self._curr.token_type == TokenType.L_PAREN:
+        if self._match(TokenType.L_PAREN, advance=False):
            columns = self._parse_wrapped_csv(self._parse_column)
         return self.expression(
             exp.Index,
@@ -1561,6 +1561,27 @@ class Parser(metaclass=_Parser):
             amp=amp,
         )

+    def _parse_table_parts(self, schema: bool = False) -> exp.Expression:
+        catalog = None
+        db = None
+        table = (not schema and self._parse_function()) or self._parse_id_var(any_token=False)
+
+        while self._match(TokenType.DOT):
+            if catalog:
+                # This allows nesting the table in arbitrarily many dot expressions if needed
+                table = self.expression(exp.Dot, this=table, expression=self._parse_id_var())
+            else:
+                catalog = db
+                db = table
+                table = self._parse_id_var()
+
+        if not table:
+            self.raise_error(f"Expected table name but got {self._curr}")
+
+        return self.expression(
+            exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
+        )
+
     def _parse_table(
         self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None
     ) -> t.Optional[exp.Expression]:
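Factoring `_parse_table_parts` out of `_parse_table` isolates the catalog.db.table resolution (and deeper nesting via `exp.Dot`) so other parsing paths can reuse it. For instance (sketch):

```python
from sqlglot import exp, parse_one

table = parse_one("SELECT * FROM c.d.t").find(exp.Table)
table.catalog, table.text("db"), table.name  # ('c', 'd', 't')
```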
@@ -1584,27 +1605,7 @@ class Parser(metaclass=_Parser):
         if subquery:
             return subquery

-        catalog = None
-        db = None
-        table = (not schema and self._parse_function()) or self._parse_id_var(
-            any_token=False, prefix_tokens=self.TABLE_PREFIX_TOKENS
-        )
-
-        while self._match(TokenType.DOT):
-            if catalog:
-                # This allows nesting the table in arbitrarily many dot expressions if needed
-                table = self.expression(exp.Dot, this=table, expression=self._parse_id_var())
-            else:
-                catalog = db
-                db = table
-                table = self._parse_id_var()
-
-        if not table:
-            self.raise_error(f"Expected table name but got {self._curr}")
-
-        this = self.expression(
-            exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
-        )
+        this = self._parse_table_parts(schema=schema)

         if schema:
             return self._parse_schema(this=this)
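A minimal sketch (not part of the diff) of what the extracted `_parse_table_parts` accepts, using sqlglot's public API; the four-part round trip is also covered by the identity fixtures further down:

```python
from sqlglot import exp, parse_one

# Three-part names resolve into catalog/db/table args on exp.Table.
table = parse_one("SELECT * FROM c.d.t").find(exp.Table)
assert table.text("catalog") == "c" and table.text("db") == "d"

# A fourth part is absorbed into a nested Dot expression and round-trips.
assert parse_one("SELECT * FROM a.b.c.d").sql() == "SELECT * FROM a.b.c.d"
```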
@@ -1889,7 +1890,7 @@ class Parser(metaclass=_Parser):
             expression,
             this=this,
             distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL),
-            expression=self._parse_select(nested=True),
+            expression=self._parse_set_operations(self._parse_select(nested=True)),
         )

     def _parse_expression(self) -> t.Optional[exp.Expression]:
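Wrapping the right-hand side in `_parse_set_operations` lets chained set operators keep nesting; a quick sketch, grounded in the new identity fixture added below:

```python
from sqlglot import parse_one

# Chained UNIONs inside a derived table now parse and round-trip.
sql = "SELECT * FROM ((SELECT 1) UNION (SELECT 2) UNION (SELECT 3))"
assert parse_one(sql).sql() == sql
```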
@@ -2286,7 +2287,9 @@ class Parser(metaclass=_Parser):
         self._match_r_paren(this)
         return self._parse_window(this)

-    def _parse_user_defined_function(self) -> t.Optional[exp.Expression]:
+    def _parse_user_defined_function(
+        self, kind: t.Optional[TokenType] = None
+    ) -> t.Optional[exp.Expression]:
         this = self._parse_id_var()

         while self._match(TokenType.DOT):
@@ -2297,7 +2300,9 @@ class Parser(metaclass=_Parser):

         expressions = self._parse_csv(self._parse_udf_kwarg)
         self._match_r_paren()
-        return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
+        return self.expression(
+            exp.UserDefinedFunction, this=this, expressions=expressions, wrapped=True
+        )

     def _parse_introducer(self, token: Token) -> t.Optional[exp.Expression]:
         literal = self._parse_primary()
@@ -2371,10 +2376,6 @@ class Parser(metaclass=_Parser):
             or self._parse_column_def(self._parse_field(any_token=True))
         )
         self._match_r_paren()

-        if isinstance(this, exp.Literal):
-            this = this.name
-
         return self.expression(exp.Schema, this=this, expressions=args)

     def _parse_column_def(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
@@ -2470,15 +2471,43 @@ class Parser(metaclass=_Parser):
     def _parse_unique(self) -> exp.Expression:
         return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars())

+    def _parse_key_constraint_options(self) -> t.List[str]:
+        options = []
+        while True:
+            if not self._curr:
+                break
+
+            if self._match_text_seq("NOT", "ENFORCED"):
+                options.append("NOT ENFORCED")
+            elif self._match_text_seq("DEFERRABLE"):
+                options.append("DEFERRABLE")
+            elif self._match_text_seq("INITIALLY", "DEFERRED"):
+                options.append("INITIALLY DEFERRED")
+            elif self._match_text_seq("NORELY"):
+                options.append("NORELY")
+            elif self._match_text_seq("MATCH", "FULL"):
+                options.append("MATCH FULL")
+            elif self._match_text_seq("ON", "UPDATE", "NO ACTION"):
+                options.append("ON UPDATE NO ACTION")
+            elif self._match_text_seq("ON", "DELETE", "NO ACTION"):
+                options.append("ON DELETE NO ACTION")
+            else:
+                break
+
+        return options
+
     def _parse_references(self) -> t.Optional[exp.Expression]:
         if not self._match(TokenType.REFERENCES):
             return None

-        return self.expression(
-            exp.Reference,
-            this=self._parse_id_var(),
-            expressions=self._parse_wrapped_id_vars(),
-        )
+        expressions = None
+        this = self._parse_id_var()
+
+        if self._match(TokenType.L_PAREN, advance=False):
+            expressions = self._parse_wrapped_id_vars()
+
+        options = self._parse_key_constraint_options()
+        return self.expression(exp.Reference, this=this, expressions=expressions, options=options)

     def _parse_foreign_key(self) -> exp.Expression:
         expressions = self._parse_wrapped_id_vars()
@@ -2503,12 +2532,14 @@ class Parser(metaclass=_Parser):
             options[kind] = action

         return self.expression(
-            exp.ForeignKey,
-            expressions=expressions,
-            reference=reference,
-            **options,  # type: ignore
+            exp.ForeignKey, expressions=expressions, reference=reference, **options  # type: ignore
         )

+    def _parse_primary_key(self) -> exp.Expression:
+        expressions = self._parse_wrapped_id_vars()
+        options = self._parse_key_constraint_options()
+        return self.expression(exp.PrimaryKey, expressions=expressions, options=options)
+
     def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
         if not self._match(TokenType.L_BRACKET):
             return this
@@ -2631,7 +2662,7 @@ class Parser(metaclass=_Parser):
         order = self._parse_order(this=expression)
         return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))

-    def _parse_convert(self, strict: bool) -> exp.Expression:
+    def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
         to: t.Optional[exp.Expression]
         this = self._parse_column()

@@ -2641,20 +2672,26 @@ class Parser(metaclass=_Parser):
             to = self._parse_types()
         else:
             to = None

         return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)

-    def _parse_position(self) -> exp.Expression:
+    def _parse_position(self, haystack_first: bool = False) -> exp.Expression:
         args = self._parse_csv(self._parse_bitwise)

         if self._match(TokenType.IN):
-            args.append(self._parse_bitwise())
-            this = exp.StrPosition(
-                this=seq_get(args, 1),
-                substr=seq_get(args, 0),
-                position=seq_get(args, 2),
+            return self.expression(
+                exp.StrPosition, this=self._parse_bitwise(), substr=seq_get(args, 0)
             )

+        if haystack_first:
+            haystack = seq_get(args, 0)
+            needle = seq_get(args, 1)
+        else:
+            needle = seq_get(args, 0)
+            haystack = seq_get(args, 1)
+
+        this = exp.StrPosition(this=haystack, substr=needle, position=seq_get(args, 2))
+
         self.validate_expression(this, args)

         return this
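A sketch of the argument-order fix this enables; the expected strings are copied from the test_dialect.py changes in this same commit:

```python
import sqlglot

# Generic POSITION(needle IN haystack) keeps needle/haystack straight
# when transpiling to dialects with differing argument orders.
sql = "POSITION(needle in haystack)"
assert sqlglot.transpile(sql, write="presto")[0] == "STRPOS(haystack, needle)"
assert sqlglot.transpile(sql, write="clickhouse")[0] == "position(haystack, needle)"
```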
@@ -2894,24 +2931,26 @@ class Parser(metaclass=_Parser):
             return None

     def _parse_placeholder(self) -> t.Optional[exp.Expression]:
-        if self._match(TokenType.PLACEHOLDER):
-            return self.expression(exp.Placeholder)
-        elif self._match(TokenType.COLON):
-            if self._match_set((TokenType.NUMBER, TokenType.VAR)):
-                return self.expression(exp.Placeholder, this=self._prev.text)
+        if self._match_set(self.PLACEHOLDER_PARSERS):
+            placeholder = self.PLACEHOLDER_PARSERS[self._prev.token_type](self)
+            if placeholder:
+                return placeholder
             self._advance(-1)
         return None

     def _parse_except(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
         if not self._match(TokenType.EXCEPT):
             return None
-        return self._parse_wrapped_id_vars()
+        if self._match(TokenType.L_PAREN, advance=False):
+            return self._parse_wrapped_id_vars()
+        return self._parse_csv(self._parse_id_var)

     def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
         if not self._match(TokenType.REPLACE):
             return None
-        return self._parse_wrapped_csv(lambda: self._parse_alias(self._parse_expression()))
+        if self._match(TokenType.L_PAREN, advance=False):
+            return self._parse_wrapped_csv(self._parse_expression)
+        return self._parse_csv(self._parse_expression)

     def _parse_csv(
         self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA
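With placeholder parsing made pluggable via `PLACEHOLDER_PARSERS`, both `?` and `:name` styles survive a round trip; a small sketch based on the identity fixture added further down:

```python
from sqlglot import exp, parse_one

stmt = parse_one("SELECT :hello, ? FROM x LIMIT :my_limit")
assert len(list(stmt.find_all(exp.Placeholder))) == 3
assert stmt.sql() == "SELECT :hello, ? FROM x LIMIT :my_limit"
```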
@@ -3021,6 +3060,28 @@ class Parser(metaclass=_Parser):
     def _parse_drop_column(self) -> t.Optional[exp.Expression]:
         return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN")

+    def _parse_add_constraint(self) -> t.Optional[exp.Expression]:
+        this = None
+        kind = self._prev.token_type
+
+        if kind == TokenType.CONSTRAINT:
+            this = self._parse_id_var()
+
+            if self._match(TokenType.CHECK):
+                expression = self._parse_wrapped(self._parse_conjunction)
+                enforced = self._match_text_seq("ENFORCED")
+
+                return self.expression(
+                    exp.AddConstraint, this=this, expression=expression, enforced=enforced
+                )
+
+        if kind == TokenType.FOREIGN_KEY or self._match(TokenType.FOREIGN_KEY):
+            expression = self._parse_foreign_key()
+        elif kind == TokenType.PRIMARY_KEY or self._match(TokenType.PRIMARY_KEY):
+            expression = self._parse_primary_key()
+
+        return self.expression(exp.AddConstraint, this=this, expression=expression)
+
     def _parse_alter(self) -> t.Optional[exp.Expression]:
         if not self._match(TokenType.TABLE):
             return None
@@ -3029,7 +3090,13 @@ class Parser(metaclass=_Parser):
         this = self._parse_table(schema=True)

         actions: t.Optional[exp.Expression | t.List[t.Optional[exp.Expression]]] = None
-        if self._match_text_seq("ADD", advance=False):
+
+        index = self._index
+        if self._match_text_seq("ADD"):
+            if self._match_set(self.ADD_CONSTRAINT_TOKENS):
+                actions = self._parse_csv(self._parse_add_constraint)
+            else:
+                self._retreat(index)
                 actions = self._parse_csv(self._parse_add_column)
         elif self._match_text_seq("DROP", advance=False):
             actions = self._parse_csv(self._parse_drop_column)
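Together with `_parse_add_constraint`, this makes `ALTER TABLE ... ADD CONSTRAINT` a first-class statement; a sketch whose inputs are new identity fixtures in this commit, so output equals input:

```python
import sqlglot

# These used to fall back to an opaque command; they now parse and round-trip.
for sql in (
    "ALTER TABLE persons ADD CONSTRAINT persons_pk PRIMARY KEY (first_name, last_name)",
    "ALTER TABLE a ADD FOREIGN KEY (x, y) REFERENCES bla",
):
    assert sqlglot.transpile(sql)[0] == sql
```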
@@ -3077,7 +3144,7 @@ class Parser(metaclass=_Parser):

     def _parse_merge(self) -> exp.Expression:
         self._match(TokenType.INTO)
-        target = self._parse_table(schema=True)
+        target = self._parse_table()

         self._match(TokenType.USING)
         using = self._parse_table()
@@ -3146,11 +3213,12 @@ class Parser(metaclass=_Parser):
             self._retreat(index)
             return None

-    def _match(self, token_type):
+    def _match(self, token_type, advance=True):
         if not self._curr:
             return None

         if self._curr.token_type == token_type:
-            self._advance()
+            if advance:
+                self._advance()
             return True

@@ -32,7 +32,7 @@ class Plan:
         return self._dag

     @property
-    def leaves(self) -> t.Generator[Step, None, None]:
+    def leaves(self) -> t.Iterator[Step]:
         return (node for node, deps in self.dag.items() if not deps)

     def __repr__(self) -> str:
@@ -401,7 +401,7 @@ class SetOperation(Step):
             op=expression.__class__,
             left=left.name,
             right=right.name,
-            distinct=expression.args.get("distinct"),
+            distinct=bool(expression.args.get("distinct")),
         )
         step.add_dependency(left)
         step.add_dependency(right)
@@ -109,9 +109,6 @@ class AbstractMappingSchema(t.Generic[T]):
         value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)

         if value == 0:
-            if raise_on_missing:
-                raise SchemaError(f"Cannot find mapping for {table}.")
-            else:
-                return None
+            return None
         elif value == 1:
             possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
@@ -262,7 +259,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
         schema = self.find(table_)

         if schema is None:
-            raise SchemaError(f"Could not find table schema {table}")
+            return []

         if not only_visible or not self.visible:
             return list(schema)
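The behavioral change here: `column_names` on an unmapped table now degrades to an empty list instead of raising. A minimal sketch, mirroring the new `assert_column_names_empty` helper in test_schema.py below:

```python
from sqlglot import exp
from sqlglot.schema import ensure_schema

schema = ensure_schema({"x": {"a": "INT"}})
assert schema.column_names(exp.to_table("x")) == ["a"]
# Previously a SchemaError; now just an empty result.
assert schema.column_names(exp.to_table("zz")) == []
```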
@@ -84,6 +84,8 @@ class TokenType(AutoName):
    TEXT = auto()
    MEDIUMTEXT = auto()
    LONGTEXT = auto()
+    MEDIUMBLOB = auto()
+    LONGBLOB = auto()
    BINARY = auto()
    VARBINARY = auto()
    JSON = auto()
@@ -587,6 +589,7 @@ class Tokenizer(metaclass=_Tokenizer):
        "PRECEDING": TokenType.PRECEDING,
        "PRIMARY KEY": TokenType.PRIMARY_KEY,
        "PROCEDURE": TokenType.PROCEDURE,
+        "QUALIFY": TokenType.QUALIFY,
        "RANGE": TokenType.RANGE,
        "RECURSIVE": TokenType.RECURSIVE,
        "REGEXP": TokenType.RLIKE,
@@ -726,6 +729,8 @@ class Tokenizer(metaclass=_Tokenizer):
        TokenType.SHOW,
    }

+    COMMAND_PREFIX_TOKENS = {TokenType.SEMICOLON, TokenType.BEGIN}
+
    # handle numeric literals like in hive (3L = BIGINT)
    NUMERIC_LITERALS: t.Dict[str, str] = {}
    ENCODE: t.Optional[str] = None
@@ -842,8 +847,10 @@ class Tokenizer(metaclass=_Tokenizer):
        )
        self._comments = []

+        # If we have either a semicolon or a begin token before the command's token, we'll parse
+        # whatever follows the command's token as a string
        if token_type in self.COMMANDS and (
-            len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
+            len(self.tokens) == 1 or self.tokens[-2].token_type in self.COMMAND_PREFIX_TOKENS
        ):
            start = self._current
            tokens = len(self.tokens)
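With QUALIFY promoted to a keyword in the base tokenizer, the clause no longer needs dialect-specific handling; a sketch whose input is a new identity fixture in this commit:

```python
import sqlglot

sql = "SELECT id FROM b.a AS a QUALIFY ROW_NUMBER() OVER (PARTITION BY br ORDER BY sadf DESC) = 1"
assert sqlglot.transpile(sql)[0] == sql
```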
@@ -17,7 +17,8 @@ class TestClickhouse(Validator):
         self.validate_identity("SELECT quantile(0.5)(a)")
         self.validate_identity("SELECT quantiles(0.5)(a) AS x FROM t")
         self.validate_identity("SELECT * FROM foo WHERE x GLOBAL IN (SELECT * FROM bar)")
-        self.validate_identity("position(a, b)")
+        self.validate_identity("position(haystack, needle)")
+        self.validate_identity("position(haystack, needle, position)")

         self.validate_all(
             "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
@@ -48,6 +49,10 @@ class TestClickhouse(Validator):
                 "clickhouse": "SELECT quantileIf(0.5)(a, TRUE)",
             },
         )
+        self.validate_all(
+            "SELECT position(needle IN haystack)",
+            write={"clickhouse": "SELECT position(haystack, needle)"},
+        )

     def test_cte(self):
         self.validate_identity("WITH 'x' AS foo SELECT foo")
@@ -950,40 +950,40 @@ class TestDialect(Validator):
             },
         )
         self.validate_all(
-            "POSITION(' ' in x)",
+            "POSITION(needle in haystack)",
             write={
-                "drill": "STRPOS(x, ' ')",
-                "duckdb": "STRPOS(x, ' ')",
-                "postgres": "STRPOS(x, ' ')",
-                "presto": "STRPOS(x, ' ')",
-                "spark": "LOCATE(' ', x)",
-                "clickhouse": "position(x, ' ')",
-                "snowflake": "POSITION(' ', x)",
-                "mysql": "LOCATE(' ', x)",
+                "drill": "STRPOS(haystack, needle)",
+                "duckdb": "STRPOS(haystack, needle)",
+                "postgres": "STRPOS(haystack, needle)",
+                "presto": "STRPOS(haystack, needle)",
+                "spark": "LOCATE(needle, haystack)",
+                "clickhouse": "position(haystack, needle)",
+                "snowflake": "POSITION(needle, haystack)",
+                "mysql": "LOCATE(needle, haystack)",
             },
         )
         self.validate_all(
-            "STR_POSITION(x, 'a')",
+            "STR_POSITION(haystack, needle)",
             write={
-                "drill": "STRPOS(x, 'a')",
-                "duckdb": "STRPOS(x, 'a')",
-                "postgres": "STRPOS(x, 'a')",
-                "presto": "STRPOS(x, 'a')",
-                "spark": "LOCATE('a', x)",
-                "clickhouse": "position(x, 'a')",
-                "snowflake": "POSITION('a', x)",
-                "mysql": "LOCATE('a', x)",
+                "drill": "STRPOS(haystack, needle)",
+                "duckdb": "STRPOS(haystack, needle)",
+                "postgres": "STRPOS(haystack, needle)",
+                "presto": "STRPOS(haystack, needle)",
+                "spark": "LOCATE(needle, haystack)",
+                "clickhouse": "position(haystack, needle)",
+                "snowflake": "POSITION(needle, haystack)",
+                "mysql": "LOCATE(needle, haystack)",
             },
         )
         self.validate_all(
-            "POSITION('a', x, 3)",
+            "POSITION(needle, haystack, pos)",
             write={
-                "drill": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
-                "presto": "STRPOS(x, 'a', 3)",
-                "spark": "LOCATE('a', x, 3)",
-                "clickhouse": "position(x, 'a', 3)",
-                "snowflake": "POSITION('a', x, 3)",
-                "mysql": "LOCATE('a', x, 3)",
+                "drill": "STRPOS(SUBSTR(haystack, pos), needle) + pos - 1",
+                "presto": "STRPOS(haystack, needle, pos)",
+                "spark": "LOCATE(needle, haystack, pos)",
+                "clickhouse": "position(haystack, needle, pos)",
+                "snowflake": "POSITION(needle, haystack, pos)",
+                "mysql": "LOCATE(needle, haystack, pos)",
             },
         )
         self.validate_all(
@@ -1365,3 +1365,19 @@ SELECT
                 "spark": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED THEN UPDATE * WHEN NOT MATCHED THEN INSERT *",
             },
         )
+        self.validate_all(
+            """
+            MERGE a b USING c d ON b.id = d.id
+            WHEN MATCHED AND EXISTS (
+                SELECT b.name
+                EXCEPT
+                SELECT d.name
+            )
+            THEN UPDATE SET b.name = d.name
+            """,
+            write={
+                "bigquery": "MERGE INTO a AS b USING c AS d ON b.id = d.id WHEN MATCHED AND EXISTS(SELECT b.name EXCEPT DISTINCT SELECT d.name) THEN UPDATE SET b.name = d.name",
+                "snowflake": "MERGE INTO a AS b USING c AS d ON b.id = d.id WHEN MATCHED AND EXISTS(SELECT b.name EXCEPT SELECT d.name) THEN UPDATE SET b.name = d.name",
+                "spark": "MERGE INTO a AS b USING c AS d ON b.id = d.id WHEN MATCHED AND EXISTS(SELECT b.name EXCEPT SELECT d.name) THEN UPDATE SET b.name = d.name",
+            },
+        )
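For the three-argument form, dialects without a native start-position parameter get an arithmetic emulation; a sketch with the expected string copied from the drill case above:

```python
import sqlglot

assert (
    sqlglot.transpile("POSITION(needle, haystack, pos)", write="drill")[0]
    == "STRPOS(SUBSTR(haystack, pos), needle) + pos - 1"
)
```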
@@ -75,6 +75,15 @@ class TestMySQL(Validator):
                 "spark": "CAST(x AS TEXT) + CAST(y AS TEXT)",
             },
         )
+        self.validate_all(
+            "CAST(x AS MEDIUMBLOB) + CAST(y AS LONGBLOB)",
+            read={
+                "mysql": "CAST(x AS MEDIUMBLOB) + CAST(y AS LONGBLOB)",
+            },
+            write={
+                "spark": "CAST(x AS BLOB) + CAST(y AS BLOB)",
+            },
+        )

     def test_canonical_functions(self):
         self.validate_identity("SELECT LEFT('str', 2)", "SELECT SUBSTRING('str', 1, 2)")
@@ -12,6 +12,24 @@ class TestSnowflake(Validator):
                 "snowflake": "SELECT * FROM xxx WHERE col ILIKE '%Don\\'t%'",
             },
         )
+        self.validate_all(
+            "SELECT * EXCLUDE a, b FROM xxx",
+            write={
+                "snowflake": "SELECT * EXCLUDE (a, b) FROM xxx",
+            },
+        )
+        self.validate_all(
+            "SELECT * RENAME a AS b, c AS d FROM xxx",
+            write={
+                "snowflake": "SELECT * RENAME (a AS b, c AS d) FROM xxx",
+            },
+        )
+        self.validate_all(
+            "SELECT * EXCLUDE a, b RENAME (c AS d, E as F) FROM xxx",
+            write={
+                "snowflake": "SELECT * EXCLUDE (a, b) RENAME (c AS d, E AS F) FROM xxx",
+            },
+        )
         self.validate_all(
             'x:a:"b c"',
             write={
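A sketch of the Snowflake star-modifier normalization these tests cover (expected string copied from the test above):

```python
import sqlglot

# Bare EXCLUDE/RENAME lists are normalized to the parenthesized form.
sql = "SELECT * EXCLUDE a, b RENAME (c AS d, E as F) FROM xxx"
assert (
    sqlglot.transpile(sql, read="snowflake", write="snowflake")[0]
    == "SELECT * EXCLUDE (a, b) RENAME (c AS d, E AS F) FROM xxx"
)
```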
@@ -1,3 +1,4 @@
+from sqlglot import exp, parse, parse_one
 from tests.dialects.test_dialect import Validator


@@ -5,6 +6,10 @@ class TestTSQL(Validator):
     dialect = "tsql"

     def test_tsql(self):
+        self.validate_identity("SELECT CASE WHEN a > 1 THEN b END")
+        self.validate_identity("END")
+        self.validate_identity("@x")
+        self.validate_identity("#x")
         self.validate_identity("DECLARE @TestVariable AS VARCHAR(100)='Save Our Planet'")
         self.validate_identity("PRINT @TestVariable")
         self.validate_identity("SELECT Employee_ID, Department_ID FROM @MyTableVar")
@@ -87,6 +92,95 @@ class TestTSQL(Validator):
             },
         )

+    def test_udf(self):
+        self.validate_identity(
+            "CREATE PROCEDURE foo @a INTEGER, @b INTEGER AS SELECT @a = SUM(bla) FROM baz AS bar"
+        )
+        self.validate_identity(
+            "CREATE PROC foo @ID INTEGER, @AGE INTEGER AS SELECT DB_NAME(@ID) AS ThatDB"
+        )
+        self.validate_identity("CREATE PROC foo AS SELECT BAR() AS baz")
+        self.validate_identity("CREATE PROCEDURE foo AS SELECT BAR() AS baz")
+        self.validate_identity("CREATE FUNCTION foo(@bar INTEGER) RETURNS TABLE AS RETURN SELECT 1")
+        self.validate_identity("CREATE FUNCTION dbo.ISOweek(@DATE DATETIME2) RETURNS INTEGER")
+
+        # The following two cases don't necessarily correspond to valid TSQL, but they are used to verify
+        # that the syntax RETURNS @return_variable TABLE <table_type_definition> ... is parsed correctly.
+        #
+        # See also "Transact-SQL Multi-Statement Table-Valued Function Syntax"
+        # https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql?view=sql-server-ver16
+        self.validate_identity(
+            "CREATE FUNCTION foo(@bar INTEGER) RETURNS @foo TABLE (x INTEGER, y NUMERIC) AS RETURN SELECT 1"
+        )
+        self.validate_identity(
+            "CREATE FUNCTION foo() RETURNS @contacts TABLE (first_name VARCHAR(50), phone VARCHAR(25)) AS SELECT @fname, @phone"
+        )
+
+        self.validate_all(
+            """
+            CREATE FUNCTION udfProductInYear (
+                @model_year INT
+            )
+            RETURNS TABLE
+            AS
+            RETURN
+                SELECT
+                    product_name,
+                    model_year,
+                    list_price
+                FROM
+                    production.products
+                WHERE
+                    model_year = @model_year
+            """,
+            write={
+                "tsql": """CREATE FUNCTION udfProductInYear(
+    @model_year INTEGER
+)
+RETURNS TABLE AS
+RETURN SELECT
+  product_name,
+  model_year,
+  list_price
+FROM production.products
+WHERE
+  model_year = @model_year""",
+            },
+            pretty=True,
+        )
+
+        sql = """
+        CREATE procedure [TRANSF].[SP_Merge_Sales_Real]
+            @Loadid INTEGER
+           ,@NumberOfRows INTEGER
+        AS
+        BEGIN
+            SET XACT_ABORT ON;
+
+            DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104);
+            DECLARE @DWH_DateModified DATETIME = CONVERT(DATETIME, getdate(), 104);
+            DECLARE @DWH_IdUserCreated INTEGER = SUSER_ID (SYSTEM_USER);
+            DECLARE @DWH_IdUserModified INTEGER = SUSER_ID (SYSTEM_USER);
+
+            DECLARE @SalesAmountBefore float;
+            SELECT @SalesAmountBefore=SUM(SalesAmount) FROM TRANSF.[Pre_Merge_Sales_Real] S;
+        END
+        """

+        expected_sqls = [
+            'CREATE PROCEDURE "TRANSF"."SP_Merge_Sales_Real" @Loadid INTEGER, @NumberOfRows INTEGER AS BEGIN SET XACT_ABORT ON',
+            "DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104)",
+            "DECLARE @DWH_DateModified DATETIME = CONVERT(DATETIME, getdate(), 104)",
+            "DECLARE @DWH_IdUserCreated INTEGER = SUSER_ID (SYSTEM_USER)",
+            "DECLARE @DWH_IdUserModified INTEGER = SUSER_ID (SYSTEM_USER)",
+            "DECLARE @SalesAmountBefore float",
+            'SELECT @SalesAmountBefore = SUM(SalesAmount) FROM TRANSF."Pre_Merge_Sales_Real" AS S',
+            "END",
+        ]
+
+        for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls):
+            self.assertEqual(expr.sql(dialect="tsql"), expected_sql)
+
     def test_charindex(self):
         self.validate_all(
             "CHARINDEX(x, y, 9)",
@@ -472,3 +566,51 @@ class TestTSQL(Validator):
             "EOMONTH(GETDATE(), -1)",
             write={"spark": "LAST_DAY(ADD_MONTHS(CURRENT_TIMESTAMP(), -1))"},
         )
+
+    def test_variables(self):
+        # In TSQL @, # can be used as a prefix for variables/identifiers
+        expr = parse_one("@x", read="tsql")
+        self.assertIsInstance(expr, exp.Column)
+        self.assertIsInstance(expr.this, exp.Identifier)
+
+        expr = parse_one("#x", read="tsql")
+        self.assertIsInstance(expr, exp.Column)
+        self.assertIsInstance(expr.this, exp.Identifier)
+
+    def test_system_time(self):
+        self.validate_all(
+            "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME AS OF 'foo'",
+            write={
+                "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME AS OF 'foo'""",
+            },
+        )
+        self.validate_all(
+            "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME AS OF 'foo' AS alias",
+            write={
+                "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME AS OF 'foo' AS alias""",
+            },
+        )
+        self.validate_all(
+            "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME FROM c TO d",
+            write={
+                "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME FROM c TO d""",
+            },
+        )
+        self.validate_all(
+            "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME BETWEEN c AND d",
+            write={
+                "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME BETWEEN c AND d""",
+            },
+        )
+        self.validate_all(
+            "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME CONTAINED IN (c, d)",
+            write={
+                "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME CONTAINED IN (c, d)""",
+            },
+        )
+        self.validate_all(
+            "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME ALL AS alias",
+            write={
+                "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME ALL AS alias""",
+            },
+        )
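A sketch of the new T-SQL variable handling these tests assert (behavior copied from test_variables above):

```python
from sqlglot import exp, parse_one

# In T-SQL, @ and # prefix variables/identifiers; both parse as columns
# wrapping identifiers rather than failing to tokenize.
for sql in ("@x", "#x"):
    expr = parse_one(sql, read="tsql")
    assert isinstance(expr, exp.Column)
    assert isinstance(expr.this, exp.Identifier)
```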
tests/fixtures/identity.sql
@@ -94,8 +94,8 @@ CONCAT_WS('-', 'a', 'b')
 CONCAT_WS('-', 'a', 'b', 'c')
 POSEXPLODE("x") AS ("a", "b")
 POSEXPLODE("x") AS ("a", "b", "c")
-STR_POSITION(x, 'a')
-STR_POSITION(x, 'a', 3)
+STR_POSITION(haystack, needle)
+STR_POSITION(haystack, needle, pos)
 LEVENSHTEIN('gumbo', 'gambol', 2, 1, 1)
 SPLIT(SPLIT(referrer, 'utm_source=')[OFFSET(1)], "&")[OFFSET(0)]
 x[ORDINAL(1)][SAFE_OFFSET(2)]
@@ -375,12 +375,16 @@ SELECT * FROM (SELECT 1 UNION ALL SELECT 2) AS x
 SELECT * FROM (SELECT 1 UNION ALL SELECT 2)
 SELECT * FROM ((SELECT 1) AS a UNION ALL (SELECT 2) AS b)
 SELECT * FROM ((SELECT 1) AS a(b))
+SELECT * FROM ((SELECT 1) UNION (SELECT 2) UNION (SELECT 3))
 SELECT * FROM x AS y(a, b)
 SELECT * EXCEPT (a, b)
+SELECT * EXCEPT (a, b) FROM y
 SELECT * REPLACE (a AS b, b AS C)
 SELECT * REPLACE (a + 1 AS b, b AS C)
 SELECT * EXCEPT (a, b) REPLACE (a AS b, b AS C)
+SELECT * EXCEPT (a, b) REPLACE (a AS b, b AS C) FROM y
 SELECT a.* EXCEPT (a, b), b.* REPLACE (a AS b, b AS C)
+SELECT a.* EXCEPT (a, b), b.* REPLACE (a AS b, b AS C) FROM x
 SELECT zoo, animals FROM (VALUES ('oakland', ARRAY('a', 'b')), ('sf', ARRAY('b', 'c'))) AS t(zoo, animals)
 SELECT zoo, animals FROM UNNEST(ARRAY(STRUCT('oakland' AS zoo, ARRAY('a', 'b') AS animals), STRUCT('sf' AS zoo, ARRAY('b', 'c') AS animals))) AS t(zoo, animals)
 WITH a AS (SELECT 1) SELECT 1 UNION ALL SELECT 2
@@ -438,6 +442,8 @@ SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLO
 SELECT AVG(x) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM t
 SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x) AS y
 SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x DESC)
+SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x)
+SELECT PERCENTILE_DISC(0.5) WITHIN GROUP (ORDER BY x)
 SELECT SUM(x) FILTER(WHERE x > 1)
 SELECT SUM(x) FILTER(WHERE x > 1) OVER (ORDER BY y)
 SELECT COUNT(DISTINCT a) OVER (PARTITION BY c ORDER BY d ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)
@@ -611,6 +617,7 @@ WITH a AS (SELECT * FROM b) DELETE FROM a
 WITH a AS (SELECT * FROM b) CACHE TABLE a
 SELECT ? AS ? FROM x WHERE b BETWEEN ? AND ? GROUP BY ?, 1 LIMIT ?
 SELECT :hello, ? FROM x LIMIT :my_limit
+SELECT * FROM x FETCH NEXT @take ROWS ONLY OFFSET @skip
 WITH a AS ((SELECT b.foo AS foo, b.bar AS bar FROM b) UNION ALL (SELECT c.foo AS foo, c.bar AS bar FROM c)) SELECT * FROM a
 WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a
 SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z
@@ -670,3 +677,17 @@ CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)
 CREATE TABLE IF NOT EXISTS customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (INCREMENT BY 1))
 CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10 INCREMENT BY 1))
 CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10))
+ALTER TABLE "schema"."tablename" ADD CONSTRAINT "CHK_Name" CHECK (NOT "IdDwh" IS NULL AND "IdDwh" <> (0))
+ALTER TABLE persons ADD CONSTRAINT persons_pk PRIMARY KEY (first_name, last_name)
+ALTER TABLE pets ADD CONSTRAINT pets_persons_fk FOREIGN KEY (owner_first_name, owner_last_name) REFERENCES persons
+ALTER TABLE pets ADD CONSTRAINT pets_name_not_cute_chk CHECK (LENGTH(name) < 20)
+ALTER TABLE people10m ADD CONSTRAINT dateWithinRange CHECK (birthDate > '1900-01-01')
+ALTER TABLE people10m ADD CONSTRAINT validIds CHECK (id > 1 AND id < 99999999) ENFORCED
+ALTER TABLE baa ADD CONSTRAINT boo PRIMARY KEY (x, y) NOT ENFORCED DEFERRABLE INITIALLY DEFERRED NORELY
+ALTER TABLE baa ADD CONSTRAINT boo PRIMARY KEY (x, y) NOT ENFORCED DEFERRABLE INITIALLY DEFERRED NORELY
+ALTER TABLE baa ADD CONSTRAINT boo FOREIGN KEY (x, y) REFERENCES persons ON UPDATE NO ACTION ON DELETE NO ACTION MATCH FULL
+ALTER TABLE a ADD PRIMARY KEY (x, y) NOT ENFORCED
+ALTER TABLE a ADD FOREIGN KEY (x, y) REFERENCES bla
+CREATE TABLE foo (baz_id INT REFERENCES baz(id) DEFERRABLE)
+SELECT end FROM a
+SELECT id FROM b.a AS a QUALIFY ROW_NUMBER() OVER (PARTITION BY br ORDER BY sadf DESC) = 1
@@ -18,3 +18,6 @@ WITH y AS (SELECT *) SELECT * FROM x AS x;

 WITH y AS (SELECT * FROM y AS y2 JOIN x AS z2) SELECT * FROM x AS x JOIN y as y;
 WITH y AS (SELECT * FROM (SELECT * FROM y AS y) AS y2 JOIN (SELECT * FROM x AS x) AS z2) SELECT * FROM (SELECT * FROM x AS x) AS x JOIN y AS y;
+
+SELECT * FROM x AS x JOIN xx AS y;
+SELECT * FROM (SELECT * FROM x AS x) AS x JOIN xx AS y;
@@ -2,7 +2,7 @@ SELECT a FROM (SELECT * FROM x);
 SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0;

 SELECT 1 FROM (SELECT * FROM x) WHERE b = 2;
-SELECT 1 AS _col_0 FROM (SELECT x.b AS b FROM x AS x) AS _q_0 WHERE _q_0.b = 2;
+SELECT 1 AS "1" FROM (SELECT x.b AS b FROM x AS x) AS _q_0 WHERE _q_0.b = 2;

 SELECT (SELECT c FROM y WHERE q.b = y.b) FROM (SELECT * FROM x) AS q;
 SELECT (SELECT y.c AS c FROM y AS y WHERE q.b = y.b) AS _col_0 FROM (SELECT x.b AS b FROM x AS x) AS q;
tests/fixtures/optimizer/qualify_columns.sql
@@ -4,6 +4,14 @@
 SELECT a FROM x;
 SELECT x.a AS a FROM x AS x;

+# execute: false
+SELECT a FROM zz GROUP BY a ORDER BY a;
+SELECT zz.a AS a FROM zz AS zz GROUP BY zz.a ORDER BY a;
+
+# execute: false
+SELECT x, p FROM (SELECT x from xx) xx CROSS JOIN yy;
+SELECT xx.x AS x, yy.p AS p FROM (SELECT xx.x AS x FROM xx AS xx) AS xx CROSS JOIN yy AS yy;
+
 SELECT a FROM x AS z;
 SELECT z.a AS a FROM x AS z;

@@ -20,8 +28,8 @@ SELECT a AS b FROM x;
 SELECT x.a AS b FROM x AS x;

 # execute: false
-SELECT 1, 2 FROM x;
-SELECT 1 AS _col_0, 2 AS _col_1 FROM x AS x;
+SELECT 1, 2 + 3 FROM x;
+SELECT 1 AS "1", 2 + 3 AS _col_1 FROM x AS x;

 # execute: false
 SELECT a + b FROM x;
@@ -57,6 +65,10 @@ SELECT x.a AS j, x.b AS a FROM x AS x ORDER BY x.a;
 SELECT SUM(a) AS c, SUM(b) AS d FROM x ORDER BY 1, 2;
 SELECT SUM(x.a) AS c, SUM(x.b) AS d FROM x AS x ORDER BY SUM(x.a), SUM(x.b);

+# execute: false
+SELECT CAST(a AS INT) FROM x ORDER BY a;
+SELECT CAST(x.a AS INT) AS a FROM x AS x ORDER BY a;
+
 # execute: false
 SELECT SUM(a), SUM(b) AS c FROM x ORDER BY 1, 2;
 SELECT SUM(x.a) AS _col_0, SUM(x.b) AS c FROM x AS x ORDER BY SUM(x.a), SUM(x.b);
@@ -1,4 +1,3 @@
-SELECT a FROM zz;
 SELECT * FROM zz;
 SELECT z.a FROM x;
 SELECT z.* FROM x;
@@ -11,3 +10,4 @@ SELECT q.a FROM (SELECT x.b FROM x) AS z JOIN (SELECT a FROM z) AS q ON z.b = q.
 SELECT b FROM x AS a CROSS JOIN y AS b CROSS JOIN y AS c;
 SELECT x.a FROM x JOIN y USING (a);
 SELECT a, SUM(b) FROM x GROUP BY 3;
+SELECT p FROM (SELECT x from xx) y CROSS JOIN yy CROSS JOIN zz
@@ -481,11 +481,11 @@ class TestExecutor(unittest.TestCase):

     def test_static_queries(self):
         for sql, cols, rows in [
-            ("SELECT 1", ["_col_0"], [(1,)]),
+            ("SELECT 1", ["1"], [(1,)]),
             ("SELECT 1 + 2 AS x", ["x"], [(3,)]),
             ("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]),
             ("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]),
-            ("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]),
+            ("SELECT 'foo' LIMIT 1", ["foo"], [("foo",)]),
             (
                 "SELECT SUM(x), COUNT(x) FROM (SELECT 1 AS x WHERE FALSE)",
                 ["_col_0", "_col_1"],
@@ -189,6 +189,27 @@ class TestExpressions(unittest.TestCase):
             "SELECT * FROM (SELECT a FROM tbl1) WHERE b > 100",
         )

+    def test_function_building(self):
+        self.assertEqual(exp.func("bla", 1, "foo").sql(), "BLA(1, 'foo')")
+        self.assertEqual(exp.func("COUNT", exp.Star()).sql(), "COUNT(*)")
+        self.assertEqual(exp.func("bloo").sql(), "BLOO()")
+        self.assertEqual(
+            exp.func("locate", "x", "xo", dialect="hive").sql("hive"), "LOCATE('x', 'xo')"
+        )
+
+        self.assertIsInstance(exp.func("instr", "x", "b", dialect="mysql"), exp.StrPosition)
+        self.assertIsInstance(exp.func("bla", 1, "foo"), exp.Anonymous)
+        self.assertIsInstance(
+            exp.func("cast", this=exp.Literal.number(5), to=exp.DataType.build("DOUBLE")),
+            exp.Cast,
+        )
+
+        with self.assertRaises(ValueError):
+            exp.func("some_func", 1, arg2="foo")
+
+        with self.assertRaises(ValueError):
+            exp.func("abs")
+
     def test_named_selects(self):
         expression = parse_one(
             "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz"
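A sketch of the `exp.func` builder these tests exercise (assertions copied from test_function_building above): it resolves known function names, per dialect, into typed nodes, and falls back to `exp.Anonymous` otherwise.

```python
from sqlglot import exp

# Unknown names become Anonymous functions; known ones get typed nodes.
assert exp.func("bla", 1, "foo").sql() == "BLA(1, 'foo')"
assert isinstance(exp.func("bla", 1, "foo"), exp.Anonymous)
assert isinstance(exp.func("instr", "x", "b", dialect="mysql"), exp.StrPosition)
```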
tests/test_lineage.py
@@ -0,0 +1,20 @@
+import unittest
+
+from sqlglot.lineage import lineage
+
+
+class TestLineage(unittest.TestCase):
+    maxDiff = None
+
+    def test_lineage(self) -> None:
+        node = lineage(
+            "a",
+            "SELECT a FROM y",
+            schema={"x": {"a": "int"}},
+            sources={"y": "SELECT * FROM x"},
+        )
+        self.assertEqual(
+            node.source.sql(),
+            "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y /* source: y */",
+        )
+        self.assertGreater(len(node.to_html()._repr_html_()), 1000)
@@ -117,6 +117,7 @@ class TestOptimizer(unittest.TestCase):
         self.check_file(
             "isolate_table_selects",
             optimizer.isolate_table_selects.isolate_table_selects,
+            schema=self.schema,
         )

     def test_qualify_tables(self):
@@ -17,6 +17,11 @@ class TestSchema(unittest.TestCase):
         with self.assertRaises(SchemaError):
             schema.column_names(to_table(table))

+    def assert_column_names_empty(self, schema, *tables):
+        for table in tables:
+            with self.subTest(table):
+                self.assertEqual(schema.column_names(to_table(table)), [])
+
     def test_schema(self):
         schema = ensure_schema(
             {
@@ -38,7 +43,7 @@ class TestSchema(unittest.TestCase):
             ("z.x.y", ["b", "c"]),
         )

-        self.assert_column_names_raises(
+        self.assert_column_names_empty(
             schema,
             "z",
             "z.z",
@@ -76,6 +81,10 @@ class TestSchema(unittest.TestCase):
         self.assert_column_names_raises(
             schema,
             "x",
+        )
+
+        self.assert_column_names_empty(
+            schema,
             "z.x",
             "z.y",
         )
@@ -129,12 +138,16 @@ class TestSchema(unittest.TestCase):

         self.assert_column_names_raises(
             schema,
-            "q",
-            "d2.x",
             "y",
             "z",
             "d1.y",
             "d1.z",
+        )
+
+        self.assert_column_names_empty(
+            schema,
+            "q",
+            "d2.x",
             "a.b.c",
         )