1
0
Fork 0

Merging upstream version 6.3.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 14:45:11 +01:00
parent 81e6900b0a
commit 393757f998
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
41 changed files with 1558 additions and 267 deletions

View file

@ -8,7 +8,9 @@ from sqlglot.expressions import (
and_,
column,
condition,
except_,
from_,
intersect,
maybe_parse,
not_,
or_,
@ -16,11 +18,12 @@ from sqlglot.expressions import (
subquery,
)
from sqlglot.expressions import table_ as table
from sqlglot.expressions import union
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
__version__ = "6.2.8"
__version__ = "6.3.1"
pretty = False

View file

@ -135,6 +135,7 @@ class BigQuery(Dialect):
exp.DateSub: _date_add_sql("DATE", "SUB"),
exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
exp.ILike: no_ilike_sql,
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
exp.TimeSub: _date_add_sql("TIME", "SUB"),
@ -172,12 +173,11 @@ class BigQuery(Dialect):
exp.AnonymousProperty,
}
EXPLICIT_UNION = True
def in_unnest_op(self, unnest):
return self.sql(unnest)
def union_op(self, expression):
return f"UNION{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
def except_op(self, expression):
if not expression.args.get("distinct", False):
self.unsupported("EXCEPT without DISTINCT is not supported in BigQuery")

View file

@ -1,10 +1,16 @@
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, inline_array_sql
from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.helper import csv
from sqlglot.parser import Parser, parse_var_map
from sqlglot.tokens import Tokenizer, TokenType
def _lower_func(sql):
index = sql.index("(")
return sql[:index].lower() + sql[index:]
class ClickHouse(Dialect):
normalize_functions = None
null_ordering = "nulls_are_last"
@ -14,17 +20,23 @@ class ClickHouse(Dialect):
KEYWORDS = {
**Tokenizer.KEYWORDS,
"NULLABLE": TokenType.NULLABLE,
"FINAL": TokenType.FINAL,
"DATETIME64": TokenType.DATETIME,
"INT8": TokenType.TINYINT,
"INT16": TokenType.SMALLINT,
"INT32": TokenType.INT,
"INT64": TokenType.BIGINT,
"FLOAT32": TokenType.FLOAT,
"FLOAT64": TokenType.DOUBLE,
"TUPLE": TokenType.STRUCT,
}
class Parser(Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
"MAP": parse_var_map,
}
def _parse_table(self, schema=False):
this = super()._parse_table(schema)
@ -39,10 +51,25 @@ class ClickHouse(Dialect):
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
exp.DataType.Type.NULLABLE: "Nullable",
exp.DataType.Type.DATETIME: "DateTime64",
exp.DataType.Type.MAP: "Map",
exp.DataType.Type.ARRAY: "Array",
exp.DataType.Type.STRUCT: "Tuple",
exp.DataType.Type.TINYINT: "Int8",
exp.DataType.Type.SMALLINT: "Int16",
exp.DataType.Type.INT: "Int32",
exp.DataType.Type.BIGINT: "Int64",
exp.DataType.Type.FLOAT: "Float32",
exp.DataType.Type.DOUBLE: "Float64",
}
TRANSFORMS = {
**Generator.TRANSFORMS,
exp.Array: inline_array_sql,
exp.StrPosition: lambda self, e: f"position({csv(self.sql(e, 'this'), self.sql(e, 'substr'), self.sql(e, 'position'))})",
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
}
EXPLICIT_UNION = True

View file

@ -77,7 +77,6 @@ class Dialect(metaclass=_Dialect):
alias_post_tablesample = False
normalize_functions = "upper"
null_ordering = "nulls_are_small"
wrap_derived_values = True
date_format = "'%Y-%m-%d'"
dateint_format = "'%Y%m%d'"
@ -170,7 +169,6 @@ class Dialect(metaclass=_Dialect):
"alias_post_tablesample": self.alias_post_tablesample,
"normalize_functions": self.normalize_functions,
"null_ordering": self.null_ordering,
"wrap_derived_values": self.wrap_derived_values,
**opts,
}
)
@ -271,6 +269,21 @@ def struct_extract_sql(self, expression):
return f"{this}.{struct_key}"
def var_map_sql(self, expression):
keys = expression.args["keys"]
values = expression.args["values"]
if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
self.unsupported("Cannot convert array columns into map.")
return f"MAP({self.sql(keys)}, {self.sql(values)})"
args = []
for key, value in zip(keys.expressions, values.expressions):
args.append(self.sql(key))
args.append(self.sql(value))
return f"MAP({csv(*args)})"
def format_time_lambda(exp_class, dialect, default=None):
"""Helper used for time expressions.

View file

@ -11,40 +11,14 @@ from sqlglot.dialects.dialect import (
no_trycast_sql,
rename_func,
struct_extract_sql,
var_map_sql,
)
from sqlglot.generator import Generator
from sqlglot.helper import csv, list_get
from sqlglot.parser import Parser
from sqlglot.parser import Parser, parse_var_map
from sqlglot.tokens import Tokenizer
def _parse_map(args):
keys = []
values = []
for i in range(0, len(args), 2):
keys.append(args[i])
values.append(args[i + 1])
return HiveMap(
keys=exp.Array(expressions=keys),
values=exp.Array(expressions=values),
)
def _map_sql(self, expression):
keys = expression.args["keys"]
values = expression.args["values"]
if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
self.unsupported("Cannot convert array columns into map use SparkSQL instead.")
return f"MAP({self.sql(keys)}, {self.sql(values)})"
args = []
for key, value in zip(keys.expressions, values.expressions):
args.append(self.sql(key))
args.append(self.sql(value))
return f"MAP({csv(*args)})"
def _array_sort(self, expression):
if expression.expression:
self.unsupported("Hive SORT_ARRAY does not support a comparator")
@ -122,10 +96,6 @@ def _index_sql(self, expression):
return f"{this} ON TABLE {table} {columns}"
class HiveMap(exp.Map):
is_var_len_args = True
class Hive(Dialect):
alias_post_tablesample = True
@ -206,7 +176,7 @@ class Hive(Dialect):
position=list_get(args, 2),
),
"LOG": (lambda args: exp.Log.from_arg_list(args) if len(args) > 1 else exp.Ln.from_arg_list(args)),
"MAP": _parse_map,
"MAP": parse_var_map,
"MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
"PERCENTILE": exp.Quantile.from_arg_list,
"PERCENTILE_APPROX": exp.ApproxQuantile.from_arg_list,
@ -245,8 +215,8 @@ class Hive(Dialect):
exp.Join: _unnest_to_explode_sql,
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
exp.Map: _map_sql,
HiveMap: _map_sql,
exp.Map: var_map_sql,
exp.VarMap: var_map_sql,
exp.Create: create_with_partitions_sql,
exp.Quantile: rename_func("PERCENTILE"),
exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),

View file

@ -10,6 +10,32 @@ def _limit_sql(self, expression):
class Oracle(Dialect):
# https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212
# https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
time_mapping = {
"AM": "%p", # Meridian indicator with or without periods
"A.M.": "%p", # Meridian indicator with or without periods
"PM": "%p", # Meridian indicator with or without periods
"P.M.": "%p", # Meridian indicator with or without periods
"D": "%u", # Day of week (1-7)
"DAY": "%A", # name of day
"DD": "%d", # day of month (1-31)
"DDD": "%j", # day of year (1-366)
"DY": "%a", # abbreviated name of day
"HH": "%I", # Hour of day (1-12)
"HH12": "%I", # alias for HH
"HH24": "%H", # Hour of day (0-23)
"IW": "%V", # Calendar week of year (1-52 or 1-53), as defined by the ISO 8601 standard
"MI": "%M", # Minute (0-59)
"MM": "%m", # Month (01-12; January = 01)
"MON": "%b", # Abbreviated name of month
"MONTH": "%B", # Name of month
"SS": "%S", # Second (0-59)
"WW": "%W", # Week of year (1-53)
"YY": "%y", # 15
"YYYY": "%Y", # 2015
}
class Generator(Generator):
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
@ -30,6 +56,9 @@ class Oracle(Dialect):
**transforms.UNALIAS_GROUP,
exp.ILike: no_ilike_sql,
exp.Limit: _limit_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
}
def query_modifiers(self, expression, *sqls):

View file

@ -118,13 +118,22 @@ def _serial_to_generated(expression):
return expression
def _to_timestamp(args):
# TO_TIMESTAMP accepts either a single double argument or (text, text)
if len(args) == 1 and args[0].is_number:
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE
return exp.UnixToTime.from_arg_list(args)
# https://www.postgresql.org/docs/current/functions-formatting.html
return format_time_lambda(exp.StrToTime, "postgres")(args)
class Postgres(Dialect):
null_ordering = "nulls_are_large"
time_format = "'YYYY-MM-DD HH24:MI:SS'"
time_mapping = {
"AM": "%p",
"PM": "%p",
"D": "%w", # 1-based day of week
"D": "%u", # 1-based day of week
"DD": "%d", # day of month
"DDD": "%j", # zero padded day of year
"FMDD": "%-d", # - is no leading zero for Python; same for FM in postgres
@ -172,7 +181,7 @@ class Postgres(Dialect):
FUNCTIONS = {
**Parser.FUNCTIONS,
"TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "postgres"),
"TO_TIMESTAMP": _to_timestamp,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
}
@ -211,4 +220,5 @@ class Postgres(Dialect):
exp.TableSample: no_tablesample_sql,
exp.Trim: _trim_sql,
exp.TryCast: no_trycast_sql,
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
}

View file

@ -121,6 +121,7 @@ class Snowflake(Dialect):
FUNC_TOKENS = {
*Parser.FUNC_TOKENS,
TokenType.RLIKE,
TokenType.TABLE,
}
COLUMN_OPERATORS = {
@ -143,7 +144,7 @@ class Snowflake(Dialect):
SINGLE_TOKENS = {
**Tokenizer.SINGLE_TOKENS,
"$": TokenType.DOLLAR, # needed to break for quotes
"$": TokenType.PARAMETER,
}
KEYWORDS = {
@ -164,6 +165,8 @@ class Snowflake(Dialect):
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: _unix_to_time,
exp.Array: inline_array_sql,
exp.StrPosition: rename_func("POSITION"),
exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
}

View file

@ -4,8 +4,9 @@ from sqlglot.dialects.dialect import (
no_ilike_sql,
rename_func,
)
from sqlglot.dialects.hive import Hive, HiveMap
from sqlglot.dialects.hive import Hive
from sqlglot.helper import list_get
from sqlglot.parser import Parser
def _create_sql(self, e):
@ -47,8 +48,6 @@ def _unix_to_time(self, expression):
class Spark(Hive):
wrap_derived_values = False
class Parser(Hive.Parser):
FUNCTIONS = {
**Hive.Parser.FUNCTIONS,
@ -78,8 +77,19 @@ class Spark(Hive):
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
}
class Generator(Hive.Generator):
FUNCTION_PARSERS = {
**Parser.FUNCTION_PARSERS,
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
"MERGE": lambda self: self._parse_join_hint("MERGE"),
"SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"),
"MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"),
"SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"),
"SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
}
class Generator(Hive.Generator):
TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "BYTE",
@ -102,8 +112,9 @@ class Spark(Hive):
exp.Map: _map_sql,
exp.Reduce: rename_func("AGGREGATE"),
exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
HiveMap: _map_sql,
}
WRAP_DERIVED_VALUES = False
class Tokenizer(Hive.Tokenizer):
HEX_STRINGS = [("X'", "'")]

View file

@ -326,6 +326,7 @@ class Python(Dialect):
exp.Alias: lambda self, e: self.sql(e.this),
exp.Array: inline_array_sql,
exp.And: lambda self, e: self.binary(e, "and"),
exp.Boolean: lambda self, e: "True" if e.this else "False",
exp.Cast: _cast_py,
exp.Column: _column_py,
exp.EQ: lambda self, e: self.binary(e, "=="),

View file

@ -508,7 +508,69 @@ class DerivedTable(Expression):
return [select.alias_or_name for select in self.selects]
class UDTF(DerivedTable):
class Unionable:
def union(self, expression, distinct=True, dialect=None, **opts):
"""
Builds a UNION expression.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT * FROM foo").union("SELECT * FROM bla").sql()
'SELECT * FROM foo UNION SELECT * FROM bla'
Args:
expression (str or Expression): the SQL code string.
If an `Expression` instance is passed, it will be used as-is.
distinct (bool): set the DISTINCT flag if and only if this is true.
dialect (str): the dialect used to parse the input expression.
opts (kwargs): other options to use to parse the input expressions.
Returns:
Union: the Union expression.
"""
return union(left=self, right=expression, distinct=distinct, dialect=dialect, **opts)
def intersect(self, expression, distinct=True, dialect=None, **opts):
"""
Builds an INTERSECT expression.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT * FROM foo").intersect("SELECT * FROM bla").sql()
'SELECT * FROM foo INTERSECT SELECT * FROM bla'
Args:
expression (str or Expression): the SQL code string.
If an `Expression` instance is passed, it will be used as-is.
distinct (bool): set the DISTINCT flag if and only if this is true.
dialect (str): the dialect used to parse the input expression.
opts (kwargs): other options to use to parse the input expressions.
Returns:
Intersect: the Intersect expression
"""
return intersect(left=self, right=expression, distinct=distinct, dialect=dialect, **opts)
def except_(self, expression, distinct=True, dialect=None, **opts):
"""
Builds an EXCEPT expression.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT * FROM foo").except_("SELECT * FROM bla").sql()
'SELECT * FROM foo EXCEPT SELECT * FROM bla'
Args:
expression (str or Expression): the SQL code string.
If an `Expression` instance is passed, it will be used as-is.
distinct (bool): set the DISTINCT flag if and only if this is true.
dialect (str): the dialect used to parse the input expression.
opts (kwargs): other options to use to parse the input expressions.
Returns:
Except: the Except expression
"""
return except_(left=self, right=expression, distinct=distinct, dialect=dialect, **opts)
class UDTF(DerivedTable, Unionable):
pass
@ -518,6 +580,10 @@ class Annotation(Expression):
"expression": True,
}
@property
def alias(self):
return self.expression.alias_or_name
class Cache(Expression):
arg_types = {
@ -700,6 +766,10 @@ class Hint(Expression):
arg_types = {"expressions": True}
class JoinHint(Expression):
arg_types = {"this": True, "expressions": True}
class Identifier(Expression):
arg_types = {"this": True, "quoted": False}
@ -971,7 +1041,7 @@ class Tuple(Expression):
arg_types = {"expressions": False}
class Subqueryable:
class Subqueryable(Unionable):
def subquery(self, alias=None, copy=True):
"""
Convert this expression to an aliased expression that can be used as a Subquery.
@ -1654,7 +1724,7 @@ class Select(Subqueryable, Expression):
return self.expressions
class Subquery(DerivedTable):
class Subquery(DerivedTable, Unionable):
arg_types = {
"this": True,
"alias": False,
@ -1731,7 +1801,7 @@ class Parameter(Expression):
class Placeholder(Expression):
arg_types = {}
arg_types = {"this": False}
class Null(Condition):
@ -1791,6 +1861,8 @@ class DataType(Expression):
IMAGE = auto()
VARIANT = auto()
OBJECT = auto()
NULL = auto()
UNKNOWN = auto() # Sentinel value, useful for type annotation
@classmethod
def build(cls, dtype, **kwargs):
@ -2007,7 +2079,7 @@ class Distinct(Expression):
class In(Predicate):
arg_types = {"this": True, "expressions": False, "query": False, "unnest": False}
arg_types = {"this": True, "expressions": False, "query": False, "unnest": False, "field": False}
class TimeUnit(Expression):
@ -2377,6 +2449,11 @@ class Map(Func):
arg_types = {"keys": True, "values": True}
class VarMap(Func):
arg_types = {"keys": True, "values": True}
is_var_len_args = True
class Max(AggFunc):
pass
@ -2449,7 +2526,7 @@ class Substring(Func):
class StrPosition(Func):
arg_types = {"this": True, "substr": True, "position": False}
arg_types = {"substr": True, "this": True, "position": False}
class StrToDate(Func):
@ -2785,6 +2862,81 @@ def _wrap_operator(expression):
return expression
def union(left, right, distinct=True, dialect=None, **opts):
"""
Initializes a syntax tree from one UNION expression.
Example:
>>> union("SELECT * FROM foo", "SELECT * FROM bla").sql()
'SELECT * FROM foo UNION SELECT * FROM bla'
Args:
left (str or Expression): the SQL code string corresponding to the left-hand side.
If an `Expression` instance is passed, it will be used as-is.
right (str or Expression): the SQL code string corresponding to the right-hand side.
If an `Expression` instance is passed, it will be used as-is.
distinct (bool): set the DISTINCT flag if and only if this is true.
dialect (str): the dialect used to parse the input expression.
opts (kwargs): other options to use to parse the input expressions.
Returns:
Union: the syntax tree for the UNION expression.
"""
left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
return Union(this=left, expression=right, distinct=distinct)
def intersect(left, right, distinct=True, dialect=None, **opts):
"""
Initializes a syntax tree from one INTERSECT expression.
Example:
>>> intersect("SELECT * FROM foo", "SELECT * FROM bla").sql()
'SELECT * FROM foo INTERSECT SELECT * FROM bla'
Args:
left (str or Expression): the SQL code string corresponding to the left-hand side.
If an `Expression` instance is passed, it will be used as-is.
right (str or Expression): the SQL code string corresponding to the right-hand side.
If an `Expression` instance is passed, it will be used as-is.
distinct (bool): set the DISTINCT flag if and only if this is true.
dialect (str): the dialect used to parse the input expression.
opts (kwargs): other options to use to parse the input expressions.
Returns:
Intersect: the syntax tree for the INTERSECT expression.
"""
left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
return Intersect(this=left, expression=right, distinct=distinct)
def except_(left, right, distinct=True, dialect=None, **opts):
"""
Initializes a syntax tree from one EXCEPT expression.
Example:
>>> except_("SELECT * FROM foo", "SELECT * FROM bla").sql()
'SELECT * FROM foo EXCEPT SELECT * FROM bla'
Args:
left (str or Expression): the SQL code string corresponding to the left-hand side.
If an `Expression` instance is passed, it will be used as-is.
right (str or Expression): the SQL code string corresponding to the right-hand side.
If an `Expression` instance is passed, it will be used as-is.
distinct (bool): set the DISTINCT flag if and only if this is true.
dialect (str): the dialect used to parse the input expression.
opts (kwargs): other options to use to parse the input expressions.
Returns:
Except: the syntax tree for the EXCEPT statement.
"""
left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
return Except(this=left, expression=right, distinct=distinct)
def select(*expressions, dialect=None, **opts):
"""
Initializes a syntax tree from one or multiple SELECT expressions.
@ -2991,7 +3143,7 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
If an Expression instance is passed, this is used as-is.
alias (str or Identifier): the alias name to use. If the name has
special characters it is quoted.
table (boolean): create a table alias, default false
table (bool): create a table alias, default false
dialect (str): the dialect used to parse the input expression.
**opts: other options to use to parse the input expressions.
@ -3002,7 +3154,7 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
alias = to_identifier(alias, quoted=quoted)
alias = TableAlias(this=alias) if table else alias
if "alias" in exp.arg_types:
if "alias" in exp.arg_types and not isinstance(exp, Window):
exp = exp.copy()
exp.set("alias", alias)
return exp
@ -3138,6 +3290,60 @@ def column_table_names(expression):
return list(dict.fromkeys(column.table for column in expression.find_all(Column)))
def table_name(table):
"""Get the full name of a table as a string.
Args:
table (exp.Table | str): Table expression node or string.
Examples:
>>> from sqlglot import exp, parse_one
>>> table_name(parse_one("select * from a.b.c").find(exp.Table))
'a.b.c'
Returns:
str: the table name
"""
table = maybe_parse(table, into=Table)
return ".".join(
part
for part in (
table.text("catalog"),
table.text("db"),
table.name,
)
if part
)
def replace_tables(expression, mapping):
"""Replace all tables in expression according to the mapping.
Args:
expression (sqlglot.Expression): Expression node to be transformed and replaced
mapping (Dict[str, str]): Mapping of table names
Examples:
>>> from sqlglot import exp, parse_one
>>> replace_tables(parse_one("select * from a.b"), {"a.b": "c"}).sql()
'SELECT * FROM "c"'
Returns:
The mapped expression
"""
def _replace_tables(node):
if isinstance(node, Table):
new_name = mapping.get(table_name(node))
if new_name:
return table_(*reversed(new_name.split(".")), quoted=True)
return node
return expression.transform(_replace_tables)
TRUE = Boolean(this=True)
FALSE = Boolean(this=False)
NULL = Null()

View file

@ -48,8 +48,9 @@ class Generator:
TRANSFORMS = {
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}",
exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.DateDiff: lambda self, e: f"DATEDIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
exp.VarMap: lambda self, e: f"MAP({self.sql(e.args['keys'])}, {self.sql(e.args['values'])})",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
@ -57,7 +58,12 @@ class Generator:
exp.VolatilityProperty: lambda self, e: self.sql(e.name),
}
# whether or not null ordering is supported in order by
NULL_ORDERING_SUPPORTED = True
# always do union distinct or union all
EXPLICIT_UNION = False
# wrap derived values in parens, usually standard but spark doesn't support it
WRAP_DERIVED_VALUES = True
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
@ -101,7 +107,6 @@ class Generator:
"unsupported_messages",
"null_ordering",
"max_unsupported",
"wrap_derived_values",
"_indent",
"_replace_backslash",
"_escaped_quote_end",
@ -130,7 +135,6 @@ class Generator:
null_ordering=None,
max_unsupported=3,
leading_comma=False,
wrap_derived_values=True,
):
import sqlglot
@ -154,7 +158,6 @@ class Generator:
self.unsupported_messages = []
self.max_unsupported = max_unsupported
self.null_ordering = null_ordering
self.wrap_derived_values = wrap_derived_values
self._indent = indent
self._replace_backslash = self.escape == "\\"
self._escaped_quote_end = self.escape + self.quote_end
@ -595,7 +598,7 @@ class Generator:
if not alias:
return f"VALUES{self.seg('')}{args}"
alias = f" AS {alias}" if alias else alias
if self.wrap_derived_values:
if self.WRAP_DERIVED_VALUES:
return f"(VALUES{self.seg('')}{args}){alias}"
return f"VALUES{self.seg('')}{args}{alias}"
@ -779,8 +782,8 @@ class Generator:
def parameter_sql(self, expression):
return f"@{self.sql(expression, 'this')}"
def placeholder_sql(self, *_):
return "?"
def placeholder_sql(self, expression):
return f":{expression.name}" if expression.name else "?"
def subquery_sql(self, expression):
alias = self.sql(expression, "alias")
@ -803,7 +806,9 @@ class Generator:
)
def union_op(self, expression):
return f"UNION{'' if expression.args.get('distinct') else ' ALL'}"
kind = " DISTINCT" if self.EXPLICIT_UNION else ""
kind = kind if expression.args.get("distinct") else " ALL"
return f"UNION{kind}"
def unnest_sql(self, expression):
args = self.expressions(expression, flat=True)
@ -940,10 +945,13 @@ class Generator:
def in_sql(self, expression):
query = expression.args.get("query")
unnest = expression.args.get("unnest")
field = expression.args.get("field")
if query:
in_sql = self.wrap(query)
elif unnest:
in_sql = self.in_unnest_op(unnest)
elif field:
in_sql = self.sql(field)
else:
in_sql = f"({self.expressions(expression, flat=True)})"
return f"{self.sql(expression, 'this')} IN {in_sql}"
@ -1178,3 +1186,8 @@ class Generator:
this = self.sql(expression, "this")
kind = self.sql(expression, "kind")
return f"{this} {kind}"
def joinhint_sql(self, expression):
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
return f"{this}({expressions})"

View file

@ -1,16 +1,20 @@
from sqlglot import exp
from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.schema import ensure_schema
from sqlglot.optimizer.scope import Scope, traverse_scope
def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
"""
Recursively infer & annotate types in an expression syntax tree against a schema.
Assumes that we've already executed the optimizer's qualify_columns step.
(TODO -- replace this with a better example after adding some functionality)
Example:
>>> import sqlglot
>>> annotated_expression = annotate_types(sqlglot.parse_one('5 + 5.3'))
>>> annotated_expression.type
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Args:
@ -22,6 +26,8 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
sqlglot.Expression: expression annotated with types
"""
schema = ensure_schema(schema)
return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
@ -35,10 +41,81 @@ class TypeAnnotator:
expr_type: lambda self, expr: self._annotate_binary(expr)
for expr_type in subclasses(exp.__name__, exp.Binary)
},
exp.Cast: lambda self, expr: self._annotate_cast(expr),
exp.DataType: lambda self, expr: self._annotate_data_type(expr),
exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"].this),
exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.this),
exp.Alias: lambda self, expr: self._annotate_unary(expr),
exp.Literal: lambda self, expr: self._annotate_literal(expr),
exp.Boolean: lambda self, expr: self._annotate_boolean(expr),
exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
exp.CurrentTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
exp.DatetimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.TimestampAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimestampSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.DateStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.DateToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.If: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.RegexpLike: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.StrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.UnixToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.VariancePop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
}
# Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
@ -97,43 +174,82 @@ class TypeAnnotator:
},
}
TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)
def __init__(self, schema=None, annotators=None, coerces_to=None):
self.schema = schema
self.annotators = annotators or self.ANNOTATORS
self.coerces_to = coerces_to or self.COERCES_TO
def annotate(self, expression):
if isinstance(expression, self.TRAVERSABLES):
for scope in traverse_scope(expression):
subscope_selects = {
name: {select.alias_or_name: select for select in source.selects}
for name, source in scope.sources.items()
if isinstance(source, Scope)
}
# First annotate the current scope's column references
for col in scope.columns:
source = scope.sources[col.table]
if isinstance(source, exp.Table):
col.type = self.schema.get_column_type(source, col)
else:
col.type = subscope_selects[col.table][col.name].type
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
def _maybe_annotate(self, expression):
if not isinstance(expression, exp.Expression):
return None
if expression.type:
return expression # We've already inferred the expression's type
annotator = self.annotators.get(expression.__class__)
return annotator(self, expression) if annotator else self._annotate_args(expression)
return (
annotator(self, expression)
if annotator
else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
)
def _annotate_args(self, expression):
for value in expression.args.values():
for v in ensure_list(value):
self.annotate(v)
self._maybe_annotate(v)
return expression
def _annotate_cast(self, expression):
expression.type = expression.args["to"].this
return self._annotate_args(expression)
def _annotate_data_type(self, expression):
expression.type = expression.this
return self._annotate_args(expression)
def _maybe_coerce(self, type1, type2):
# We propagate the NULL / UNKNOWN types upwards if found
if exp.DataType.Type.NULL in (type1, type2):
return exp.DataType.Type.NULL
if exp.DataType.Type.UNKNOWN in (type1, type2):
return exp.DataType.Type.UNKNOWN
return type2 if type2 in self.coerces_to[type1] else type1
def _annotate_binary(self, expression):
self._annotate_args(expression)
if isinstance(expression, (exp.Condition, exp.Predicate)):
left_type = expression.left.type
right_type = expression.right.type
if isinstance(expression, (exp.And, exp.Or)):
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
expression.type = exp.DataType.Type.NULL
elif exp.DataType.Type.NULL in (left_type, right_type):
expression.type = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
else:
expression.type = exp.DataType.Type.BOOLEAN
elif isinstance(expression, (exp.Condition, exp.Predicate)):
expression.type = exp.DataType.Type.BOOLEAN
else:
expression.type = self._maybe_coerce(expression.left.type, expression.right.type)
expression.type = self._maybe_coerce(left_type, right_type)
return expression
@ -157,6 +273,6 @@ class TypeAnnotator:
return expression
def _annotate_boolean(self, expression):
expression.type = exp.DataType.Type.BOOLEAN
return expression
def _annotate_with_type(self, expression, target_type):
expression.type = target_type
return self._annotate_args(expression)

View file

@ -44,6 +44,7 @@ UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
"joins",
"where",
"order",
"hint",
}
@ -67,21 +68,22 @@ def merge_ctes(expression, leave_tables_isolated=False):
singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
for outer_scope, inner_scope, table in singular_cte_selections:
inner_select = inner_scope.expression.unnest()
if _mergeable(outer_scope, inner_select, leave_tables_isolated):
from_or_join = table.find_ancestor(exp.From, exp.Join)
from_or_join = table.find_ancestor(exp.From, exp.Join)
if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
node_to_replace = table
if isinstance(node_to_replace.parent, exp.Alias):
node_to_replace = node_to_replace.parent
alias = node_to_replace.alias
else:
alias = table.name
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, node_to_replace, alias)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
_merge_hints(outer_scope, inner_scope)
_pop_cte(inner_scope)
return expression
@ -90,9 +92,9 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
for outer_scope in traverse_scope(expression):
for subquery in outer_scope.derived_tables:
inner_select = subquery.unnest()
if _mergeable(outer_scope, inner_select, leave_tables_isolated):
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
alias = subquery.alias_or_name
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
inner_scope = outer_scope.sources[alias]
_rename_inner_sources(outer_scope, inner_scope, alias)
@ -101,10 +103,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
_merge_hints(outer_scope, inner_scope)
return expression
def _mergeable(outer_scope, inner_select, leave_tables_isolated):
def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
"""
Return True if `inner_select` can be merged into outer query.
@ -112,6 +115,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated):
outer_scope (Scope)
inner_select (exp.Select)
leave_tables_isolated (bool)
from_or_join (exp.From|exp.Join)
Returns:
bool: True if can be merged
"""
@ -123,6 +127,16 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated):
and inner_select.args.get("from")
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
and not (
isinstance(from_or_join, exp.Join)
and inner_select.args.get("where")
and from_or_join.side in {"FULL", "LEFT", "RIGHT"}
)
and not (
isinstance(from_or_join, exp.From)
and inner_select.args.get("where")
and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []))
)
)
@ -170,6 +184,12 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
"""
new_subquery = inner_scope.expression.args.get("from").expressions[0]
node_to_replace.replace(new_subquery)
for join_hint in outer_scope.join_hints:
tables = join_hint.find_all(exp.Table)
for table in tables:
if table.alias_or_name == node_to_replace.alias_or_name:
new_table = new_subquery.this if isinstance(new_subquery, exp.Alias) else new_subquery
table.set("this", exp.to_identifier(new_table.alias_or_name))
outer_scope.remove_source(alias)
outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
@ -273,6 +293,18 @@ def _merge_order(outer_scope, inner_scope):
outer_scope.expression.set("order", inner_scope.expression.args.get("order"))
def _merge_hints(outer_scope, inner_scope):
inner_scope_hint = inner_scope.expression.args.get("hint")
if not inner_scope_hint:
return
outer_scope_hint = outer_scope.expression.args.get("hint")
if outer_scope_hint:
for hint_expression in inner_scope_hint.expressions:
outer_scope_hint.append("expressions", hint_expression)
else:
outer_scope.expression.set("hint", inner_scope_hint)
def _pop_cte(inner_scope):
"""
Remove CTE from the AST.

View file

@ -1,3 +1,5 @@
from collections import defaultdict
from sqlglot import exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import traverse_scope
@ -20,22 +22,30 @@ def pushdown_predicates(expression):
Returns:
sqlglot.Expression: optimized expression
"""
for scope in reversed(traverse_scope(expression)):
scope_ref_count = defaultdict(lambda: 0)
scopes = traverse_scope(expression)
scopes.reverse()
for scope in scopes:
for _, source in scope.selected_sources.values():
scope_ref_count[id(source)] += 1
for scope in scopes:
select = scope.expression
where = select.args.get("where")
if where:
pushdown(where.this, scope.selected_sources)
pushdown(where.this, scope.selected_sources, scope_ref_count)
# joins should only pushdown into itself, not to other joins
# so we limit the selected sources to only itself
for join in select.args.get("joins") or []:
name = join.this.alias_or_name
pushdown(join.args.get("on"), {name: scope.selected_sources[name]})
pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)
return expression
def pushdown(condition, sources):
def pushdown(condition, sources, scope_ref_count):
if not condition:
return
@ -45,17 +55,17 @@ def pushdown(condition, sources):
predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition])
if cnf_like:
pushdown_cnf(predicates, sources)
pushdown_cnf(predicates, sources, scope_ref_count)
else:
pushdown_dnf(predicates, sources)
pushdown_dnf(predicates, sources, scope_ref_count)
def pushdown_cnf(predicates, scope):
def pushdown_cnf(predicates, scope, scope_ref_count):
"""
If the predicates are in CNF like form, we can simply replace each block in the parent.
"""
for predicate in predicates:
for node in nodes_for_predicate(predicate, scope).values():
for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
if isinstance(node, exp.Join):
predicate.replace(exp.TRUE)
node.on(predicate, copy=False)
@ -65,7 +75,7 @@ def pushdown_cnf(predicates, scope):
node.where(replace_aliases(node, predicate), copy=False)
def pushdown_dnf(predicates, scope):
def pushdown_dnf(predicates, scope, scope_ref_count):
"""
If the predicates are in DNF form, we can only push down conditions that are in all blocks.
Additionally, we can't remove predicates from their original form.
@ -91,7 +101,7 @@ def pushdown_dnf(predicates, scope):
# (a.x AND and a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
for table in sorted(pushdown_tables):
for predicate in predicates:
nodes = nodes_for_predicate(predicate, scope)
nodes = nodes_for_predicate(predicate, scope, scope_ref_count)
if table not in nodes:
continue
@ -120,7 +130,7 @@ def pushdown_dnf(predicates, scope):
node.where(replace_aliases(node, predicate), copy=False)
def nodes_for_predicate(predicate, sources):
def nodes_for_predicate(predicate, sources, scope_ref_count):
nodes = {}
tables = exp.column_table_names(predicate)
where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
@ -133,7 +143,7 @@ def nodes_for_predicate(predicate, sources):
if node and where_condition:
node = node.find_ancestor(exp.Join, exp.From)
# a node can reference a CTE which should be push down
# a node can reference a CTE which should be pushed down
if isinstance(node, exp.From) and not isinstance(source, exp.Table):
node = source.expression
@ -142,7 +152,9 @@ def nodes_for_predicate(predicate, sources):
return {}
nodes[table] = node
elif isinstance(node, exp.Select) and len(tables) == 1:
if not node.args.get("group"):
# we can't push down predicates to select statements if they are referenced in
# multiple places.
if not node.args.get("group") and scope_ref_count[id(source)] < 2:
nodes[table] = node
return nodes

View file

@ -31,8 +31,8 @@ def qualify_columns(expression, schema):
_pop_table_column_aliases(scope.derived_tables)
_expand_using(scope, resolver)
_expand_group_by(scope, resolver)
_expand_order_by(scope)
_qualify_columns(scope, resolver)
_expand_order_by(scope)
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver)
_qualify_outputs(scope)
@ -235,7 +235,7 @@ def _expand_stars(scope, resolver):
for table in tables:
if table not in scope.sources:
raise OptimizeError(f"Unknown table: {table}")
columns = resolver.get_source_columns(table)
columns = resolver.get_source_columns(table, only_visible=True)
table_id = id(table)
for name in columns:
if name not in except_columns.get(table_id, set()):
@ -332,7 +332,7 @@ class _Resolver:
self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns)
return self._all_columns
def get_source_columns(self, name):
def get_source_columns(self, name, only_visible=False):
"""Resolve the source columns for a given source `name`"""
if name not in self.scope.sources:
raise OptimizeError(f"Unknown table: {name}")
@ -342,7 +342,7 @@ class _Resolver:
# If referencing a table, return the columns from the schema
if isinstance(source, exp.Table):
try:
return self.schema.column_names(source)
return self.schema.column_names(source, only_visible)
except Exception as e:
raise OptimizeError(str(e)) from e

View file

@ -9,16 +9,28 @@ class Schema(abc.ABC):
"""Abstract base class for database schemas"""
@abc.abstractmethod
def column_names(self, table):
def column_names(self, table, only_visible=False):
"""
Get the column names for a table.
Args:
table (sqlglot.expressions.Table): Table expression instance
only_visible (bool): Whether to include invisible columns
Returns:
list[str]: list of column names
"""
@abc.abstractmethod
def get_column_type(self, table, column):
"""
Get the exp.DataType type of a column in the schema.
Args:
table (sqlglot.expressions.Table): The source table.
column (sqlglot.expressions.Column): The target column.
Returns:
sqlglot.expressions.DataType.Type: The resulting column type.
"""
class MappingSchema(Schema):
"""
@ -29,10 +41,19 @@ class MappingSchema(Schema):
1. {table: {col: type}}
2. {db: {table: {col: type}}}
3. {catalog: {db: {table: {col: type}}}}
visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
1. {table: set(*cols)}}
2. {db: {table: set(*cols)}}}
3. {catalog: {db: {table: set(*cols)}}}}
dialect (str): The dialect to be used for custom type mappings.
"""
def __init__(self, schema):
def __init__(self, schema, visible=None, dialect=None):
self.schema = schema
self.visible = visible
self.dialect = dialect
self._type_mapping_cache = {}
depth = _dict_depth(schema)
@ -49,7 +70,7 @@ class MappingSchema(Schema):
self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args)
def column_names(self, table):
def column_names(self, table, only_visible=False):
if not isinstance(table.this, exp.Identifier):
return fs_get(table)
@ -58,7 +79,39 @@ class MappingSchema(Schema):
for forbidden in self.forbidden_args:
if table.text(forbidden):
raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
return list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
if not only_visible or not self.visible:
return columns
visible = _nested_get(self.visible, *zip(self.supported_table_args, args))
return [col for col in columns if col in visible]
def get_column_type(self, table, column):
try:
schema_type = self.schema.get(table.name, {}).get(column.name).upper()
return self._convert_type(schema_type)
except:
raise OptimizeError(f"Failed to get type for column {column.sql()}")
def _convert_type(self, schema_type):
"""
Convert a type represented as a string to the corresponding exp.DataType.Type object.
Args:
schema_type (str): The type we want to convert.
Returns:
sqlglot.expressions.DataType.Type: The resulting expression type.
"""
if schema_type not in self._type_mapping_cache:
try:
self._type_mapping_cache[schema_type] = exp.maybe_parse(
schema_type, into=exp.DataType, dialect=self.dialect
).this
except AttributeError:
raise OptimizeError(f"Failed to convert type {schema_type}")
return self._type_mapping_cache[schema_type]
def ensure_schema(schema):

View file

@ -68,6 +68,7 @@ class Scope:
self._selected_sources = None
self._columns = None
self._external_columns = None
self._join_hints = None
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
"""Branch from the current scope to a new, inner scope"""
@ -85,14 +86,17 @@ class Scope:
self._subqueries = []
self._derived_tables = []
self._raw_columns = []
self._join_hints = []
for node, parent, _ in self.walk(bfs=False):
if node is self.expression:
continue
elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
self._raw_columns.append(node)
elif isinstance(node, exp.Table):
elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
self._tables.append(node)
elif isinstance(node, exp.JoinHint):
self._join_hints.append(node)
elif isinstance(node, exp.UDTF):
self._derived_tables.append(node)
elif isinstance(node, exp.CTE):
@ -246,7 +250,7 @@ class Scope:
table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Subquery, exp.Table|Scope)]: selected sources and nodes
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
"""
if self._selected_sources is None:
referenced_names = []
@ -310,6 +314,18 @@ class Scope:
self._external_columns = [c for c in self.columns if c.table not in self.selected_sources]
return self._external_columns
@property
def join_hints(self):
"""
Hints that exist in the scope that reference tables
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
"""
if self._join_hints is None:
return []
return self._join_hints
def source_columns(self, source_name):
"""
Get all columns in the current scope for a particular source.

View file

@ -56,12 +56,16 @@ def simplify_not(expression):
NOT (x AND y) -> NOT x OR NOT y
"""
if isinstance(expression, exp.Not):
if isinstance(expression.this, exp.Null):
return NULL
if isinstance(expression.this, exp.Paren):
condition = expression.this.unnest()
if isinstance(condition, exp.And):
return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
if isinstance(condition, exp.Or):
return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
if isinstance(condition, exp.Null):
return NULL
if always_true(expression.this):
return FALSE
if expression.this == FALSE:
@ -95,10 +99,10 @@ def simplify_connectors(expression):
return left
if isinstance(expression, exp.And):
if NULL in (left, right):
return NULL
if FALSE in (left, right):
return FALSE
if NULL in (left, right):
return NULL
if always_true(left) and always_true(right):
return TRUE
if always_true(left):

View file

@ -8,6 +8,18 @@ from sqlglot.tokens import Token, Tokenizer, TokenType
logger = logging.getLogger("sqlglot")
def parse_var_map(args):
keys = []
values = []
for i in range(0, len(args), 2):
keys.append(args[i])
values.append(args[i + 1])
return exp.VarMap(
keys=exp.Array(expressions=keys),
values=exp.Array(expressions=values),
)
class Parser:
"""
Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer`
@ -48,6 +60,7 @@ class Parser:
start=exp.Literal.number(1),
length=exp.Literal.number(10),
),
"VAR_MAP": parse_var_map,
}
NO_PAREN_FUNCTIONS = {
@ -117,6 +130,7 @@ class Parser:
TokenType.VAR,
TokenType.ALTER,
TokenType.ALWAYS,
TokenType.ANTI,
TokenType.BEGIN,
TokenType.BOTH,
TokenType.BUCKET,
@ -164,6 +178,7 @@ class Parser:
TokenType.ROWS,
TokenType.SCHEMA_COMMENT,
TokenType.SEED,
TokenType.SEMI,
TokenType.SET,
TokenType.SHOW,
TokenType.STABLE,
@ -273,6 +288,8 @@ class Parser:
TokenType.INNER,
TokenType.OUTER,
TokenType.CROSS,
TokenType.SEMI,
TokenType.ANTI,
}
COLUMN_OPERATORS = {
@ -318,6 +335,8 @@ class Parser:
exp.Properties: lambda self: self._parse_properties(),
exp.Where: lambda self: self._parse_where(),
exp.Ordered: lambda self: self._parse_ordered(),
exp.Having: lambda self: self._parse_having(),
exp.With: lambda self: self._parse_with(),
"JOIN_TYPE": lambda self: self._parse_join_side_and_kind(),
}
@ -338,7 +357,6 @@ class Parser:
TokenType.NULL: lambda *_: exp.Null(),
TokenType.TRUE: lambda *_: exp.Boolean(this=True),
TokenType.FALSE: lambda *_: exp.Boolean(this=False),
TokenType.PLACEHOLDER: lambda *_: exp.Placeholder(),
TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()),
TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text),
TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text),
@ -910,7 +928,20 @@ class Parser:
return self.expression(exp.Tuple, expressions=expressions)
def _parse_select(self, nested=False, table=False):
if self._match(TokenType.SELECT):
cte = self._parse_with()
if cte:
this = self._parse_statement()
if not this:
self.raise_error("Failed to parse any statement following CTE")
return cte
if "with" in this.arg_types:
this.set("with", cte)
else:
self.raise_error(f"{this.key} does not support CTE")
this = cte
elif self._match(TokenType.SELECT):
hint = self._parse_hint()
all_ = self._match(TokenType.ALL)
distinct = self._match(TokenType.DISTINCT)
@ -938,39 +969,6 @@ class Parser:
if from_:
this.set("from", from_)
self._parse_query_modifiers(this)
elif self._match(TokenType.WITH):
recursive = self._match(TokenType.RECURSIVE)
expressions = []
while True:
expressions.append(self._parse_cte())
if not self._match(TokenType.COMMA):
break
cte = self.expression(
exp.With,
expressions=expressions,
recursive=recursive,
)
this = self._parse_statement()
if not this:
self.raise_error("Failed to parse any statement following CTE")
return cte
if "with" in this.arg_types:
this.set(
"with",
self.expression(
exp.With,
expressions=expressions,
recursive=recursive,
),
)
else:
self.raise_error(f"{this.key} does not support CTE")
elif (table or nested) and self._match(TokenType.L_PAREN):
this = self._parse_table() if table else self._parse_select(nested=True)
self._parse_query_modifiers(this)
@ -986,6 +984,26 @@ class Parser:
return self._parse_set_operations(this) if this else None
def _parse_with(self):
if not self._match(TokenType.WITH):
return None
recursive = self._match(TokenType.RECURSIVE)
expressions = []
while True:
expressions.append(self._parse_cte())
if not self._match(TokenType.COMMA):
break
return self.expression(
exp.With,
expressions=expressions,
recursive=recursive,
)
def _parse_cte(self):
alias = self._parse_table_alias()
if not alias or not alias.this:
@ -1485,8 +1503,7 @@ class Parser:
unnest = self._parse_unnest()
if unnest:
this = self.expression(exp.In, this=this, unnest=unnest)
else:
self._match_l_paren()
elif self._match(TokenType.L_PAREN):
expressions = self._parse_csv(self._parse_select_or_expression)
if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable):
@ -1495,6 +1512,9 @@ class Parser:
this = self.expression(exp.In, this=this, expressions=expressions)
self._match_r_paren()
else:
this = self.expression(exp.In, this=this, field=self._parse_field())
return this
def _parse_between(self, this):
@ -1591,7 +1611,7 @@ class Parser:
elif nested:
expressions = self._parse_csv(self._parse_types)
else:
expressions = self._parse_csv(self._parse_number)
expressions = self._parse_csv(self._parse_type)
if not expressions:
self._retreat(index)
@ -1706,7 +1726,7 @@ class Parser:
def _parse_field(self, any_token=False):
return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token)
def _parse_function(self):
def _parse_function(self, functions=None):
if not self._curr:
return None
@ -1742,7 +1762,9 @@ class Parser:
self._match_r_paren()
return this
function = self.FUNCTIONS.get(upper)
if functions is None:
functions = self.FUNCTIONS
function = functions.get(upper)
args = self._parse_csv(self._parse_lambda)
if function:
@ -2025,10 +2047,20 @@ class Parser:
return self.expression(exp.Cast, this=this, to=to)
def _parse_position(self):
substr = self._parse_bitwise()
args = self._parse_csv(self._parse_bitwise)
if self._match(TokenType.IN):
string = self._parse_bitwise()
return self.expression(exp.StrPosition, this=string, substr=substr)
args.append(self._parse_bitwise())
# Note: we're parsing in order needle, haystack, position
this = exp.StrPosition.from_arg_list(args)
self.validate_expression(this, args)
return this
def _parse_join_hint(self, func_name):
args = self._parse_csv(self._parse_table)
return exp.JoinHint(this=func_name.upper(), expressions=args)
def _parse_substring(self):
# Postgres supports the form: substring(string [from int] [for int])
@ -2247,6 +2279,9 @@ class Parser:
def _parse_placeholder(self):
if self._match(TokenType.PLACEHOLDER):
return exp.Placeholder()
elif self._match(TokenType.COLON):
self._advance()
return exp.Placeholder(this=self._prev.text)
return None
def _parse_except(self):

View file

@ -104,6 +104,7 @@ class TokenType(AutoName):
ALL = auto()
ALTER = auto()
ANALYZE = auto()
ANTI = auto()
ANY = auto()
ARRAY = auto()
ASC = auto()
@ -236,6 +237,7 @@ class TokenType(AutoName):
SCHEMA_COMMENT = auto()
SEED = auto()
SELECT = auto()
SEMI = auto()
SEPARATOR = auto()
SET = auto()
SHOW = auto()
@ -262,6 +264,7 @@ class TokenType(AutoName):
USE = auto()
USING = auto()
VALUES = auto()
VACUUM = auto()
VIEW = auto()
VOLATILE = auto()
WHEN = auto()
@ -406,6 +409,7 @@ class Tokenizer(metaclass=_Tokenizer):
"ALTER": TokenType.ALTER,
"ANALYZE": TokenType.ANALYZE,
"AND": TokenType.AND,
"ANTI": TokenType.ANTI,
"ANY": TokenType.ANY,
"ASC": TokenType.ASC,
"AS": TokenType.ALIAS,
@ -528,6 +532,7 @@ class Tokenizer(metaclass=_Tokenizer):
"ROWS": TokenType.ROWS,
"SEED": TokenType.SEED,
"SELECT": TokenType.SELECT,
"SEMI": TokenType.SEMI,
"SET": TokenType.SET,
"SHOW": TokenType.SHOW,
"SOME": TokenType.SOME,
@ -551,6 +556,7 @@ class Tokenizer(metaclass=_Tokenizer):
"UPDATE": TokenType.UPDATE,
"USE": TokenType.USE,
"USING": TokenType.USING,
"VACUUM": TokenType.VACUUM,
"VALUES": TokenType.VALUES,
"VIEW": TokenType.VIEW,
"VOLATILE": TokenType.VOLATILE,
@ -577,6 +583,7 @@ class Tokenizer(metaclass=_Tokenizer):
"INT8": TokenType.BIGINT,
"DECIMAL": TokenType.DECIMAL,
"MAP": TokenType.MAP,
"NULLABLE": TokenType.NULLABLE,
"NUMBER": TokenType.DECIMAL,
"NUMERIC": TokenType.DECIMAL,
"FIXED": TokenType.DECIMAL,
@ -629,6 +636,7 @@ class Tokenizer(metaclass=_Tokenizer):
TokenType.SHOW,
TokenType.TRUNCATE,
TokenType.USE,
TokenType.VACUUM,
}
# handle numeric literals like in hive (3L = BIGINT)