
Merging upstream version 10.4.2.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 15:01:55 +01:00
parent de4e42d4d3
commit 0c79f8b507
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
88 changed files with 1637 additions and 436 deletions

View file

@@ -1,4 +1,6 @@
"""## Python SQL parser, transpiler and optimizer."""
"""
.. include:: ../README.md
"""
from __future__ import annotations
@@ -30,7 +32,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType
__version__ = "10.2.9"
__version__ = "10.4.2"
pretty = False

View file

@@ -1,9 +1,15 @@
import argparse
import sys
import sqlglot
parser = argparse.ArgumentParser(description="Transpile SQL")
parser.add_argument("sql", metavar="sql", type=str, help="SQL string to transpile")
parser.add_argument(
"sql",
metavar="sql",
type=str,
help="SQL statement(s) to transpile, or - to parse stdin.",
)
parser.add_argument(
"--read",
dest="read",
@@ -48,14 +54,20 @@ parser.add_argument(
args = parser.parse_args()
error_level = sqlglot.ErrorLevel[args.error_level.upper()]
sql = sys.stdin.read() if args.sql == "-" else args.sql
if args.parse:
sqls = [
repr(expression)
for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)
for expression in sqlglot.parse(
sql,
read=args.read,
error_level=error_level,
)
]
else:
sqls = sqlglot.transpile(
args.sql,
sql,
read=args.read,
write=args.write,
identify=args.identify,
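For context, the new `-` convention above can be reproduced with the library directly. A minimal sketch of the equivalent calls (the read/write dialect names here are illustrative):

import sys

import sqlglot

# Mimic the CLI: treat "-" as "read the SQL from stdin".
args_sql = "-"  # stand-in for args.sql
sql = sys.stdin.read() if args_sql == "-" else args_sql
for statement in sqlglot.transpile(sql, read="mysql", write="duckdb"):
    print(statement)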

View file

@@ -0,0 +1,3 @@
"""
.. include:: ./README.md
"""

View file

@@ -9,18 +9,8 @@ if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.types import StructType
ColumnLiterals = t.TypeVar(
"ColumnLiterals",
bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
)
ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str])
ColumnOrLiteral = t.TypeVar(
"ColumnOrLiteral",
bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
)
SchemaInput = t.TypeVar(
"SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]]
)
OutputExpressionContainer = t.TypeVar(
"OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert]
)
ColumnLiterals = t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
ColumnOrName = t.Union[Column, str]
ColumnOrLiteral = t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
SchemaInput = t.Union[str, t.List[str], StructType, t.Dict[str, t.Optional[str]]]
OutputExpressionContainer = t.Union[exp.Select, exp.Create, exp.Insert]
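These names used to be TypeVars bound to unions; rewritten as plain aliases they can annotate parameters directly, without TypeVar binding semantics. A small sketch with a stand-in Column class (the real one lives in sqlglot.dataframe.sql.column):

import typing as t

class Column:  # stand-in for the real Column class
    def __init__(self, name: str):
        self.name = name

ColumnOrName = t.Union[Column, str]  # a plain alias, as above

def name_of(col: ColumnOrName) -> str:
    # either form is accepted directly
    return col if isinstance(col, str) else col.name

print(name_of("id"), name_of(Column("x")))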

View file

@@ -634,7 +634,7 @@ class DataFrame:
all_columns = self._get_outer_select_columns(new_df.expression)
all_column_mapping = {column.alias_or_name: column for column in all_columns}
if isinstance(value, dict):
values = value.values()
values = list(value.values())
columns = self._ensure_and_normalize_cols(list(value))
if not columns:
columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
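The added `list(...)` call matters because a dict values view cannot be indexed, which the fillna logic needs later; a one-line illustration:

value = {"a": 0, "b": 1}
values = list(value.values())
print(values[0])  # 0; `value.values()[0]` would raise a TypeError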

View file

@@ -1,11 +1,15 @@
"""Supports BigQuery Standard SQL."""
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
datestrtodate_sql,
inline_array_sql,
no_ilike_sql,
rename_func,
timestrtotime_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -120,13 +124,12 @@ class BigQuery(Dialect):
"NOT DETERMINISTIC": TokenType.VOLATILE,
"QUALIFY": TokenType.QUALIFY,
"UNKNOWN": TokenType.NULL,
"WINDOW": TokenType.WINDOW,
}
KEYWORDS.pop("DIV")
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"DATE_TRUNC": _date_trunc,
"DATE_ADD": _date_add(exp.DateAdd),
"DATETIME_ADD": _date_add(exp.DatetimeAdd),
@@ -144,31 +147,33 @@ class BigQuery(Dialect):
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
**parser.Parser.FUNCTION_PARSERS, # type: ignore
"ARRAY": lambda self: self.expression(exp.Array, expressions=[self._parse_statement()]),
}
FUNCTION_PARSERS.pop("TRIM")
NO_PAREN_FUNCTIONS = {
**parser.Parser.NO_PAREN_FUNCTIONS,
**parser.Parser.NO_PAREN_FUNCTIONS, # type: ignore
TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
TokenType.CURRENT_TIME: exp.CurrentTime,
}
NESTED_TYPE_TOKENS = {
*parser.Parser.NESTED_TYPE_TOKENS,
*parser.Parser.NESTED_TYPE_TOKENS, # type: ignore
TokenType.TABLE,
}
class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
exp.DateSub: _date_add_sql("DATE", "SUB"),
exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
exp.DateStrToDate: datestrtodate_sql,
exp.GroupConcat: rename_func("STRING_AGG"),
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
@@ -176,6 +181,7 @@ class BigQuery(Dialect):
exp.TimeSub: _date_add_sql("TIME", "SUB"),
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
exp.VariancePop: rename_func("VAR_POP"),
exp.Values: _derived_table_values_to_unnest,
exp.ReturnsProperty: _returnsproperty_sql,
@@ -188,7 +194,7 @@ class BigQuery(Dialect):
}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.SMALLINT: "INT64",
exp.DataType.Type.INT: "INT64",
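As a quick check of the transforms above, `no_ilike_sql` rewrites ILIKE for BigQuery, which has no ILIKE operator; a sketch (the output should be a LOWER(...) LIKE comparison, not verified here):

import sqlglot

print(sqlglot.transpile("SELECT * FROM t WHERE name ILIKE '%x%'", write="bigquery")[0])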

View file

@@ -35,13 +35,13 @@ class ClickHouse(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"MAP": parse_var_map,
}
JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF}
JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} # type: ignore
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY}
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore
def _parse_table(self, schema=False):
this = super()._parse_table(schema)
@@ -55,7 +55,7 @@ class ClickHouse(Dialect):
STRUCT_DELIMITER = ("(", ")")
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.NULLABLE: "Nullable",
exp.DataType.Type.DATETIME: "DateTime64",
exp.DataType.Type.MAP: "Map",
@@ -70,7 +70,7 @@ class ClickHouse(Dialect):
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
exp.Array: inline_array_sql,
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",

View file

@@ -198,7 +198,7 @@ class Dialect(metaclass=_Dialect):
def rename_func(name):
def _rename(self, expression):
args = flatten(expression.args.values())
return f"{name}({self.format_args(*args)})"
return f"{self.normalize_func(name)}({self.format_args(*args)})"
return _rename
@@ -217,11 +217,11 @@ def if_sql(self, expression):
def arrow_json_extract_sql(self, expression):
return f"{self.sql(expression, 'this')}->{self.sql(expression, 'path')}"
return self.binary(expression, "->")
def arrow_json_extract_scalar_sql(self, expression):
return f"{self.sql(expression, 'this')}->>{self.sql(expression, 'path')}"
return self.binary(expression, "->>")
def inline_array_sql(self, expression):
@@ -373,3 +373,11 @@ def strposition_to_local_sql(self, expression):
expression.args.get("substr"), expression.this, expression.args.get("position")
)
return f"LOCATE({args})"
def timestrtotime_sql(self, expression: exp.TimeStrToTime) -> str:
return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"
def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str:
return f"CAST({self.sql(expression, 'this')} AS DATE)"
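These two helpers let dialects share the plain-CAST rendering instead of repeating per-dialect lambdas. A minimal sketch of calling one directly:

from sqlglot import exp
from sqlglot.dialects.dialect import datestrtodate_sql
from sqlglot.generator import Generator

node = exp.DateStrToDate(this=exp.Literal.string("2020-01-01"))
print(datestrtodate_sql(Generator(), node))  # CAST('2020-01-01' AS DATE)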

View file

@@ -6,13 +6,14 @@ from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
create_with_partitions_sql,
datestrtodate_sql,
format_time_lambda,
no_pivot_sql,
no_trycast_sql,
rename_func,
str_position_sql,
timestrtotime_sql,
)
from sqlglot.dialects.postgres import _lateral_sql
def _to_timestamp(args):
@@ -117,14 +118,14 @@ class Drill(Dialect):
STRICT_CAST = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
}
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.SMALLINT: "INTEGER",
exp.DataType.Type.TINYINT: "INTEGER",
@@ -139,14 +140,13 @@ class Drill(Dialect):
ROOT_PROPERTIES = {exp.PartitionedByProperty}
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.Lateral: _lateral_sql,
exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
exp.ArraySize: rename_func("REPEATED_COUNT"),
exp.Create: create_with_partitions_sql,
exp.DateAdd: _date_add_sql("ADD"),
exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _date_add_sql("SUB"),
exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})",
@@ -160,7 +160,7 @@ class Drill(Dialect):
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),

View file

@@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import (
approx_count_distinct_sql,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
datestrtodate_sql,
format_time_lambda,
no_pivot_sql,
no_properties_sql,
@@ -13,6 +14,7 @@ from sqlglot.dialects.dialect import (
no_tablesample_sql,
rename_func,
str_position_sql,
timestrtotime_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -83,11 +85,12 @@ class DuckDB(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
":=": TokenType.EQ,
"CHARACTER VARYING": TokenType.VARCHAR,
}
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"ARRAY_LENGTH": exp.ArraySize.from_arg_list,
"ARRAY_SORT": exp.SortArray.from_arg_list,
@@ -119,16 +122,18 @@ class DuckDB(Dialect):
STRUCT_DELIMITER = ("(", ")")
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: rename_func("LIST_VALUE"),
exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
if isinstance(seq_get(e.expressions, 0), exp.Select)
else rename_func("LIST_VALUE")(self, e),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArraySort: _array_sort_sql,
exp.ArraySum: rename_func("LIST_SUM"),
exp.DataType: _datatype_sql,
exp.DateAdd: _date_add,
exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.format_args(e.args.get("unit") or "'day'", e.expression, e.this)})""",
exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)",
exp.Explode: rename_func("UNNEST"),
@@ -136,6 +141,7 @@ class DuckDB(Dialect):
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: arrow_json_extract_sql,
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Pivot: no_pivot_sql,
exp.Properties: no_properties_sql,
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
@@ -150,7 +156,7 @@ class DuckDB(Dialect):
exp.Struct: _struct_pack_sql,
exp.TableSample: no_tablesample_sql,
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))",
exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("EPOCH"),
@@ -163,7 +169,7 @@ class DuckDB(Dialect):
}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.VARCHAR: "TEXT",
exp.DataType.Type.NVARCHAR: "TEXT",
}

View file

@@ -15,6 +15,7 @@ from sqlglot.dialects.dialect import (
rename_func,
strposition_to_local_sql,
struct_extract_sql,
timestrtotime_sql,
var_map_sql,
)
from sqlglot.helper import seq_get
@@ -197,7 +198,7 @@ class Hive(Dialect):
STRICT_CAST = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"COLLECT_LIST": exp.ArrayAgg.from_arg_list,
"DATE_ADD": lambda args: exp.TsOrDsAdd(
@@ -217,7 +218,12 @@ class Hive(Dialect):
),
unit=exp.Literal.string("DAY"),
),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"),
"DATE_FORMAT": lambda args: format_time_lambda(exp.TimeToStr, "hive")(
[
exp.TimeStrToTime(this=seq_get(args, 0)),
seq_get(args, 1),
]
),
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
@@ -240,7 +246,7 @@ class Hive(Dialect):
}
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS,
**parser.Parser.PROPERTY_PARSERS, # type: ignore
TokenType.SERDE_PROPERTIES: lambda self: exp.SerdeProperties(
expressions=self._parse_wrapped_csv(self._parse_property)
),
@@ -248,14 +254,14 @@ class Hive(Dialect):
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.DATETIME: "TIMESTAMP",
exp.DataType.Type.VARBINARY: "BINARY",
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
exp.Property: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
@@ -294,7 +300,7 @@ class Hive(Dialect):
exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}",
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeToStr: _time_to_str,
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),

View file

@@ -161,8 +161,6 @@ class MySQL(Dialect):
"_UCS2": TokenType.INTRODUCER,
"_UJIS": TokenType.INTRODUCER,
# https://dev.mysql.com/doc/refman/8.0/en/string-literals.html
"N": TokenType.INTRODUCER,
"n": TokenType.INTRODUCER,
"_UTF8": TokenType.INTRODUCER,
"_UTF16": TokenType.INTRODUCER,
"_UTF16LE": TokenType.INTRODUCER,
@@ -175,10 +173,10 @@ class MySQL(Dialect):
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}
class Parser(parser.Parser):
FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA}
FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA} # type: ignore
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"DATE_ADD": _date_add(exp.DateAdd),
"DATE_SUB": _date_add(exp.DateSub),
"STR_TO_DATE": _str_to_date,
@@ -190,7 +188,7 @@ class MySQL(Dialect):
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
**parser.Parser.FUNCTION_PARSERS, # type: ignore
"GROUP_CONCAT": lambda self: self.expression(
exp.GroupConcat,
this=self._parse_lambda(),
@@ -199,12 +197,12 @@ class MySQL(Dialect):
}
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS,
**parser.Parser.PROPERTY_PARSERS, # type: ignore
TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty),
}
STATEMENT_PARSERS = {
**parser.Parser.STATEMENT_PARSERS,
**parser.Parser.STATEMENT_PARSERS, # type: ignore
TokenType.SHOW: lambda self: self._parse_show(),
TokenType.SET: lambda self: self._parse_set(),
}
@@ -429,7 +427,7 @@ class MySQL(Dialect):
NULL_ORDERING_SUPPORTED = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.ILike: no_ilike_sql,

View file

@@ -39,13 +39,13 @@ class Oracle(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"DECODE": exp.Matches.from_arg_list,
}
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "NUMBER",
exp.DataType.Type.SMALLINT: "NUMBER",
exp.DataType.Type.INT: "NUMBER",
@@ -60,7 +60,7 @@ class Oracle(Dialect):
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
exp.ILike: no_ilike_sql,
exp.Limit: _limit_sql,

View file

@@ -11,9 +11,19 @@ from sqlglot.dialects.dialect import (
no_trycast_sql,
str_position_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
from sqlglot.transforms import delegate, preprocess
DATE_DIFF_FACTOR = {
"MICROSECOND": " * 1000000",
"MILLISECOND": " * 1000",
"SECOND": "",
"MINUTE": " / 60",
"HOUR": " / 3600",
"DAY": " / 86400",
}
def _date_add_sql(kind):
def func(self, expression):
@@ -34,16 +44,30 @@ def _date_add_sql(kind):
return func
def _lateral_sql(self, expression):
this = self.sql(expression, "this")
if isinstance(expression.this, exp.Subquery):
return f"LATERAL{self.sep()}{this}"
alias = expression.args["alias"]
table = alias.name
table = f" {table}" if table else table
columns = self.expressions(alias, key="columns", flat=True)
columns = f" AS {columns}" if columns else ""
return f"LATERAL{self.sep()}{this}{table}{columns}"
def _date_diff_sql(self, expression):
unit = expression.text("unit").upper()
factor = DATE_DIFF_FACTOR.get(unit)
end = f"CAST({expression.this} AS TIMESTAMP)"
start = f"CAST({expression.expression} AS TIMESTAMP)"
if factor is not None:
return f"CAST(EXTRACT(epoch FROM {end} - {start}){factor} AS BIGINT)"
age = f"AGE({end}, {start})"
if unit == "WEEK":
extract = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7"
elif unit == "MONTH":
extract = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})"
elif unit == "QUARTER":
extract = f"EXTRACT(year FROM {age}) * 4 + EXTRACT(month FROM {age}) / 3"
elif unit == "YEAR":
extract = f"EXTRACT(year FROM {age})"
else:
self.unsupported(f"Unsupported DATEDIFF unit {unit}")
return f"CAST({extract} AS BIGINT)"
def _substring_sql(self, expression):
@@ -141,7 +165,7 @@ def _serial_to_generated(expression):
def _to_timestamp(args):
# TO_TIMESTAMP accepts either a single double argument or (text, text)
if len(args) == 1 and args[0].is_number:
if len(args) == 1:
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE
return exp.UnixToTime.from_arg_list(args)
# https://www.postgresql.org/docs/current/functions-formatting.html
@@ -211,11 +235,16 @@ class Postgres(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"~~": TokenType.LIKE,
"~~*": TokenType.ILIKE,
"~*": TokenType.IRLIKE,
"~": TokenType.RLIKE,
"ALWAYS": TokenType.ALWAYS,
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
"BIGSERIAL": TokenType.BIGSERIAL,
"BY DEFAULT": TokenType.BY_DEFAULT,
"CHARACTER VARYING": TokenType.VARCHAR,
"COMMENT ON": TokenType.COMMAND,
"DECLARE": TokenType.COMMAND,
"DO": TokenType.COMMAND,
@@ -233,6 +262,7 @@ class Postgres(Dialect):
"SMALLSERIAL": TokenType.SMALLSERIAL,
"TEMP": TokenType.TEMPORARY,
"UUID": TokenType.UUID,
"CSTRING": TokenType.PSEUDO_TYPE,
**{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
**{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
}
@@ -244,17 +274,16 @@ class Postgres(Dialect):
class Parser(parser.Parser):
STRICT_CAST = False
LATERAL_FUNCTION_AS_VIEW = True
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"TO_TIMESTAMP": _to_timestamp,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
}
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "SMALLINT",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
@@ -264,7 +293,7 @@ class Postgres(Dialect):
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
exp.ColumnDef: preprocess(
[
_auto_increment_to_serial,
@@ -274,13 +303,16 @@ class Postgres(Dialect):
),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: lambda self, e: f"{self.sql(e, 'this')}#>{self.sql(e, 'path')}",
exp.JSONBExtractScalar: lambda self, e: f"{self.sql(e, 'this')}#>>{self.sql(e, 'path')}",
exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"),
exp.JSONBContains: lambda self, e: self.binary(e, "?"),
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: _date_add_sql("+"),
exp.DateSub: _date_add_sql("-"),
exp.Lateral: _lateral_sql,
exp.DateDiff: _date_diff_sql,
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
exp.StrPosition: str_position_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Substring: _substring_sql,
@@ -291,5 +323,7 @@ class Postgres(Dialect):
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
exp.DataType: _datatype_sql,
exp.GroupConcat: _string_agg_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
if isinstance(seq_get(e.expressions, 0), exp.Select)
else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]",
}
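A sketch of the new DATEDIFF lowering, following the DATE_DIFF_FACTOR table and _date_diff_sql above; the comment shows the shape the formula implies:

from sqlglot import exp
from sqlglot.dialects.postgres import Postgres

diff = exp.DateDiff(
    this=exp.Literal.string("2020-01-02"),
    expression=exp.Literal.string("2020-01-01"),
    unit=exp.Var(this="HOUR"),
)
# HOUR maps to " / 3600", so this should render roughly as:
# CAST(EXTRACT(epoch FROM CAST(... AS TIMESTAMP) - CAST(... AS TIMESTAMP)) / 3600 AS BIGINT)
print(Postgres().generate(diff))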

View file

@@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
rename_func,
str_position_sql,
struct_extract_sql,
timestrtotime_sql,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.errors import UnsupportedError
@@ -38,10 +39,6 @@ def _datatype_sql(self, expression):
return sql
def _date_parse_sql(self, expression):
return f"DATE_PARSE({self.sql(expression, 'this')}, '%Y-%m-%d %H:%i:%s')"
def _explode_to_unnest_sql(self, expression):
if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
return self.sql(
@@ -137,7 +134,7 @@ class Presto(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
"CARDINALITY": exp.ArraySize.from_arg_list,
"CONTAINS": exp.ArrayContains.from_arg_list,
@@ -174,7 +171,7 @@ class Presto(Dialect):
ROOT_PROPERTIES = {exp.SchemaCommentProperty}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.BINARY: "VARBINARY",
@@ -184,7 +181,7 @@ class Presto(Dialect):
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
@@ -224,8 +221,8 @@ class Presto(Dialect):
exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'",
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.TimeStrToDate: _date_parse_sql,
exp.TimeStrToTime: _date_parse_sql,
exp.TimeStrToDate: timestrtotime_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",
exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("TO_UNIXTIME"),

View file

@@ -36,7 +36,6 @@ class Redshift(Postgres):
"TIMETZ": TokenType.TIMESTAMPTZ,
"UNLOAD": TokenType.COMMAND,
"VARBYTE": TokenType.VARBINARY,
"SIMILAR TO": TokenType.SIMILAR_TO,
}
class Generator(Postgres.Generator):

View file

@@ -3,13 +3,15 @@ from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
datestrtodate_sql,
format_time_lambda,
inline_array_sql,
rename_func,
timestrtotime_sql,
var_map_sql,
)
from sqlglot.expressions import Literal
from sqlglot.helper import seq_get
from sqlglot.helper import flatten, seq_get
from sqlglot.tokens import TokenType
@@ -183,7 +185,7 @@ class Snowflake(Dialect):
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
ESCAPES = ["\\"]
ESCAPES = ["\\", "'"]
SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
@@ -206,9 +208,10 @@ class Snowflake(Dialect):
CREATE_TRANSIENT = True
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.If: rename_func("IFF"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
@@ -218,13 +221,14 @@ class Snowflake(Dialect):
exp.Matches: rename_func("DECODE"),
exp.StrPosition: rename_func("POSITION"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
exp.UnixToTime: _unix_to_time_sql,
}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
}
@@ -246,3 +250,47 @@ class Snowflake(Dialect):
if not expression.args.get("distinct", False):
self.unsupported("INTERSECT with All is not supported in Snowflake")
return super().intersect_op(expression)
def values_sql(self, expression: exp.Values) -> str:
"""Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted.
We also want to make sure that after we find matches where we need to unquote a column that we prevent users
from adding quotes to the column by using the `identify` argument when generating the SQL.
"""
alias = expression.args.get("alias")
if alias and alias.args.get("columns"):
expression = expression.transform(
lambda node: exp.Identifier(**{**node.args, "quoted": False})
if isinstance(node, exp.Identifier)
and isinstance(node.parent, exp.TableAlias)
and node.arg_key == "columns"
else node,
)
return self.no_identify(lambda: super(self.__class__, self).values_sql(expression))
return super().values_sql(expression)
def select_sql(self, expression: exp.Select) -> str:
"""Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also
that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need
to unquote a column that we prevent users from adding quotes to the column by using the `identify` argument when
generating the SQL.
Note: We make an assumption that any columns referenced in a VALUES expression should be unquoted throughout the
expression. This might not be true in a case where the same column name can be sourced from another table that can
properly quote but should be true in most cases.
"""
values_expressions = expression.find_all(exp.Values)
values_identifiers = set(
flatten(
v.args.get("alias", exp.Alias()).args.get("columns", [])
for v in values_expressions
)
)
if values_identifiers:
expression = expression.transform(
lambda node: exp.Identifier(**{**node.args, "quoted": False})
if isinstance(node, exp.Identifier) and node in values_identifiers
else node,
)
return self.no_identify(lambda: super(self.__class__, self).select_sql(expression))
return super().select_sql(expression)
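A sketch of the behavior the docstrings above describe: the aliased VALUES columns should come out unquoted even with identify=True (output not asserted here):

import sqlglot

sql = 'SELECT "c1" FROM (VALUES (1)) AS "t"("c1")'
print(sqlglot.transpile(sql, read="snowflake", write="snowflake", identify=True)[0])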

View file

@@ -76,7 +76,7 @@ class Spark(Hive):
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
**parser.Parser.FUNCTION_PARSERS, # type: ignore
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@@ -87,6 +87,16 @@ class Spark(Hive):
"SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
}
def _parse_add_column(self):
return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
def _parse_drop_column(self):
return self._match_text_seq("DROP", "COLUMNS") and self.expression(
exp.Drop,
this=self._parse_schema(),
kind="COLUMNS",
)
class Generator(Hive.Generator):
TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING, # type: ignore
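The two parser hooks above make Spark's ALTER TABLE column forms parseable; together with the new altertable_sql in the generator, a statement like this should round-trip roughly unchanged (a sketch, output not verified):

import sqlglot

print(sqlglot.transpile("ALTER TABLE db.t ADD COLUMNS (c1 INT)", read="spark", write="spark")[0])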

View file

@@ -42,13 +42,13 @@ class SQLite(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"EDITDIST3": exp.Levenshtein.from_arg_list,
}
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BOOLEAN: "INTEGER",
exp.DataType.Type.TINYINT: "INTEGER",
exp.DataType.Type.SMALLINT: "INTEGER",
@@ -70,7 +70,7 @@ class SQLite(Dialect):
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
exp.ILike: no_ilike_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,

View file

@@ -8,7 +8,7 @@ from sqlglot.dialects.mysql import MySQL
class StarRocks(MySQL):
class Generator(MySQL.Generator): # type: ignore
TYPE_MAPPING = {
**MySQL.Generator.TYPE_MAPPING,
**MySQL.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.TIMESTAMP: "DATETIME",
exp.DataType.Type.TIMESTAMPTZ: "DATETIME",

View file

@@ -30,7 +30,7 @@ class Tableau(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"IFNULL": exp.Coalesce.from_arg_list,
"COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
}

View file

@@ -224,11 +224,7 @@ class TSQL(Dialect):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]")]
QUOTES = [
(prefix + quote, quote) if prefix else quote
for quote in ["'", '"']
for prefix in ["", "n", "N"]
]
QUOTES = ["'", '"']
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -253,7 +249,7 @@ class TSQL(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS, # type: ignore
"CHARINDEX": exp.StrPosition.from_arg_list,
"ISNULL": exp.Coalesce.from_arg_list,
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
@@ -314,7 +310,7 @@ class TSQL(Dialect):
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BOOLEAN: "BIT",
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.DECIMAL: "NUMERIC",

View file

@@ -1,3 +1,7 @@
"""
.. include:: ../posts/sql_diff.md
"""
from __future__ import annotations
import typing as t

View file

@@ -29,10 +29,10 @@ class Context:
self._table: t.Optional[Table] = None
self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
self.row_readers = {name: table.reader for name, table in tables.items()}
self.env = {**(env or {}), "scope": self.row_readers}
self.env = {**ENV, **(env or {}), "scope": self.row_readers}
def eval(self, code):
return eval(code, ENV, self.env)
return eval(code, self.env)
def eval_tuple(self, codes):
return tuple(self.eval(code) for code in codes)
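The effect of the change: ENV, the builtin function namespace, is folded into each Context's env once, so eval() runs against a single flat namespace. An illustration with a stand-in ENV:

ENV = {"ABS": abs}           # stand-in for sqlglot.executor.env.ENV
env = {**ENV, "scope": {}}   # what Context.__init__ now builds
print(eval("ABS(-1)", env))  # 1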

View file

@@ -127,14 +127,16 @@ def interval(this, unit):
ENV = {
"exp": exp,
# aggs
"SUM": filter_nulls(sum),
"ARRAYAGG": list,
"AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore
"COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False),
"MAX": filter_nulls(max),
"MIN": filter_nulls(min),
"SUM": filter_nulls(sum),
# scalar functions
"ABS": null_if_any(lambda this: abs(this)),
"ADD": null_if_any(lambda e, this: e + this),
"ARRAYANY": null_if_any(lambda arr, func: any(func(e) for e in arr)),
"BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
"BITWISEAND": null_if_any(lambda this, e: this & e),
"BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),

View file

@@ -394,6 +394,18 @@ def _case_sql(self, expression):
return chain
def _lambda_sql(self, e: exp.Lambda) -> str:
names = {e.name.lower() for e in e.expressions}
e = e.transform(
lambda n: exp.Var(this=n.name)
if isinstance(n, exp.Identifier) and n.name.lower() in names
else n
)
return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}"
class Python(Dialect):
class Tokenizer(tokens.Tokenizer):
ESCAPES = ["\\"]
@@ -414,6 +426,7 @@ class Python(Dialect):
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})",
exp.Is: lambda self, e: self.binary(e, "is"),
exp.Lambda: _lambda_sql,
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
exp.Null: lambda *_: "None",
exp.Or: lambda self, e: self.binary(e, "or"),

View file

@@ -1,6 +1,11 @@
"""
.. include:: ../pdoc/docs/expressions.md
"""
from __future__ import annotations
import datetime
import math
import numbers
import re
import typing as t
@@ -682,6 +687,10 @@ class CharacterSet(Expression):
class With(Expression):
arg_types = {"expressions": True, "recursive": False}
@property
def recursive(self) -> bool:
return bool(self.args.get("recursive"))
class WithinGroup(Expression):
arg_types = {"this": True, "expression": False}
@@ -724,6 +733,18 @@ class ColumnDef(Expression):
"this": True,
"kind": True,
"constraints": False,
"exists": False,
}
class AlterColumn(Expression):
arg_types = {
"this": True,
"dtype": False,
"collate": False,
"using": False,
"default": False,
"drop": False,
}
@@ -877,6 +898,11 @@ class Introducer(Expression):
arg_types = {"this": True, "expression": True}
# national char, like n'utf8'
class National(Expression):
pass
class LoadData(Expression):
arg_types = {
"this": True,
@@ -894,7 +920,7 @@ class Partition(Expression):
class Fetch(Expression):
arg_types = {"direction": False, "count": True}
arg_types = {"direction": False, "count": False}
class Group(Expression):
@@ -1316,7 +1342,7 @@ QUERY_MODIFIERS = {
"group": False,
"having": False,
"qualify": False,
"window": False,
"windows": False,
"distribute": False,
"sort": False,
"cluster": False,
@@ -1353,7 +1379,7 @@ class Union(Subqueryable):
Example:
>>> select("1").union(select("1")).limit(1).sql()
'SELECT * FROM (SELECT 1 UNION SELECT 1) AS "_l_0" LIMIT 1'
'SELECT * FROM (SELECT 1 UNION SELECT 1) AS _l_0 LIMIT 1'
Args:
expression (str | int | Expression): the SQL code string to parse.
@@ -1889,6 +1915,18 @@ class Select(Subqueryable):
**opts,
)
def window(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
return _apply_list_builder(
*expressions,
instance=self,
arg="windows",
append=append,
into=Window,
dialect=dialect,
copy=copy,
**opts,
)
def distinct(self, distinct=True, copy=True) -> Select:
"""
Set the DISTINCT expression.
@@ -2140,6 +2178,11 @@ class DataType(Expression):
)
# https://www.postgresql.org/docs/15/datatype-pseudo.html
class PseudoType(Expression):
pass
class StructKwarg(Expression):
arg_types = {"this": True, "expression": True}
@@ -2167,18 +2210,26 @@ class Command(Expression):
arg_types = {"this": True, "expression": False}
class Transaction(Command):
class Transaction(Expression):
arg_types = {"this": False, "modes": False}
class Commit(Command):
class Commit(Expression):
arg_types = {"chain": False}
class Rollback(Command):
class Rollback(Expression):
arg_types = {"savepoint": False}
class AlterTable(Expression):
arg_types = {
"this": True,
"actions": True,
"exists": False,
}
# Binary expressions like (ADD a b)
class Binary(Expression):
arg_types = {"this": True, "expression": True}
@@ -2312,6 +2363,10 @@ class SimilarTo(Binary, Predicate):
pass
class Slice(Binary):
arg_types = {"this": False, "expression": False}
class Sub(Binary):
pass
@@ -2392,7 +2447,7 @@ class TimeUnit(Expression):
class Interval(TimeUnit):
arg_types = {"this": True, "unit": False}
arg_types = {"this": False, "unit": False}
class IgnoreNulls(Expression):
@@ -2730,8 +2785,11 @@ class Initcap(Func):
pass
class JSONExtract(Func):
arg_types = {"this": True, "path": True}
class JSONBContains(Binary):
_sql_names = ["JSONB_CONTAINS"]
class JSONExtract(Binary, Func):
_sql_names = ["JSON_EXTRACT"]
@@ -2776,6 +2834,10 @@ class Log10(Func):
pass
class LogicalOr(AggFunc):
_sql_names = ["LOGICAL_OR", "BOOL_OR"]
class Lower(Func):
_sql_names = ["LOWER", "LCASE"]
@@ -2846,6 +2908,10 @@ class RegexpLike(Func):
arg_types = {"this": True, "expression": True, "flag": False}
class RegexpILike(Func):
arg_types = {"this": True, "expression": True, "flag": False}
class RegexpSplit(Func):
arg_types = {"this": True, "expression": True}
@@ -3388,11 +3454,17 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts) -> U
],
)
if from_:
update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts))
update.set(
"from",
maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts),
)
if isinstance(where, Condition):
where = Where(this=where)
if where:
update.set("where", maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts))
update.set(
"where",
maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts),
)
return update
@@ -3522,7 +3594,7 @@ def paren(expression) -> Paren:
return Paren(this=expression)
SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$")
SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")
def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
@@ -3724,6 +3796,8 @@ def convert(value) -> Expression:
return Boolean(this=value)
if isinstance(value, str):
return Literal.string(value)
if isinstance(value, float) and math.isnan(value):
return NULL
if isinstance(value, numbers.Number):
return Literal.number(value)
if isinstance(value, tuple):
@@ -3732,11 +3806,13 @@ def convert(value) -> Expression:
return Array(expressions=[convert(v) for v in value])
if isinstance(value, dict):
return Map(
keys=[convert(k) for k in value.keys()],
keys=[convert(k) for k in value],
values=[convert(v) for v in value.values()],
)
if isinstance(value, datetime.datetime):
datetime_literal = Literal.string(value.strftime("%Y-%m-%d %H:%M:%S.%f%z"))
datetime_literal = Literal.string(
(value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat()
)
return TimeStrToTime(this=datetime_literal)
if isinstance(value, datetime.date):
date_literal = Literal.string(value.strftime("%Y-%m-%d"))
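A sketch of the new convert() branches above (the second comment shows the expected shape; naive datetimes are assumed UTC):

import datetime
import math

from sqlglot import exp

print(exp.convert(math.nan).sql())  # NULL
# e.g. TIME_STR_TO_TIME('2020-01-01T00:00:00+00:00')
print(exp.convert(datetime.datetime(2020, 1, 1)).sql())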

View file

@@ -361,10 +361,11 @@ class Generator:
column = self.sql(expression, "this")
kind = self.sql(expression, "kind")
constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
exists = "IF NOT EXISTS " if expression.args.get("exists") else ""
if not constraints:
return f"{column} {kind}"
return f"{column} {kind} {constraints}"
return f"{exists}{column} {kind}"
return f"{exists}{column} {kind} {constraints}"
def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
this = self.sql(expression, "this")
@@ -549,6 +550,9 @@ class Generator:
text = f"{self.identifier_start}{text}{self.identifier_end}"
return text
def national_sql(self, expression: exp.National) -> str:
return f"N{self.sql(expression, 'this')}"
def partition_sql(self, expression: exp.Partition) -> str:
keys = csv(
*[
@@ -633,6 +637,9 @@ class Generator:
def introducer_sql(self, expression: exp.Introducer) -> str:
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
def pseudotype_sql(self, expression: exp.PseudoType) -> str:
return expression.name.upper()
def rowformatdelimitedproperty_sql(self, expression: exp.RowFormatDelimitedProperty) -> str:
fields = expression.args.get("fields")
fields = f" FIELDS TERMINATED BY {fields}" if fields else ""
@@ -793,19 +800,17 @@ class Generator:
if isinstance(expression.this, exp.Subquery):
return f"LATERAL {this}"
alias = expression.args["alias"]
table = alias.name
columns = self.expressions(alias, key="columns", flat=True)
if expression.args.get("view"):
table = f" {table}" if table else table
alias = expression.args["alias"]
columns = self.expressions(alias, key="columns", flat=True)
table = f" {alias.name}" if alias.name else ""
columns = f" AS {columns}" if columns else ""
op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}")
return f"{op_sql}{self.sep()}{this}{table}{columns}"
table = f" AS {table}" if table else table
columns = f"({columns})" if columns else ""
return f"LATERAL {this}{table}{columns}"
alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else ""
return f"LATERAL {this}{alias}"
def limit_sql(self, expression: exp.Limit) -> str:
this = self.sql(expression, "this")
@@ -891,13 +896,15 @@ class Generator:
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
return csv(
*sqls,
*[self.sql(sql) for sql in expression.args.get("joins", [])],
*[self.sql(sql) for sql in expression.args.get("laterals", [])],
*[self.sql(sql) for sql in expression.args.get("joins") or []],
*[self.sql(sql) for sql in expression.args.get("laterals") or []],
self.sql(expression, "where"),
self.sql(expression, "group"),
self.sql(expression, "having"),
self.sql(expression, "qualify"),
self.sql(expression, "window"),
self.seg("WINDOW ") + self.expressions(expression, "windows", flat=True)
if expression.args.get("windows")
else "",
self.sql(expression, "distribute"),
self.sql(expression, "sort"),
self.sql(expression, "cluster"),
@@ -1008,11 +1015,7 @@ class Generator:
spec_sql = " " + self.window_spec_sql(spec) if spec else ""
alias = self.sql(expression, "alias")
if expression.arg_key == "window":
this = this = f"{self.seg('WINDOW')} {this} AS"
else:
this = f"{this} OVER"
this = f"{this} {'AS' if expression.arg_key == 'windows' else 'OVER'}"
if not partition and not order and not spec and alias:
return f"{this} {alias}"
@@ -1141,9 +1144,11 @@ class Generator:
return f"(SELECT {self.sql(unnest)})"
def interval_sql(self, expression: exp.Interval) -> str:
this = self.sql(expression, "this")
this = f" {this}" if this else ""
unit = self.sql(expression, "unit")
unit = f" {unit}" if unit else ""
return f"INTERVAL {self.sql(expression, 'this')}{unit}"
return f"INTERVAL{this}{unit}"
def reference_sql(self, expression: exp.Reference) -> str:
this = self.sql(expression, "this")
@@ -1245,6 +1250,43 @@ class Generator:
savepoint = f" TO {savepoint}" if savepoint else ""
return f"ROLLBACK{savepoint}"
def altercolumn_sql(self, expression: exp.AlterColumn) -> str:
this = self.sql(expression, "this")
dtype = self.sql(expression, "dtype")
if dtype:
collate = self.sql(expression, "collate")
collate = f" COLLATE {collate}" if collate else ""
using = self.sql(expression, "using")
using = f" USING {using}" if using else ""
return f"ALTER COLUMN {this} TYPE {dtype}{collate}{using}"
default = self.sql(expression, "default")
if default:
return f"ALTER COLUMN {this} SET DEFAULT {default}"
if not expression.args.get("drop"):
self.unsupported("Unsupported ALTER COLUMN syntax")
return f"ALTER COLUMN {this} DROP DEFAULT"
def altertable_sql(self, expression: exp.AlterTable) -> str:
actions = expression.args["actions"]
if isinstance(actions[0], exp.ColumnDef):
actions = self.expressions(expression, "actions", prefix="ADD COLUMN ")
elif isinstance(actions[0], exp.Schema):
actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ")
elif isinstance(actions[0], exp.Drop):
actions = self.expressions(expression, "actions")
elif isinstance(actions[0], exp.AlterColumn):
actions = self.sql(actions[0])
else:
self.unsupported(f"Unsupported ALTER TABLE action {actions[0].__class__.__name__}")
exists = " IF EXISTS" if expression.args.get("exists") else ""
return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}"
def distinct_sql(self, expression: exp.Distinct) -> str:
this = self.expressions(expression, flat=True)
this = f" {this}" if this else ""
@@ -1327,6 +1369,9 @@ class Generator:
def or_sql(self, expression: exp.Or) -> str:
return self.connector_sql(expression, "OR")
def slice_sql(self, expression: exp.Slice) -> str:
return self.binary(expression, ":")
def sub_sql(self, expression: exp.Sub) -> str:
return self.binary(expression, "-")
@@ -1369,6 +1414,7 @@ class Generator:
flat: bool = False,
indent: bool = True,
sep: str = ", ",
prefix: str = "",
) -> str:
expressions = expression.args.get(key or "expressions")
@@ -1391,11 +1437,13 @@ class Generator:
if self.pretty:
if self._leading_comma:
result_sqls.append(f"{sep if i > 0 else pad}{sql}{comments}")
result_sqls.append(f"{sep if i > 0 else pad}{prefix}{sql}{comments}")
else:
result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}")
result_sqls.append(
f"{prefix}{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}"
)
else:
result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}")
result_sqls.append(f"{prefix}{sql}{comments}{sep if i + 1 < num_sqls else ''}")
result_sql = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
return self.indent(result_sql, skip_first=False) if indent else result_sql
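A sketch exercising the new altertable_sql/altercolumn_sql paths by building the AST directly, using the AlterTable/AlterColumn expressions defined in expressions.py above:

from sqlglot import exp
from sqlglot.generator import Generator

alter = exp.AlterTable(
    this=exp.table_("t"),
    actions=[exp.AlterColumn(this=exp.to_identifier("c"), dtype=exp.DataType.build("INT"))],
)
print(Generator().generate(alter))  # ALTER TABLE t ALTER COLUMN c TYPE INT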

View file

@@ -18,6 +18,9 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
if isinstance(expression, exp.Identifier):
expression.set("quoted", True)
return expression
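With quoting folded into canonicalize, a full optimize() pass still yields quoted identifiers even though the quote_identities rule is removed (see the optimizer.py change below); a sketch:

import sqlglot
from sqlglot.optimizer import optimize

expression = optimize(sqlglot.parse_one("SELECT a FROM x"), schema={"x": {"a": "INT"}})
print(expression.sql())  # e.g. SELECT "x"."a" AS "a" FROM "x" AS "x"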

View file

@@ -129,10 +129,23 @@ def join_condition(join):
"""
name = join.this.alias_or_name
on = (join.args.get("on") or exp.true()).copy()
on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
source_key = []
join_key = []
def extract_condition(condition):
left, right = condition.unnest_operands()
left_tables = exp.column_table_names(left)
right_tables = exp.column_table_names(right)
if name in left_tables and name not in right_tables:
join_key.append(left)
source_key.append(right)
condition.replace(exp.true())
elif name in right_tables and name not in left_tables:
join_key.append(right)
source_key.append(left)
condition.replace(exp.true())
# find the join keys
# SELECT
# FROM x
@@ -141,20 +154,30 @@ def join_condition(join):
#
# should pull y.b as the join key and x.a as the source key
if normalized(on):
on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
for condition in on.flatten():
if isinstance(condition, exp.EQ):
left, right = condition.unnest_operands()
left_tables = exp.column_table_names(left)
right_tables = exp.column_table_names(right)
extract_condition(condition)
elif normalized(on, dnf=True):
conditions = None
if name in left_tables and name not in right_tables:
join_key.append(left)
source_key.append(right)
condition.replace(exp.true())
elif name in right_tables and name not in left_tables:
join_key.append(right)
source_key.append(left)
condition.replace(exp.true())
for condition in on.flatten():
parts = [part for part in condition.flatten() if isinstance(part, exp.EQ)]
if conditions is None:
conditions = parts
else:
temp = []
for p in parts:
cs = [c for c in conditions if p == c]
if cs:
temp.append(p)
temp.extend(cs)
conditions = temp
for condition in conditions:
extract_condition(condition)
on = simplify(on)
remaining_condition = None if on == exp.true() else on
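The DNF branch keeps only the equality predicates that occur in every OR-branch, since only those are guaranteed to hold for the join. The set-intersection idea in plain Python, with stand-in predicate strings:

# one list of EQ predicates per OR-branch of the join condition
branches = [["x.a = y.a", "y.b = 1"], ["x.a = y.a", "y.b = 2"]]
common = None
for parts in branches:
    common = parts if common is None else [p for p in parts if p in common]
print(common)  # ['x.a = y.a'] -- the only predicate usable as a join key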

View file

@@ -58,7 +58,9 @@ def eliminate_subqueries(expression):
existing_ctes = {}
with_ = root.expression.args.get("with")
recursive = False
if with_:
recursive = with_.args.get("recursive")
for cte in with_.expressions:
existing_ctes[cte.this] = cte.alias
new_ctes = []
@@ -88,7 +90,7 @@ def eliminate_subqueries(expression):
new_ctes.append(new_cte)
if new_ctes:
expression.set("with", exp.With(expressions=new_ctes))
expression.set("with", exp.With(expressions=new_ctes, recursive=recursive))
return expression

View file

@@ -69,8 +69,9 @@ def _predicate_lengths(expression, dnf):
left, right = expression.args.values()
if isinstance(expression, exp.And if dnf else exp.Or):
x = [a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)]
return x
return [
a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
]
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)

View file

@@ -14,7 +14,6 @@ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.quote_identities import quote_identities
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
RULES = (
@@ -34,7 +33,6 @@ RULES = (
eliminate_ctes,
annotate_types,
canonicalize,
quote_identities,
)

View file

@@ -27,7 +27,14 @@ def pushdown_predicates(expression):
select = scope.expression
where = select.args.get("where")
if where:
pushdown(where.this, scope.selected_sources, scope_ref_count)
selected_sources = scope.selected_sources
# a right join can only push down to itself and not the source FROM table
for k, (node, source) in selected_sources.items():
parent = node.find_ancestor(exp.Join, exp.From)
if isinstance(parent, exp.Join) and parent.side == "RIGHT":
selected_sources = {k: (node, source)}
break
pushdown(where.this, selected_sources, scope_ref_count)
# joins should only pushdown into itself, not to other joins
# so we limit the selected sources to only itself
@@ -148,10 +155,13 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
# a node can reference a CTE which should be pushed down
if isinstance(node, exp.From) and not isinstance(source, exp.Table):
with_ = source.parent.expression.args.get("with")
if with_ and with_.recursive:
return {}
node = source.expression
if isinstance(node, exp.Join):
if node.side:
if node.side and node.side != "RIGHT":
return {}
nodes[table] = node
elif isinstance(node, exp.Select) and len(tables) == 1:
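A sketch of the RIGHT-join restriction above: a WHERE predicate may only be pushed into the right side of the join (the printed result is not asserted here):

import sqlglot
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates

expression = sqlglot.parse_one("SELECT * FROM x RIGHT JOIN y ON x.a = y.a WHERE y.b = 1")
print(pushdown_predicates(expression).sql())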

View file

@@ -6,7 +6,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
# Sentinel value that means an outer query selecting ALL columns
SELECT_ALL = object()
# SELECTION TO USE IF SELECTION LIST IS EMPTY
# Selection to use if selection list is empty
DEFAULT_SELECTION = alias("1", "_")
@@ -91,7 +91,7 @@ def _remove_unused_selections(scope, parent_selections):
# If there are no remaining selections, just select a single constant
if not new_selections:
new_selections.append(DEFAULT_SELECTION)
new_selections.append(DEFAULT_SELECTION.copy())
scope.expression.set("expressions", new_selections)
return removed_indexes
@@ -102,5 +102,5 @@ def _remove_indexed_selections(scope, indexes_to_remove):
selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
]
if not new_selections:
new_selections.append(DEFAULT_SELECTION)
new_selections.append(DEFAULT_SELECTION.copy())
scope.expression.set("expressions", new_selections)

View file

@@ -311,6 +311,9 @@ def _qualify_outputs(scope):
alias_ = alias(exp.column(""), alias=selection.name)
alias_.set("this", selection)
selection = alias_
elif isinstance(selection, exp.Subquery):
if not selection.alias:
selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
elif not isinstance(selection, exp.Alias):
alias_ = alias(exp.column(""), f"_col_{i}")
alias_.set("this", selection)

View file

@@ -1,25 +0,0 @@
from sqlglot import exp
def quote_identities(expression):
"""
Rewrite sqlglot AST to ensure all identities are quoted.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT x.a AS a FROM db.x")
>>> quote_identities(expression).sql()
'SELECT "x"."a" AS "a" FROM "db"."x"'
Args:
expression (sqlglot.Expression): expression to quote
Returns:
sqlglot.Expression: quoted expression
"""
def qualify(node):
if isinstance(node, exp.Identifier):
node.set("quoted", True)
return node
return expression.transform(qualify, copy=False)

View file

@@ -511,9 +511,20 @@ def _traverse_union(scope):
def _traverse_derived_tables(derived_tables, scope, scope_type):
sources = {}
is_cte = scope_type == ScopeType.CTE
for derived_table in derived_tables:
top = None
recursive_scope = None
# if the scope is a recursive cte, it must be in the form of
# base_case UNION recursive. thus the recursive scope is the first
# section of the union.
if is_cte and scope.expression.args["with"].recursive:
union = derived_table.this
if isinstance(union, exp.Union):
recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)
for child_scope in _traverse_scope(
scope.branch(
derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,
@@ -523,16 +534,23 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
)
):
yield child_scope
top = child_scope
# Tables without aliases will be set as ""
# This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
# Until then, this means that only a single, unaliased derived table is allowed (rather,
# the latest one wins).
sources[derived_table.alias] = child_scope
if scope_type == ScopeType.CTE:
scope.cte_scopes.append(top)
alias = derived_table.alias
sources[alias] = child_scope
if recursive_scope:
child_scope.add_source(alias, recursive_scope)
# append the final child_scope yielded
if is_cte:
scope.cte_scopes.append(child_scope)
else:
scope.derived_table_scopes.append(top)
scope.derived_table_scopes.append(child_scope)
scope.sources.update(sources)
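Per the comment above, a recursive CTE's first UNION branch becomes a source of the CTE's own scope; a sketch that walks the scopes of a small recursive query:

import sqlglot
from sqlglot.optimizer.scope import traverse_scope

sql = "WITH RECURSIVE t AS (SELECT 1 AS n UNION ALL SELECT n + 1 FROM t WHERE n < 3) SELECT * FROM t"
for scope in traverse_scope(sqlglot.parse_one(sql)):
    print(scope.expression.key, sorted(scope.sources))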

View file

@@ -16,7 +16,7 @@ def unnest_subqueries(expression):
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)'
AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)'
Args:
expression (sqlglot.Expression): expression to unnest
@@ -97,8 +97,8 @@ def decorrelate(select, parent_select, external_columns, sequence):
table_alias = _alias(sequence)
keys = []
# for all external columns in the where statement,
# split out the relevant data to convert it into a join
# for all external columns in the where statement, find the relevant predicate
# keys to convert it into a join
for column in external_columns:
if column.find_ancestor(exp.Where) is not where:
return
@@ -122,6 +122,10 @@ def decorrelate(select, parent_select, external_columns, sequence):
if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
return
is_subquery_projection = any(
node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
)
value = select.selects[0]
key_aliases = {}
group_by = []
@@ -142,9 +146,14 @@ def decorrelate(select, parent_select, external_columns, sequence):
parent_predicate = select.find_ancestor(exp.Predicate)
# if the value of the subquery is not an agg or a key, we need to collect it into an array
# so that it can be grouped
# so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
if not value.find(exp.AggFunc) and value.this not in group_by:
select.select(f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False)
select.select(
exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
append=False,
copy=False,
)
# exists queries should not have any selects as it only checks if there are any rows
# all selects will be added by the optimizer and only used for join keys
@ -158,7 +167,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
if isinstance(parent_predicate, exp.Exists) or key != value.this:
select.select(f"{key} AS {alias}", copy=False)
else:
select.select(f"ARRAY_AGG({key}) AS {alias}", copy=False)
select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)
alias = exp.column(value.alias, table_alias)
other = _other_operand(parent_predicate)
@ -186,12 +195,18 @@ def decorrelate(select, parent_select, external_columns, sequence):
f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
)
else:
if is_subquery_projection:
alias = exp.alias_(alias, select.parent.alias)
select.parent.replace(alias)
for key, column, predicate in keys:
predicate.replace(exp.true())
nested = exp.column(key_aliases[key], table_alias)
if is_subquery_projection:
key.replace(nested)
continue
if key in group_by:
key.replace(nested)
parent_predicate = _replace(
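
(Illustrative input for the new is_subquery_projection branch: an unaliased
correlated scalar subquery in the SELECT list is aggregated with MAX() rather
than ARRAY_AGG(). A sketch only; in the full optimizer pipeline this rewrite
runs after qualification, so exact output may differ.)

import sqlglot
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

sql = "SELECT (SELECT y.b FROM y AS y WHERE x.a = y.a) FROM x AS x"
print(unnest_subqueries(sqlglot.parse_one(sql)).sql())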

View file

@ -5,7 +5,7 @@ import typing as t
from sqlglot import exp
from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
from sqlglot.helper import apply_index_offset, ensure_collection, seq_get
from sqlglot.helper import apply_index_offset, ensure_collection, ensure_list, seq_get
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import in_trie, new_trie
@ -117,6 +117,7 @@ class Parser(metaclass=_Parser):
TokenType.GEOMETRY,
TokenType.HLLSKETCH,
TokenType.HSTORE,
TokenType.PSEUDO_TYPE,
TokenType.SUPER,
TokenType.SERIAL,
TokenType.SMALLSERIAL,
@ -153,6 +154,7 @@ class Parser(metaclass=_Parser):
TokenType.CACHE,
TokenType.CASCADE,
TokenType.COLLATE,
TokenType.COLUMN,
TokenType.COMMAND,
TokenType.COMMIT,
TokenType.COMPOUND,
@ -169,6 +171,7 @@ class Parser(metaclass=_Parser):
TokenType.ESCAPE,
TokenType.FALSE,
TokenType.FIRST,
TokenType.FILTER,
TokenType.FOLLOWING,
TokenType.FORMAT,
TokenType.FUNCTION,
@ -188,6 +191,7 @@ class Parser(metaclass=_Parser):
TokenType.MERGE,
TokenType.NATURAL,
TokenType.NEXT,
TokenType.OFFSET,
TokenType.ONLY,
TokenType.OPTIONS,
TokenType.ORDINALITY,
@ -222,12 +226,18 @@ class Parser(metaclass=_Parser):
TokenType.PROPERTIES,
TokenType.PROCEDURE,
TokenType.VOLATILE,
TokenType.WINDOW,
*SUBQUERY_PREDICATES,
*TYPE_TOKENS,
*NO_PAREN_FUNCTIONS,
}
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY}
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {
TokenType.APPLY,
TokenType.NATURAL,
TokenType.OFFSET,
TokenType.WINDOW,
}
UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET}
@ -257,6 +267,7 @@ class Parser(metaclass=_Parser):
TokenType.TABLE,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
TokenType.WINDOW,
*TYPE_TOKENS,
*SUBQUERY_PREDICATES,
}
@ -351,22 +362,27 @@ class Parser(metaclass=_Parser):
TokenType.ARROW: lambda self, this, path: self.expression(
exp.JSONExtract,
this=this,
path=path,
expression=path,
),
TokenType.DARROW: lambda self, this, path: self.expression(
exp.JSONExtractScalar,
this=this,
path=path,
expression=path,
),
TokenType.HASH_ARROW: lambda self, this, path: self.expression(
exp.JSONBExtract,
this=this,
path=path,
expression=path,
),
TokenType.DHASH_ARROW: lambda self, this, path: self.expression(
exp.JSONBExtractScalar,
this=this,
path=path,
expression=path,
),
TokenType.PLACEHOLDER: lambda self, this, key: self.expression(
exp.JSONBContains,
this=this,
expression=key,
),
}
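
(The operators those lambdas cover, written as Postgres-flavored SQL; '?'
reaches the parser as a PLACEHOLDER token and becomes JSONBContains. The
table and key names are illustrative.)

import sqlglot

for sql in (
    "SELECT x -> 'a' FROM t",    # exp.JSONExtract
    "SELECT x ->> 'a' FROM t",   # exp.JSONExtractScalar
    "SELECT x #> 'a' FROM t",    # exp.JSONBExtract
    "SELECT x #>> 'a' FROM t",   # exp.JSONBExtractScalar
    "SELECT x ? 'a' FROM t",     # exp.JSONBContains
):
    print(repr(sqlglot.parse_one(sql, read="postgres")))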
@ -392,25 +408,27 @@ class Parser(metaclass=_Parser):
exp.Ordered: lambda self: self._parse_ordered(),
exp.Having: lambda self: self._parse_having(),
exp.With: lambda self: self._parse_with(),
exp.Window: lambda self: self._parse_named_window(),
"JOIN_TYPE": lambda self: self._parse_join_side_and_kind(),
}
STATEMENT_PARSERS = {
TokenType.ALTER: lambda self: self._parse_alter(),
TokenType.BEGIN: lambda self: self._parse_transaction(),
TokenType.CACHE: lambda self: self._parse_cache(),
TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
TokenType.CREATE: lambda self: self._parse_create(),
TokenType.DELETE: lambda self: self._parse_delete(),
TokenType.DESCRIBE: lambda self: self._parse_describe(),
TokenType.DROP: lambda self: self._parse_drop(),
TokenType.END: lambda self: self._parse_commit_or_rollback(),
TokenType.INSERT: lambda self: self._parse_insert(),
TokenType.LOAD_DATA: lambda self: self._parse_load_data(),
TokenType.UPDATE: lambda self: self._parse_update(),
TokenType.DELETE: lambda self: self._parse_delete(),
TokenType.CACHE: lambda self: self._parse_cache(),
TokenType.UNCACHE: lambda self: self._parse_uncache(),
TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()),
TokenType.BEGIN: lambda self: self._parse_transaction(),
TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
TokenType.END: lambda self: self._parse_commit_or_rollback(),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
TokenType.MERGE: lambda self: self._parse_merge(),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
TokenType.UNCACHE: lambda self: self._parse_uncache(),
TokenType.UPDATE: lambda self: self._parse_update(),
TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()),
}
UNARY_PARSERS = {
@ -441,6 +459,7 @@ class Parser(metaclass=_Parser):
TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text),
TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text),
TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token),
TokenType.NATIONAL: lambda self, token: self._parse_national(token),
TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
}
@ -454,6 +473,9 @@ class Parser(metaclass=_Parser):
TokenType.ILIKE: lambda self, this: self._parse_escape(
self.expression(exp.ILike, this=this, expression=self._parse_bitwise())
),
TokenType.IRLIKE: lambda self, this: self.expression(
exp.RegexpILike, this=this, expression=self._parse_bitwise()
),
TokenType.RLIKE: lambda self, this: self.expression(
exp.RegexpLike, this=this, expression=self._parse_bitwise()
),
@ -535,8 +557,7 @@ class Parser(metaclass=_Parser):
"group": lambda self: self._parse_group(),
"having": lambda self: self._parse_having(),
"qualify": lambda self: self._parse_qualify(),
"window": lambda self: self._match(TokenType.WINDOW)
and self._parse_window(self._parse_id_var(), alias=True),
"windows": lambda self: self._parse_window_clause(),
"distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute),
"sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort),
"cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster),
@ -551,18 +572,18 @@ class Parser(metaclass=_Parser):
MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
CREATABLES = {
TokenType.TABLE,
TokenType.VIEW,
TokenType.COLUMN,
TokenType.FUNCTION,
TokenType.INDEX,
TokenType.PROCEDURE,
TokenType.SCHEMA,
TokenType.TABLE,
TokenType.VIEW,
}
TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
STRICT_CAST = True
LATERAL_FUNCTION_AS_VIEW = False
__slots__ = (
"error_level",
@ -782,13 +803,16 @@ class Parser(metaclass=_Parser):
self._parse_query_modifiers(expression)
return expression
def _parse_drop(self):
def _parse_drop(self, default_kind=None):
temporary = self._match(TokenType.TEMPORARY)
materialized = self._match(TokenType.MATERIALIZED)
kind = self._match_set(self.CREATABLES) and self._prev.text
if not kind:
self.raise_error(f"Expected {self.CREATABLES}")
return
if default_kind:
kind = default_kind
else:
self.raise_error(f"Expected {self.CREATABLES}")
return
return self.expression(
exp.Drop,
@ -876,7 +900,7 @@ class Parser(metaclass=_Parser):
) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False)
if assignment:
key = self._parse_var() or self._parse_string()
key = self._parse_var_or_string()
self._match(TokenType.EQ)
return self.expression(exp.Property, this=key, value=self._parse_column())
@ -1152,18 +1176,32 @@ class Parser(metaclass=_Parser):
elif (table or nested) and self._match(TokenType.L_PAREN):
this = self._parse_table() if table else self._parse_select(nested=True)
self._parse_query_modifiers(this)
this = self._parse_set_operations(this)
self._match_r_paren()
this = self._parse_subquery(this)
# early return so that subquery unions aren't parsed again
# SELECT * FROM (SELECT 1) UNION ALL SELECT 1
# UNION ALL should be a property of the top select node, not the subquery
return self._parse_subquery(this)
elif self._match(TokenType.VALUES):
if self._curr.token_type == TokenType.L_PAREN:
# We don't consume the left paren because it's consumed in _parse_value
expressions = self._parse_csv(self._parse_value)
else:
# In Presto we can have VALUES 1, 2, which results in 1 column & 2 rows.
# Source: https://prestodb.io/docs/current/sql/values.html
expressions = self._parse_csv(
lambda: self.expression(exp.Tuple, expressions=[self._parse_conjunction()])
)
this = self.expression(
exp.Values,
expressions=self._parse_csv(self._parse_value),
expressions=expressions,
alias=self._parse_table_alias(),
)
else:
this = None
return self._parse_set_operations(this) if this else None
return self._parse_set_operations(this)
def _parse_with(self, skip_with_token=False):
if not skip_with_token and not self._match(TokenType.WITH):
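
(A sketch of the two VALUES forms distinguished above; per the Presto docs
linked in the comment, the unparenthesized form yields one column per row.)

import sqlglot

# one column, two rows (new Presto-style branch)
print(repr(sqlglot.parse_one("SELECT * FROM (VALUES 1, 2) AS t(a)", read="presto")))
# two columns, two rows (parenthesized form, handled by _parse_value)
print(repr(sqlglot.parse_one("SELECT * FROM (VALUES (1, 'x'), (2, 'y')) AS t(a, b)", read="presto")))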
@ -1201,11 +1239,12 @@ class Parser(metaclass=_Parser):
alias = self._parse_id_var(
any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS
)
columns = None
if self._match(TokenType.L_PAREN):
columns = self._parse_csv(lambda: self._parse_id_var(any_token))
columns = self._parse_csv(lambda: self._parse_column_def(self._parse_id_var()))
self._match_r_paren()
else:
columns = None
if not alias and not columns:
return None
@ -1295,26 +1334,19 @@ class Parser(metaclass=_Parser):
expression=self._parse_function() or self._parse_id_var(any_token=False),
)
columns = None
table_alias = None
if view or self.LATERAL_FUNCTION_AS_VIEW:
table_alias = self._parse_id_var(any_token=False)
if self._match(TokenType.ALIAS):
columns = self._parse_csv(self._parse_id_var)
if view:
table = self._parse_id_var(any_token=False)
columns = self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else []
table_alias = self.expression(exp.TableAlias, this=table, columns=columns)
else:
self._match(TokenType.ALIAS)
table_alias = self._parse_id_var(any_token=False)
if self._match(TokenType.L_PAREN):
columns = self._parse_csv(self._parse_id_var)
self._match_r_paren()
table_alias = self._parse_table_alias()
expression = self.expression(
exp.Lateral,
this=this,
view=view,
outer=outer,
alias=self.expression(exp.TableAlias, this=table_alias, columns=columns),
alias=table_alias,
)
if outer_apply or cross_apply:
@ -1693,6 +1725,9 @@ class Parser(metaclass=_Parser):
if negate:
this = self.expression(exp.Not, this=this)
if self._match(TokenType.IS):
this = self._parse_is(this)
return this
def _parse_is(self, this):
@ -1796,6 +1831,10 @@ class Parser(metaclass=_Parser):
return None
type_token = self._prev.token_type
if type_token == TokenType.PSEUDO_TYPE:
return self.expression(exp.PseudoType, this=self._prev.text)
nested = type_token in self.NESTED_TYPE_TOKENS
is_struct = type_token == TokenType.STRUCT
expressions = None
@ -1851,6 +1890,8 @@ class Parser(metaclass=_Parser):
if value is None:
value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)
elif type_token == TokenType.INTERVAL:
value = self.expression(exp.Interval, unit=self._parse_var())
if maybe_func and check_func:
index2 = self._index
@ -1924,7 +1965,16 @@ class Parser(metaclass=_Parser):
def _parse_primary(self):
if self._match_set(self.PRIMARY_PARSERS):
return self.PRIMARY_PARSERS[self._prev.token_type](self, self._prev)
token_type = self._prev.token_type
primary = self.PRIMARY_PARSERS[token_type](self, self._prev)
if token_type == TokenType.STRING:
expressions = [primary]
while self._match(TokenType.STRING):
expressions.append(exp.Literal.string(self._prev.text))
if len(expressions) > 1:
return self.expression(exp.Concat, expressions=expressions)
return primary
if self._match_pair(TokenType.DOT, TokenType.NUMBER):
return exp.Literal.number(f"0.{self._prev.text}")
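
(What the new loop does: SQL-standard adjacent string literals fold into a
single Concat node. A sketch; the alias is illustrative.)

import sqlglot

# 'foo' 'bar' is read as one concatenated literal
print(sqlglot.parse_one("SELECT 'foo' 'bar' AS s").sql())
# expected roughly: SELECT CONCAT('foo', 'bar') AS s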
@ -2027,6 +2077,9 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Identifier, this=token.text)
def _parse_national(self, token):
return self.expression(exp.National, this=exp.Literal.string(token.text))
def _parse_session_parameter(self):
kind = None
this = self._parse_id_var() or self._parse_primary()
@ -2051,7 +2104,9 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.L_PAREN):
expressions = self._parse_csv(self._parse_id_var)
self._match(TokenType.R_PAREN)
if not self._match(TokenType.R_PAREN):
self._retreat(index)
else:
expressions = [self._parse_id_var()]
@ -2065,14 +2120,14 @@ class Parser(metaclass=_Parser):
exp.Distinct, expressions=self._parse_csv(self._parse_conjunction)
)
else:
this = self._parse_conjunction()
this = self._parse_select_or_expression()
if self._match(TokenType.IGNORE_NULLS):
this = self.expression(exp.IgnoreNulls, this=this)
else:
self._match(TokenType.RESPECT_NULLS)
return self._parse_alias(self._parse_limit(self._parse_order(this)))
return self._parse_limit(self._parse_order(this))
def _parse_schema(self, this=None):
index = self._index
@ -2081,7 +2136,8 @@ class Parser(metaclass=_Parser):
return this
args = self._parse_csv(
lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True))
lambda: self._parse_constraint()
or self._parse_column_def(self._parse_field(any_token=True))
)
self._match_r_paren()
return self.expression(exp.Schema, this=this, expressions=args)
@ -2120,7 +2176,7 @@ class Parser(metaclass=_Parser):
elif self._match(TokenType.ENCODE):
kind = self.expression(exp.EncodeColumnConstraint, this=self._parse_var())
elif self._match(TokenType.DEFAULT):
kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_conjunction())
kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_bitwise())
elif self._match_pair(TokenType.NOT, TokenType.NULL):
kind = exp.NotNullColumnConstraint()
elif self._match(TokenType.NULL):
@ -2211,7 +2267,10 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.L_BRACKET):
return this
expressions = self._parse_csv(self._parse_conjunction)
if self._match(TokenType.COLON):
expressions = [self.expression(exp.Slice, expression=self._parse_conjunction())]
else:
expressions = self._parse_csv(lambda: self._parse_slice(self._parse_conjunction()))
if not this or this.name.upper() == "ARRAY":
this = self.expression(exp.Array, expressions=expressions)
@ -2225,6 +2284,11 @@ class Parser(metaclass=_Parser):
this.comments = self._prev_comments
return self._parse_bracket(this)
def _parse_slice(self, this):
if self._match(TokenType.COLON):
return self.expression(exp.Slice, this=this, expression=self._parse_conjunction())
return this
def _parse_case(self):
ifs = []
default = None
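
(The bracket forms the new slice handling distinguishes; names are illustrative.)

import sqlglot

print(repr(sqlglot.parse_one("SELECT x[1] FROM t")))    # plain index
print(repr(sqlglot.parse_one("SELECT x[1:3] FROM t")))  # bounded slice -> exp.Slice
print(repr(sqlglot.parse_one("SELECT x[:3] FROM t")))   # open lower bound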
@ -2386,6 +2450,12 @@ class Parser(metaclass=_Parser):
collation=collation,
)
def _parse_window_clause(self):
return self._match(TokenType.WINDOW) and self._parse_csv(self._parse_named_window)
def _parse_named_window(self):
return self._parse_window(self._parse_id_var(), alias=True)
def _parse_window(self, this, alias=False):
if self._match(TokenType.FILTER):
where = self._parse_wrapped(self._parse_where)
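
(The clause the new named-window parsers accept; a sketch with illustrative
names, printed as an AST since generator output may vary.)

import sqlglot

sql = "SELECT SUM(x) OVER w, AVG(x) OVER w FROM t WINDOW w AS (PARTITION BY y)"
print(repr(sqlglot.parse_one(sql)))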
@ -2501,11 +2571,9 @@ class Parser(metaclass=_Parser):
if identifier:
return identifier
if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS:
self._advance()
elif not self._match_set(tokens or self.ID_VAR_TOKENS):
return None
return exp.Identifier(this=self._prev.text, quoted=False)
if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS):
return exp.Identifier(this=self._prev.text, quoted=False)
return None
def _parse_string(self):
if self._match(TokenType.STRING):
@ -2522,11 +2590,17 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Identifier, this=self._prev.text, quoted=True)
return self._parse_placeholder()
def _parse_var(self):
if self._match(TokenType.VAR):
def _parse_var(self, any_token=False):
if (any_token and self._advance_any()) or self._match(TokenType.VAR):
return self.expression(exp.Var, this=self._prev.text)
return self._parse_placeholder()
def _advance_any(self):
if self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS:
self._advance()
return self._prev
return None
def _parse_var_or_string(self):
return self._parse_var() or self._parse_string()
@ -2551,8 +2625,9 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.PLACEHOLDER):
return self.expression(exp.Placeholder)
elif self._match(TokenType.COLON):
self._advance()
return self.expression(exp.Placeholder, this=self._prev.text)
if self._match_set((TokenType.NUMBER, TokenType.VAR)):
return self.expression(exp.Placeholder, this=self._prev.text)
self._advance(-1)
return None
def _parse_except(self):
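
(Placeholder forms after this change: a bare '?' and a ':name' parameter both
become exp.Placeholder, while a lone ':' that is not followed by a number or
variable is rewound. Illustrative queries.)

import sqlglot

print(repr(sqlglot.parse_one("SELECT * FROM t WHERE id = ?")))
print(repr(sqlglot.parse_one("SELECT * FROM t WHERE id = :id")))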
@ -2647,6 +2722,54 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Rollback, savepoint=savepoint)
return self.expression(exp.Commit, chain=chain)
def _parse_add_column(self):
if not self._match_text_seq("ADD"):
return None
self._match(TokenType.COLUMN)
exists_column = self._parse_exists(not_=True)
expression = self._parse_column_def(self._parse_field(any_token=True))
expression.set("exists", exists_column)
return expression
def _parse_drop_column(self):
return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN")
def _parse_alter(self):
if not self._match(TokenType.TABLE):
return None
exists = self._parse_exists()
this = self._parse_table(schema=True)
actions = None
if self._match_text_seq("ADD", advance=False):
actions = self._parse_csv(self._parse_add_column)
elif self._match_text_seq("DROP", advance=False):
actions = self._parse_csv(self._parse_drop_column)
elif self._match_text_seq("ALTER"):
self._match(TokenType.COLUMN)
column = self._parse_field(any_token=True)
if self._match_pair(TokenType.DROP, TokenType.DEFAULT):
actions = self.expression(exp.AlterColumn, this=column, drop=True)
elif self._match_pair(TokenType.SET, TokenType.DEFAULT):
actions = self.expression(
exp.AlterColumn, this=column, default=self._parse_conjunction()
)
else:
self._match_text_seq("SET", "DATA")
actions = self.expression(
exp.AlterColumn,
this=column,
dtype=self._match_text_seq("TYPE") and self._parse_types(),
collate=self._match(TokenType.COLLATE) and self._parse_term(),
using=self._match(TokenType.USING) and self._parse_conjunction(),
)
actions = ensure_list(actions)
return self.expression(exp.AlterTable, this=this, exists=exists, actions=actions)
def _parse_show(self):
parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
if parser:
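
(Statements the new ALTER TABLE parser covers, as a sketch; anything outside
these shapes still falls back to the generic COMMAND handling registered in
the tokenizer below.)

import sqlglot

for sql in (
    "ALTER TABLE t ADD COLUMN IF NOT EXISTS c INT",
    "ALTER TABLE t DROP COLUMN c",
    "ALTER TABLE t ALTER COLUMN c SET DEFAULT 0",
    "ALTER TABLE t ALTER COLUMN c TYPE TEXT",
):
    print(repr(sqlglot.parse_one(sql)))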
@ -2782,7 +2905,7 @@ class Parser(metaclass=_Parser):
return True
return False
def _match_text_seq(self, *texts):
def _match_text_seq(self, *texts, advance=True):
index = self._index
for text in texts:
if self._curr and self._curr.text.upper() == text:
@ -2790,6 +2913,10 @@ class Parser(metaclass=_Parser):
else:
self._retreat(index)
return False
if not advance:
self._retreat(index)
return True
def _replace_columns_with_dots(self, this):

View file

@ -160,9 +160,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
super().__init__(schema)
self.visible = visible or {}
self.dialect = dialect
self._type_mapping_cache: t.Dict[str, exp.DataType] = {
"STR": exp.DataType.build("text"),
}
self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
@classmethod
def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:

View file

@ -48,6 +48,7 @@ class TokenType(AutoName):
DOLLAR = auto()
PARAMETER = auto()
SESSION_PARAMETER = auto()
NATIONAL = auto()
BLOCK_START = auto()
BLOCK_END = auto()
@ -111,6 +112,7 @@ class TokenType(AutoName):
# keywords
ALIAS = auto()
ALTER = auto()
ALWAYS = auto()
ALL = auto()
ANTI = auto()
@ -196,6 +198,7 @@ class TokenType(AutoName):
INTERVAL = auto()
INTO = auto()
INTRODUCER = auto()
IRLIKE = auto()
IS = auto()
ISNULL = auto()
JOIN = auto()
@ -241,6 +244,7 @@ class TokenType(AutoName):
PRIMARY_KEY = auto()
PROCEDURE = auto()
PROPERTIES = auto()
PSEUDO_TYPE = auto()
QUALIFY = auto()
QUOTE = auto()
RANGE = auto()
@ -346,7 +350,11 @@ class _Tokenizer(type):
def __new__(cls, clsname, bases, attrs): # type: ignore
klass = super().__new__(cls, clsname, bases, attrs)
klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES)
klass._QUOTES = {
f"{prefix}{s}": e
for s, e in cls._delimeter_list_to_dict(klass.QUOTES).items()
for prefix in (("",) if s[0].isalpha() else ("", "n", "N"))
}
klass._BIT_STRINGS = cls._delimeter_list_to_dict(klass.BIT_STRINGS)
klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS)
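
(Effect of the prefixed quote mapping, together with the NATIONAL token added
below: N'...' literals tokenize separately from plain strings. Illustrative.)

import sqlglot

print(repr(sqlglot.parse_one("SELECT N'foo'")))  # exp.National
print(repr(sqlglot.parse_one("SELECT n'foo'")))  # lowercase prefix also matches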
@ -470,6 +478,7 @@ class Tokenizer(metaclass=_Tokenizer):
"CHECK": TokenType.CHECK,
"CLUSTER BY": TokenType.CLUSTER_BY,
"COLLATE": TokenType.COLLATE,
"COLUMN": TokenType.COLUMN,
"COMMENT": TokenType.SCHEMA_COMMENT,
"COMMIT": TokenType.COMMIT,
"COMPOUND": TokenType.COMPOUND,
@ -587,6 +596,7 @@ class Tokenizer(metaclass=_Tokenizer):
"SEMI": TokenType.SEMI,
"SET": TokenType.SET,
"SHOW": TokenType.SHOW,
"SIMILAR TO": TokenType.SIMILAR_TO,
"SOME": TokenType.SOME,
"SORTKEY": TokenType.SORTKEY,
"SORT BY": TokenType.SORT_BY,
@ -614,6 +624,7 @@ class Tokenizer(metaclass=_Tokenizer):
"VOLATILE": TokenType.VOLATILE,
"WHEN": TokenType.WHEN,
"WHERE": TokenType.WHERE,
"WINDOW": TokenType.WINDOW,
"WITH": TokenType.WITH,
"WITH TIME ZONE": TokenType.WITH_TIME_ZONE,
"WITH LOCAL TIME ZONE": TokenType.WITH_LOCAL_TIME_ZONE,
@ -652,6 +663,7 @@ class Tokenizer(metaclass=_Tokenizer):
"VARCHAR2": TokenType.VARCHAR,
"NVARCHAR": TokenType.NVARCHAR,
"NVARCHAR2": TokenType.NVARCHAR,
"STR": TokenType.TEXT,
"STRING": TokenType.TEXT,
"TEXT": TokenType.TEXT,
"CLOB": TokenType.TEXT,
@ -667,7 +679,16 @@ class Tokenizer(metaclass=_Tokenizer):
"UNIQUE": TokenType.UNIQUE,
"STRUCT": TokenType.STRUCT,
"VARIANT": TokenType.VARIANT,
"ALTER": TokenType.COMMAND,
"ALTER": TokenType.ALTER,
"ALTER AGGREGATE": TokenType.COMMAND,
"ALTER DEFAULT": TokenType.COMMAND,
"ALTER DOMAIN": TokenType.COMMAND,
"ALTER ROLE": TokenType.COMMAND,
"ALTER RULE": TokenType.COMMAND,
"ALTER SEQUENCE": TokenType.COMMAND,
"ALTER TYPE": TokenType.COMMAND,
"ALTER USER": TokenType.COMMAND,
"ALTER VIEW": TokenType.COMMAND,
"ANALYZE": TokenType.COMMAND,
"CALL": TokenType.COMMAND,
"EXPLAIN": TokenType.COMMAND,
@ -967,7 +988,7 @@ class Tokenizer(metaclass=_Tokenizer):
text = self._extract_string(quote_end)
text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text # type: ignore
text = text.replace("\\\\", "\\") if self._replace_backslash else text
self._add(TokenType.STRING, text)
self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text)
return True
# X'1234', b'0110', E'\\\\' etc.