1
0
Fork 0

Merging upstream version 10.5.10.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 15:07:05 +01:00
parent 8588db6332
commit 4d496b7a6a
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
43 changed files with 1384 additions and 356 deletions

View file

@ -32,6 +32,7 @@ We use GitHub issues to track public bugs. Report a bug by opening a new issue.
- What you expected would happen - What you expected would happen
- What actually happens - What actually happens
- Notes (possibly including why you think this might be happening, or stuff you tried that didn't work) - Notes (possibly including why you think this might be happening, or stuff you tried that didn't work)
- References (e.g. documentation pages related to the issue)
## Start a discussion using Github's [discussions](https://github.com/tobymao/sqlglot/discussions) ## Start a discussion using Github's [discussions](https://github.com/tobymao/sqlglot/discussions)
[We use GitHub discussions](https://github.com/tobymao/sqlglot/discussions/190) to discuss about the current state [We use GitHub discussions](https://github.com/tobymao/sqlglot/discussions/190) to discuss about the current state

View file

@ -18,7 +18,7 @@ style:
check: style test check: style test
docs: docs:
pdoc/cli.py -o pdoc/docs python pdoc/cli.py -o pdoc/docs
docs-serve: docs-serve:
pdoc/cli.py python pdoc/cli.py

View file

@ -1,6 +1,6 @@
# SQLGlot # SQLGlot
SQLGlot is a no dependency Python SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between [18 different dialects](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py) like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects. SQLGlot is a no dependency Python SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between [19 different dialects](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py) like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects.
It is a very comprehensive generic SQL parser with a robust [test suite](https://github.com/tobymao/sqlglot/blob/main/tests/). It is also quite [performant](#benchmarks) while being written purely in Python. It is a very comprehensive generic SQL parser with a robust [test suite](https://github.com/tobymao/sqlglot/blob/main/tests/). It is also quite [performant](#benchmarks) while being written purely in Python.
@ -189,7 +189,6 @@ except sqlglot.errors.ParseError as e:
print(e.errors) print(e.errors)
``` ```
Output:
```python ```python
[{ [{
'description': 'Expecting )', 'description': 'Expecting )',

View file

@ -1,41 +0,0 @@
# Expressions
Every AST node in SQLGlot is represented by a subclass of `Expression`. Each such expression encapsulates any necessary context, such as its child expressions, their names, or arg keys, and whether each child expression is optional or not.
Furthermore, the following attributes are common across all expressions:
#### key
A unique key for each class in the `Expression` hierarchy. This is useful for hashing and representing expressions as strings.
#### args
A dictionary used for mapping child arg keys to the corresponding expressions. A value in this mapping is usually either a single `Expression` instance or a list of `Expression` instances, but SQLGlot doesn't impose any constraints on the actual type of the value.
#### arg_types
A dictionary used for mapping arg keys to booleans that indicate whether the corresponding expressions are required (`True`) or optional (`False`). Consider the following example:
```python
class Limit(Expression):
arg_types = {"this": False, "expression": True}
```
Here, `Limit` declares that it expects to have one optional and one required child expression, which can be referenced through `this` and `expression`, respectively. The arg keys are generally arbitrary, but there are helper methods for keys like `this`, `expression` and `expressions` that abstract away dictionary lookups and related checks. For this reason, these keys are common throughout SQLGlot's codebase.
#### parent
A reference to the parent expression (may be `None`).
#### arg_key
The arg key an expression is associated with, i.e. the name its parent expression uses to refer to it.
#### comments
A list of comments that are associated with a given expression. This is used in order to preserve comments when transpiling SQL code.
#### type
The data type of an expression, as inferred by SQLGlot's optimizer.

View file

@ -1,5 +1,6 @@
""" """
.. include:: ../README.md .. include:: ../README.md
----
""" """
from __future__ import annotations from __future__ import annotations
@ -29,14 +30,16 @@ from sqlglot.expressions import table_ as table
from sqlglot.expressions import to_column, to_table, union from sqlglot.expressions import to_column, to_table, union
from sqlglot.generator import Generator from sqlglot.generator import Generator
from sqlglot.parser import Parser from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema from sqlglot.schema import MappingSchema, Schema
from sqlglot.tokens import Tokenizer, TokenType from sqlglot.tokens import Tokenizer, TokenType
__version__ = "10.5.6" __version__ = "10.5.10"
pretty = False pretty = False
"""Whether to format generated SQL by default."""
schema = MappingSchema() schema = MappingSchema()
"""The default schema used by SQLGlot (e.g. in the optimizer)."""
def parse( def parse(
@ -48,7 +51,7 @@ def parse(
Args: Args:
sql: the SQL code string to parse. sql: the SQL code string to parse.
read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql"). read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
**opts: other options. **opts: other `sqlglot.parser.Parser` options.
Returns: Returns:
The resulting syntax tree collection. The resulting syntax tree collection.
@ -60,7 +63,7 @@ def parse(
def parse_one( def parse_one(
sql: str, sql: str,
read: t.Optional[str | Dialect] = None, read: t.Optional[str | Dialect] = None,
into: t.Optional[t.Type[Expression] | str] = None, into: t.Optional[exp.IntoType] = None,
**opts, **opts,
) -> Expression: ) -> Expression:
""" """
@ -70,7 +73,7 @@ def parse_one(
sql: the SQL code string to parse. sql: the SQL code string to parse.
read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql"). read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
into: the SQLGlot Expression to parse into. into: the SQLGlot Expression to parse into.
**opts: other options. **opts: other `sqlglot.parser.Parser` options.
Returns: Returns:
The syntax tree for the first parsed statement. The syntax tree for the first parsed statement.
@ -110,7 +113,7 @@ def transpile(
identity: if set to `True` and if the target dialect is not specified the source dialect will be used as both: identity: if set to `True` and if the target dialect is not specified the source dialect will be used as both:
the source and the target dialect. the source and the target dialect.
error_level: the desired error level of the parser. error_level: the desired error level of the parser.
**opts: other options. **opts: other `sqlglot.generator.Generator` options.
Returns: Returns:
The list of transpiled SQL statements. The list of transpiled SQL statements.

View file

@ -1,29 +1,29 @@
# PySpark DataFrame SQL Generator # PySpark DataFrame SQL Generator
This is a drop-in replacement for the PysPark DataFrame API that will generate SQL instead of executing DataFrame operations directly. This, when combined with the transpiling support in SQLGlot, allows one to write PySpark DataFrame code and execute it on other engines like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). This is a drop-in replacement for the PySpark DataFrame API that will generate SQL instead of executing DataFrame operations directly. This, when combined with the transpiling support in SQLGlot, allows one to write PySpark DataFrame code and execute it on other engines like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/).
Currently many of the common operations are covered and more functionality will be added over time. Please [open an issue](https://github.com/tobymao/sqlglot/issues) or [PR](https://github.com/tobymao/sqlglot/pulls) with your feedback or contribution to help influence what should be prioritized next and make sure your use case is properly supported. Currently many of the common operations are covered and more functionality will be added over time. Please [open an issue](https://github.com/tobymao/sqlglot/issues) or [PR](https://github.com/tobymao/sqlglot/pulls) with your feedback or contribution to help influence what should be prioritized next and make sure your use case is properly supported.
# How to use # How to use
## Instructions ## Instructions
* [Install SQLGlot](https://github.com/tobymao/sqlglot/blob/main/README.md#install) and that is all that is required to just generate SQL. [The examples](#examples) show generating SQL and then executing that SQL on a specific engine and that will require that engine's client library * [Install SQLGlot](https://github.com/tobymao/sqlglot/blob/main/README.md#install) and that is all that is required to just generate SQL. [The examples](#examples) show generating SQL and then executing that SQL on a specific engine and that will require that engine's client library.
* Find/replace all `from pyspark.sql` with `from sqlglot.dataframe` * Find/replace all `from pyspark.sql` with `from sqlglot.dataframe`.
* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('<table_name>', <column_structure>)` * Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('<table_name>', <column_structure>)`.
* The column structure can be defined the following ways: * The column structure can be defined the following ways:
* Dictionary where the keys are column names and values are string of the Spark SQL type name * Dictionary where the keys are column names and values are string of the Spark SQL type name.
* Ex: {'cola': 'string', 'colb': 'int'} * Ex: `{'cola': 'string', 'colb': 'int'}`
* PySpark DataFrame `StructType` similar to when using `createDataFrame` * PySpark DataFrame `StructType` similar to when using `createDataFrame`.
* Ex: `StructType([StructField('cola', StringType()), StructField('colb', IntegerType())])` * Ex: `StructType([StructField('cola', StringType()), StructField('colb', IntegerType())])`
* A string of names and types similar to what is supported in `createDataFrame` * A string of names and types similar to what is supported in `createDataFrame`.
* Ex: `cola: STRING, colb: INT` * Ex: `cola: STRING, colb: INT`
* [Not Recommended] A list of string column names without type * [Not Recommended] A list of string column names without type.
* Ex: ['cola', 'colb'] * Ex: `['cola', 'colb']`
* The lack of types may limit functionality in future releases * The lack of types may limit functionality in future releases.
* See [Registering Custom Schema](#registering-custom-schema-class) for information on how to skip this step if the information is stored externally * See [Registering Custom Schema](#registering-custom-schema-class) for information on how to skip this step if the information is stored externally.
* Add `.sql(pretty=True)` to your final DataFrame command to return a list of sql statements to run that command * Add `.sql(pretty=True)` to your final DataFrame command to return a list of sql statements to run that command.
* In most cases a single SQL statement is returned. Currently the only exception is when caching DataFrames which isn't supported in other dialects. * In most cases a single SQL statement is returned. Currently the only exception is when caching DataFrames which isn't supported in other dialects.
* Spark is the default output dialect. See [dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) for a full list of dialects * Spark is the default output dialect. See [dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) for a full list of dialects.
* Ex: `.sql(pretty=True, dialect='bigquery')` * Ex: `.sql(pretty=True, dialect='bigquery')`
## Examples ## Examples
@ -51,7 +51,7 @@ df = (
print(df.sql(pretty=True)) # Spark will be the dialect used by default print(df.sql(pretty=True)) # Spark will be the dialect used by default
``` ```
Output:
```sparksql ```sparksql
SELECT SELECT
`employee`.`age` AS `age`, `employee`.`age` AS `age`,
@ -206,7 +206,7 @@ sql_statements = (
.createDataFrame(data, schema) .createDataFrame(data, schema)
.groupBy(F.col("age")) .groupBy(F.col("age"))
.agg(F.countDistinct(F.col("employee_id")).alias("num_employees")) .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
.sql(dialect="bigquery") .sql(dialect="spark")
) )
pyspark = PySparkSession.builder.master("local[*]").getOrCreate() pyspark = PySparkSession.builder.master("local[*]").getOrCreate()

View file

@ -111,16 +111,13 @@ class DataFrame:
return DataFrameNaFunctions(self) return DataFrameNaFunctions(self)
def _replace_cte_names_with_hashes(self, expression: exp.Select): def _replace_cte_names_with_hashes(self, expression: exp.Select):
expression = expression.copy()
ctes = expression.ctes
replacement_mapping = {} replacement_mapping = {}
for cte in ctes: for cte in expression.ctes:
old_name_id = cte.args["alias"].this old_name_id = cte.args["alias"].this
new_hashed_id = exp.to_identifier( new_hashed_id = exp.to_identifier(
self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"] self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
) )
replacement_mapping[old_name_id] = new_hashed_id replacement_mapping[old_name_id] = new_hashed_id
cte.set("alias", exp.TableAlias(this=new_hashed_id))
expression = expression.transform(replace_id_value, replacement_mapping) expression = expression.transform(replace_id_value, replacement_mapping)
return expression return expression
@ -183,7 +180,7 @@ class DataFrame:
expression = df.expression expression = df.expression
hint_expression = expression.args.get("hint") or exp.Hint(expressions=[]) hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
for hint in df.pending_partition_hints: for hint in df.pending_partition_hints:
hint_expression.args.get("expressions").append(hint) hint_expression.append("expressions", hint)
df.pending_hints.remove(hint) df.pending_hints.remove(hint)
join_aliases = { join_aliases = {
@ -209,7 +206,7 @@ class DataFrame:
sequence_id_expression.set("this", matching_cte.args["alias"].this) sequence_id_expression.set("this", matching_cte.args["alias"].this)
df.pending_hints.remove(hint) df.pending_hints.remove(hint)
break break
hint_expression.args.get("expressions").append(hint) hint_expression.append("expressions", hint)
if hint_expression.expressions: if hint_expression.expressions:
expression.set("hint", hint_expression) expression.set("hint", hint_expression)
return df return df

View file

@ -129,7 +129,7 @@ class SparkSession:
@property @property
def _random_name(self) -> str: def _random_name(self) -> str:
return f"a{str(uuid.uuid4())[:8]}" return "r" + uuid.uuid4().hex
@property @property
def _random_branch_id(self) -> str: def _random_branch_id(self) -> str:
@ -145,7 +145,7 @@ class SparkSession:
@property @property
def _random_id(self) -> str: def _random_id(self) -> str:
id = f"a{str(uuid.uuid4())[:8]}" id = self._random_name
self.known_ids.add(id) self.known_ids.add(id)
return id return id

View file

@ -1,3 +1,64 @@
"""
## Dialects
One of the core abstractions in SQLGlot is the concept of a "dialect". The `Dialect` class essentially implements a
"SQLGlot dialect", which aims to be as generic and ANSI-compliant as possible. It relies on the base `Tokenizer`,
`Parser` and `Generator` classes to achieve this goal, so these need to be very lenient when it comes to consuming
SQL code.
However, there are cases where the syntax of different SQL dialects varies wildly, even for common tasks. One such
example is the date/time functions, which can be hard to deal with. For this reason, it's sometimes necessary to
override the base dialect in order to specialize its behavior. This can be easily done in SQLGlot: supporting new
dialects is as simple as subclassing from `Dialect` and overriding its various components (e.g. the `Parser` class),
in order to implement the target behavior.
### Implementing a custom Dialect
Consider the following example:
```python
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.generator import Generator
from sqlglot.tokens import Tokenizer, TokenType
class Custom(Dialect):
class Tokenizer(Tokenizer):
QUOTES = ["'", '"']
IDENTIFIERS = ["`"]
KEYWORDS = {
**Tokenizer.KEYWORDS,
"INT64": TokenType.BIGINT,
"FLOAT64": TokenType.DOUBLE,
}
class Generator(Generator):
TRANSFORMS = {exp.Array: lambda self, e: f"[{self.expressions(e)}]"}
TYPE_MAPPING = {
exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.SMALLINT: "INT64",
exp.DataType.Type.INT: "INT64",
exp.DataType.Type.BIGINT: "INT64",
exp.DataType.Type.DECIMAL: "NUMERIC",
exp.DataType.Type.FLOAT: "FLOAT64",
exp.DataType.Type.DOUBLE: "FLOAT64",
exp.DataType.Type.BOOLEAN: "BOOL",
exp.DataType.Type.TEXT: "STRING",
}
```
This is a typical example of adding a new dialect implementation in SQLGlot: we specify its identifier and string
delimiters, as well as what tokens it uses for its types and how they're associated with SQLGlot types. Since
the `Expression` classes are common for each dialect supported in SQLGlot, we may also need to override the generation
logic for some expressions; this is usually done by adding new entries to the `TRANSFORMS` mapping.
----
"""
from sqlglot.dialects.bigquery import BigQuery from sqlglot.dialects.bigquery import BigQuery
from sqlglot.dialects.clickhouse import ClickHouse from sqlglot.dialects.clickhouse import ClickHouse
from sqlglot.dialects.databricks import Databricks from sqlglot.dialects.databricks import Databricks

View file

@ -124,7 +124,6 @@ class BigQuery(Dialect):
"FLOAT64": TokenType.DOUBLE, "FLOAT64": TokenType.DOUBLE,
"INT64": TokenType.BIGINT, "INT64": TokenType.BIGINT,
"NOT DETERMINISTIC": TokenType.VOLATILE, "NOT DETERMINISTIC": TokenType.VOLATILE,
"QUALIFY": TokenType.QUALIFY,
"UNKNOWN": TokenType.NULL, "UNKNOWN": TokenType.NULL,
} }
KEYWORDS.pop("DIV") KEYWORDS.pop("DIV")

View file

@ -73,13 +73,8 @@ class ClickHouse(Dialect):
return this return this
def _parse_position(self) -> exp.Expression: def _parse_position(self, haystack_first: bool = False) -> exp.Expression:
this = super()._parse_position() return super()._parse_position(haystack_first=True)
# clickhouse position args are swapped
substr = this.this
this.args["this"] = this.args.get("substr")
this.args["substr"] = substr
return this
# https://clickhouse.com/docs/en/sql-reference/statements/select/with/ # https://clickhouse.com/docs/en/sql-reference/statements/select/with/
def _parse_cte(self) -> exp.Expression: def _parse_cte(self) -> exp.Expression:

View file

@ -124,6 +124,8 @@ class MySQL(Dialect):
**tokens.Tokenizer.KEYWORDS, **tokens.Tokenizer.KEYWORDS,
"MEDIUMTEXT": TokenType.MEDIUMTEXT, "MEDIUMTEXT": TokenType.MEDIUMTEXT,
"LONGTEXT": TokenType.LONGTEXT, "LONGTEXT": TokenType.LONGTEXT,
"MEDIUMBLOB": TokenType.MEDIUMBLOB,
"LONGBLOB": TokenType.LONGBLOB,
"START": TokenType.BEGIN, "START": TokenType.BEGIN,
"SEPARATOR": TokenType.SEPARATOR, "SEPARATOR": TokenType.SEPARATOR,
"_ARMSCII8": TokenType.INTRODUCER, "_ARMSCII8": TokenType.INTRODUCER,
@ -459,6 +461,8 @@ class MySQL(Dialect):
TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy() TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy()
TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT) TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT)
TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT) TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT)
TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMBLOB)
TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB)
WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set() WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()

View file

@ -194,7 +194,8 @@ class Snowflake(Dialect):
KEYWORDS = { KEYWORDS = {
**tokens.Tokenizer.KEYWORDS, **tokens.Tokenizer.KEYWORDS,
"QUALIFY": TokenType.QUALIFY, "EXCLUDE": TokenType.EXCEPT,
"RENAME": TokenType.REPLACE,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ, "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
"TIMESTAMP_NTZ": TokenType.TIMESTAMP, "TIMESTAMP_NTZ": TokenType.TIMESTAMP,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ, "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
@ -232,6 +233,11 @@ class Snowflake(Dialect):
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ", exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
} }
STAR_MAPPING = {
"except": "EXCLUDE",
"replace": "RENAME",
}
ROOT_PROPERTIES = { ROOT_PROPERTIES = {
exp.PartitionedByProperty, exp.PartitionedByProperty,
exp.ReturnsProperty, exp.ReturnsProperty,

View file

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import re import re
import typing as t
from sqlglot import exp, generator, parser, tokens from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func
@ -251,6 +252,7 @@ class TSQL(Dialect):
"NTEXT": TokenType.TEXT, "NTEXT": TokenType.TEXT,
"NVARCHAR(MAX)": TokenType.TEXT, "NVARCHAR(MAX)": TokenType.TEXT,
"PRINT": TokenType.COMMAND, "PRINT": TokenType.COMMAND,
"PROC": TokenType.PROCEDURE,
"REAL": TokenType.FLOAT, "REAL": TokenType.FLOAT,
"ROWVERSION": TokenType.ROWVERSION, "ROWVERSION": TokenType.ROWVERSION,
"SMALLDATETIME": TokenType.DATETIME, "SMALLDATETIME": TokenType.DATETIME,
@ -263,6 +265,11 @@ class TSQL(Dialect):
"XML": TokenType.XML, "XML": TokenType.XML,
} }
# T-SQL allows `@` and `#` to appear as a variable/identifier prefix, so they are removed from the single-character tokens
SINGLE_TOKENS = tokens.Tokenizer.SINGLE_TOKENS.copy()
SINGLE_TOKENS.pop("@")
SINGLE_TOKENS.pop("#")
class Parser(parser.Parser): class Parser(parser.Parser):
FUNCTIONS = { FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore **parser.Parser.FUNCTIONS, # type: ignore
@ -293,26 +300,82 @@ class TSQL(Dialect):
DataType.Type.NCHAR, DataType.Type.NCHAR,
} }
# https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-temporary#create-a-temporary-table RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - { # type: ignore
TABLE_PREFIX_TOKENS = {TokenType.HASH, TokenType.PARAMETER} TokenType.TABLE,
*parser.Parser.TYPE_TOKENS, # type: ignore
}
def _parse_convert(self, strict): STATEMENT_PARSERS = {
**parser.Parser.STATEMENT_PARSERS, # type: ignore
TokenType.END: lambda self: self._parse_command(),
}
def _parse_system_time(self) -> t.Optional[exp.Expression]:
    """Parse an optional T-SQL ``FOR SYSTEM_TIME`` clause.

    Handled forms (after the ``FOR SYSTEM_TIME`` prefix):
        - ``AS OF <expr>``
        - ``FROM <expr> TO <expr>`` / ``BETWEEN <expr> AND <expr>``
        - ``CONTAINED IN (<expr>, <expr>)``
        - ``ALL``

    Returns:
        An `exp.SystemTime` node whose "kind" arg is one of "AS OF", "FROM",
        "BETWEEN", "CONTAINED IN" or "ALL"; `None` if the upcoming tokens are not
        ``FOR SYSTEM_TIME``, or if the clause is not one of the forms above and
        `raise_error` returns instead of raising.
    """
    if not self._match_text_seq("FOR", "SYSTEM_TIME"):
        return None

    if self._match_text_seq("AS", "OF"):
        system_time = self.expression(
            exp.SystemTime, this=self._parse_bitwise(), kind="AS OF"
        )
    elif self._match_set((TokenType.FROM, TokenType.BETWEEN)):
        # Normalize to upper case: the generator compares "kind" against the literal
        # strings "FROM" / "BETWEEN", and the matched token's text presumably keeps
        # the casing used in the SQL source (NOTE(review): confirm against Tokenizer).
        kind = self._prev.text.upper()
        this = self._parse_bitwise()
        # The separator is consumed if present; either keyword is accepted for
        # either form and its absence is not reported.
        self._match_texts(("TO", "AND"))
        expression = self._parse_bitwise()
        system_time = self.expression(
            exp.SystemTime, this=this, expression=expression, kind=kind
        )
    elif self._match_text_seq("CONTAINED", "IN"):
        args = self._parse_wrapped_csv(self._parse_bitwise)
        system_time = self.expression(
            exp.SystemTime,
            this=seq_get(args, 0),
            expression=seq_get(args, 1),
            kind="CONTAINED IN",
        )
    elif self._match(TokenType.ALL):
        system_time = self.expression(exp.SystemTime, kind="ALL")
    else:
        # Assigned before reporting so that a value is still returned in case
        # raise_error does not raise (presumably depends on the error level).
        system_time = None
        self.raise_error("Unable to parse FOR SYSTEM_TIME clause")

    return system_time
def _parse_table_parts(self, schema: bool = False) -> exp.Expression:
    """Parse a table reference and attach any ``FOR SYSTEM_TIME`` clause that follows.

    Args:
        schema: forwarded unchanged to the base parser's `_parse_table_parts`.

    Returns:
        The expression produced by the base parser, with its "system_time" arg set
        to the parsed clause (`None` when no such clause follows).
    """
    parsed = super()._parse_table_parts(schema=schema)
    temporal_clause = self._parse_system_time()
    parsed.set("system_time", temporal_clause)
    return parsed
def _parse_returns(self) -> exp.Expression:
    """Parse a ``RETURNS`` clause that may start with an identifier.

    An identifier limited to `RETURNS_TABLE_TOKENS` is tried first (presumably the
    table variable in ``RETURNS @name TABLE (...)`` - confirm), then the base
    parser handles the rest of the clause.

    Returns:
        The expression produced by the base parser, with the leading identifier
        (or `None`) stored under its "table" arg.
    """
    leading_name = self._parse_id_var(any_token=False, tokens=self.RETURNS_TABLE_TOKENS)
    clause = super()._parse_returns()
    clause.set("table", leading_name)
    return clause
def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
to = self._parse_types() to = self._parse_types()
self._match(TokenType.COMMA) self._match(TokenType.COMMA)
this = self._parse_conjunction() this = self._parse_conjunction()
if not to or not this:
return None
# Retrieve length of datatype and override to default if not specified # Retrieve length of datatype and override to default if not specified
if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES: if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False) to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False)
# Check whether a conversion with format is applicable # Check whether a conversion with format is applicable
if self._match(TokenType.COMMA): if self._match(TokenType.COMMA):
format_val = self._parse_number().name format_val = self._parse_number()
if format_val not in TSQL.convert_format_mapping: format_val_name = format_val.name if format_val else ""
if format_val_name not in TSQL.convert_format_mapping:
raise ValueError( raise ValueError(
f"CONVERT function at T-SQL does not support format style {format_val}" f"CONVERT function at T-SQL does not support format style {format_val_name}"
) )
format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val])
format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val_name])
# Check whether the convert entails a string to date format # Check whether the convert entails a string to date format
if to.this == DataType.Type.DATE: if to.this == DataType.Type.DATE:
@ -333,6 +396,21 @@ class TSQL(Dialect):
# Entails a simple cast without any format requirement # Entails a simple cast without any format requirement
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
def _parse_user_defined_function(
    self, kind: t.Optional[TokenType] = None
) -> t.Optional[exp.Expression]:
    """Parse the name and parameters of a user defined function or procedure.

    The base parser runs first. Its result is returned untouched when `kind` is
    `TokenType.FUNCTION`, when it already is an `exp.UserDefinedFunction`, or when
    the next token is `TokenType.ALIAS` (checked with ``advance=False``). In every
    other case a comma-separated list of UDF kwargs is parsed and wrapped together
    with the base result (presumably parameters written without parentheses, as
    in stored procedures - confirm).

    Args:
        kind: the kind of object being created, forwarded to the base parser.

    Returns:
        The base parser's result, or a new `exp.UserDefinedFunction` wrapping it.
    """
    parsed = super()._parse_user_defined_function(kind=kind)

    if kind == TokenType.FUNCTION or isinstance(parsed, exp.UserDefinedFunction):
        return parsed
    if self._match(TokenType.ALIAS, advance=False):
        return parsed

    params = self._parse_csv(self._parse_udf_kwarg)
    return self.expression(exp.UserDefinedFunction, this=parsed, expressions=params)
class Generator(generator.Generator): class Generator(generator.Generator):
TYPE_MAPPING = { TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore **generator.Generator.TYPE_MAPPING, # type: ignore
@ -354,3 +432,27 @@ class TSQL(Dialect):
exp.TimeToStr: _format_sql, exp.TimeToStr: _format_sql,
exp.GroupConcat: _string_agg_sql, exp.GroupConcat: _string_agg_sql,
} }
TRANSFORMS.pop(exp.ReturnsProperty)
def systemtime_sql(self, expression: exp.SystemTime) -> str:
    """Render an `exp.SystemTime` node as a T-SQL ``FOR SYSTEM_TIME`` clause.

    The node's "kind" arg selects the form. String kinds are compared
    case-insensitively, so "from" and "FROM" generate the same SQL. Any kind other
    than ALL, AS OF, FROM or BETWEEN is rendered as ``CONTAINED IN``.

    Args:
        expression: node carrying "kind" and, depending on the kind, the "this"
            and "expression" args.

    Returns:
        The SQL text of the clause.

    Raises:
        KeyError: if the node has no "kind" arg.
    """
    kind = expression.args["kind"]
    if isinstance(kind, str):
        # Without this, a lower/mixed-case kind would silently fall through to
        # the CONTAINED IN form at the bottom.
        kind = kind.upper()

    if kind == "ALL":
        return "FOR SYSTEM_TIME ALL"

    start = self.sql(expression, "this")
    if kind == "AS OF":
        return f"FOR SYSTEM_TIME AS OF {start}"

    end = self.sql(expression, "expression")
    if kind == "FROM":
        return f"FOR SYSTEM_TIME FROM {start} TO {end}"
    if kind == "BETWEEN":
        return f"FOR SYSTEM_TIME BETWEEN {start} AND {end}"

    return f"FOR SYSTEM_TIME CONTAINED IN ({start}, {end})"
def returnsproperty_sql(self, expression: exp.ReturnsProperty) -> str:
    """Render a ``RETURNS`` property, including the optional leading "table" arg.

    Args:
        expression: node whose "this" arg is the returned type and whose optional
            "table" arg, when truthy, is written (via string formatting) in front
            of it.

    Returns:
        ``RETURNS <table> <this>`` when "table" is set, otherwise ``RETURNS <this>``.
    """
    table_name = expression.args.get("table")
    if table_name:
        prefix = f"{table_name} "
    else:
        prefix = ""
    returned = self.sql(expression, "this")
    return f"RETURNS {prefix}{returned}"

View file

@ -1,5 +1,6 @@
""" """
.. include:: ../posts/sql_diff.md .. include:: ../posts/sql_diff.md
----
""" """
from __future__ import annotations from __future__ import annotations
@ -75,12 +76,13 @@ def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
] ]
Args: Args:
source (sqlglot.Expression): the source expression. source: the source expression.
target (sqlglot.Expression): the target expression against which the diff should be calculated. target: the target expression against which the diff should be calculated.
Returns: Returns:
the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the target expression trees. the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the
This list represents a sequence of steps needed to transform the source expression tree into the target one. target expression trees. This list represents a sequence of steps needed to transform the source
expression tree into the target one.
""" """
return ChangeDistiller().diff(source.copy(), target.copy()) return ChangeDistiller().diff(source.copy(), target.copy())
@ -258,7 +260,7 @@ class ChangeDistiller:
return bigram_histo return bigram_histo
def _get_leaves(expression: exp.Expression) -> t.Generator[exp.Expression, None, None]: def _get_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]:
has_child_exprs = False has_child_exprs = False
for a in expression.args.values(): for a in expression.args.values():

View file

@ -63,7 +63,7 @@ class Context:
reader = table[i] reader = table[i]
yield reader, self yield reader, self
def table_iter(self, table: str) -> t.Generator[t.Tuple[TableIter, Context], None, None]: def table_iter(self, table: str) -> t.Iterator[t.Tuple[TableIter, Context]]:
self.env["scope"] = self.row_readers self.env["scope"] = self.row_readers
for reader in self.tables[table]: for reader in self.tables[table]:

View file

@ -1,5 +1,12 @@
""" """
.. include:: ../pdoc/docs/expressions.md ## Expressions
Every AST node in SQLGlot is represented by a subclass of `Expression`.
This module contains the implementation of all supported `Expression` types. Additionally,
it exposes a number of helper functions, which are mainly used to programmatically build
SQL expressions, such as `sqlglot.expressions.select`.
----
""" """
from __future__ import annotations from __future__ import annotations
@ -27,35 +34,66 @@ from sqlglot.tokens import Token
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import Dialect from sqlglot.dialects.dialect import Dialect
IntoType = t.Union[
str,
t.Type[Expression],
t.Collection[t.Union[str, t.Type[Expression]]],
]
class _Expression(type): class _Expression(type):
def __new__(cls, clsname, bases, attrs): def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs) klass = super().__new__(cls, clsname, bases, attrs)
# When an Expression class is created, its key is automatically set to be
# the lowercase version of the class' name.
klass.key = clsname.lower() klass.key = clsname.lower()
# This is so that docstrings are not inherited in pdoc
klass.__doc__ = klass.__doc__ or ""
return klass return klass
class Expression(metaclass=_Expression): class Expression(metaclass=_Expression):
""" """
The base class for all expressions in a syntax tree. The base class for all expressions in a syntax tree. Each Expression encapsulates any necessary
context, such as its child expressions, their names (arg keys), and whether a given child expression
is optional or not.
Attributes: Attributes:
arg_types (dict): determines arguments supported by this expression. key: a unique key for each class in the Expression hierarchy. This is useful for hashing
The key in a dictionary defines a unique key of an argument using and representing expressions as strings.
which the argument's value can be retrieved. The value is a boolean arg_types: determines what arguments (child nodes) are supported by an expression. It
flag which indicates whether the argument's value is required (True) maps arg keys to booleans that indicate whether the corresponding args are required (True) or optional (False).
or optional (False).
Example:
>>> class Foo(Expression):
... arg_types = {"this": True, "expression": False}
The above definition informs us that Foo is an Expression that requires an argument called
"this" and may also optionally receive an argument called "expression".
Args:
args: a mapping used for retrieving the arguments of an expression, given their arg keys.
parent: a reference to the parent expression (or None, in case of root expressions).
arg_key: the arg key an expression is associated with, i.e. the name its parent expression
uses to refer to it.
comments: a list of comments that are associated with a given expression. This is used in
order to preserve comments when transpiling SQL code.
_type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the
optimizer, in order to enable some transformations that require type information.
""" """
key = "Expression" key = "expression"
arg_types = {"this": True} arg_types = {"this": True}
__slots__ = ("args", "parent", "arg_key", "comments", "_type") __slots__ = ("args", "parent", "arg_key", "comments", "_type")
def __init__(self, **args): def __init__(self, **args: t.Any):
self.args = args self.args: t.Dict[str, t.Any] = args
self.parent = None self.parent: t.Optional[Expression] = None
self.arg_key = None self.arg_key: t.Optional[str] = None
self.comments = None self.comments: t.Optional[t.List[str]] = None
self._type: t.Optional[DataType] = None self._type: t.Optional[DataType] = None
for arg_key, value in self.args.items(): for arg_key, value in self.args.items():
@ -76,17 +114,30 @@ class Expression(metaclass=_Expression):
@property @property
def this(self): def this(self):
"""
Retrieves the argument with key "this".
"""
return self.args.get("this") return self.args.get("this")
@property @property
def expression(self): def expression(self):
"""
Retrieves the argument with key "expression".
"""
return self.args.get("expression") return self.args.get("expression")
@property @property
def expressions(self): def expressions(self):
"""
Retrieves the argument with key "expressions".
"""
return self.args.get("expressions") or [] return self.args.get("expressions") or []
def text(self, key): def text(self, key):
"""
Returns a textual representation of the argument corresponding to "key". This can only be used
for args that are strings or leaf Expression instances, such as identifiers and literals.
"""
field = self.args.get(key) field = self.args.get(key)
if isinstance(field, str): if isinstance(field, str):
return field return field
@ -96,14 +147,23 @@ class Expression(metaclass=_Expression):
@property @property
def is_string(self): def is_string(self):
"""
Checks whether a Literal expression is a string.
"""
return isinstance(self, Literal) and self.args["is_string"] return isinstance(self, Literal) and self.args["is_string"]
@property @property
def is_number(self): def is_number(self):
"""
Checks whether a Literal expression is a number.
"""
return isinstance(self, Literal) and not self.args["is_string"] return isinstance(self, Literal) and not self.args["is_string"]
@property @property
def is_int(self): def is_int(self):
"""
Checks whether a Literal expression is an integer.
"""
if self.is_number: if self.is_number:
try: try:
int(self.name) int(self.name)
@ -114,6 +174,9 @@ class Expression(metaclass=_Expression):
@property @property
def alias(self): def alias(self):
"""
Returns the alias of the expression, or an empty string if it's not aliased.
"""
if isinstance(self.args.get("alias"), TableAlias): if isinstance(self.args.get("alias"), TableAlias):
return self.args["alias"].name return self.args["alias"].name
return self.text("alias") return self.text("alias")
@ -128,6 +191,24 @@ class Expression(metaclass=_Expression):
return "NULL" return "NULL"
return self.alias or self.name return self.alias or self.name
@property
def output_name(self):
"""
Name of the output column if this expression is a selection.
If the Expression has no output name, an empty string is returned.
Example:
>>> from sqlglot import parse_one
>>> parse_one("SELECT a").expressions[0].output_name
'a'
>>> parse_one("SELECT b AS c").expressions[0].output_name
'c'
>>> parse_one("SELECT 1 + 2").expressions[0].output_name
''
"""
return ""
@property @property
def type(self) -> t.Optional[DataType]: def type(self) -> t.Optional[DataType]:
return self._type return self._type
@ -145,6 +226,9 @@ class Expression(metaclass=_Expression):
return copy return copy
def copy(self): def copy(self):
"""
Returns a deep copy of the expression.
"""
new = deepcopy(self) new = deepcopy(self)
for item, parent, _ in new.bfs(): for item, parent, _ in new.bfs():
if isinstance(item, Expression) and parent: if isinstance(item, Expression) and parent:
@ -169,7 +253,7 @@ class Expression(metaclass=_Expression):
Sets `arg_key` to `value`. Sets `arg_key` to `value`.
Args: Args:
arg_key (str): name of the expression arg arg_key (str): name of the expression arg.
value: value to set the arg to. value: value to set the arg to.
""" """
self.args[arg_key] = value self.args[arg_key] = value
@ -203,8 +287,7 @@ class Expression(metaclass=_Expression):
expression_types (type): the expression type(s) to match. expression_types (type): the expression type(s) to match.
Returns: Returns:
the node which matches the criteria or None if no node matching The node which matches the criteria or None if no such node was found.
the criteria was found.
""" """
return next(self.find_all(*expression_types, bfs=bfs), None) return next(self.find_all(*expression_types, bfs=bfs), None)
@ -217,7 +300,7 @@ class Expression(metaclass=_Expression):
expression_types (type): the expression type(s) to match. expression_types (type): the expression type(s) to match.
Returns: Returns:
the generator object. The generator object.
""" """
for expression, _, _ in self.walk(bfs=bfs): for expression, _, _ in self.walk(bfs=bfs):
if isinstance(expression, expression_types): if isinstance(expression, expression_types):
@ -231,7 +314,7 @@ class Expression(metaclass=_Expression):
expression_types (type): the expression type(s) to match. expression_types (type): the expression type(s) to match.
Returns: Returns:
the parent node The parent node.
""" """
ancestor = self.parent ancestor = self.parent
while ancestor and not isinstance(ancestor, expression_types): while ancestor and not isinstance(ancestor, expression_types):
@ -269,7 +352,7 @@ class Expression(metaclass=_Expression):
the DFS (Depth-first) order. the DFS (Depth-first) order.
Returns: Returns:
the generator object. The generator object.
""" """
parent = parent or self.parent parent = parent or self.parent
yield self, parent, key yield self, parent, key
@ -287,7 +370,7 @@ class Expression(metaclass=_Expression):
the BFS (Breadth-first) order. the BFS (Breadth-first) order.
Returns: Returns:
the generator object. The generator object.
""" """
queue = deque([(self, self.parent, None)]) queue = deque([(self, self.parent, None)])
@ -341,32 +424,33 @@ class Expression(metaclass=_Expression):
return self.sql() return self.sql()
def __repr__(self): def __repr__(self):
return self.to_s() return self._to_s()
def sql(self, dialect: Dialect | str | None = None, **opts) -> str: def sql(self, dialect: Dialect | str | None = None, **opts) -> str:
""" """
Returns SQL string representation of this tree. Returns SQL string representation of this tree.
Args Args:
dialect (str): the dialect of the output SQL string dialect: the dialect of the output SQL string (eg. "spark", "hive", "presto", "mysql").
(eg. "spark", "hive", "presto", "mysql"). opts: other `sqlglot.generator.Generator` options.
opts (dict): other :class:`~sqlglot.generator.Generator` options.
Returns Returns:
the SQL string. The SQL string.
""" """
from sqlglot.dialects import Dialect from sqlglot.dialects import Dialect
return Dialect.get_or_raise(dialect)().generate(self, **opts) return Dialect.get_or_raise(dialect)().generate(self, **opts)
def to_s(self, hide_missing: bool = True, level: int = 0) -> str: def _to_s(self, hide_missing: bool = True, level: int = 0) -> str:
indent = "" if not level else "\n" indent = "" if not level else "\n"
indent += "".join([" "] * level) indent += "".join([" "] * level)
left = f"({self.key.upper()} " left = f"({self.key.upper()} "
args: t.Dict[str, t.Any] = { args: t.Dict[str, t.Any] = {
k: ", ".join( k: ", ".join(
v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v) v._to_s(hide_missing=hide_missing, level=level + 1)
if hasattr(v, "_to_s")
else str(v)
for v in ensure_collection(vs) for v in ensure_collection(vs)
if v is not None if v is not None
) )
@ -394,7 +478,7 @@ class Expression(metaclass=_Expression):
modified in place. modified in place.
Returns: Returns:
the transformed tree. The transformed tree.
""" """
node = self.copy() if copy else self node = self.copy() if copy else self
new_node = fun(node, *args, **kwargs) new_node = fun(node, *args, **kwargs)
@ -423,8 +507,8 @@ class Expression(metaclass=_Expression):
Args: Args:
expression (Expression|None): new node expression (Expression|None): new node
Returns : Returns:
the new expression or expressions The new expression or expressions.
""" """
if not self.parent: if not self.parent:
return expression return expression
@ -458,6 +542,40 @@ class Expression(metaclass=_Expression):
assert isinstance(self, type_) assert isinstance(self, type_)
return self return self
def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]:
    """
    Validates this expression against its `arg_types` declaration.

    Three kinds of problems are reported: args that were set but are not declared in
    `arg_types`, mandatory args that are missing (None or an empty list), and, for Func
    expressions that are not variable-length, too many positional arguments.

    Args:
        args: the positional values that were used to instantiate a Func expression. They
            are only used to check that the function's argument limit is not exceeded.

    Returns:
        A list with one message per problem found; empty if the expression is valid.
    """
    node_type = self.__class__
    supported = self.arg_types

    # Args that were provided but that this expression type does not declare.
    problems: t.List[str] = [
        f"Unexpected keyword: '{key}' for {node_type}"
        for key in self.args
        if key not in supported
    ]

    for key, required in supported.items():
        if not required:
            continue

        value = self.args.get(key)
        # A mandatory arg counts as missing when it is None or an empty list.
        if value is None or (isinstance(value, list) and not value):
            problems.append(f"Required keyword: '{key}' missing for {node_type}")

    if args and isinstance(self, Func):
        limit = len(supported)
        if len(args) > limit and not self.is_var_len_args:
            problems.append(
                f"The number of provided arguments ({len(args)}) is greater than "
                f"the maximum number of supported arguments ({limit})"
            )

    return problems
def dump(self): def dump(self):
""" """
Dump this Expression to a JSON-serializable dict. Dump this Expression to a JSON-serializable dict.
@ -552,7 +670,7 @@ class DerivedTable(Expression):
@property @property
def named_selects(self): def named_selects(self):
return [select.alias_or_name for select in self.selects] return [select.output_name for select in self.selects]
class Unionable(Expression): class Unionable(Expression):
@ -654,6 +772,7 @@ class Create(Expression):
"no_primary_index": False, "no_primary_index": False,
"indexes": False, "indexes": False,
"no_schema_binding": False, "no_schema_binding": False,
"begin": False,
} }
@ -696,7 +815,7 @@ class Show(Expression):
class UserDefinedFunction(Expression): class UserDefinedFunction(Expression):
arg_types = {"this": True, "expressions": False} arg_types = {"this": True, "expressions": False, "wrapped": False}
class UserDefinedFunctionKwarg(Expression): class UserDefinedFunctionKwarg(Expression):
@ -750,6 +869,10 @@ class Column(Condition):
def table(self): def table(self):
return self.text("table") return self.text("table")
@property
def output_name(self):
return self.name
class ColumnDef(Expression): class ColumnDef(Expression):
arg_types = { arg_types = {
@ -865,6 +988,10 @@ class ForeignKey(Expression):
} }
class PrimaryKey(Expression):
arg_types = {"expressions": True, "options": False}
class Unique(Expression): class Unique(Expression):
arg_types = {"expressions": True} arg_types = {"expressions": True}
@ -904,6 +1031,10 @@ class Identifier(Expression):
def __hash__(self): def __hash__(self):
return hash((self.key, self.this.lower())) return hash((self.key, self.this.lower()))
@property
def output_name(self):
return self.name
class Index(Expression): class Index(Expression):
arg_types = { arg_types = {
@ -996,6 +1127,10 @@ class Literal(Condition):
def string(cls, string) -> Literal: def string(cls, string) -> Literal:
return cls(this=str(string), is_string=True) return cls(this=str(string), is_string=True)
@property
def output_name(self):
return self.name
class Join(Expression): class Join(Expression):
arg_types = { arg_types = {
@ -1186,7 +1321,7 @@ class SchemaCommentProperty(Property):
class ReturnsProperty(Property): class ReturnsProperty(Property):
arg_types = {"this": True, "is_table": False} arg_types = {"this": True, "is_table": False, "table": False}
class LanguageProperty(Property): class LanguageProperty(Property):
@ -1262,8 +1397,13 @@ class Qualify(Expression):
pass pass
# https://www.ibm.com/docs/en/ias?topic=procedures-return-statement-in-sql
class Return(Expression):
pass
class Reference(Expression): class Reference(Expression):
arg_types = {"this": True, "expressions": True} arg_types = {"this": True, "expressions": False, "options": False}
class Tuple(Expression): class Tuple(Expression):
@ -1397,6 +1537,16 @@ class Table(Expression):
"joins": False, "joins": False,
"pivots": False, "pivots": False,
"hints": False, "hints": False,
"system_time": False,
}
# See the TSQL "Querying data in a system-versioned temporal table" page
class SystemTime(Expression):
arg_types = {
"this": False,
"expression": False,
"kind": True,
} }
@ -2027,7 +2177,7 @@ class Select(Subqueryable):
@property @property
def named_selects(self) -> t.List[str]: def named_selects(self) -> t.List[str]:
return [e.alias_or_name for e in self.expressions if e.alias_or_name] return [e.output_name for e in self.expressions if e.alias_or_name]
@property @property
def selects(self) -> t.List[Expression]: def selects(self) -> t.List[Expression]:
@ -2051,6 +2201,10 @@ class Subquery(DerivedTable, Unionable):
expression = expression.this expression = expression.this
return expression return expression
@property
def output_name(self):
return self.alias
class TableSample(Expression): class TableSample(Expression):
arg_types = { arg_types = {
@ -2066,6 +2220,16 @@ class TableSample(Expression):
} }
class Tag(Expression):
"""Tags are used for generating arbitrary sql like SELECT <span>x</span>."""
arg_types = {
"this": False,
"prefix": False,
"postfix": False,
}
class Pivot(Expression): class Pivot(Expression):
arg_types = { arg_types = {
"this": False, "this": False,
@ -2106,6 +2270,10 @@ class Star(Expression):
def name(self): def name(self):
return "*" return "*"
@property
def output_name(self):
return self.name
class Parameter(Expression): class Parameter(Expression):
pass pass
@ -2143,6 +2311,8 @@ class DataType(Expression):
TEXT = auto() TEXT = auto()
MEDIUMTEXT = auto() MEDIUMTEXT = auto()
LONGTEXT = auto() LONGTEXT = auto()
MEDIUMBLOB = auto()
LONGBLOB = auto()
BINARY = auto() BINARY = auto()
VARBINARY = auto() VARBINARY = auto()
INT = auto() INT = auto()
@ -2282,11 +2452,11 @@ class Rollback(Expression):
class AlterTable(Expression): class AlterTable(Expression):
arg_types = { arg_types = {"this": True, "actions": True, "exists": False}
"this": True,
"actions": True,
"exists": False, class AddConstraint(Expression):
} arg_types = {"this": False, "expression": False, "enforced": False}
# Binary expressions like (ADD a b) # Binary expressions like (ADD a b)
@ -2456,6 +2626,10 @@ class Neg(Unary):
class Alias(Expression): class Alias(Expression):
arg_types = {"this": True, "alias": False} arg_types = {"this": True, "alias": False}
@property
def output_name(self):
return self.alias
class Aliases(Expression): class Aliases(Expression):
arg_types = {"this": True, "expressions": True} arg_types = {"this": True, "expressions": True}
@ -2523,16 +2697,13 @@ class Func(Condition):
""" """
The base class for all function expressions. The base class for all function expressions.
Attributes Attributes:
is_var_len_args (bool): if set to True the last argument defined in is_var_len_args (bool): if set to True the last argument defined in arg_types will be
arg_types will be treated as a variable length argument and the treated as a variable length argument and the argument's value will be stored as a list.
argument's value will be stored as a list. _sql_names (list): determines the SQL name (1st item in the list) and aliases (subsequent items)
_sql_names (list): determines the SQL name (1st item in the list) and for this function expression. These values are used to map this node to a name during parsing
aliases (subsequent items) for this function expression. These as well as to provide the function's name during SQL string generation. By default the SQL
values are used to map this node to a name during parsing as well name is set to the expression's class name transformed to snake case.
as to provide the function's name during SQL string generation. By
default the SQL name is set to the expression's class name transformed
to snake case.
""" """
is_var_len_args = False is_var_len_args = False
@ -2558,7 +2729,7 @@ class Func(Condition):
raise NotImplementedError( raise NotImplementedError(
"SQL name is only supported by concrete function implementations" "SQL name is only supported by concrete function implementations"
) )
if not hasattr(cls, "_sql_names"): if "_sql_names" not in cls.__dict__:
cls._sql_names = [camel_to_snake_case(cls.__name__)] cls._sql_names = [camel_to_snake_case(cls.__name__)]
return cls._sql_names return cls._sql_names
@ -2658,6 +2829,10 @@ class Cast(Func):
def to(self): def to(self):
return self.args["to"] return self.args["to"]
@property
def output_name(self):
return self.name
class Collate(Binary): class Collate(Binary):
pass pass
@ -2956,6 +3131,14 @@ class Pow(Func):
_sql_names = ["POWER", "POW"] _sql_names = ["POWER", "POW"]
class PercentileCont(AggFunc):
pass
class PercentileDisc(AggFunc):
pass
class Quantile(AggFunc): class Quantile(AggFunc):
arg_types = {"this": True, "quantile": True} arg_types = {"this": True, "quantile": True}
@ -3213,12 +3396,13 @@ def _norm_arg(arg):
ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func)) ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
# Helpers
def maybe_parse( def maybe_parse(
sql_or_expression, sql_or_expression: str | Expression,
*, *,
into=None, into: t.Optional[IntoType] = None,
dialect=None, dialect: t.Optional[str] = None,
prefix=None, prefix: t.Optional[str] = None,
**opts, **opts,
) -> Expression: ) -> Expression:
"""Gracefully handle a possible string or expression. """Gracefully handle a possible string or expression.
@ -3230,11 +3414,11 @@ def maybe_parse(
(IDENTIFIER this: x, quoted: False) (IDENTIFIER this: x, quoted: False)
Args: Args:
sql_or_expression (str | Expression): the SQL code string or an expression sql_or_expression: the SQL code string or an expression
into (Expression): the SQLGlot Expression to parse into into: the SQLGlot Expression to parse into
dialect (str): the dialect used to parse the input expressions (in the case that an dialect: the dialect used to parse the input expressions (in the case that an
input expression is a SQL string). input expression is a SQL string).
prefix (str): a string to prefix the sql with before it gets parsed prefix: a string to prefix the sql with before it gets parsed
(automatically includes a space) (automatically includes a space)
**opts: other options to use to parse the input expressions (again, in the case **opts: other options to use to parse the input expressions (again, in the case
that an input expression is a SQL string). that an input expression is a SQL string).
@ -3993,7 +4177,7 @@ def table_name(table) -> str:
"""Get the full name of a table as a string. """Get the full name of a table as a string.
Args: Args:
table (exp.Table | str): Table expression node or string. table (exp.Table | str): table expression node or string.
Examples: Examples:
>>> from sqlglot import exp, parse_one >>> from sqlglot import exp, parse_one
@ -4001,7 +4185,7 @@ def table_name(table) -> str:
'a.b.c' 'a.b.c'
Returns: Returns:
str: the table name The table name.
""" """
table = maybe_parse(table, into=Table) table = maybe_parse(table, into=Table)
@ -4024,8 +4208,8 @@ def replace_tables(expression, mapping):
"""Replace all tables in expression according to the mapping. """Replace all tables in expression according to the mapping.
Args: Args:
expression (sqlglot.Expression): Expression node to be transformed and replaced expression (sqlglot.Expression): expression node to be transformed and replaced.
mapping (Dict[str, str]): Mapping of table names mapping (Dict[str, str]): mapping of table names.
Examples: Examples:
>>> from sqlglot import exp, parse_one >>> from sqlglot import exp, parse_one
@ -4033,7 +4217,7 @@ def replace_tables(expression, mapping):
'SELECT * FROM c' 'SELECT * FROM c'
Returns: Returns:
The mapped expression The mapped expression.
""" """
def _replace_tables(node): def _replace_tables(node):
@ -4053,9 +4237,9 @@ def replace_placeholders(expression, *args, **kwargs):
"""Replace placeholders in an expression. """Replace placeholders in an expression.
Args: Args:
expression (sqlglot.Expression): Expression node to be transformed and replaced expression (sqlglot.Expression): expression node to be transformed and replaced.
args: Positional names that will substitute unnamed placeholders in the given order args: positional names that will substitute unnamed placeholders in the given order.
kwargs: Keyword arguments that will substitute named placeholders kwargs: keyword arguments that will substitute named placeholders.
Examples: Examples:
>>> from sqlglot import exp, parse_one >>> from sqlglot import exp, parse_one
@ -4065,7 +4249,7 @@ def replace_placeholders(expression, *args, **kwargs):
'SELECT * FROM foo WHERE a = b' 'SELECT * FROM foo WHERE a = b'
Returns: Returns:
The mapped expression The mapped expression.
""" """
def _replace_placeholders(node, args, **kwargs): def _replace_placeholders(node, args, **kwargs):
@ -4084,15 +4268,101 @@ def replace_placeholders(expression, *args, **kwargs):
return expression.transform(_replace_placeholders, iter(args), **kwargs) return expression.transform(_replace_placeholders, iter(args), **kwargs)
def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True) -> Expression:
    """Replaces every table that names a known source with that source as a subquery.

    Examples:
        >>> from sqlglot import parse_one
        >>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y")}).sql()
        'SELECT * FROM (SELECT * FROM y) AS z /* source: x */'

    Args:
        expression: the expression to expand.
        sources: a mapping from table name to the Subqueryable that should replace it.
        copy: whether or not to copy the expression during transformation. Defaults to True.

    Returns:
        The transformed expression.
    """

    def _expand(node: Expression):
        # Only Table nodes can refer to a source; everything else passes through untouched.
        if not isinstance(node, Table):
            return node

        name = table_name(node)
        source = sources.get(name)
        if not source:
            return node

        # Keep the table's own alias if it has one, otherwise alias by the table name,
        # and record where the subquery came from as a comment.
        subquery = source.subquery(node.alias or name)
        subquery.comments = [f"source: {name}"]
        return subquery

    return expression.transform(_expand, copy=copy)
def func(name: str, *args, dialect: t.Optional[Dialect | str] = None, **kwargs) -> Func:
    """
    Builds a Func expression from a function name and its arguments.

    Examples:
        >>> func("abs", 5).sql()
        'ABS(5)'

        >>> func("cast", this=5, to=DataType.build("DOUBLE")).sql()
        'CAST(5 AS DOUBLE)'

    Args:
        name: the name of the function to build.
        args: the positional args used to instantiate the function of interest.
        dialect: the dialect whose parser is used to look `name` up.
        kwargs: the keyword args used to instantiate the function of interest.

    Note:
        The arguments `args` and `kwargs` are mutually exclusive.

    Returns:
        An instance of the function of interest, or an anonymous function, if `name` doesn't
        correspond to an existing `sqlglot.expressions.Func` class.

    Raises:
        ValueError: if both `args` and `kwargs` are given, or if the built function
            reports a validation error (the first reported message is raised).
    """
    if args and kwargs:
        raise ValueError("Can't use both args and kwargs to instantiate a function.")

    from sqlglot.dialects.dialect import Dialect

    positional = tuple(convert(arg) for arg in args)
    named = {key: convert(value) for key, value in kwargs.items()}

    # The parser's FUNCTIONS table maps upper-cased SQL names to builder callables.
    builder = Dialect.get_or_raise(dialect)().parser().FUNCTIONS.get(name.upper())

    if not builder:
        # Unknown name: fall back to an anonymous function carrying the raw arguments.
        function = Anonymous(this=name, **(named or {"expressions": positional}))
    elif positional:
        function = builder(positional)
    else:
        function = builder.__self__(**named)  # type: ignore

    errors = function.error_messages(positional)
    if errors:
        raise ValueError(errors[0])

    return function
def true(): def true():
"""
Returns a true Boolean expression.
"""
return Boolean(this=True) return Boolean(this=True)
def false(): def false():
"""
Returns a false Boolean expression.
"""
return Boolean(this=False) return Boolean(this=False)
def null(): def null():
"""
Returns a Null expression.
"""
return Null() return Null()

View file

@ -16,7 +16,7 @@ class Generator:
""" """
Generator interprets the given syntax tree and produces a SQL string as an output. Generator interprets the given syntax tree and produces a SQL string as an output.
Args Args:
time_mapping (dict): the dictionary of custom time mappings in which the key time_mapping (dict): the dictionary of custom time mappings in which the key
represents a python time format and the output the target time format represents a python time format and the output the target time format
time_trie (trie): a trie of the time_mapping keys time_trie (trie): a trie of the time_mapping keys
@ -84,6 +84,13 @@ class Generator:
exp.DataType.Type.NVARCHAR: "VARCHAR", exp.DataType.Type.NVARCHAR: "VARCHAR",
exp.DataType.Type.MEDIUMTEXT: "TEXT", exp.DataType.Type.MEDIUMTEXT: "TEXT",
exp.DataType.Type.LONGTEXT: "TEXT", exp.DataType.Type.LONGTEXT: "TEXT",
exp.DataType.Type.MEDIUMBLOB: "BLOB",
exp.DataType.Type.LONGBLOB: "BLOB",
}
STAR_MAPPING = {
"except": "EXCEPT",
"replace": "REPLACE",
} }
TOKEN_MAPPING: t.Dict[TokenType, str] = {} TOKEN_MAPPING: t.Dict[TokenType, str] = {}
@ -106,6 +113,8 @@ class Generator:
exp.TableFormatProperty, exp.TableFormatProperty,
} }
WITH_SINGLE_ALTER_TABLE_ACTION = (exp.AlterColumn, exp.RenameTable, exp.AddConstraint)
WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary) WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__" SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
@ -241,15 +250,17 @@ class Generator:
return sql return sql
sep = "\n" if self.pretty else " " sep = "\n" if self.pretty else " "
comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments if comment) comments_sql = sep.join(
f"/*{self.pad_comment(comment)}*/" for comment in comments if comment
)
if not comments: if not comments_sql:
return sql return sql
if isinstance(expression, self.WITH_SEPARATED_COMMENTS): if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
return f"{comments}{self.sep()}{sql}" return f"{comments_sql}{self.sep()}{sql}"
return f"{sql} {comments}" return f"{sql} {comments_sql}"
def wrap(self, expression: exp.Expression | str) -> str: def wrap(self, expression: exp.Expression | str) -> str:
this_sql = self.indent( this_sql = self.indent(
@ -433,8 +444,9 @@ class Generator:
def create_sql(self, expression: exp.Create) -> str: def create_sql(self, expression: exp.Create) -> str:
this = self.sql(expression, "this") this = self.sql(expression, "this")
kind = self.sql(expression, "kind").upper() kind = self.sql(expression, "kind").upper()
begin = " BEGIN" if expression.args.get("begin") else ""
expression_sql = self.sql(expression, "expression") expression_sql = self.sql(expression, "expression")
expression_sql = f" AS{self.sep()}{expression_sql}" if expression_sql else "" expression_sql = f" AS{begin}{self.sep()}{expression_sql}" if expression_sql else ""
temporary = " TEMPORARY" if expression.args.get("temporary") else "" temporary = " TEMPORARY" if expression.args.get("temporary") else ""
transient = ( transient = (
" TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else "" " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
@ -741,12 +753,14 @@ class Generator:
laterals = self.expressions(expression, key="laterals", sep="") laterals = self.expressions(expression, key="laterals", sep="")
joins = self.expressions(expression, key="joins", sep="") joins = self.expressions(expression, key="joins", sep="")
pivots = self.expressions(expression, key="pivots", sep="") pivots = self.expressions(expression, key="pivots", sep="")
system_time = expression.args.get("system_time")
system_time = f" {self.sql(expression, 'system_time')}" if system_time else ""
if alias and pivots: if alias and pivots:
pivots = f"{pivots}{alias}" pivots = f"{pivots}{alias}"
alias = "" alias = ""
return f"{table}{alias}{hints}{laterals}{joins}{pivots}" return f"{table}{system_time}{alias}{hints}{laterals}{joins}{pivots}"
def tablesample_sql(self, expression: exp.TableSample) -> str: def tablesample_sql(self, expression: exp.TableSample) -> str:
if self.alias_post_tablesample and expression.this.alias: if self.alias_post_tablesample and expression.this.alias:
@ -1009,9 +1023,9 @@ class Generator:
def star_sql(self, expression: exp.Star) -> str: def star_sql(self, expression: exp.Star) -> str:
except_ = self.expressions(expression, key="except", flat=True) except_ = self.expressions(expression, key="except", flat=True)
except_ = f"{self.seg('EXCEPT')} ({except_})" if except_ else "" except_ = f"{self.seg(self.STAR_MAPPING['except'])} ({except_})" if except_ else ""
replace = self.expressions(expression, key="replace", flat=True) replace = self.expressions(expression, key="replace", flat=True)
replace = f"{self.seg('REPLACE')} ({replace})" if replace else "" replace = f"{self.seg(self.STAR_MAPPING['replace'])} ({replace})" if replace else ""
return f"*{except_}{replace}" return f"*{except_}{replace}"
def structkwarg_sql(self, expression: exp.StructKwarg) -> str: def structkwarg_sql(self, expression: exp.StructKwarg) -> str:
@ -1193,6 +1207,12 @@ class Generator:
update = f" ON UPDATE {update}" if update else "" update = f" ON UPDATE {update}" if update else ""
return f"FOREIGN KEY ({expressions}){reference}{delete}{update}" return f"FOREIGN KEY ({expressions}){reference}{delete}{update}"
def primarykey_sql(self, expression: exp.PrimaryKey) -> str:
    """
    Generates the SQL for a PRIMARY KEY constraint.

    Args:
        expression: the PrimaryKey node; its "expressions" arg is rendered as the
            parenthesized list and its "options" arg (if any) is appended after it.

    Returns:
        A string of the form `PRIMARY KEY (<expressions>)[ <options>]`.
    """
    expressions = self.expressions(expression, flat=True)
    # Options are space-separated and only emitted (with a leading space) when present.
    options = self.expressions(expression, "options", flat=True, sep=" ")
    options = f" {options}" if options else ""
    return f"PRIMARY KEY ({expressions}){options}"
def unique_sql(self, expression: exp.Unique) -> str: def unique_sql(self, expression: exp.Unique) -> str:
columns = self.expressions(expression, key="expressions") columns = self.expressions(expression, key="expressions")
return f"UNIQUE ({columns})" return f"UNIQUE ({columns})"
@ -1229,10 +1249,16 @@ class Generator:
unit = f" {unit}" if unit else "" unit = f" {unit}" if unit else ""
return f"INTERVAL{this}{unit}" return f"INTERVAL{this}{unit}"
def return_sql(self, expression: exp.Return) -> str:
    """Render a ``RETURN`` statement from the node's ``this`` argument."""
    returned = self.sql(expression, "this")
    return f"RETURN {returned}"
def reference_sql(self, expression: exp.Reference) -> str: def reference_sql(self, expression: exp.Reference) -> str:
this = self.sql(expression, "this") this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True) expressions = self.expressions(expression, flat=True)
return f"REFERENCES {this}({expressions})" expressions = f"({expressions})" if expressions else ""
options = self.expressions(expression, "options", flat=True, sep=" ")
options = f" {options}" if options else ""
return f"REFERENCES {this}{expressions}{options}"
def anonymous_sql(self, expression: exp.Anonymous) -> str: def anonymous_sql(self, expression: exp.Anonymous) -> str:
args = self.format_args(*expression.expressions) args = self.format_args(*expression.expressions)
@ -1362,7 +1388,7 @@ class Generator:
actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ") actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ")
elif isinstance(actions[0], exp.Drop): elif isinstance(actions[0], exp.Drop):
actions = self.expressions(expression, "actions") actions = self.expressions(expression, "actions")
elif isinstance(actions[0], (exp.AlterColumn, exp.RenameTable)): elif isinstance(actions[0], self.WITH_SINGLE_ALTER_TABLE_ACTION):
actions = self.sql(actions[0]) actions = self.sql(actions[0])
else: else:
self.unsupported(f"Unsupported ALTER TABLE action {actions[0].__class__.__name__}") self.unsupported(f"Unsupported ALTER TABLE action {actions[0].__class__.__name__}")
@ -1370,6 +1396,17 @@ class Generator:
exists = " IF EXISTS" if expression.args.get("exists") else "" exists = " IF EXISTS" if expression.args.get("exists") else ""
return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}" return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}"
def addconstraint_sql(self, expression: exp.AddConstraint) -> str:
    """Render the ``ADD [CONSTRAINT <name>] ...`` action of an ALTER TABLE.

    When the node carries an ``enforced`` argument (even a falsy one) the
    constraint body is rendered as ``CHECK (<body>)``, with a trailing
    `` ENFORCED`` only if the flag is truthy. Otherwise the body is emitted
    as-is after the ``ADD`` prefix.
    """
    name = self.sql(expression, "this")
    body = self.sql(expression, "expression")
    prefix = "ADD" if not name else f"ADD CONSTRAINT {name}"

    enforced = expression.args.get("enforced")
    if enforced is None:
        return f"{prefix} {body}"

    suffix = " ENFORCED" if enforced else ""
    return f"{prefix} CHECK ({body}){suffix}"
def distinct_sql(self, expression: exp.Distinct) -> str: def distinct_sql(self, expression: exp.Distinct) -> str:
this = self.expressions(expression, flat=True) this = self.expressions(expression, flat=True)
this = f" {this}" if this else "" this = f" {this}" if this else ""
@ -1550,13 +1587,19 @@ class Generator:
expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}" expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
) )
def tag_sql(self, expression: exp.Tag) -> str:
    """Render a tagged expression as ``<prefix><sql of this><postfix>``.

    Args:
        expression: The tag node; ``prefix`` and ``postfix`` are read from
            its ``args`` and wrapped around the SQL of ``expression.this``.

    Returns:
        The wrapped SQL. A prefix or postfix that is absent (``None``)
        contributes nothing, rather than the literal text ``None``.
    """
    prefix = expression.args.get("prefix")
    postfix = expression.args.get("postfix")
    # Formatting None in an f-string yields the string "None"; treat a
    # missing affix as empty instead.
    prefix = "" if prefix is None else prefix
    postfix = "" if postfix is None else postfix
    return f"{prefix}{self.sql(expression.this)}{postfix}"
def token_sql(self, token_type: TokenType) -> str: def token_sql(self, token_type: TokenType) -> str:
return self.TOKEN_MAPPING.get(token_type, token_type.name) return self.TOKEN_MAPPING.get(token_type, token_type.name)
def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str: def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str:
this = self.sql(expression, "this") this = self.sql(expression, "this")
expressions = self.no_identify(lambda: self.expressions(expression)) expressions = self.no_identify(lambda: self.expressions(expression))
return f"{this}({expressions})" expressions = (
self.wrap(expressions) if expression.args.get("wrapped") else f" {expressions}"
)
return f"{this}{expressions}"
def userdefinedfunctionkwarg_sql(self, expression: exp.UserDefinedFunctionKwarg) -> str: def userdefinedfunctionkwarg_sql(self, expression: exp.UserDefinedFunctionKwarg) -> str:
this = self.sql(expression, "this") this = self.sql(expression, "this")

View file

@ -332,7 +332,7 @@ def is_iterable(value: t.Any) -> bool:
return hasattr(value, "__iter__") and not isinstance(value, (str, bytes)) return hasattr(value, "__iter__") and not isinstance(value, (str, bytes))
def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any, None, None]: def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
""" """
Flattens an iterable that can contain both iterable and non-iterable elements. Objects of Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
type `str` and `bytes` are not regarded as iterables. type `str` and `bytes` are not regarded as iterables.

228
sqlglot/lineage.py Normal file
View file

@ -0,0 +1,228 @@
from __future__ import annotations
import json
import typing as t
from dataclasses import dataclass, field
from sqlglot import Schema, exp, maybe_parse
from sqlglot.optimizer import Scope, build_scope, optimize
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
@dataclass(frozen=True)
class Node:
    """A vertex of a column lineage graph.

    The dataclass is frozen, so fields cannot be reassigned, but the
    `downstream` list itself is mutable and is appended to while the graph
    is being built.
    """

    # Display name of the node.
    name: str
    # The expression this node stands for.
    expression: exp.Expression
    # The expression in which `expression` appears.
    source: exp.Expression
    # Nodes this node points to; a fresh list per instance.
    downstream: t.List[Node] = field(default_factory=list)

    def walk(self) -> t.Iterator[Node]:
        """Yield this node and then, depth first, everything below it.

        Entries of `downstream` that are not `Node` instances are yielded
        unchanged without being descended into. A node reachable through
        several parents is yielded once per path.
        """
        yield self

        for child in self.downstream:
            if not isinstance(child, Node):
                yield child
                continue
            yield from child.walk()

    def to_html(self, **opts) -> LineageHTML:
        """Wrap this node in a `LineageHTML` renderer, forwarding `opts`."""
        html = LineageHTML(self, **opts)
        return html
def lineage(
    column: str | exp.Column,
    sql: str | exp.Expression,
    schema: t.Optional[t.Dict | Schema] = None,
    sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
    rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns),
    dialect: t.Optional[str] = None,
) -> Node:
    """Build the lineage graph for one output column of a SQL query.

    Args:
        column: The column to trace, as a name or a column expression
            (only its name is used).
        sql: The query, as a SQL string or an already parsed expression.
        schema: The schema of the tables, passed to the optimizer.
        sources: A mapping of names to queries; when given, the query is
            expanded with them so that lineage continues into those queries.
        rules: Optimizer rules to apply; by default only table and column
            qualification.
        dialect: The dialect used to parse `sql` and `sources`.

    Returns:
        The root lineage node for `column`; its `downstream` list links to
        the nodes it is derived from.
    """
    expression = maybe_parse(sql, dialect=dialect)

    if sources:
        parsed_sources = {}
        for source_name, query in sources.items():
            parsed_sources[source_name] = t.cast(
                exp.Subqueryable, maybe_parse(query, dialect=dialect)
            )
        expression = exp.expand(expression, parsed_sources)

    root_scope = build_scope(optimize(expression, schema=schema, rules=rules))

    # One shared node per physical table, so every column that reads from
    # the same table links to the same node object.
    table_nodes: t.Dict[str, Node] = {}

    def to_node(
        column_name: str,
        scope: Scope,
        scope_name: t.Optional[str] = None,
        upstream: t.Optional[Node] = None,
    ) -> Node:
        # A union has no select list of its own: trace the column through
        # each branch. Every branch attaches itself to `upstream`; the node
        # of the last branch is what gets returned.
        if isinstance(scope.expression, exp.Union):
            for branch in scope.union_scopes:
                node = to_node(
                    column_name,
                    scope=branch,
                    scope_name=scope_name,
                    upstream=upstream,
                )
            return node

        matching = (
            candidate for candidate in scope.selects if candidate.alias_or_name == column_name
        )
        select = next(matching)

        # Re-optimize a copy of the query narrowed to just this projection,
        # then take the projection from that result so that `expression`
        # is a node inside `source`.
        source = optimize(scope.expression.select(select, append=False), schema=schema, rules=rules)
        select = source.selects[0]

        node_name = column_name if not scope_name else f"{scope_name}.{column_name}"
        node = Node(name=node_name, source=source, expression=select)

        if upstream:
            upstream.downstream.append(node)

        for col in set(select.find_all(exp.Column)):
            table = col.table
            origin = scope.sources[table]

            if not isinstance(origin, Scope):
                # Anything that is not a nested scope is a leaf of the graph.
                if table not in table_nodes:
                    table_nodes[table] = Node(name=table, source=origin, expression=origin)
                node.downstream.append(table_nodes[table])
                continue

            # The column comes from a nested query: recurse into its scope.
            to_node(
                col.name,
                scope=origin,
                scope_name=table,
                upstream=node,
            )

        return node

    root_column = column if isinstance(column, str) else column.name
    return to_node(root_column, root_scope)
class LineageHTML:
    """Render a lineage `Node` graph as an HTML snippet driven by vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self,
        node: Node,
        dialect: t.Optional[str] = None,
        imports: bool = True,
        **opts: t.Any,
    ):
        """Collect the vis.js nodes, edges and options for `node`'s graph.

        Args:
            node: Root of the lineage graph to render.
            dialect: Dialect used when generating SQL for labels and titles.
            imports: Whether the rendered HTML includes the vis.js script
                and stylesheet tags.
            **opts: Extra vis.js options; they override the defaults below.
        """
        self.node = node
        self.imports = imports

        defaults = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
        }
        self.options = {**defaults, **opts}

        self.nodes = {}
        self.edges = []

        for current in node.walk():
            expression = current.expression

            if isinstance(expression, exp.Table):
                label = f"FROM {expression.this}"
                title = f"<pre>SELECT {current.name} FROM {expression.this}</pre>"
                group = 1
            else:
                label = expression.sql(pretty=True, dialect=dialect)
                # Wrap the node's own expression in <b>...</b> inside its
                # source query so it stands out in the tooltip.
                # NOTE(review): copy=False presumably makes the transform work
                # on current.source itself, which the identity check needs —
                # confirm, since that would leave the Tag in the tree.
                highlighted = current.source.transform(
                    lambda n: exp.Tag(this=n, prefix="<b>", postfix="</b>")
                    if n is current.expression
                    else n,
                    copy=False,
                )
                source = highlighted.sql(pretty=True, dialect=dialect)
                title = f"<pre>{source}</pre>"
                group = 0

            # Object identity is the vis.js id, so a node object reached
            # through several parents collapses into a single entry.
            node_id = id(current)

            self.nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            for child in current.downstream:
                self.edges.append({"from": node_id, "to": id(child)})

    def __str__(self):
        """Return the HTML/JavaScript snippet that draws the graph."""
        nodes = json.dumps(list(self.nodes.values()))
        edges = json.dumps(self.edges)
        options = json.dumps(self.options)

        imports = ""
        if self.imports:
            imports = """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
<script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
<link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""

        return f"""<div>
<div id="sqlglot-lineage"></div>
{imports}
<script type="text/javascript">
var nodes = new vis.DataSet({nodes})
nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])
new vis.Network(
document.getElementById("sqlglot-lineage"),
{{
nodes: nodes,
edges: new vis.DataSet({edges})
}},
{options},
)
</script>
</div>"""

    def _repr_html_(self) -> str:
        """Return the same snippet as `__str__`, under the `_repr_html_` name."""
        return self.__str__()

View file

@ -1 +1,2 @@
from sqlglot.optimizer.optimizer import RULES, optimize from sqlglot.optimizer.optimizer import RULES, optimize
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope

View file

@ -1,15 +1,18 @@
from sqlglot import alias, exp from sqlglot import alias, exp
from sqlglot.errors import OptimizeError from sqlglot.errors import OptimizeError
from sqlglot.optimizer.scope import traverse_scope from sqlglot.optimizer.scope import traverse_scope
from sqlglot.schema import ensure_schema
def isolate_table_selects(expression): def isolate_table_selects(expression, schema=None):
schema = ensure_schema(schema)
for scope in traverse_scope(expression): for scope in traverse_scope(expression):
if len(scope.selected_sources) == 1: if len(scope.selected_sources) == 1:
continue continue
for (_, source) in scope.selected_sources.values(): for (_, source) in scope.selected_sources.values():
if not isinstance(source, exp.Table): if not isinstance(source, exp.Table) or not schema.column_names(source):
continue continue
if not source.alias: if not source.alias:

View file

@ -1,7 +1,8 @@
import itertools import itertools
import typing as t
from sqlglot import alias, exp from sqlglot import alias, exp
from sqlglot.errors import OptimizeError, SchemaError from sqlglot.errors import OptimizeError
from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema from sqlglot.schema import ensure_schema
@ -190,20 +191,15 @@ def _qualify_columns(scope, resolver):
column_table = column.table column_table = column.table
column_name = column.name column_name = column.name
if ( if column_table and column_table in scope.sources:
column_table source_columns = resolver.get_source_columns(column_table)
and column_table in scope.sources if source_columns and column_name not in source_columns:
and column_name not in resolver.get_source_columns(column_table)
):
raise OptimizeError(f"Unknown column: {column_name}") raise OptimizeError(f"Unknown column: {column_name}")
if not column_table: if not column_table:
column_table = resolver.get_table(column_name) column_table = resolver.get_table(column_name)
if not scope.is_subquery and not scope.is_udtf: if not scope.is_subquery and not scope.is_udtf:
if column_name not in resolver.all_columns:
raise OptimizeError(f"Unknown column: {column_name}")
if column_table is None: if column_table is None:
raise OptimizeError(f"Ambiguous column: {column_name}") raise OptimizeError(f"Ambiguous column: {column_name}")
@ -265,6 +261,10 @@ def _expand_stars(scope, resolver):
if table not in scope.sources: if table not in scope.sources:
raise OptimizeError(f"Unknown table: {table}") raise OptimizeError(f"Unknown table: {table}")
columns = resolver.get_source_columns(table, only_visible=True) columns = resolver.get_source_columns(table, only_visible=True)
if not columns:
raise OptimizeError(
f"Table has no schema/columns. Cannot expand star for table: {table}."
)
table_id = id(table) table_id = id(table)
for name in columns: for name in columns:
if name not in except_columns.get(table_id, set()): if name not in except_columns.get(table_id, set()):
@ -306,16 +306,11 @@ def _qualify_outputs(scope):
for i, (selection, aliased_column) in enumerate( for i, (selection, aliased_column) in enumerate(
itertools.zip_longest(scope.selects, scope.outer_column_list) itertools.zip_longest(scope.selects, scope.outer_column_list)
): ):
if isinstance(selection, exp.Column): if isinstance(selection, exp.Subquery):
# convoluted setter because a simple selection.replace(alias) would require a copy if not selection.output_name:
alias_ = alias(exp.column(""), alias=selection.name)
alias_.set("this", selection)
selection = alias_
elif isinstance(selection, exp.Subquery):
if not selection.alias:
selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
elif not isinstance(selection, exp.Alias): elif not isinstance(selection, exp.Alias):
alias_ = alias(exp.column(""), f"_col_{i}") alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
alias_.set("this", selection) alias_.set("this", selection)
selection = alias_ selection = alias_
@ -346,20 +341,30 @@ class _Resolver:
self._unambiguous_columns = None self._unambiguous_columns = None
self._all_columns = None self._all_columns = None
def get_table(self, column_name): def get_table(self, column_name: str) -> t.Optional[str]:
""" """
Get the table for a column name. Get the table for a column name.
Args: Args:
column_name (str) column_name: The column name to find the table for.
Returns: Returns:
(str) table name The table name if it can be found/inferred.
""" """
if self._unambiguous_columns is None: if self._unambiguous_columns is None:
self._unambiguous_columns = self._get_unambiguous_columns( self._unambiguous_columns = self._get_unambiguous_columns(
self._get_all_source_columns() self._get_all_source_columns()
) )
return self._unambiguous_columns.get(column_name)
table = self._unambiguous_columns.get(column_name)
if not table:
sources_without_schema = tuple(
source for source, columns in self._get_all_source_columns().items() if not columns
)
if len(sources_without_schema) == 1:
return sources_without_schema[0]
return table
@property @property
def all_columns(self): def all_columns(self):
@ -379,10 +384,7 @@ class _Resolver:
# If referencing a table, return the columns from the schema # If referencing a table, return the columns from the schema
if isinstance(source, exp.Table): if isinstance(source, exp.Table):
try:
return self.schema.column_names(source, only_visible) return self.schema.column_names(source, only_visible)
except Exception as e:
raise SchemaError(str(e)) from e
if isinstance(source, Scope) and isinstance(source.expression, exp.Values): if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
return source.expression.alias_column_names return source.expression.alias_column_names

View file

@ -230,7 +230,7 @@ class Scope:
column for scope in self.subquery_scopes for column in scope.external_columns column for scope in self.subquery_scopes for column in scope.external_columns
] ]
named_outputs = {e.alias_or_name for e in self.expression.expressions} named_selects = set(self.expression.named_selects)
self._columns = [] self._columns = []
for column in columns + external_columns: for column in columns + external_columns:
@ -238,7 +238,7 @@ class Scope:
if ( if (
not ancestor not ancestor
or column.table or column.table
or (column.name not in named_outputs and not isinstance(ancestor, exp.Hint)) or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
): ):
self._columns.append(column) self._columns.append(column)

View file

@ -40,22 +40,23 @@ class _Parser(type):
class Parser(metaclass=_Parser): class Parser(metaclass=_Parser):
""" """
Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer` Parser consumes a list of tokens produced by the `sqlglot.tokens.Tokenizer` and produces
and produces a parsed syntax tree. a parsed syntax tree.
Args Args:
error_level (ErrorLevel): the desired error level. Default: ErrorLevel.RAISE. error_level: the desired error level.
error_message_context (int): determines the amount of context to capture from Default: ErrorLevel.RAISE
a query string when displaying the error message (in number of characters). error_message_context: determines the amount of context to capture from a
query string when displaying the error message (in number of characters).
Default: 50. Default: 50.
index_offset (int): Index offset for arrays eg ARRAY[0] vs ARRAY[1] as the head of a list index_offset: Index offset for arrays eg ARRAY[0] vs ARRAY[1] as the head of a list.
Default: 0 Default: 0
alias_post_tablesample (bool): If the table alias comes after tablesample alias_post_tablesample: If the table alias comes after tablesample.
Default: False Default: False
max_errors (int): Maximum number of error messages to include in a raised ParseError. max_errors: Maximum number of error messages to include in a raised ParseError.
This is only relevant if error_level is ErrorLevel.RAISE. This is only relevant if error_level is ErrorLevel.RAISE.
Default: 3 Default: 3
null_ordering (str): Indicates the default null ordering method to use if not explicitly set. null_ordering: Indicates the default null ordering method to use if not explicitly set.
Options are "nulls_are_small", "nulls_are_large", "nulls_are_last". Options are "nulls_are_small", "nulls_are_large", "nulls_are_last".
Default: "nulls_are_small" Default: "nulls_are_small"
""" """
@ -109,6 +110,8 @@ class Parser(metaclass=_Parser):
TokenType.TEXT, TokenType.TEXT,
TokenType.MEDIUMTEXT, TokenType.MEDIUMTEXT,
TokenType.LONGTEXT, TokenType.LONGTEXT,
TokenType.MEDIUMBLOB,
TokenType.LONGBLOB,
TokenType.BINARY, TokenType.BINARY,
TokenType.VARBINARY, TokenType.VARBINARY,
TokenType.JSON, TokenType.JSON,
@ -176,6 +179,7 @@ class Parser(metaclass=_Parser):
TokenType.DIV, TokenType.DIV,
TokenType.DISTKEY, TokenType.DISTKEY,
TokenType.DISTSTYLE, TokenType.DISTSTYLE,
TokenType.END,
TokenType.EXECUTE, TokenType.EXECUTE,
TokenType.ENGINE, TokenType.ENGINE,
TokenType.ESCAPE, TokenType.ESCAPE,
@ -468,9 +472,6 @@ class Parser(metaclass=_Parser):
TokenType.NULL: lambda self, _: self.expression(exp.Null), TokenType.NULL: lambda self, _: self.expression(exp.Null),
TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True), TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True),
TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False), TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False),
TokenType.PARAMETER: lambda self, _: self.expression(
exp.Parameter, this=self._parse_var() or self._parse_primary()
),
TokenType.BIT_STRING: lambda self, token: self.expression(exp.BitString, this=token.text), TokenType.BIT_STRING: lambda self, token: self.expression(exp.BitString, this=token.text),
TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text), TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text),
TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text), TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text),
@ -479,6 +480,16 @@ class Parser(metaclass=_Parser):
TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(), TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
} }
PLACEHOLDER_PARSERS = {
TokenType.PLACEHOLDER: lambda self: self.expression(exp.Placeholder),
TokenType.PARAMETER: lambda self: self.expression(
exp.Parameter, this=self._parse_var() or self._parse_primary()
),
TokenType.COLON: lambda self: self.expression(exp.Placeholder, this=self._prev.text)
if self._match_set((TokenType.NUMBER, TokenType.VAR))
else None,
}
RANGE_PARSERS = { RANGE_PARSERS = {
TokenType.BETWEEN: lambda self, this: self._parse_between(this), TokenType.BETWEEN: lambda self, this: self._parse_between(this),
TokenType.IN: lambda self, this: self._parse_in(this), TokenType.IN: lambda self, this: self._parse_in(this),
@ -601,8 +612,7 @@ class Parser(metaclass=_Parser):
WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS} WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}
# allows tables to have special tokens as prefixes ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}
TABLE_PREFIX_TOKENS: t.Set[TokenType] = set()
STRICT_CAST = True STRICT_CAST = True
@ -677,7 +687,7 @@ class Parser(metaclass=_Parser):
def parse_into( def parse_into(
self, self,
expression_types: str | exp.Expression | t.Collection[exp.Expression | str], expression_types: exp.IntoType,
raw_tokens: t.List[Token], raw_tokens: t.List[Token],
sql: t.Optional[str] = None, sql: t.Optional[str] = None,
) -> t.List[t.Optional[exp.Expression]]: ) -> t.List[t.Optional[exp.Expression]]:
@ -820,24 +830,8 @@ class Parser(metaclass=_Parser):
if self.error_level == ErrorLevel.IGNORE: if self.error_level == ErrorLevel.IGNORE:
return return
for k in expression.args: for error_message in expression.error_messages(args):
if k not in expression.arg_types: self.raise_error(error_message)
self.raise_error(f"Unexpected keyword: '{k}' for {expression.__class__}")
for k, mandatory in expression.arg_types.items():
v = expression.args.get(k)
if mandatory and (v is None or (isinstance(v, list) and not v)):
self.raise_error(f"Required keyword: '{k}' missing for {expression.__class__}")
if (
args
and isinstance(expression, exp.Func)
and len(args) > len(expression.arg_types)
and not expression.is_var_len_args
):
self.raise_error(
f"The number of provided arguments ({len(args)}) is greater than "
f"the maximum number of supported arguments ({len(expression.arg_types)})"
)
def _find_token(self, token: Token, sql: str) -> int: def _find_token(self, token: Token, sql: str) -> int:
line = 1 line = 1
@ -868,6 +862,9 @@ class Parser(metaclass=_Parser):
def _retreat(self, index: int) -> None: def _retreat(self, index: int) -> None:
self._advance(index - self._index) self._advance(index - self._index)
def _parse_command(self) -> exp.Expression:
    """Build an `exp.Command` from the previous token and a following string.

    The previous token's text becomes the command's `this`; whatever
    `_parse_string` returns becomes its `expression`.
    """
    # Read the keyword first: it is taken from `_prev` before
    # `_parse_string` runs, matching the original argument order.
    keyword = self._prev.text
    payload = self._parse_string()
    return self.expression(exp.Command, this=keyword, expression=payload)
def _parse_statement(self) -> t.Optional[exp.Expression]: def _parse_statement(self) -> t.Optional[exp.Expression]:
if self._curr is None: if self._curr is None:
return None return None
@ -876,11 +873,7 @@ class Parser(metaclass=_Parser):
return self.STATEMENT_PARSERS[self._prev.token_type](self) return self.STATEMENT_PARSERS[self._prev.token_type](self)
if self._match_set(Tokenizer.COMMANDS): if self._match_set(Tokenizer.COMMANDS):
return self.expression( return self._parse_command()
exp.Command,
this=self._prev.text,
expression=self._parse_string(),
)
expression = self._parse_expression() expression = self._parse_expression()
expression = self._parse_set_operations(expression) if expression else self._parse_select() expression = self._parse_set_operations(expression) if expression else self._parse_select()
@ -942,12 +935,18 @@ class Parser(metaclass=_Parser):
no_primary_index = None no_primary_index = None
indexes = None indexes = None
no_schema_binding = None no_schema_binding = None
begin = None
if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE): if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
this = self._parse_user_defined_function() this = self._parse_user_defined_function(kind=create_token.token_type)
properties = self._parse_properties() properties = self._parse_properties()
if self._match(TokenType.ALIAS): if self._match(TokenType.ALIAS):
expression = self._parse_select_or_expression() begin = self._match(TokenType.BEGIN)
return_ = self._match_text_seq("RETURN")
expression = self._parse_statement()
if return_:
expression = self.expression(exp.Return, this=expression)
elif create_token.token_type == TokenType.INDEX: elif create_token.token_type == TokenType.INDEX:
this = self._parse_index() this = self._parse_index()
elif create_token.token_type in ( elif create_token.token_type in (
@ -1002,6 +1001,7 @@ class Parser(metaclass=_Parser):
no_primary_index=no_primary_index, no_primary_index=no_primary_index,
indexes=indexes, indexes=indexes,
no_schema_binding=no_schema_binding, no_schema_binding=no_schema_binding,
begin=begin,
) )
def _parse_property(self) -> t.Optional[exp.Expression]: def _parse_property(self) -> t.Optional[exp.Expression]:
@ -1087,7 +1087,7 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.GT): if not self._match(TokenType.GT):
self.raise_error("Expecting >") self.raise_error("Expecting >")
else: else:
value = self._parse_schema(exp.Literal.string("TABLE")) value = self._parse_schema(exp.Var(this="TABLE"))
else: else:
value = self._parse_types() value = self._parse_types()
@ -1550,7 +1550,7 @@ class Parser(metaclass=_Parser):
return None return None
index = self._parse_id_var() index = self._parse_id_var()
columns = None columns = None
if self._curr and self._curr.token_type == TokenType.L_PAREN: if self._match(TokenType.L_PAREN, advance=False):
columns = self._parse_wrapped_csv(self._parse_column) columns = self._parse_wrapped_csv(self._parse_column)
return self.expression( return self.expression(
exp.Index, exp.Index,
@ -1561,6 +1561,27 @@ class Parser(metaclass=_Parser):
amp=amp, amp=amp,
) )
def _parse_table_parts(self, schema: bool = False) -> exp.Expression:
    """Parse a possibly dotted table reference into an `exp.Table`.

    Args:
        schema: When true, the first part is parsed as an identifier only;
            otherwise a function call is tried first.

    Returns:
        A table expression with `this`, `db`, `catalog` and `pivots` set.
        Parts beyond the third are nested into `exp.Dot` expressions.
    """
    catalog = None
    db = None

    table = None if schema else self._parse_function()
    if not table:
        table = self._parse_id_var(any_token=False)

    while self._match(TokenType.DOT):
        if not catalog:
            # Shift the parts read so far one level up and read the next.
            catalog, db, table = db, table, self._parse_id_var()
            continue
        # This allows nesting the table in arbitrarily many dot expressions if needed
        table = self.expression(exp.Dot, this=table, expression=self._parse_id_var())

    if not table:
        self.raise_error(f"Expected table name but got {self._curr}")

    pivots = self._parse_pivots()
    return self.expression(exp.Table, this=table, db=db, catalog=catalog, pivots=pivots)
def _parse_table( def _parse_table(
self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None
) -> t.Optional[exp.Expression]: ) -> t.Optional[exp.Expression]:
@ -1584,27 +1605,7 @@ class Parser(metaclass=_Parser):
if subquery: if subquery:
return subquery return subquery
catalog = None this = self._parse_table_parts(schema=schema)
db = None
table = (not schema and self._parse_function()) or self._parse_id_var(
any_token=False, prefix_tokens=self.TABLE_PREFIX_TOKENS
)
while self._match(TokenType.DOT):
if catalog:
# This allows nesting the table in arbitrarily many dot expressions if needed
table = self.expression(exp.Dot, this=table, expression=self._parse_id_var())
else:
catalog = db
db = table
table = self._parse_id_var()
if not table:
self.raise_error(f"Expected table name but got {self._curr}")
this = self.expression(
exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
)
if schema: if schema:
return self._parse_schema(this=this) return self._parse_schema(this=this)
@ -1889,7 +1890,7 @@ class Parser(metaclass=_Parser):
expression, expression,
this=this, this=this,
distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL), distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL),
expression=self._parse_select(nested=True), expression=self._parse_set_operations(self._parse_select(nested=True)),
) )
def _parse_expression(self) -> t.Optional[exp.Expression]: def _parse_expression(self) -> t.Optional[exp.Expression]:
@ -2286,7 +2287,9 @@ class Parser(metaclass=_Parser):
self._match_r_paren(this) self._match_r_paren(this)
return self._parse_window(this) return self._parse_window(this)
def _parse_user_defined_function(self) -> t.Optional[exp.Expression]: def _parse_user_defined_function(
self, kind: t.Optional[TokenType] = None
) -> t.Optional[exp.Expression]:
this = self._parse_id_var() this = self._parse_id_var()
while self._match(TokenType.DOT): while self._match(TokenType.DOT):
@ -2297,7 +2300,9 @@ class Parser(metaclass=_Parser):
expressions = self._parse_csv(self._parse_udf_kwarg) expressions = self._parse_csv(self._parse_udf_kwarg)
self._match_r_paren() self._match_r_paren()
return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) return self.expression(
exp.UserDefinedFunction, this=this, expressions=expressions, wrapped=True
)
def _parse_introducer(self, token: Token) -> t.Optional[exp.Expression]: def _parse_introducer(self, token: Token) -> t.Optional[exp.Expression]:
literal = self._parse_primary() literal = self._parse_primary()
@ -2371,10 +2376,6 @@ class Parser(metaclass=_Parser):
or self._parse_column_def(self._parse_field(any_token=True)) or self._parse_column_def(self._parse_field(any_token=True))
) )
self._match_r_paren() self._match_r_paren()
if isinstance(this, exp.Literal):
this = this.name
return self.expression(exp.Schema, this=this, expressions=args) return self.expression(exp.Schema, this=this, expressions=args)
def _parse_column_def(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: def _parse_column_def(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
@ -2470,15 +2471,43 @@ class Parser(metaclass=_Parser):
def _parse_unique(self) -> exp.Expression: def _parse_unique(self) -> exp.Expression:
return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars()) return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars())
def _parse_key_constraint_options(self) -> t.List[str]:
options = []
while True:
if not self._curr:
break
if self._match_text_seq("NOT", "ENFORCED"):
options.append("NOT ENFORCED")
elif self._match_text_seq("DEFERRABLE"):
options.append("DEFERRABLE")
elif self._match_text_seq("INITIALLY", "DEFERRED"):
options.append("INITIALLY DEFERRED")
elif self._match_text_seq("NORELY"):
options.append("NORELY")
elif self._match_text_seq("MATCH", "FULL"):
options.append("MATCH FULL")
elif self._match_text_seq("ON", "UPDATE", "NO ACTION"):
options.append("ON UPDATE NO ACTION")
elif self._match_text_seq("ON", "DELETE", "NO ACTION"):
options.append("ON DELETE NO ACTION")
else:
break
return options
def _parse_references(self) -> t.Optional[exp.Expression]: def _parse_references(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.REFERENCES): if not self._match(TokenType.REFERENCES):
return None return None
return self.expression( expressions = None
exp.Reference, this = self._parse_id_var()
this=self._parse_id_var(),
expressions=self._parse_wrapped_id_vars(), if self._match(TokenType.L_PAREN, advance=False):
) expressions = self._parse_wrapped_id_vars()
options = self._parse_key_constraint_options()
return self.expression(exp.Reference, this=this, expressions=expressions, options=options)
def _parse_foreign_key(self) -> exp.Expression: def _parse_foreign_key(self) -> exp.Expression:
expressions = self._parse_wrapped_id_vars() expressions = self._parse_wrapped_id_vars()
@ -2503,12 +2532,14 @@ class Parser(metaclass=_Parser):
options[kind] = action options[kind] = action
return self.expression( return self.expression(
exp.ForeignKey, exp.ForeignKey, expressions=expressions, reference=reference, **options # type: ignore
expressions=expressions,
reference=reference,
**options, # type: ignore
) )
def _parse_primary_key(self) -> exp.Expression:
    """Parse the body of a PRIMARY KEY constraint.

    Reads the wrapped identifier list via `_parse_wrapped_id_vars`, then any
    trailing options via `_parse_key_constraint_options`.

    Returns:
        An `exp.PrimaryKey` node holding the identifiers as `expressions`
        and the trailing options as `options`.
    """
    key_columns = self._parse_wrapped_id_vars()
    return self.expression(
        exp.PrimaryKey,
        expressions=key_columns,
        options=self._parse_key_constraint_options(),
    )
def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match(TokenType.L_BRACKET): if not self._match(TokenType.L_BRACKET):
return this return this
@ -2631,7 +2662,7 @@ class Parser(metaclass=_Parser):
order = self._parse_order(this=expression) order = self._parse_order(this=expression)
return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1)) return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))
def _parse_convert(self, strict: bool) -> exp.Expression: def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
to: t.Optional[exp.Expression] to: t.Optional[exp.Expression]
this = self._parse_column() this = self._parse_column()
@ -2641,20 +2672,26 @@ class Parser(metaclass=_Parser):
to = self._parse_types() to = self._parse_types()
else: else:
to = None to = None
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
def _parse_position(self) -> exp.Expression: def _parse_position(self, haystack_first: bool = False) -> exp.Expression:
args = self._parse_csv(self._parse_bitwise) args = self._parse_csv(self._parse_bitwise)
if self._match(TokenType.IN): if self._match(TokenType.IN):
args.append(self._parse_bitwise()) return self.expression(
exp.StrPosition, this=self._parse_bitwise(), substr=seq_get(args, 0)
this = exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
position=seq_get(args, 2),
) )
if haystack_first:
haystack = seq_get(args, 0)
needle = seq_get(args, 1)
else:
needle = seq_get(args, 0)
haystack = seq_get(args, 1)
this = exp.StrPosition(this=haystack, substr=needle, position=seq_get(args, 2))
self.validate_expression(this, args) self.validate_expression(this, args)
return this return this
@ -2894,24 +2931,26 @@ class Parser(metaclass=_Parser):
return None return None
def _parse_placeholder(self) -> t.Optional[exp.Expression]: def _parse_placeholder(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.PLACEHOLDER): if self._match_set(self.PLACEHOLDER_PARSERS):
return self.expression(exp.Placeholder) placeholder = self.PLACEHOLDER_PARSERS[self._prev.token_type](self)
elif self._match(TokenType.COLON): if placeholder:
if self._match_set((TokenType.NUMBER, TokenType.VAR)): return placeholder
return self.expression(exp.Placeholder, this=self._prev.text)
self._advance(-1) self._advance(-1)
return None return None
def _parse_except(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: def _parse_except(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
if not self._match(TokenType.EXCEPT): if not self._match(TokenType.EXCEPT):
return None return None
if self._match(TokenType.L_PAREN, advance=False):
return self._parse_wrapped_id_vars() return self._parse_wrapped_id_vars()
return self._parse_csv(self._parse_id_var)
def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
if not self._match(TokenType.REPLACE): if not self._match(TokenType.REPLACE):
return None return None
return self._parse_wrapped_csv(lambda: self._parse_alias(self._parse_expression())) if self._match(TokenType.L_PAREN, advance=False):
return self._parse_wrapped_csv(self._parse_expression)
return self._parse_csv(self._parse_expression)
def _parse_csv( def _parse_csv(
self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA
@ -3021,6 +3060,28 @@ class Parser(metaclass=_Parser):
def _parse_drop_column(self) -> t.Optional[exp.Expression]: def _parse_drop_column(self) -> t.Optional[exp.Expression]:
return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN") return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN")
def _parse_add_constraint(self) -> t.Optional[exp.Expression]:
    """Parse a single constraint of an `ALTER TABLE ... ADD` action.

    The token consumed immediately before this call (`self._prev`) selects
    the form that is parsed:

    - `CONSTRAINT`: the constraint name is read, optionally followed by a
      `CHECK (...)` clause and an `ENFORCED` marker.
    - `FOREIGN KEY` / `PRIMARY KEY` (either already consumed or next in the
      stream): the matching key definition is parsed.

    Returns:
        An `exp.AddConstraint` node. When no supported constraint body is
        found, `expression` is `None` rather than the function failing
        with `UnboundLocalError`.
    """
    this = None
    # Bound up front: a named constraint of an unsupported kind (for example
    # UNIQUE) takes none of the branches below, and the final return would
    # otherwise reference an unbound local.
    expression = None
    kind = self._prev.token_type

    if kind == TokenType.CONSTRAINT:
        this = self._parse_id_var()

        # NOTE(review): CHECK is assumed to be nested under the CONSTRAINT
        # branch; the nesting is not recoverable from the flattened source.
        if self._match(TokenType.CHECK):
            expression = self._parse_wrapped(self._parse_conjunction)
            enforced = self._match_text_seq("ENFORCED")

            return self.expression(
                exp.AddConstraint, this=this, expression=expression, enforced=enforced
            )

    if kind == TokenType.FOREIGN_KEY or self._match(TokenType.FOREIGN_KEY):
        expression = self._parse_foreign_key()
    elif kind == TokenType.PRIMARY_KEY or self._match(TokenType.PRIMARY_KEY):
        expression = self._parse_primary_key()

    return self.expression(exp.AddConstraint, this=this, expression=expression)
def _parse_alter(self) -> t.Optional[exp.Expression]: def _parse_alter(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.TABLE): if not self._match(TokenType.TABLE):
return None return None
@ -3029,7 +3090,13 @@ class Parser(metaclass=_Parser):
this = self._parse_table(schema=True) this = self._parse_table(schema=True)
actions: t.Optional[exp.Expression | t.List[t.Optional[exp.Expression]]] = None actions: t.Optional[exp.Expression | t.List[t.Optional[exp.Expression]]] = None
if self._match_text_seq("ADD", advance=False):
index = self._index
if self._match_text_seq("ADD"):
if self._match_set(self.ADD_CONSTRAINT_TOKENS):
actions = self._parse_csv(self._parse_add_constraint)
else:
self._retreat(index)
actions = self._parse_csv(self._parse_add_column) actions = self._parse_csv(self._parse_add_column)
elif self._match_text_seq("DROP", advance=False): elif self._match_text_seq("DROP", advance=False):
actions = self._parse_csv(self._parse_drop_column) actions = self._parse_csv(self._parse_drop_column)
@ -3077,7 +3144,7 @@ class Parser(metaclass=_Parser):
def _parse_merge(self) -> exp.Expression: def _parse_merge(self) -> exp.Expression:
self._match(TokenType.INTO) self._match(TokenType.INTO)
target = self._parse_table(schema=True) target = self._parse_table()
self._match(TokenType.USING) self._match(TokenType.USING)
using = self._parse_table() using = self._parse_table()
@ -3146,11 +3213,12 @@ class Parser(metaclass=_Parser):
self._retreat(index) self._retreat(index)
return None return None
def _match(self, token_type): def _match(self, token_type, advance=True):
if not self._curr: if not self._curr:
return None return None
if self._curr.token_type == token_type: if self._curr.token_type == token_type:
if advance:
self._advance() self._advance()
return True return True

View file

@ -32,7 +32,7 @@ class Plan:
return self._dag return self._dag
@property @property
def leaves(self) -> t.Generator[Step, None, None]: def leaves(self) -> t.Iterator[Step]:
return (node for node, deps in self.dag.items() if not deps) return (node for node, deps in self.dag.items() if not deps)
def __repr__(self) -> str: def __repr__(self) -> str:
@ -401,7 +401,7 @@ class SetOperation(Step):
op=expression.__class__, op=expression.__class__,
left=left.name, left=left.name,
right=right.name, right=right.name,
distinct=expression.args.get("distinct"), distinct=bool(expression.args.get("distinct")),
) )
step.add_dependency(left) step.add_dependency(left)
step.add_dependency(right) step.add_dependency(right)

View file

@ -109,9 +109,6 @@ class AbstractMappingSchema(t.Generic[T]):
value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
if value == 0: if value == 0:
if raise_on_missing:
raise SchemaError(f"Cannot find mapping for {table}.")
else:
return None return None
elif value == 1: elif value == 1:
possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
@ -262,7 +259,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
schema = self.find(table_) schema = self.find(table_)
if schema is None: if schema is None:
raise SchemaError(f"Could not find table schema {table}") return []
if not only_visible or not self.visible: if not only_visible or not self.visible:
return list(schema) return list(schema)

View file

@ -84,6 +84,8 @@ class TokenType(AutoName):
TEXT = auto() TEXT = auto()
MEDIUMTEXT = auto() MEDIUMTEXT = auto()
LONGTEXT = auto() LONGTEXT = auto()
MEDIUMBLOB = auto()
LONGBLOB = auto()
BINARY = auto() BINARY = auto()
VARBINARY = auto() VARBINARY = auto()
JSON = auto() JSON = auto()
@ -587,6 +589,7 @@ class Tokenizer(metaclass=_Tokenizer):
"PRECEDING": TokenType.PRECEDING, "PRECEDING": TokenType.PRECEDING,
"PRIMARY KEY": TokenType.PRIMARY_KEY, "PRIMARY KEY": TokenType.PRIMARY_KEY,
"PROCEDURE": TokenType.PROCEDURE, "PROCEDURE": TokenType.PROCEDURE,
"QUALIFY": TokenType.QUALIFY,
"RANGE": TokenType.RANGE, "RANGE": TokenType.RANGE,
"RECURSIVE": TokenType.RECURSIVE, "RECURSIVE": TokenType.RECURSIVE,
"REGEXP": TokenType.RLIKE, "REGEXP": TokenType.RLIKE,
@ -726,6 +729,8 @@ class Tokenizer(metaclass=_Tokenizer):
TokenType.SHOW, TokenType.SHOW,
} }
COMMAND_PREFIX_TOKENS = {TokenType.SEMICOLON, TokenType.BEGIN}
# handle numeric literals like in hive (3L = BIGINT) # handle numeric literals like in hive (3L = BIGINT)
NUMERIC_LITERALS: t.Dict[str, str] = {} NUMERIC_LITERALS: t.Dict[str, str] = {}
ENCODE: t.Optional[str] = None ENCODE: t.Optional[str] = None
@ -842,8 +847,10 @@ class Tokenizer(metaclass=_Tokenizer):
) )
self._comments = [] self._comments = []
# If we have either a semicolon or a begin token before the command's token, we'll parse
# whatever follows the command's token as a string
if token_type in self.COMMANDS and ( if token_type in self.COMMANDS and (
len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON len(self.tokens) == 1 or self.tokens[-2].token_type in self.COMMAND_PREFIX_TOKENS
): ):
start = self._current start = self._current
tokens = len(self.tokens) tokens = len(self.tokens)

View file

@ -17,7 +17,8 @@ class TestClickhouse(Validator):
self.validate_identity("SELECT quantile(0.5)(a)") self.validate_identity("SELECT quantile(0.5)(a)")
self.validate_identity("SELECT quantiles(0.5)(a) AS x FROM t") self.validate_identity("SELECT quantiles(0.5)(a) AS x FROM t")
self.validate_identity("SELECT * FROM foo WHERE x GLOBAL IN (SELECT * FROM bar)") self.validate_identity("SELECT * FROM foo WHERE x GLOBAL IN (SELECT * FROM bar)")
self.validate_identity("position(a, b)") self.validate_identity("position(haystack, needle)")
self.validate_identity("position(haystack, needle, position)")
self.validate_all( self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
@ -48,6 +49,10 @@ class TestClickhouse(Validator):
"clickhouse": "SELECT quantileIf(0.5)(a, TRUE)", "clickhouse": "SELECT quantileIf(0.5)(a, TRUE)",
}, },
) )
self.validate_all(
"SELECT position(needle IN haystack)",
write={"clickhouse": "SELECT position(haystack, needle)"},
)
def test_cte(self): def test_cte(self):
self.validate_identity("WITH 'x' AS foo SELECT foo") self.validate_identity("WITH 'x' AS foo SELECT foo")

View file

@ -950,40 +950,40 @@ class TestDialect(Validator):
}, },
) )
self.validate_all( self.validate_all(
"POSITION(' ' in x)", "POSITION(needle in haystack)",
write={ write={
"drill": "STRPOS(x, ' ')", "drill": "STRPOS(haystack, needle)",
"duckdb": "STRPOS(x, ' ')", "duckdb": "STRPOS(haystack, needle)",
"postgres": "STRPOS(x, ' ')", "postgres": "STRPOS(haystack, needle)",
"presto": "STRPOS(x, ' ')", "presto": "STRPOS(haystack, needle)",
"spark": "LOCATE(' ', x)", "spark": "LOCATE(needle, haystack)",
"clickhouse": "position(x, ' ')", "clickhouse": "position(haystack, needle)",
"snowflake": "POSITION(' ', x)", "snowflake": "POSITION(needle, haystack)",
"mysql": "LOCATE(' ', x)", "mysql": "LOCATE(needle, haystack)",
}, },
) )
self.validate_all( self.validate_all(
"STR_POSITION(x, 'a')", "STR_POSITION(haystack, needle)",
write={ write={
"drill": "STRPOS(x, 'a')", "drill": "STRPOS(haystack, needle)",
"duckdb": "STRPOS(x, 'a')", "duckdb": "STRPOS(haystack, needle)",
"postgres": "STRPOS(x, 'a')", "postgres": "STRPOS(haystack, needle)",
"presto": "STRPOS(x, 'a')", "presto": "STRPOS(haystack, needle)",
"spark": "LOCATE('a', x)", "spark": "LOCATE(needle, haystack)",
"clickhouse": "position(x, 'a')", "clickhouse": "position(haystack, needle)",
"snowflake": "POSITION('a', x)", "snowflake": "POSITION(needle, haystack)",
"mysql": "LOCATE('a', x)", "mysql": "LOCATE(needle, haystack)",
}, },
) )
self.validate_all( self.validate_all(
"POSITION('a', x, 3)", "POSITION(needle, haystack, pos)",
write={ write={
"drill": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", "drill": "STRPOS(SUBSTR(haystack, pos), needle) + pos - 1",
"presto": "STRPOS(x, 'a', 3)", "presto": "STRPOS(haystack, needle, pos)",
"spark": "LOCATE('a', x, 3)", "spark": "LOCATE(needle, haystack, pos)",
"clickhouse": "position(x, 'a', 3)", "clickhouse": "position(haystack, needle, pos)",
"snowflake": "POSITION('a', x, 3)", "snowflake": "POSITION(needle, haystack, pos)",
"mysql": "LOCATE('a', x, 3)", "mysql": "LOCATE(needle, haystack, pos)",
}, },
) )
self.validate_all( self.validate_all(
@ -1365,3 +1365,19 @@ SELECT
"spark": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED THEN UPDATE * WHEN NOT MATCHED THEN INSERT *", "spark": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED THEN UPDATE * WHEN NOT MATCHED THEN INSERT *",
}, },
) )
self.validate_all(
"""
MERGE a b USING c d ON b.id = d.id
WHEN MATCHED AND EXISTS (
SELECT b.name
EXCEPT
SELECT d.name
)
THEN UPDATE SET b.name = d.name
""",
write={
"bigquery": "MERGE INTO a AS b USING c AS d ON b.id = d.id WHEN MATCHED AND EXISTS(SELECT b.name EXCEPT DISTINCT SELECT d.name) THEN UPDATE SET b.name = d.name",
"snowflake": "MERGE INTO a AS b USING c AS d ON b.id = d.id WHEN MATCHED AND EXISTS(SELECT b.name EXCEPT SELECT d.name) THEN UPDATE SET b.name = d.name",
"spark": "MERGE INTO a AS b USING c AS d ON b.id = d.id WHEN MATCHED AND EXISTS(SELECT b.name EXCEPT SELECT d.name) THEN UPDATE SET b.name = d.name",
},
)

View file

@ -75,6 +75,15 @@ class TestMySQL(Validator):
"spark": "CAST(x AS TEXT) + CAST(y AS TEXT)", "spark": "CAST(x AS TEXT) + CAST(y AS TEXT)",
}, },
) )
self.validate_all(
"CAST(x AS MEDIUMBLOB) + CAST(y AS LONGBLOB)",
read={
"mysql": "CAST(x AS MEDIUMBLOB) + CAST(y AS LONGBLOB)",
},
write={
"spark": "CAST(x AS BLOB) + CAST(y AS BLOB)",
},
)
def test_canonical_functions(self): def test_canonical_functions(self):
self.validate_identity("SELECT LEFT('str', 2)", "SELECT SUBSTRING('str', 1, 2)") self.validate_identity("SELECT LEFT('str', 2)", "SELECT SUBSTRING('str', 1, 2)")

View file

@ -12,6 +12,24 @@ class TestSnowflake(Validator):
"snowflake": "SELECT * FROM xxx WHERE col ILIKE '%Don\\'t%'", "snowflake": "SELECT * FROM xxx WHERE col ILIKE '%Don\\'t%'",
}, },
) )
self.validate_all(
"SELECT * EXCLUDE a, b FROM xxx",
write={
"snowflake": "SELECT * EXCLUDE (a, b) FROM xxx",
},
)
self.validate_all(
"SELECT * RENAME a AS b, c AS d FROM xxx",
write={
"snowflake": "SELECT * RENAME (a AS b, c AS d) FROM xxx",
},
)
self.validate_all(
"SELECT * EXCLUDE a, b RENAME (c AS d, E as F) FROM xxx",
write={
"snowflake": "SELECT * EXCLUDE (a, b) RENAME (c AS d, E AS F) FROM xxx",
},
)
self.validate_all( self.validate_all(
'x:a:"b c"', 'x:a:"b c"',
write={ write={

View file

@ -1,3 +1,4 @@
from sqlglot import exp, parse, parse_one
from tests.dialects.test_dialect import Validator from tests.dialects.test_dialect import Validator
@ -5,6 +6,10 @@ class TestTSQL(Validator):
dialect = "tsql" dialect = "tsql"
def test_tsql(self): def test_tsql(self):
self.validate_identity("SELECT CASE WHEN a > 1 THEN b END")
self.validate_identity("END")
self.validate_identity("@x")
self.validate_identity("#x")
self.validate_identity("DECLARE @TestVariable AS VARCHAR(100)='Save Our Planet'") self.validate_identity("DECLARE @TestVariable AS VARCHAR(100)='Save Our Planet'")
self.validate_identity("PRINT @TestVariable") self.validate_identity("PRINT @TestVariable")
self.validate_identity("SELECT Employee_ID, Department_ID FROM @MyTableVar") self.validate_identity("SELECT Employee_ID, Department_ID FROM @MyTableVar")
@ -87,6 +92,95 @@ class TestTSQL(Validator):
}, },
) )
def test_udf(self):
self.validate_identity(
"CREATE PROCEDURE foo @a INTEGER, @b INTEGER AS SELECT @a = SUM(bla) FROM baz AS bar"
)
self.validate_identity(
"CREATE PROC foo @ID INTEGER, @AGE INTEGER AS SELECT DB_NAME(@ID) AS ThatDB"
)
self.validate_identity("CREATE PROC foo AS SELECT BAR() AS baz")
self.validate_identity("CREATE PROCEDURE foo AS SELECT BAR() AS baz")
self.validate_identity("CREATE FUNCTION foo(@bar INTEGER) RETURNS TABLE AS RETURN SELECT 1")
self.validate_identity("CREATE FUNCTION dbo.ISOweek(@DATE DATETIME2) RETURNS INTEGER")
# The following two cases don't necessarily correspond to valid TSQL, but they are used to verify
# that the syntax RETURNS @return_variable TABLE <table_type_definition> ... is parsed correctly.
#
# See also "Transact-SQL Multi-Statement Table-Valued Function Syntax"
# https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql?view=sql-server-ver16
self.validate_identity(
"CREATE FUNCTION foo(@bar INTEGER) RETURNS @foo TABLE (x INTEGER, y NUMERIC) AS RETURN SELECT 1"
)
self.validate_identity(
"CREATE FUNCTION foo() RETURNS @contacts TABLE (first_name VARCHAR(50), phone VARCHAR(25)) AS SELECT @fname, @phone"
)
self.validate_all(
"""
CREATE FUNCTION udfProductInYear (
@model_year INT
)
RETURNS TABLE
AS
RETURN
SELECT
product_name,
model_year,
list_price
FROM
production.products
WHERE
model_year = @model_year
""",
write={
"tsql": """CREATE FUNCTION udfProductInYear(
@model_year INTEGER
)
RETURNS TABLE AS
RETURN SELECT
product_name,
model_year,
list_price
FROM production.products
WHERE
model_year = @model_year""",
},
pretty=True,
)
sql = """
CREATE procedure [TRANSF].[SP_Merge_Sales_Real]
@Loadid INTEGER
,@NumberOfRows INTEGER
AS
BEGIN
SET XACT_ABORT ON;
DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104);
DECLARE @DWH_DateModified DATETIME = CONVERT(DATETIME, getdate(), 104);
DECLARE @DWH_IdUserCreated INTEGER = SUSER_ID (SYSTEM_USER);
DECLARE @DWH_IdUserModified INTEGER = SUSER_ID (SYSTEM_USER);
DECLARE @SalesAmountBefore float;
SELECT @SalesAmountBefore=SUM(SalesAmount) FROM TRANSF.[Pre_Merge_Sales_Real] S;
END
"""
expected_sqls = [
'CREATE PROCEDURE "TRANSF"."SP_Merge_Sales_Real" @Loadid INTEGER, @NumberOfRows INTEGER AS BEGIN SET XACT_ABORT ON',
"DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104)",
"DECLARE @DWH_DateModified DATETIME = CONVERT(DATETIME, getdate(), 104)",
"DECLARE @DWH_IdUserCreated INTEGER = SUSER_ID (SYSTEM_USER)",
"DECLARE @DWH_IdUserModified INTEGER = SUSER_ID (SYSTEM_USER)",
"DECLARE @SalesAmountBefore float",
'SELECT @SalesAmountBefore = SUM(SalesAmount) FROM TRANSF."Pre_Merge_Sales_Real" AS S',
"END",
]
for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls):
self.assertEqual(expr.sql(dialect="tsql"), expected_sql)
def test_charindex(self): def test_charindex(self):
self.validate_all( self.validate_all(
"CHARINDEX(x, y, 9)", "CHARINDEX(x, y, 9)",
@ -472,3 +566,51 @@ class TestTSQL(Validator):
"EOMONTH(GETDATE(), -1)", "EOMONTH(GETDATE(), -1)",
write={"spark": "LAST_DAY(ADD_MONTHS(CURRENT_TIMESTAMP(), -1))"}, write={"spark": "LAST_DAY(ADD_MONTHS(CURRENT_TIMESTAMP(), -1))"},
) )
def test_variables(self):
    """Check that `@x` and `#x` parse to a column wrapping an identifier."""
    # In TSQL @, # can be used as a prefix for variables/identifiers
    for prefixed_name in ("@x", "#x"):
        parsed = parse_one(prefixed_name, read="tsql")
        self.assertIsInstance(parsed, exp.Column)
        self.assertIsInstance(parsed.this, exp.Identifier)
def test_system_time(self):
    """Validate tsql output for each `FOR SYSTEM_TIME` clause variant."""
    # (input using bracket-quoted identifiers, expected tsql output)
    cases = [
        (
            "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME AS OF 'foo'",
            """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME AS OF 'foo'""",
        ),
        (
            "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME AS OF 'foo' AS alias",
            """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME AS OF 'foo' AS alias""",
        ),
        (
            "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME FROM c TO d",
            """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME FROM c TO d""",
        ),
        (
            "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME BETWEEN c AND d",
            """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME BETWEEN c AND d""",
        ),
        (
            "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME CONTAINED IN (c, d)",
            """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME CONTAINED IN (c, d)""",
        ),
        (
            "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME ALL AS alias",
            """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME ALL AS alias""",
        ),
    ]

    for source_sql, expected_tsql in cases:
        self.validate_all(source_sql, write={"tsql": expected_tsql})

View file

@ -94,8 +94,8 @@ CONCAT_WS('-', 'a', 'b')
CONCAT_WS('-', 'a', 'b', 'c') CONCAT_WS('-', 'a', 'b', 'c')
POSEXPLODE("x") AS ("a", "b") POSEXPLODE("x") AS ("a", "b")
POSEXPLODE("x") AS ("a", "b", "c") POSEXPLODE("x") AS ("a", "b", "c")
STR_POSITION(x, 'a') STR_POSITION(haystack, needle)
STR_POSITION(x, 'a', 3) STR_POSITION(haystack, needle, pos)
LEVENSHTEIN('gumbo', 'gambol', 2, 1, 1) LEVENSHTEIN('gumbo', 'gambol', 2, 1, 1)
SPLIT(SPLIT(referrer, 'utm_source=')[OFFSET(1)], "&")[OFFSET(0)] SPLIT(SPLIT(referrer, 'utm_source=')[OFFSET(1)], "&")[OFFSET(0)]
x[ORDINAL(1)][SAFE_OFFSET(2)] x[ORDINAL(1)][SAFE_OFFSET(2)]
@ -375,12 +375,16 @@ SELECT * FROM (SELECT 1 UNION ALL SELECT 2) AS x
SELECT * FROM (SELECT 1 UNION ALL SELECT 2) SELECT * FROM (SELECT 1 UNION ALL SELECT 2)
SELECT * FROM ((SELECT 1) AS a UNION ALL (SELECT 2) AS b) SELECT * FROM ((SELECT 1) AS a UNION ALL (SELECT 2) AS b)
SELECT * FROM ((SELECT 1) AS a(b)) SELECT * FROM ((SELECT 1) AS a(b))
SELECT * FROM ((SELECT 1) UNION (SELECT 2) UNION (SELECT 3))
SELECT * FROM x AS y(a, b) SELECT * FROM x AS y(a, b)
SELECT * EXCEPT (a, b) SELECT * EXCEPT (a, b)
SELECT * EXCEPT (a, b) FROM y
SELECT * REPLACE (a AS b, b AS C) SELECT * REPLACE (a AS b, b AS C)
SELECT * REPLACE (a + 1 AS b, b AS C) SELECT * REPLACE (a + 1 AS b, b AS C)
SELECT * EXCEPT (a, b) REPLACE (a AS b, b AS C) SELECT * EXCEPT (a, b) REPLACE (a AS b, b AS C)
SELECT * EXCEPT (a, b) REPLACE (a AS b, b AS C) FROM y
SELECT a.* EXCEPT (a, b), b.* REPLACE (a AS b, b AS C) SELECT a.* EXCEPT (a, b), b.* REPLACE (a AS b, b AS C)
SELECT a.* EXCEPT (a, b), b.* REPLACE (a AS b, b AS C) FROM x
SELECT zoo, animals FROM (VALUES ('oakland', ARRAY('a', 'b')), ('sf', ARRAY('b', 'c'))) AS t(zoo, animals) SELECT zoo, animals FROM (VALUES ('oakland', ARRAY('a', 'b')), ('sf', ARRAY('b', 'c'))) AS t(zoo, animals)
SELECT zoo, animals FROM UNNEST(ARRAY(STRUCT('oakland' AS zoo, ARRAY('a', 'b') AS animals), STRUCT('sf' AS zoo, ARRAY('b', 'c') AS animals))) AS t(zoo, animals) SELECT zoo, animals FROM UNNEST(ARRAY(STRUCT('oakland' AS zoo, ARRAY('a', 'b') AS animals), STRUCT('sf' AS zoo, ARRAY('b', 'c') AS animals))) AS t(zoo, animals)
WITH a AS (SELECT 1) SELECT 1 UNION ALL SELECT 2 WITH a AS (SELECT 1) SELECT 1 UNION ALL SELECT 2
@ -438,6 +442,8 @@ SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLO
SELECT AVG(x) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM t SELECT AVG(x) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM t
SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x) AS y SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x) AS y
SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x DESC) SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x DESC)
SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x)
SELECT PERCENTILE_DISC(0.5) WITHIN GROUP (ORDER BY x)
SELECT SUM(x) FILTER(WHERE x > 1) SELECT SUM(x) FILTER(WHERE x > 1)
SELECT SUM(x) FILTER(WHERE x > 1) OVER (ORDER BY y) SELECT SUM(x) FILTER(WHERE x > 1) OVER (ORDER BY y)
SELECT COUNT(DISTINCT a) OVER (PARTITION BY c ORDER BY d ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) SELECT COUNT(DISTINCT a) OVER (PARTITION BY c ORDER BY d ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)
@ -611,6 +617,7 @@ WITH a AS (SELECT * FROM b) DELETE FROM a
WITH a AS (SELECT * FROM b) CACHE TABLE a WITH a AS (SELECT * FROM b) CACHE TABLE a
SELECT ? AS ? FROM x WHERE b BETWEEN ? AND ? GROUP BY ?, 1 LIMIT ? SELECT ? AS ? FROM x WHERE b BETWEEN ? AND ? GROUP BY ?, 1 LIMIT ?
SELECT :hello, ? FROM x LIMIT :my_limit SELECT :hello, ? FROM x LIMIT :my_limit
SELECT * FROM x FETCH NEXT @take ROWS ONLY OFFSET @skip
WITH a AS ((SELECT b.foo AS foo, b.bar AS bar FROM b) UNION ALL (SELECT c.foo AS foo, c.bar AS bar FROM c)) SELECT * FROM a WITH a AS ((SELECT b.foo AS foo, b.bar AS bar FROM b) UNION ALL (SELECT c.foo AS foo, c.bar AS bar FROM c)) SELECT * FROM a
WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a
SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z
@ -670,3 +677,17 @@ CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)
CREATE TABLE IF NOT EXISTS customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (INCREMENT BY 1)) CREATE TABLE IF NOT EXISTS customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (INCREMENT BY 1))
CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10 INCREMENT BY 1)) CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10 INCREMENT BY 1))
CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10)) CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10))
ALTER TABLE "schema"."tablename" ADD CONSTRAINT "CHK_Name" CHECK (NOT "IdDwh" IS NULL AND "IdDwh" <> (0))
ALTER TABLE persons ADD CONSTRAINT persons_pk PRIMARY KEY (first_name, last_name)
ALTER TABLE pets ADD CONSTRAINT pets_persons_fk FOREIGN KEY (owner_first_name, owner_last_name) REFERENCES persons
ALTER TABLE pets ADD CONSTRAINT pets_name_not_cute_chk CHECK (LENGTH(name) < 20)
ALTER TABLE people10m ADD CONSTRAINT dateWithinRange CHECK (birthDate > '1900-01-01')
ALTER TABLE people10m ADD CONSTRAINT validIds CHECK (id > 1 AND id < 99999999) ENFORCED
ALTER TABLE baa ADD CONSTRAINT boo PRIMARY KEY (x, y) NOT ENFORCED DEFERRABLE INITIALLY DEFERRED NORELY
ALTER TABLE baa ADD CONSTRAINT boo PRIMARY KEY (x, y) NOT ENFORCED DEFERRABLE INITIALLY DEFERRED NORELY
ALTER TABLE baa ADD CONSTRAINT boo FOREIGN KEY (x, y) REFERENCES persons ON UPDATE NO ACTION ON DELETE NO ACTION MATCH FULL
ALTER TABLE a ADD PRIMARY KEY (x, y) NOT ENFORCED
ALTER TABLE a ADD FOREIGN KEY (x, y) REFERENCES bla
CREATE TABLE foo (baz_id INT REFERENCES baz(id) DEFERRABLE)
SELECT end FROM a
SELECT id FROM b.a AS a QUALIFY ROW_NUMBER() OVER (PARTITION BY br ORDER BY sadf DESC) = 1

View file

@ -18,3 +18,6 @@ WITH y AS (SELECT *) SELECT * FROM x AS x;
WITH y AS (SELECT * FROM y AS y2 JOIN x AS z2) SELECT * FROM x AS x JOIN y as y; WITH y AS (SELECT * FROM y AS y2 JOIN x AS z2) SELECT * FROM x AS x JOIN y as y;
WITH y AS (SELECT * FROM (SELECT * FROM y AS y) AS y2 JOIN (SELECT * FROM x AS x) AS z2) SELECT * FROM (SELECT * FROM x AS x) AS x JOIN y AS y; WITH y AS (SELECT * FROM (SELECT * FROM y AS y) AS y2 JOIN (SELECT * FROM x AS x) AS z2) SELECT * FROM (SELECT * FROM x AS x) AS x JOIN y AS y;
SELECT * FROM x AS x JOIN xx AS y;
SELECT * FROM (SELECT * FROM x AS x) AS x JOIN xx AS y;

View file

@ -2,7 +2,7 @@ SELECT a FROM (SELECT * FROM x);
SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0; SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0;
SELECT 1 FROM (SELECT * FROM x) WHERE b = 2; SELECT 1 FROM (SELECT * FROM x) WHERE b = 2;
SELECT 1 AS _col_0 FROM (SELECT x.b AS b FROM x AS x) AS _q_0 WHERE _q_0.b = 2; SELECT 1 AS "1" FROM (SELECT x.b AS b FROM x AS x) AS _q_0 WHERE _q_0.b = 2;
SELECT (SELECT c FROM y WHERE q.b = y.b) FROM (SELECT * FROM x) AS q; SELECT (SELECT c FROM y WHERE q.b = y.b) FROM (SELECT * FROM x) AS q;
SELECT (SELECT y.c AS c FROM y AS y WHERE q.b = y.b) AS _col_0 FROM (SELECT x.b AS b FROM x AS x) AS q; SELECT (SELECT y.c AS c FROM y AS y WHERE q.b = y.b) AS _col_0 FROM (SELECT x.b AS b FROM x AS x) AS q;

View file

@ -4,6 +4,14 @@
SELECT a FROM x; SELECT a FROM x;
SELECT x.a AS a FROM x AS x; SELECT x.a AS a FROM x AS x;
# execute: false
SELECT a FROM zz GROUP BY a ORDER BY a;
SELECT zz.a AS a FROM zz AS zz GROUP BY zz.a ORDER BY a;
# execute: false
SELECT x, p FROM (SELECT x from xx) xx CROSS JOIN yy;
SELECT xx.x AS x, yy.p AS p FROM (SELECT xx.x AS x FROM xx AS xx) AS xx CROSS JOIN yy AS yy;
SELECT a FROM x AS z; SELECT a FROM x AS z;
SELECT z.a AS a FROM x AS z; SELECT z.a AS a FROM x AS z;
@ -20,8 +28,8 @@ SELECT a AS b FROM x;
SELECT x.a AS b FROM x AS x; SELECT x.a AS b FROM x AS x;
# execute: false # execute: false
SELECT 1, 2 FROM x; SELECT 1, 2 + 3 FROM x;
SELECT 1 AS _col_0, 2 AS _col_1 FROM x AS x; SELECT 1 AS "1", 2 + 3 AS _col_1 FROM x AS x;
# execute: false # execute: false
SELECT a + b FROM x; SELECT a + b FROM x;
@ -57,6 +65,10 @@ SELECT x.a AS j, x.b AS a FROM x AS x ORDER BY x.a;
SELECT SUM(a) AS c, SUM(b) AS d FROM x ORDER BY 1, 2; SELECT SUM(a) AS c, SUM(b) AS d FROM x ORDER BY 1, 2;
SELECT SUM(x.a) AS c, SUM(x.b) AS d FROM x AS x ORDER BY SUM(x.a), SUM(x.b); SELECT SUM(x.a) AS c, SUM(x.b) AS d FROM x AS x ORDER BY SUM(x.a), SUM(x.b);
# execute: false
SELECT CAST(a AS INT) FROM x ORDER BY a;
SELECT CAST(x.a AS INT) AS a FROM x AS x ORDER BY a;
# execute: false # execute: false
SELECT SUM(a), SUM(b) AS c FROM x ORDER BY 1, 2; SELECT SUM(a), SUM(b) AS c FROM x ORDER BY 1, 2;
SELECT SUM(x.a) AS _col_0, SUM(x.b) AS c FROM x AS x ORDER BY SUM(x.a), SUM(x.b); SELECT SUM(x.a) AS _col_0, SUM(x.b) AS c FROM x AS x ORDER BY SUM(x.a), SUM(x.b);

View file

@ -1,4 +1,3 @@
SELECT a FROM zz;
SELECT * FROM zz; SELECT * FROM zz;
SELECT z.a FROM x; SELECT z.a FROM x;
SELECT z.* FROM x; SELECT z.* FROM x;
@ -11,3 +10,4 @@ SELECT q.a FROM (SELECT x.b FROM x) AS z JOIN (SELECT a FROM z) AS q ON z.b = q.
SELECT b FROM x AS a CROSS JOIN y AS b CROSS JOIN y AS c; SELECT b FROM x AS a CROSS JOIN y AS b CROSS JOIN y AS c;
SELECT x.a FROM x JOIN y USING (a); SELECT x.a FROM x JOIN y USING (a);
SELECT a, SUM(b) FROM x GROUP BY 3; SELECT a, SUM(b) FROM x GROUP BY 3;
SELECT p FROM (SELECT x from xx) y CROSS JOIN yy CROSS JOIN zz

View file

@ -481,11 +481,11 @@ class TestExecutor(unittest.TestCase):
def test_static_queries(self): def test_static_queries(self):
for sql, cols, rows in [ for sql, cols, rows in [
("SELECT 1", ["_col_0"], [(1,)]), ("SELECT 1", ["1"], [(1,)]),
("SELECT 1 + 2 AS x", ["x"], [(3,)]), ("SELECT 1 + 2 AS x", ["x"], [(3,)]),
("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]), ("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]),
("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]), ("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]),
("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]), ("SELECT 'foo' LIMIT 1", ["foo"], [("foo",)]),
( (
"SELECT SUM(x), COUNT(x) FROM (SELECT 1 AS x WHERE FALSE)", "SELECT SUM(x), COUNT(x) FROM (SELECT 1 AS x WHERE FALSE)",
["_col_0", "_col_1"], ["_col_0", "_col_1"],

View file

@ -189,6 +189,27 @@ class TestExpressions(unittest.TestCase):
"SELECT * FROM (SELECT a FROM tbl1) WHERE b > 100", "SELECT * FROM (SELECT a FROM tbl1) WHERE b > 100",
) )
def test_function_building(self):
    """Exercise ``exp.func`` for SQL rendering, node types, and argument errors."""
    # Rendered SQL for plain positional calls.
    bla_sql = exp.func("bla", 1, "foo").sql()
    self.assertEqual(bla_sql, "BLA(1, 'foo')")

    count_sql = exp.func("COUNT", exp.Star()).sql()
    self.assertEqual(count_sql, "COUNT(*)")

    bloo_sql = exp.func("bloo").sql()
    self.assertEqual(bloo_sql, "BLOO()")

    # Built with the hive dialect and rendered back to hive.
    locate_sql = exp.func("locate", "x", "xo", dialect="hive").sql("hive")
    self.assertEqual(locate_sql, "LOCATE('x', 'xo')")

    # Expression classes produced for particular function names.
    instr_node = exp.func("instr", "x", "b", dialect="mysql")
    self.assertIsInstance(instr_node, exp.StrPosition)

    bla_node = exp.func("bla", 1, "foo")
    self.assertIsInstance(bla_node, exp.Anonymous)

    cast_node = exp.func(
        "cast",
        this=exp.Literal.number(5),
        to=exp.DataType.build("DOUBLE"),
    )
    self.assertIsInstance(cast_node, exp.Cast)

    # Both of these calls are expected to be rejected with ValueError.
    self.assertRaises(ValueError, exp.func, "some_func", 1, arg2="foo")
    self.assertRaises(ValueError, exp.func, "abs")
def test_named_selects(self): def test_named_selects(self):
expression = parse_one( expression = parse_one(
"SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz" "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz"

20
tests/test_lineage.py Normal file
View file

@ -0,0 +1,20 @@
import unittest
from sqlglot.lineage import lineage
class TestLineage(unittest.TestCase):
    """Tests for ``sqlglot.lineage.lineage``."""

    # Disable unittest's truncation of long assertion diffs.
    maxDiff = None

    def test_lineage(self) -> None:
        """Check the source SQL and HTML rendering of the lineage node for column ``a``."""
        expected_source_sql = (
            "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y /* source: y */"
        )

        node = lineage(
            "a",
            "SELECT a FROM y",
            schema={"x": {"a": "int"}},
            sources={"y": "SELECT * FROM x"},
        )

        self.assertEqual(node.source.sql(), expected_source_sql)

        # Only the size of the rendered HTML is checked, not its content.
        rendered_html = node.to_html()._repr_html_()
        self.assertGreater(len(rendered_html), 1000)

View file

@ -117,6 +117,7 @@ class TestOptimizer(unittest.TestCase):
self.check_file( self.check_file(
"isolate_table_selects", "isolate_table_selects",
optimizer.isolate_table_selects.isolate_table_selects, optimizer.isolate_table_selects.isolate_table_selects,
schema=self.schema,
) )
def test_qualify_tables(self): def test_qualify_tables(self):

View file

@ -17,6 +17,11 @@ class TestSchema(unittest.TestCase):
with self.assertRaises(SchemaError): with self.assertRaises(SchemaError):
schema.column_names(to_table(table)) schema.column_names(to_table(table))
def assert_column_names_empty(self, schema, *tables):
    """Assert that ``schema.column_names`` returns ``[]`` for every name in ``tables``.

    Each table name is checked in its own ``subTest``, so one failing name
    does not stop the remaining names from being checked.
    """
    for table_name in tables:
        with self.subTest(table_name):
            found_columns = schema.column_names(to_table(table_name))
            self.assertEqual(found_columns, [])
def test_schema(self): def test_schema(self):
schema = ensure_schema( schema = ensure_schema(
{ {
@ -38,7 +43,7 @@ class TestSchema(unittest.TestCase):
("z.x.y", ["b", "c"]), ("z.x.y", ["b", "c"]),
) )
self.assert_column_names_raises( self.assert_column_names_empty(
schema, schema,
"z", "z",
"z.z", "z.z",
@ -76,6 +81,10 @@ class TestSchema(unittest.TestCase):
self.assert_column_names_raises( self.assert_column_names_raises(
schema, schema,
"x", "x",
)
self.assert_column_names_empty(
schema,
"z.x", "z.x",
"z.y", "z.y",
) )
@ -129,12 +138,16 @@ class TestSchema(unittest.TestCase):
self.assert_column_names_raises( self.assert_column_names_raises(
schema, schema,
"q",
"d2.x",
"y", "y",
"z", "z",
"d1.y", "d1.y",
"d1.z", "d1.z",
)
self.assert_column_names_empty(
schema,
"q",
"d2.x",
"a.b.c", "a.b.c",
) )