Merging upstream version 7.1.3.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent 964bd62de9
commit e6b3d2fe54
42 changed files with 1430 additions and 253 deletions
sqlglot/__init__.py

@@ -23,7 +23,7 @@ from sqlglot.generator import Generator
 from sqlglot.parser import Parser
 from sqlglot.tokens import Tokenizer, TokenType
 
-__version__ = "6.3.1"
+__version__ = "7.1.3"
 
 pretty = False
 
sqlglot/dialects/clickhouse.py

@@ -1,7 +1,6 @@
 from sqlglot import exp
 from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
 from sqlglot.generator import Generator
-from sqlglot.helper import csv
 from sqlglot.parser import Parser, parse_var_map
 from sqlglot.tokens import Tokenizer, TokenType
 
@@ -66,7 +65,7 @@ class ClickHouse(Dialect):
         TRANSFORMS = {
             **Generator.TRANSFORMS,
             exp.Array: inline_array_sql,
-            exp.StrPosition: lambda self, e: f"position({csv(self.sql(e, 'this'), self.sql(e, 'substr'), self.sql(e, 'position'))})",
+            exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
             exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
             exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
             exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
sqlglot/dialects/dialect.py

@@ -2,7 +2,7 @@ from enum import Enum
 
 from sqlglot import exp
 from sqlglot.generator import Generator
-from sqlglot.helper import csv, list_get
+from sqlglot.helper import list_get
 from sqlglot.parser import Parser
 from sqlglot.time import format_time
 from sqlglot.tokens import Tokenizer
@@ -177,11 +177,11 @@ class Dialect(metaclass=_Dialect):
 def rename_func(name):
     def _rename(self, expression):
         args = (
-            self.expressions(expression, flat=True)
+            expression.expressions
             if isinstance(expression, exp.Func) and expression.is_var_len_args
-            else csv(*[self.sql(e) for e in expression.args.values()])
+            else expression.args.values()
         )
-        return f"{name}({args})"
+        return f"{name}({self.format_args(*args)})"
 
     return _rename
 
@@ -189,15 +189,11 @@ def rename_func(name):
 def approx_count_distinct_sql(self, expression):
     if expression.args.get("accuracy"):
         self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
-    return f"APPROX_COUNT_DISTINCT({self.sql(expression, 'this')})"
+    return f"APPROX_COUNT_DISTINCT({self.format_args(expression.this)})"
 
 
 def if_sql(self, expression):
-    expressions = csv(
-        self.sql(expression, "this"),
-        self.sql(expression, "true"),
-        self.sql(expression, "false"),
-    )
+    expressions = self.format_args(expression.this, expression.args.get("true"), expression.args.get("false"))
     return f"IF({expressions})"
 
 
@@ -254,6 +250,11 @@ def no_trycast_sql(self, expression):
     return self.cast_sql(expression)
 
 
+def no_properties_sql(self, expression):
+    self.unsupported("Properties unsupported")
+    return ""
+
+
 def str_position_sql(self, expression):
     this = self.sql(expression, "this")
     substr = self.sql(expression, "substr")
@@ -275,13 +276,13 @@ def var_map_sql(self, expression):
 
     if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
         self.unsupported("Cannot convert array columns into map.")
-        return f"MAP({self.sql(keys)}, {self.sql(values)})"
+        return f"MAP({self.format_args(keys, values)})"
 
     args = []
     for key, value in zip(keys.expressions, values.expressions):
         args.append(self.sql(key))
         args.append(self.sql(value))
-    return f"MAP({csv(*args)})"
+    return f"MAP({self.format_args(*args)})"
 
 
 def format_time_lambda(exp_class, dialect, default=None):
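Note: the dialect helpers above now route argument lists through the generator's new format_args (defined later in this commit, in sqlglot/generator.py), which skips arguments that are None. A minimal sketch of the effect — output shown is approximate, not asserted by this diff:

    import sqlglot

    # The optional "position" argument of STR_POSITION is simply omitted
    # instead of producing a dangling separator in the generated SQL.
    print(sqlglot.transpile("SELECT STR_POSITION(haystack, needle)", write="hive")[0])
    # e.g. SELECT LOCATE(needle, haystack)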
sqlglot/dialects/duckdb.py

@@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import (
     arrow_json_extract_sql,
     format_time_lambda,
     no_pivot_sql,
+    no_properties_sql,
     no_safe_divide_sql,
     no_tablesample_sql,
     rename_func,
@@ -68,6 +69,12 @@ def _struct_pack_sql(self, expression):
     return f"STRUCT_PACK({', '.join(args)})"
 
 
+def _datatype_sql(self, expression):
+    if expression.this == exp.DataType.Type.ARRAY:
+        return f"{self.expressions(expression, flat=True)}[]"
+    return self.datatype_sql(expression)
+
+
 class DuckDB(Dialect):
     class Tokenizer(Tokenizer):
         KEYWORDS = {
@@ -106,6 +113,8 @@ class DuckDB(Dialect):
         }
 
     class Generator(Generator):
+        STRUCT_DELIMITER = ("(", ")")
+
         TRANSFORMS = {
             **Generator.TRANSFORMS,
             exp.ApproxDistinct: approx_count_distinct_sql,
@@ -113,8 +122,9 @@ class DuckDB(Dialect):
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.ArraySort: _array_sort_sql,
             exp.ArraySum: rename_func("LIST_SUM"),
+            exp.DataType: _datatype_sql,
             exp.DateAdd: _date_add,
-            exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
+            exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.format_args(e.args.get("unit") or "'day'", e.expression, e.this)})""",
             exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
             exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
             exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)",
@@ -124,6 +134,7 @@ class DuckDB(Dialect):
             exp.JSONBExtract: arrow_json_extract_sql,
             exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
             exp.Pivot: no_pivot_sql,
+            exp.Properties: no_properties_sql,
             exp.RegexpLike: rename_func("REGEXP_MATCHES"),
             exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"),
             exp.SafeDivide: no_safe_divide_sql,
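Note: with _datatype_sql wired into TRANSFORMS, DuckDB renders array types in its postfix bracket form. A small sketch (expected output is approximate):

    import sqlglot

    # ARRAY<INT> from the default dialect becomes DuckDB's INT[] syntax.
    print(sqlglot.transpile("SELECT CAST(a AS ARRAY<INT>)", write="duckdb")[0])
    # e.g. SELECT CAST(a AS INT[])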
sqlglot/dialects/hive.py

@@ -14,7 +14,7 @@ from sqlglot.dialects.dialect import (
     var_map_sql,
 )
 from sqlglot.generator import Generator
-from sqlglot.helper import csv, list_get
+from sqlglot.helper import list_get
 from sqlglot.parser import Parser, parse_var_map
 from sqlglot.tokens import Tokenizer
 
@@ -32,7 +32,7 @@ def _property_sql(self, expression):
 
 
 def _str_to_unix(self, expression):
-    return f"UNIX_TIMESTAMP({csv(self.sql(expression, 'this'), _time_format(self, expression))})"
+    return f"UNIX_TIMESTAMP({self.format_args(expression.this, _time_format(self, expression))})"
 
 
 def _str_to_date(self, expression):
@@ -226,7 +226,7 @@ class Hive(Dialect):
             exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
             exp.SetAgg: rename_func("COLLECT_SET"),
             exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
-            exp.StrPosition: lambda self, e: f"LOCATE({csv(self.sql(e, 'substr'), self.sql(e, 'this'), self.sql(e, 'position'))})",
+            exp.StrPosition: lambda self, e: f"LOCATE({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})",
             exp.StrToDate: _str_to_date,
             exp.StrToTime: _str_to_time,
             exp.StrToUnix: _str_to_unix,
@@ -241,7 +241,7 @@ class Hive(Dialect):
             exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
             exp.TsOrDsToDate: _to_date_sql,
             exp.TryCast: no_trycast_sql,
-            exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({csv(self.sql(e, 'this'), _time_format(self, e))})",
+            exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.format_args(e.this, _time_format(self, e))})",
             exp.UnixToTime: rename_func("FROM_UNIXTIME"),
             exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
             exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'value')}",
sqlglot/dialects/postgres.py

@@ -167,6 +167,7 @@ class Postgres(Dialect):
             **Tokenizer.KEYWORDS,
             "ALWAYS": TokenType.ALWAYS,
             "BY DEFAULT": TokenType.BY_DEFAULT,
+            "COMMENT ON": TokenType.COMMENT_ON,
             "IDENTITY": TokenType.IDENTITY,
             "GENERATED": TokenType.GENERATED,
             "DOUBLE PRECISION": TokenType.DOUBLE,
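Note: "COMMENT ON" maps to the new TokenType.COMMENT_ON, which the base tokenizer treats as a command (see the COMMANDS change in sqlglot/tokens.py below), so Postgres COMMENT ON statements should tokenize as one command instead of failing to parse. A rough check, under that assumption:

    import sqlglot

    # Expected to come back as a generic command node rather than a ParseError.
    statement = sqlglot.parse_one("COMMENT ON TABLE my_table IS 'demo'", read="postgres")
    print(type(statement).__name__)  # e.g. Command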
sqlglot/dialects/presto.py

@@ -11,7 +11,7 @@ from sqlglot.dialects.dialect import (
 )
 from sqlglot.dialects.mysql import MySQL
 from sqlglot.generator import Generator
-from sqlglot.helper import csv, list_get
+from sqlglot.helper import list_get
 from sqlglot.parser import Parser
 from sqlglot.tokens import Tokenizer, TokenType
 
@@ -26,7 +26,7 @@ def _concat_ws_sql(self, expression):
     sep, *args = expression.expressions
     sep = self.sql(sep)
     if len(args) > 1:
-        return f"ARRAY_JOIN(ARRAY[{csv(*(self.sql(e) for e in args))}], {sep})"
+        return f"ARRAY_JOIN(ARRAY[{self.format_args(*args)}], {sep})"
     return f"ARRAY_JOIN({self.sql(args[0])}, {sep})"
 
 
@@ -66,7 +66,7 @@ def _no_sort_array(self, expression):
         comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
     else:
         comparator = None
-    args = csv(self.sql(expression, "this"), comparator)
+    args = self.format_args(expression.this, comparator)
     return f"ARRAY_SORT({args})"
 
 
sqlglot/dialects/tableau.py

@@ -1,7 +1,6 @@
 from sqlglot import exp
 from sqlglot.dialects.dialect import Dialect
 from sqlglot.generator import Generator
-from sqlglot.helper import list_get
 from sqlglot.parser import Parser
 
 
@@ -16,7 +15,7 @@ def _coalesce_sql(self, expression):
 def _count_sql(self, expression):
     this = expression.this
     if isinstance(this, exp.Distinct):
-        return f"COUNTD({self.sql(this, 'this')})"
+        return f"COUNTD({self.expressions(this, flat=True)})"
     return f"COUNT({self.sql(expression, 'this')})"
 
 
@@ -33,5 +32,5 @@ class Tableau(Dialect):
         FUNCTIONS = {
             **Parser.FUNCTIONS,
             "IFNULL": exp.Coalesce.from_arg_list,
-            "COUNTD": lambda args: exp.Count(this=exp.Distinct(this=list_get(args, 0))),
+            "COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
         }
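Note: COUNTD now parses into a Distinct carrying an expressions list, matching the reworked exp.Distinct signature in sqlglot/expressions.py below, so it survives a round trip. A sketch (output approximate):

    import sqlglot

    sql = sqlglot.transpile("SELECT COUNTD(a)", read="tableau", write="tableau")[0]
    print(sql)  # e.g. SELECT COUNTD(a)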
sqlglot/expressions.py

@@ -1,3 +1,4 @@
+import datetime
 import numbers
 import re
 from collections import deque
@@ -508,7 +509,7 @@ class DerivedTable(Expression):
         return [select.alias_or_name for select in self.selects]
 
 
-class Unionable:
+class Unionable(Expression):
     def union(self, expression, distinct=True, dialect=None, **opts):
         """
         Builds a UNION expression.
@@ -614,6 +615,10 @@ class Create(Expression):
     }
 
 
+class Describe(Expression):
+    pass
+
+
 class UserDefinedFunction(Expression):
     arg_types = {"this": True, "expressions": False}
 
@@ -741,6 +746,11 @@ class Check(Expression):
     pass
 
 
+class Directory(Expression):
+    # https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-dml-insert-overwrite-directory-hive.html
+    arg_types = {"this": True, "local": False, "row_format": False}
+
+
 class ForeignKey(Expression):
     arg_types = {
         "expressions": True,
@@ -804,6 +814,18 @@ class Introducer(Expression):
     arg_types = {"this": True, "expression": True}
 
 
+class LoadData(Expression):
+    arg_types = {
+        "this": True,
+        "local": False,
+        "overwrite": False,
+        "inpath": True,
+        "partition": False,
+        "input_format": False,
+        "serde": False,
+    }
+
+
 class Partition(Expression):
     pass
 
@@ -1037,6 +1059,18 @@ class Reference(Expression):
     arg_types = {"this": True, "expressions": True}
 
 
+class RowFormat(Expression):
+    # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml
+    arg_types = {
+        "fields": False,
+        "escaped": False,
+        "collection_items": False,
+        "map_keys": False,
+        "lines": False,
+        "null": False,
+    }
+
+
 class Tuple(Expression):
     arg_types = {"expressions": False}
 
@@ -1071,6 +1105,14 @@ class Subqueryable(Unionable):
             return []
         return with_.expressions
 
+    @property
+    def selects(self):
+        raise NotImplementedError("Subqueryable objects must implement `selects`")
+
+    @property
+    def named_selects(self):
+        raise NotImplementedError("Subqueryable objects must implement `named_selects`")
+
     def with_(
         self,
         alias,
@@ -1158,7 +1200,7 @@ class Table(Expression):
     }
 
 
-class Union(Subqueryable, Expression):
+class Union(Subqueryable):
     arg_types = {
         "with": False,
         "this": True,
@@ -1169,7 +1211,11 @@ class Union(Subqueryable, Expression):
 
     @property
     def named_selects(self):
-        return self.args["this"].unnest().named_selects
+        return self.this.unnest().named_selects
+
+    @property
+    def selects(self):
+        return self.this.unnest().selects
 
     @property
     def left(self):
@@ -1222,7 +1268,7 @@ class Schema(Expression):
     arg_types = {"this": False, "expressions": True}
 
 
-class Select(Subqueryable, Expression):
+class Select(Subqueryable):
     arg_types = {
         "with": False,
         "expressions": False,
@@ -2075,7 +2121,7 @@ class Bracket(Condition):
 
 
 class Distinct(Expression):
-    arg_types = {"this": False, "on": False}
+    arg_types = {"expressions": False, "on": False}
 
 
 class In(Predicate):
@@ -2233,6 +2279,14 @@ class Case(Func):
 class Cast(Func):
     arg_types = {"this": True, "to": True}
 
+    @property
+    def name(self):
+        return self.this.name
+
+    @property
+    def to(self):
+        return self.args["to"]
+
 
 class TryCast(Cast):
     pass
@@ -2666,7 +2720,7 @@ def _norm_args(expression):
         else:
             arg = _norm_arg(arg)
 
-        if arg is not None:
+        if arg is not None and arg is not False:
            args[k] = arg
 
     return args
@@ -3012,6 +3066,30 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts):
     return update
 
 
+def delete(table, where=None, dialect=None, **opts):
+    """
+    Builds a delete statement.
+
+    Example:
+        >>> delete("my_table", where="id > 1").sql()
+        'DELETE FROM my_table WHERE id > 1'
+
+    Args:
+        where (str|Condition): sql conditional parsed into a WHERE statement
+        dialect (str): the dialect used to parse the input expressions.
+        **opts: other options to use to parse the input expressions.
+
+    Returns:
+        Delete: the syntax tree for the DELETE statement.
+    """
+    return Delete(
+        this=maybe_parse(table, into=Table, dialect=dialect, **opts),
+        where=Where(this=where)
+        if isinstance(where, Condition)
+        else maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts),
+    )
+
+
 def condition(expression, dialect=None, **opts):
     """
     Initialize a logical condition expression.
@@ -3131,6 +3209,25 @@ def to_identifier(alias, quoted=None):
     return identifier
 
 
+def to_table(sql_path, **kwargs):
+    """
+    Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.
+    Example:
+        >>> to_table('catalog.db.table_name').sql()
+        'catalog.db.table_name'
+
+    Args:
+        sql_path(str): `[catalog].[schema].[table]` string
+    Returns:
+        Table: A table expression
+    """
+    table_parts = sql_path.split(".")
+    catalog, db, table_name = [
+        to_identifier(x) if x is not None else x for x in [None] * (3 - len(table_parts)) + table_parts
+    ]
+    return Table(this=table_name, db=db, catalog=catalog, **kwargs)
+
+
 def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
     """
     Create an Alias expression.
@@ -3216,6 +3313,28 @@ def table_(table, db=None, catalog=None, quoted=None):
     )
 
 
+def values(values, alias=None):
+    """Build VALUES statement.
+
+    Example:
+        >>> values([(1, '2')]).sql()
+        "VALUES (1, '2')"
+
+    Args:
+        values (list[tuple[str | Expression]]): values statements that will be converted to SQL
+        alias (str): optional alias
+        dialect (str): the dialect used to parse the input expression.
+        **opts: other options to use to parse the input expressions.
+
+    Returns:
+        Values: the Values expression object
+    """
+    return Values(
+        expressions=[convert(tup) for tup in values],
+        alias=to_identifier(alias) if alias else None,
+    )
+
+
 def convert(value):
     """Convert a python value into an expression object.
 
@@ -3246,6 +3365,12 @@ def convert(value):
             keys=[convert(k) for k in value.keys()],
             values=[convert(v) for v in value.values()],
         )
+    if isinstance(value, datetime.datetime):
+        datetime_literal = Literal.string(value.strftime("%Y-%m-%d %H:%M:%S"))
+        return TimeStrToTime(this=datetime_literal)
+    if isinstance(value, datetime.date):
+        date_literal = Literal.string(value.strftime("%Y-%m-%d"))
+        return DateStrToDate(this=date_literal)
     raise ValueError(f"Cannot convert {value}")
 
 
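Note: the new builder helpers can be exercised directly from sqlglot.expressions; the doctests above show the intent. A combined sketch (comment outputs are approximate):

    import datetime
    from sqlglot import expressions as exp

    print(exp.delete("my_table", where="id > 1").sql())  # DELETE FROM my_table WHERE id > 1
    print(exp.to_table("catalog.db.table_name").sql())   # catalog.db.table_name
    print(exp.values([(1, "2")]).sql())                  # VALUES (1, '2')
    # datetime/date values are now convertible too:
    print(exp.convert(datetime.date(2022, 1, 1)).sql())  # e.g. DATE_STR_TO_DATE('2022-01-01')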
sqlglot/generator.py

@@ -2,7 +2,7 @@ import logging
 
 from sqlglot import exp
 from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors
-from sqlglot.helper import apply_index_offset, csv, ensure_list
+from sqlglot.helper import apply_index_offset, csv
 from sqlglot.time import format_time
 from sqlglot.tokens import TokenType
 
@@ -43,14 +43,18 @@ class Generator:
             Default: 3
         leading_comma (bool): if the the comma is leading or trailing in select statements
             Default: False
+        max_text_width: The max number of characters in a segment before creating new lines in pretty mode.
+            The default is on the smaller end because the length only represents a segment and not the true
+            line length.
+            Default: 80
     """
 
     TRANSFORMS = {
         exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}",
-        exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
-        exp.DateDiff: lambda self, e: f"DATEDIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
-        exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
-        exp.VarMap: lambda self, e: f"MAP({self.sql(e.args['keys'])}, {self.sql(e.args['values'])})",
+        exp.DateAdd: lambda self, e: f"DATE_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
+        exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})",
+        exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
+        exp.VarMap: lambda self, e: f"MAP({self.format_args(e.args['keys'], e.args['values'])})",
         exp.LanguageProperty: lambda self, e: self.naked_property(e),
         exp.LocationProperty: lambda self, e: self.naked_property(e),
         exp.ReturnsProperty: lambda self, e: self.naked_property(e),
@@ -111,6 +115,7 @@ class Generator:
         "_replace_backslash",
         "_escaped_quote_end",
         "_leading_comma",
+        "_max_text_width",
     )
 
     def __init__(
@@ -135,6 +140,7 @@ class Generator:
         null_ordering=None,
         max_unsupported=3,
         leading_comma=False,
+        max_text_width=80,
     ):
         import sqlglot
 
@@ -162,6 +168,7 @@ class Generator:
         self._replace_backslash = self.escape == "\\"
         self._escaped_quote_end = self.escape + self.quote_end
         self._leading_comma = leading_comma
+        self._max_text_width = max_text_width
 
     def generate(self, expression):
         """
@@ -268,7 +275,7 @@ class Generator:
         raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
 
     def annotation_sql(self, expression):
-        return self.sql(expression, "expression")
+        return f"{self.sql(expression, 'expression')} # {expression.name.strip()}"
 
     def uncache_sql(self, expression):
         table = self.sql(expression, "this")
@@ -364,6 +371,9 @@ class Generator:
         )
         return self.prepend_ctes(expression, expression_sql)
 
+    def describe_sql(self, expression):
+        return f"DESCRIBE {self.sql(expression, 'this')}"
+
     def prepend_ctes(self, expression, sql):
         with_ = self.sql(expression, "with")
         if with_:
@@ -405,6 +415,12 @@ class Generator:
         )
         return f"{type_sql}{nested}"
 
+    def directory_sql(self, expression):
+        local = "LOCAL " if expression.args.get("local") else ""
+        row_format = self.sql(expression, "row_format")
+        row_format = f" {row_format}" if row_format else ""
+        return f"{local}DIRECTORY {self.sql(expression, 'this')}{row_format}"
+
     def delete_sql(self, expression):
         this = self.sql(expression, "this")
         where_sql = self.sql(expression, "where")
@@ -513,13 +529,19 @@ class Generator:
         return f"{key}={value}"
 
     def insert_sql(self, expression):
-        kind = "OVERWRITE TABLE" if expression.args.get("overwrite") else "INTO"
-        this = self.sql(expression, "this")
+        overwrite = expression.args.get("overwrite")
+
+        if isinstance(expression.this, exp.Directory):
+            this = "OVERWRITE " if overwrite else "INTO "
+        else:
+            this = "OVERWRITE TABLE " if overwrite else "INTO "
+
+        this = f"{this}{self.sql(expression, 'this')}"
         exists = " IF EXISTS " if expression.args.get("exists") else " "
         partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else ""
         expression_sql = self.sql(expression, "expression")
         sep = self.sep() if partition_sql else ""
-        sql = f"INSERT {kind} {this}{exists}{partition_sql}{sep}{expression_sql}"
+        sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}"
         return self.prepend_ctes(expression, sql)
 
     def intersect_sql(self, expression):
@@ -534,6 +556,21 @@ class Generator:
     def introducer_sql(self, expression):
         return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
 
+    def rowformat_sql(self, expression):
+        fields = expression.args.get("fields")
+        fields = f" FIELDS TERMINATED BY {fields}" if fields else ""
+        escaped = expression.args.get("escaped")
+        escaped = f" ESCAPED BY {escaped}" if escaped else ""
+        items = expression.args.get("collection_items")
+        items = f" COLLECTION ITEMS TERMINATED BY {items}" if items else ""
+        keys = expression.args.get("map_keys")
+        keys = f" MAP KEYS TERMINATED BY {keys}" if keys else ""
+        lines = expression.args.get("lines")
+        lines = f" LINES TERMINATED BY {lines}" if lines else ""
+        null = expression.args.get("null")
+        null = f" NULL DEFINED AS {null}" if null else ""
+        return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}"
+
     def table_sql(self, expression):
         table = ".".join(
             part
@@ -688,6 +725,19 @@ class Generator:
             return f"{self.quote_start}{text}{self.quote_end}"
         return text
 
+    def loaddata_sql(self, expression):
+        local = " LOCAL" if expression.args.get("local") else ""
+        inpath = f" INPATH {self.sql(expression, 'inpath')}"
+        overwrite = " OVERWRITE" if expression.args.get("overwrite") else ""
+        this = f" INTO TABLE {self.sql(expression, 'this')}"
+        partition = self.sql(expression, "partition")
+        partition = f" {partition}" if partition else ""
+        input_format = self.sql(expression, "input_format")
+        input_format = f" INPUTFORMAT {input_format}" if input_format else ""
+        serde = self.sql(expression, "serde")
+        serde = f" SERDE {serde}" if serde else ""
+        return f"LOAD DATA{local}{inpath}{overwrite}{this}{partition}{input_format}{serde}"
+
     def null_sql(self, *_):
         return "NULL"
 
@@ -885,20 +935,24 @@ class Generator:
         return f"EXISTS{self.wrap(expression)}"
 
     def case_sql(self, expression):
-        this = self.indent(self.sql(expression, "this"), skip_first=True)
-        this = f" {this}" if this else ""
-        ifs = []
+        this = self.sql(expression, "this")
+        statements = [f"CASE {this}" if this else "CASE"]
 
         for e in expression.args["ifs"]:
-            ifs.append(self.indent(f"WHEN {self.sql(e, 'this')}"))
-            ifs.append(self.indent(f"THEN {self.sql(e, 'true')}"))
+            statements.append(f"WHEN {self.sql(e, 'this')}")
+            statements.append(f"THEN {self.sql(e, 'true')}")
 
-        if expression.args.get("default") is not None:
-            ifs.append(self.indent(f"ELSE {self.sql(expression, 'default')}"))
+        default = self.sql(expression, "default")
 
-        ifs = "".join(self.seg(self.indent(e, skip_first=True)) for e in ifs)
-        statement = f"CASE{this}{ifs}{self.seg('END')}"
-        return statement
+        if default:
+            statements.append(f"ELSE {default}")
+
+        statements.append("END")
+
+        if self.pretty and self.text_width(statements) > self._max_text_width:
+            return self.indent("\n".join(statements), skip_first=True, skip_last=True)
+
+        return " ".join(statements)
 
     def constraint_sql(self, expression):
         this = self.sql(expression, "this")
@@ -970,7 +1024,7 @@ class Generator:
         return f"REFERENCES {this}({expressions})"
 
     def anonymous_sql(self, expression):
-        args = self.indent(self.expressions(expression, flat=True), skip_first=True, skip_last=True)
+        args = self.format_args(*expression.expressions)
         return f"{self.normalize_func(self.sql(expression, 'this'))}({args})"
 
     def paren_sql(self, expression):
@@ -1008,7 +1062,9 @@ class Generator:
         if not self.pretty:
             return self.binary(expression, op)
 
-        return f"\n{op} ".join(self.sql(e) for e in expression.flatten(unnest=False))
+        sqls = tuple(self.sql(e) for e in expression.flatten(unnest=False))
+        sep = "\n" if self.text_width(sqls) > self._max_text_width else " "
+        return f"{sep}{op} ".join(sqls)
 
     def bitwiseand_sql(self, expression):
         return self.binary(expression, "&")
@@ -1039,7 +1095,7 @@ class Generator:
         return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}"
 
     def distinct_sql(self, expression):
-        this = self.sql(expression, "this")
+        this = self.expressions(expression, flat=True)
         this = f" {this}" if this else ""
 
         on = self.sql(expression, "on")
@@ -1128,13 +1184,23 @@ class Generator:
 
     def function_fallback_sql(self, expression):
         args = []
-        for arg_key in expression.arg_types:
-            arg_value = ensure_list(expression.args.get(arg_key) or [])
-            for a in arg_value:
-                args.append(self.sql(a))
+        for arg_value in expression.args.values():
+            if isinstance(arg_value, list):
+                for value in arg_value:
+                    args.append(value)
+            elif arg_value:
+                args.append(arg_value)
 
-        args_str = self.indent(", ".join(args), skip_first=True, skip_last=True)
-        return f"{self.normalize_func(expression.sql_name())}({args_str})"
+        return f"{self.normalize_func(expression.sql_name())}({self.format_args(*args)})"
+
+    def format_args(self, *args):
+        args = tuple(self.sql(arg) for arg in args if arg is not None)
+        if self.pretty and self.text_width(args) > self._max_text_width:
+            return self.indent("\n" + f",\n".join(args) + "\n", skip_first=True, skip_last=True)
+        return ", ".join(args)
+
+    def text_width(self, args):
+        return sum(len(arg) for arg in args)
 
     def format_time(self, expression):
         return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie)
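Note: format_args and text_width drive the new pretty-printing behavior: argument lists, AND/OR chains, and CASE branches only wrap across lines once their combined width exceeds max_text_width (default 80). A sketch:

    import sqlglot

    sql = "SELECT CASE WHEN a = 1 THEN 'one' WHEN a = 2 THEN 'two' ELSE 'many' END FROM t"
    # Short CASE expressions now stay on one line even in pretty mode;
    # only long ones are broken across lines.
    print(sqlglot.transpile(sql, pretty=True)[0])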
sqlglot/optimizer/eliminate_ctes.py (new file, 42 lines)

@@ -0,0 +1,42 @@
+from sqlglot.optimizer.scope import Scope, build_scope
+
+
+def eliminate_ctes(expression):
+    """
+    Remove unused CTEs from an expression.
+
+    Example:
+        >>> import sqlglot
+        >>> sql = "WITH y AS (SELECT a FROM x) SELECT a FROM z"
+        >>> expression = sqlglot.parse_one(sql)
+        >>> eliminate_ctes(expression).sql()
+        'SELECT a FROM z'
+
+    Args:
+        expression (sqlglot.Expression): expression to optimize
+    Returns:
+        sqlglot.Expression: optimized expression
+    """
+    root = build_scope(expression)
+
+    ref_count = root.ref_count()
+
+    # Traverse the scope tree in reverse so we can remove chains of unused CTEs
+    for scope in reversed(list(root.traverse())):
+        if scope.is_cte:
+            count = ref_count[id(scope)]
+            if count <= 0:
+                cte_node = scope.expression.parent
+                with_node = cte_node.parent
+                cte_node.pop()
+
+                # Pop the entire WITH clause if this is the last CTE
+                if len(with_node.expressions) <= 0:
+                    with_node.pop()
+
+                # Decrement the ref count for all sources this CTE selects from
+                for _, source in scope.selected_sources.values():
+                    if isinstance(source, Scope):
+                        ref_count[id(source)] -= 1
+
+    return expression
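Note: the rule can be run standalone, exactly as in its docstring:

    import sqlglot
    from sqlglot.optimizer.eliminate_ctes import eliminate_ctes

    expression = sqlglot.parse_one("WITH y AS (SELECT a FROM x) SELECT a FROM z")
    print(eliminate_ctes(expression).sql())  # SELECT a FROM z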
sqlglot/optimizer/eliminate_joins.py (new file, 160 lines)

@@ -0,0 +1,160 @@
+from sqlglot import expressions as exp
+from sqlglot.optimizer.normalize import normalized
+from sqlglot.optimizer.scope import Scope, traverse_scope
+from sqlglot.optimizer.simplify import simplify
+
+
+def eliminate_joins(expression):
+    """
+    Remove unused joins from an expression.
+
+    This only removes joins when we know that the join condition doesn't produce duplicate rows.
+
+    Example:
+        >>> import sqlglot
+        >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
+        >>> expression = sqlglot.parse_one(sql)
+        >>> eliminate_joins(expression).sql()
+        'SELECT x.a FROM x'
+
+    Args:
+        expression (sqlglot.Expression): expression to optimize
+    Returns:
+        sqlglot.Expression: optimized expression
+    """
+    for scope in traverse_scope(expression):
+        # If any columns in this scope aren't qualified, it's hard to determine if a join isn't used.
+        # It's probably possible to infer this from the outputs of derived tables.
+        # But for now, let's just skip this rule.
+        if scope.unqualified_columns:
+            continue
+
+        joins = scope.expression.args.get("joins", [])
+
+        # Reverse the joins so we can remove chains of unused joins
+        for join in reversed(joins):
+            alias = join.this.alias_or_name
+            if _should_eliminate_join(scope, join, alias):
+                join.pop()
+                scope.remove_source(alias)
+    return expression
+
+
+def _should_eliminate_join(scope, join, alias):
+    inner_source = scope.sources.get(alias)
+    return (
+        isinstance(inner_source, Scope)
+        and not _join_is_used(scope, join, alias)
+        and (
+            (join.side == "LEFT" and _is_joined_on_all_unique_outputs(inner_source, join))
+            or (not join.args.get("on") and _has_single_output_row(inner_source))
+        )
+    )
+
+
+def _join_is_used(scope, join, alias):
+    # We need to find all columns that reference this join.
+    # But columns in the ON clause shouldn't count.
+    on = join.args.get("on")
+    if on:
+        on_clause_columns = set(id(column) for column in on.find_all(exp.Column))
+    else:
+        on_clause_columns = set()
+    return any(column for column in scope.source_columns(alias) if id(column) not in on_clause_columns)
+
+
+def _is_joined_on_all_unique_outputs(scope, join):
+    unique_outputs = _unique_outputs(scope)
+    if not unique_outputs:
+        return False
+
+    _, join_keys, _ = join_condition(join)
+    remaining_unique_outputs = unique_outputs - set(c.name for c in join_keys)
+    return not remaining_unique_outputs
+
+
+def _unique_outputs(scope):
+    """Determine output columns of `scope` that must have a unique combination per row"""
+    if scope.expression.args.get("distinct"):
+        return set(scope.expression.named_selects)
+
+    group = scope.expression.args.get("group")
+    if group:
+        grouped_expressions = set(group.expressions)
+        grouped_outputs = set()
+
+        unique_outputs = set()
+        for select in scope.selects:
+            output = select.unalias()
+            if output in grouped_expressions:
+                grouped_outputs.add(output)
+                unique_outputs.add(select.alias_or_name)
+
+        # All the grouped expressions must be in the output
+        if not grouped_expressions.difference(grouped_outputs):
+            return unique_outputs
+        else:
+            return set()
+
+    if _has_single_output_row(scope):
+        return set(scope.expression.named_selects)
+
+    return set()
+
+
+def _has_single_output_row(scope):
+    return isinstance(scope.expression, exp.Select) and (
+        all(isinstance(e.unalias(), exp.AggFunc) for e in scope.selects)
+        or _is_limit_1(scope)
+        or not scope.expression.args.get("from")
+    )
+
+
+def _is_limit_1(scope):
+    limit = scope.expression.args.get("limit")
+    return limit and limit.expression.this == "1"
+
+
+def join_condition(join):
+    """
+    Extract the join condition from a join expression.
+
+    Args:
+        join (exp.Join)
+    Returns:
+        tuple[list[str], list[str], exp.Expression]:
+            Tuple of (source key, join key, remaining predicate)
+    """
+    name = join.this.alias_or_name
+    on = join.args.get("on") or exp.TRUE
+    on = on.copy()
+    source_key = []
+    join_key = []
+
+    # find the join keys
+    # SELECT
+    # FROM x
+    # JOIN y
+    #   ON x.a = y.b AND y.b > 1
+    #
+    # should pull y.b as the join key and x.a as the source key
+    if normalized(on):
+        for condition in on.flatten() if isinstance(on, exp.And) else [on]:
+            if isinstance(condition, exp.EQ):
+                left, right = condition.unnest_operands()
+                left_tables = exp.column_table_names(left)
+                right_tables = exp.column_table_names(right)
+
+                if name in left_tables and name not in right_tables:
+                    join_key.append(left)
+                    source_key.append(right)
+                    condition.replace(exp.TRUE)
+                elif name in right_tables and name not in left_tables:
+                    join_key.append(right)
+                    source_key.append(left)
+                    condition.replace(exp.TRUE)
+
+    on = simplify(on)
+    remaining_condition = None if on == exp.TRUE else on
+
+    return source_key, join_key, remaining_condition
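Note: likewise runnable on its own; the LEFT JOIN is dropped because the inner SELECT DISTINCT guarantees at most one match per outer row and no outer column uses it:

    import sqlglot
    from sqlglot.optimizer.eliminate_joins import eliminate_joins

    sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
    print(eliminate_joins(sqlglot.parse_one(sql)).sql())  # SELECT x.a FROM x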
sqlglot/optimizer/eliminate_subqueries.py

@@ -8,7 +8,7 @@ from sqlglot.optimizer.simplify import simplify
 
 def eliminate_subqueries(expression):
     """
-    Rewrite subqueries as CTES, deduplicating if possible.
+    Rewrite derived tables as CTES, deduplicating if possible.
 
     Example:
         >>> import sqlglot
sqlglot/optimizer/merge_subqueries.py

@@ -119,6 +119,23 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
     Returns:
         bool: True if can be merged
     """
+
+    def _is_a_window_expression_in_unmergable_operation():
+        window_expressions = inner_select.find_all(exp.Window)
+        window_alias_names = {window.parent.alias_or_name for window in window_expressions}
+        inner_select_name = inner_select.parent.alias_or_name
+        unmergable_window_columns = [
+            column
+            for column in outer_scope.columns
+            if column.find_ancestor(exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc)
+        ]
+        window_expressions_in_unmergable = [
+            column
+            for column in unmergable_window_columns
+            if column.table == inner_select_name and column.name in window_alias_names
+        ]
+        return any(window_expressions_in_unmergable)
+
     return (
         isinstance(outer_scope.expression, exp.Select)
         and isinstance(inner_select, exp.Select)
@@ -137,6 +154,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
             and inner_select.args.get("where")
            and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []))
         )
+        and not _is_a_window_expression_in_unmergable_operation()
     )
 
 
sqlglot/optimizer/optimizer.py

@@ -1,3 +1,5 @@
+from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
+from sqlglot.optimizer.eliminate_joins import eliminate_joins
 from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
 from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
 from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
@@ -23,6 +25,8 @@ RULES = (
     optimize_joins,
     eliminate_subqueries,
     merge_subqueries,
+    eliminate_joins,
+    eliminate_ctes,
     quote_identities,
 )
 
sqlglot/optimizer/pushdown_predicates.py

@@ -1,8 +1,6 @@
-from collections import defaultdict
-
 from sqlglot import exp
 from sqlglot.optimizer.normalize import normalized
-from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.optimizer.scope import build_scope
 from sqlglot.optimizer.simplify import simplify
 
 
@@ -22,15 +20,10 @@ def pushdown_predicates(expression):
     Returns:
         sqlglot.Expression: optimized expression
     """
-    scope_ref_count = defaultdict(lambda: 0)
-    scopes = traverse_scope(expression)
-    scopes.reverse()
+    root = build_scope(expression)
+    scope_ref_count = root.ref_count()
 
-    for scope in scopes:
-        for _, source in scope.selected_sources.values():
-            scope_ref_count[id(source)] += 1
-
-    for scope in scopes:
+    for scope in reversed(list(root.traverse())):
         select = scope.expression
         where = select.args.get("where")
         if where:
@@ -152,9 +145,11 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
                 return {}
             nodes[table] = node
         elif isinstance(node, exp.Select) and len(tables) == 1:
+            # We can't push down window expressions
+            has_window_expression = any(select for select in node.selects if select.find(exp.Window))
            # we can't push down predicates to select statements if they are referenced in
             # multiple places.
-            if not node.args.get("group") and scope_ref_count[id(source)] < 2:
+            if not node.args.get("group") and scope_ref_count[id(source)] < 2 and not has_window_expression:
                 nodes[table] = node
     return nodes
 
sqlglot/optimizer/scope.py

@@ -1,4 +1,5 @@
 import itertools
+from collections import defaultdict
 from enum import Enum, auto
 
 from sqlglot import exp
@@ -314,6 +315,16 @@ class Scope:
             self._external_columns = [c for c in self.columns if c.table not in self.selected_sources]
         return self._external_columns
 
+    @property
+    def unqualified_columns(self):
+        """
+        Unqualified columns in the current scope.
+
+        Returns:
+            list[exp.Column]: Unqualified columns
+        """
+        return [c for c in self.columns if not c.table]
+
     @property
     def join_hints(self):
         """
@@ -403,6 +414,21 @@ class Scope:
                 yield from child_scope.traverse()
         yield self
 
+    def ref_count(self):
+        """
+        Count the number of times each scope in this tree is referenced.
+
+        Returns:
+            dict[int, int]: Mapping of Scope instance ID to reference count
+        """
+        scope_ref_count = defaultdict(lambda: 0)
+
+        for scope in self.traverse():
+            for _, source in scope.selected_sources.values():
+                scope_ref_count[id(source)] += 1
+
+        return scope_ref_count
+
 
 def traverse_scope(expression):
     """
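Note: ref_count yields a mapping from id(scope) to the number of places that scope is selected from; both eliminate_ctes and the reworked pushdown_predicates consume it. A sketch:

    import sqlglot
    from sqlglot.optimizer.scope import build_scope

    root = build_scope(sqlglot.parse_one("WITH y AS (SELECT a FROM x) SELECT a FROM y"))
    counts = root.ref_count()  # {id(scope): reference count, ...}
    print(len(list(root.traverse())), sum(counts.values()))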
sqlglot/parser.py

@@ -135,11 +135,13 @@ class Parser:
         TokenType.BOTH,
         TokenType.BUCKET,
         TokenType.CACHE,
         TokenType.CALL,
         TokenType.COLLATE,
         TokenType.COMMIT,
         TokenType.CONSTRAINT,
         TokenType.DEFAULT,
+        TokenType.DELETE,
+        TokenType.DESCRIBE,
         TokenType.DETERMINISTIC,
         TokenType.EXECUTE,
         TokenType.ENGINE,
@@ -160,6 +162,7 @@ class Parser:
         TokenType.LAZY,
         TokenType.LANGUAGE,
         TokenType.LEADING,
+        TokenType.LOCAL,
         TokenType.LOCATION,
         TokenType.MATERIALIZED,
         TokenType.NATURAL,
@@ -176,6 +179,7 @@ class Parser:
         TokenType.REFERENCES,
         TokenType.RETURNS,
         TokenType.ROWS,
+        TokenType.SCHEMA,
         TokenType.SCHEMA_COMMENT,
         TokenType.SEED,
         TokenType.SEMI,
@@ -294,6 +298,11 @@ class Parser:
 
     COLUMN_OPERATORS = {
         TokenType.DOT: None,
+        TokenType.DCOLON: lambda self, this, to: self.expression(
+            exp.Cast,
+            this=this,
+            to=to,
+        ),
         TokenType.ARROW: lambda self, this, path: self.expression(
             exp.JSONExtract,
             this=this,
@@ -342,8 +351,10 @@ class Parser:
 
     STATEMENT_PARSERS = {
         TokenType.CREATE: lambda self: self._parse_create(),
+        TokenType.DESCRIBE: lambda self: self._parse_describe(),
         TokenType.DROP: lambda self: self._parse_drop(),
         TokenType.INSERT: lambda self: self._parse_insert(),
+        TokenType.LOAD_DATA: lambda self: self._parse_load_data(),
         TokenType.UPDATE: lambda self: self._parse_update(),
         TokenType.DELETE: lambda self: self._parse_delete(),
         TokenType.CACHE: lambda self: self._parse_cache(),
@@ -449,7 +460,14 @@ class Parser:
 
     MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
 
-    CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX, TokenType.PROCEDURE}
+    CREATABLES = {
+        TokenType.TABLE,
+        TokenType.VIEW,
+        TokenType.FUNCTION,
+        TokenType.INDEX,
+        TokenType.PROCEDURE,
+        TokenType.SCHEMA,
+    }
 
     STRICT_CAST = True
 
@@ -650,7 +668,7 @@ class Parser:
         materialized = self._match(TokenType.MATERIALIZED)
         kind = self._match_set(self.CREATABLES) and self._prev.text
         if not kind:
-            self.raise_error("Expected TABLE, VIEW, INDEX, FUNCTION, or PROCEDURE")
+            self.raise_error(f"Expected {self.CREATABLES}")
             return
 
         return self.expression(
@@ -677,7 +695,7 @@ class Parser:
         create_token = self._match_set(self.CREATABLES) and self._prev
 
         if not create_token:
-            self.raise_error("Expected TABLE, VIEW, INDEX, FUNCTION, or PROCEDURE")
+            self.raise_error(f"Expected {self.CREATABLES}")
            return
 
         exists = self._parse_exists(not_=True)
@@ -692,7 +710,7 @@ class Parser:
             expression = self._parse_select_or_expression()
         elif create_token.token_type == TokenType.INDEX:
             this = self._parse_index()
-        elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW):
+        elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW, TokenType.SCHEMA):
             this = self._parse_table(schema=True)
             properties = self._parse_properties()
             if self._match(TokenType.ALIAS):
@@ -836,19 +854,74 @@ class Parser:
             return self.expression(exp.Properties, expressions=properties)
         return None
 
+    def _parse_describe(self):
+        self._match(TokenType.TABLE)
+
+        return self.expression(exp.Describe, this=self._parse_id_var())
+
     def _parse_insert(self):
         overwrite = self._match(TokenType.OVERWRITE)
-        self._match(TokenType.INTO)
-        self._match(TokenType.TABLE)
+        local = self._match(TokenType.LOCAL)
+        if self._match_text("DIRECTORY"):
+            this = self.expression(
+                exp.Directory,
+                this=self._parse_var_or_string(),
+                local=local,
+                row_format=self._parse_row_format(),
+            )
+        else:
+            self._match(TokenType.INTO)
+            self._match(TokenType.TABLE)
+            this = self._parse_table(schema=True)
         return self.expression(
             exp.Insert,
-            this=self._parse_table(schema=True),
+            this=this,
             exists=self._parse_exists(),
             partition=self._parse_partition(),
             expression=self._parse_select(nested=True),
             overwrite=overwrite,
         )
 
+    def _parse_row_format(self):
+        if not self._match_pair(TokenType.ROW, TokenType.FORMAT):
+            return None
+
+        self._match_text("DELIMITED")
+
+        kwargs = {}
+
+        if self._match_text("FIELDS", "TERMINATED", "BY"):
+            kwargs["fields"] = self._parse_string()
+        if self._match_text("ESCAPED", "BY"):
+            kwargs["escaped"] = self._parse_string()
+        if self._match_text("COLLECTION", "ITEMS", "TERMINATED", "BY"):
+            kwargs["collection_items"] = self._parse_string()
+        if self._match_text("MAP", "KEYS", "TERMINATED", "BY"):
+            kwargs["map_keys"] = self._parse_string()
+        if self._match_text("LINES", "TERMINATED", "BY"):
+            kwargs["lines"] = self._parse_string()
+        if self._match_text("NULL", "DEFINED", "AS"):
+            kwargs["null"] = self._parse_string()
+        return self.expression(exp.RowFormat, **kwargs)
+
+    def _parse_load_data(self):
+        local = self._match(TokenType.LOCAL)
+        self._match_text("INPATH")
+        inpath = self._parse_string()
+        overwrite = self._match(TokenType.OVERWRITE)
+        self._match_pair(TokenType.INTO, TokenType.TABLE)
+
+        return self.expression(
+            exp.LoadData,
+            this=self._parse_table(schema=True),
+            local=local,
+            overwrite=overwrite,
+            inpath=inpath,
+            partition=self._parse_partition(),
+            input_format=self._match_text("INPUTFORMAT") and self._parse_string(),
+            serde=self._match_text("SERDE") and self._parse_string(),
+        )
+
     def _parse_delete(self):
         self._match(TokenType.FROM)
 
@@ -1484,6 +1557,14 @@ class Parser:
 
         if self._match_set(self.RANGE_PARSERS):
             this = self.RANGE_PARSERS[self._prev.token_type](self, this)
+        elif self._match(TokenType.ISNULL):
+            this = self.expression(exp.Is, this=this, expression=exp.Null())
+
+        # Postgres supports ISNULL and NOTNULL for conditions.
+        # https://blog.andreiavram.ro/postgresql-null-composite-type/
+        if self._match(TokenType.NOTNULL):
+            this = self.expression(exp.Is, this=this, expression=exp.Null())
+            this = self.expression(exp.Not, this=this)
 
         if negate:
             this = self.expression(exp.Not, this=this)
@@ -1582,12 +1663,6 @@ class Parser:
                 return self._parse_column()
             return type_token
 
-        while self._match(TokenType.DCOLON):
-            type_token = self._parse_types()
-            if not type_token:
-                self.raise_error("Expected type")
-            this = self.expression(exp.Cast, this=this, to=type_token)
-
         return this
 
     def _parse_types(self):
@@ -1601,6 +1676,11 @@ class Parser:
         is_struct = type_token == TokenType.STRUCT
         expressions = None
 
+        if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
+            return exp.DataType(
+                this=exp.DataType.Type.ARRAY, expressions=[exp.DataType.build(type_token.value)], nested=True
+            )
+
         if self._match(TokenType.L_BRACKET):
             self._retreat(index)
             return None
@@ -1611,7 +1691,7 @@ class Parser:
             elif nested:
                 expressions = self._parse_csv(self._parse_types)
             else:
-                expressions = self._parse_csv(self._parse_type)
+                expressions = self._parse_csv(self._parse_conjunction)
 
             if not expressions:
                 self._retreat(index)
@@ -1677,8 +1757,17 @@ class Parser:
             this = self._parse_bracket(this)
 
         while self._match_set(self.COLUMN_OPERATORS):
-            op = self.COLUMN_OPERATORS.get(self._prev.token_type)
-            field = self._parse_star() or self._parse_function() or self._parse_id_var()
+            op_token = self._prev.token_type
+            op = self.COLUMN_OPERATORS.get(op_token)
+
+            if op_token == TokenType.DCOLON:
+                field = self._parse_types()
+                if not field:
+                    self.raise_error("Expected type")
+            elif op:
+                field = exp.Literal.string(self._advance() or self._prev.text)
+            else:
+                field = self._parse_star() or self._parse_function() or self._parse_id_var()
 
             if isinstance(field, exp.Func):
                 # bigquery allows function calls like x.y.count(...)
@@ -1687,7 +1776,7 @@ class Parser:
                 this = self._replace_columns_with_dots(this)
 
             if op:
-                this = op(self, this, exp.Literal.string(field.name))
+                this = op(self, this, field)
             elif isinstance(this, exp.Column) and not this.table:
                 this = self.expression(exp.Column, this=field, table=this.this)
             else:
@@ -1808,11 +1897,10 @@ class Parser:
         if not self._match(TokenType.ARROW):
             self._retreat(index)
 
-            distinct = self._match(TokenType.DISTINCT)
-            this = self._parse_conjunction()
-
-            if distinct:
-                this = self.expression(exp.Distinct, this=this)
+            if self._match(TokenType.DISTINCT):
+                this = self.expression(exp.Distinct, expressions=self._parse_csv(self._parse_conjunction))
+            else:
+                this = self._parse_conjunction()
 
             if self._match(TokenType.IGNORE_NULLS):
                 this = self.expression(exp.IgnoreNulls, this=this)
@@ -2112,6 +2200,8 @@ class Parser:
             this = self.expression(exp.Filter, this=this, expression=self._parse_where())
             self._match_r_paren()
 
+        # T-SQL allows the OVER (...) syntax after WITHIN GROUP.
+        # https://learn.microsoft.com/en-us/sql/t-sql/functions/percentile-disc-transact-sql?view=sql-server-ver16
         if self._match(TokenType.WITHIN_GROUP):
             self._match_l_paren()
             this = self.expression(
@@ -2120,7 +2210,6 @@ class Parser:
                 expression=self._parse_order(),
             )
             self._match_r_paren()
-            return this
 
         # SQL spec defines an optional [ { IGNORE | RESPECT } NULLS ] OVER
         # Some dialects choose to implement and some do not.
@@ -2366,6 +2455,16 @@ class Parser:
         if not self._match(TokenType.R_PAREN):
             self.raise_error("Expecting )")
 
+    def _match_text(self, *texts):
+        index = self._index
+        for text in texts:
+            if self._curr and self._curr.text.upper() == text:
+                self._advance()
+            else:
+                self._retreat(index)
+                return False
+        return True
+
     def _replace_columns_with_dots(self, this):
         if isinstance(this, exp.Dot):
             exp.replace_children(this, self._replace_columns_with_dots)
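Note: a few of the new parser behaviors, exercised end to end (comment outputs are approximate, not asserted by this diff):

    import sqlglot

    # DESCRIBE is now a real statement instead of an error.
    print(sqlglot.parse_one("DESCRIBE my_table").sql())  # e.g. DESCRIBE my_table

    # :: casts now go through the DCOLON column operator and _parse_types.
    print(sqlglot.parse_one("SELECT a::INT FROM t").sql())  # e.g. SELECT CAST(a AS INT) FROM t

    # DISTINCT inside an aggregate may now carry several expressions.
    print(sqlglot.parse_one("SELECT COUNT(DISTINCT a, b) FROM t").sql())
    # e.g. SELECT COUNT(DISTINCT a, b) FROM t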
sqlglot/planner.py

@@ -3,7 +3,7 @@ import math
 
 from sqlglot import alias, exp
 from sqlglot.errors import UnsupportedError
-from sqlglot.optimizer.simplify import simplify
+from sqlglot.optimizer.eliminate_joins import join_condition
 
 
 class Plan:
@@ -236,40 +236,12 @@ class Join(Step):
         step = Join()
 
         for join in joins:
-            name = join.this.alias
-            on = join.args.get("on") or exp.TRUE
-            source_key = []
-            join_key = []
-
-            # find the join keys
-            # SELECT
-            # FROM x
-            # JOIN y
-            #   ON x.a = y.b AND y.b > 1
-            #
-            # should pull y.b as the join key and x.a as the source key
-            for condition in on.flatten() if isinstance(on, exp.And) else [on]:
-                if isinstance(condition, exp.EQ):
-                    left, right = condition.unnest_operands()
-                    left_tables = exp.column_table_names(left)
-                    right_tables = exp.column_table_names(right)
-
-                    if name in left_tables and name not in right_tables:
-                        join_key.append(left)
-                        source_key.append(right)
-                        condition.replace(exp.TRUE)
-                    elif name in right_tables and name not in left_tables:
-                        join_key.append(right)
-                        source_key.append(left)
-                        condition.replace(exp.TRUE)
-
-            on = simplify(on)
-
-            step.joins[name] = {
+            source_key, join_key, condition = join_condition(join)
+            step.joins[join.this.alias_or_name] = {
                 "side": join.side,
                 "join_key": join_key,
                 "source_key": source_key,
-                "condition": None if on == exp.TRUE else on,
+                "condition": condition,
             }
 
             step.add_dependency(Scan.from_expression(join.this, ctes))
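Note: the planner now delegates join-key extraction to join_condition from eliminate_joins, so both call sites share one implementation. A sketch of the shared helper:

    import sqlglot
    from sqlglot.optimizer.eliminate_joins import join_condition

    join = sqlglot.parse_one("SELECT * FROM x JOIN y ON x.a = y.b AND y.b > 1").args["joins"][0]
    source_key, join_key, condition = join_condition(join)
    print([c.sql() for c in source_key], [c.sql() for c in join_key], condition.sql())
    # e.g. ['x.a'] ['y.b'] 'y.b > 1'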
sqlglot/tokens.py

@@ -123,6 +123,7 @@ class TokenType(AutoName):
     CLUSTER_BY = auto()
     COLLATE = auto()
     COMMENT = auto()
+    COMMENT_ON = auto()
     COMMIT = auto()
     CONSTRAINT = auto()
     CREATE = auto()
@@ -133,13 +134,14 @@ class TokenType(AutoName):
     CURRENT_ROW = auto()
     CURRENT_TIME = auto()
     CURRENT_TIMESTAMP = auto()
-    DIV = auto()
     DEFAULT = auto()
     DELETE = auto()
     DESC = auto()
+    DESCRIBE = auto()
     DETERMINISTIC = auto()
     DISTINCT = auto()
     DISTRIBUTE_BY = auto()
+    DIV = auto()
     DROP = auto()
     ELSE = auto()
     END = auto()
@@ -189,6 +191,8 @@ class TokenType(AutoName):
     LEFT = auto()
     LIKE = auto()
     LIMIT = auto()
+    LOAD_DATA = auto()
+    LOCAL = auto()
     LOCATION = auto()
     MAP = auto()
     MATERIALIZED = auto()
@@ -196,6 +200,7 @@ class TokenType(AutoName):
     NATURAL = auto()
     NEXT = auto()
     NO_ACTION = auto()
+    NOTNULL = auto()
     NULL = auto()
     NULLS_FIRST = auto()
     NULLS_LAST = auto()
@@ -436,13 +441,14 @@ class Tokenizer(metaclass=_Tokenizer):
         "CURRENT_DATE": TokenType.CURRENT_DATE,
         "CURRENT ROW": TokenType.CURRENT_ROW,
         "CURRENT_TIMESTAMP": TokenType.CURRENT_TIMESTAMP,
-        "DIV": TokenType.DIV,
         "DEFAULT": TokenType.DEFAULT,
         "DELETE": TokenType.DELETE,
         "DESC": TokenType.DESC,
+        "DESCRIBE": TokenType.DESCRIBE,
         "DETERMINISTIC": TokenType.DETERMINISTIC,
         "DISTINCT": TokenType.DISTINCT,
         "DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
+        "DIV": TokenType.DIV,
         "DROP": TokenType.DROP,
         "ELSE": TokenType.ELSE,
         "END": TokenType.END,
@@ -487,12 +493,15 @@ class Tokenizer(metaclass=_Tokenizer):
         "LEFT": TokenType.LEFT,
         "LIKE": TokenType.LIKE,
         "LIMIT": TokenType.LIMIT,
+        "LOAD DATA": TokenType.LOAD_DATA,
+        "LOCAL": TokenType.LOCAL,
         "LOCATION": TokenType.LOCATION,
         "MATERIALIZED": TokenType.MATERIALIZED,
         "NATURAL": TokenType.NATURAL,
         "NEXT": TokenType.NEXT,
         "NO ACTION": TokenType.NO_ACTION,
         "NOT": TokenType.NOT,
+        "NOTNULL": TokenType.NOTNULL,
         "NULL": TokenType.NULL,
         "NULLS FIRST": TokenType.NULLS_FIRST,
         "NULLS LAST": TokenType.NULLS_LAST,
@@ -530,6 +539,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "ROLLUP": TokenType.ROLLUP,
         "ROW": TokenType.ROW,
         "ROWS": TokenType.ROWS,
+        "SCHEMA": TokenType.SCHEMA,
         "SEED": TokenType.SEED,
         "SELECT": TokenType.SELECT,
         "SEMI": TokenType.SEMI,
@@ -629,6 +639,7 @@ class Tokenizer(metaclass=_Tokenizer):
         TokenType.ANALYZE,
         TokenType.BEGIN,
         TokenType.CALL,
+        TokenType.COMMENT_ON,
         TokenType.COMMIT,
         TokenType.EXPLAIN,
         TokenType.OPTIMIZE,
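Note: the new LOAD DATA tokens and keywords combine with _parse_load_data and loaddata_sql above for a full round trip. A sketch (output approximate):

    import sqlglot

    sql = "LOAD DATA LOCAL INPATH '/files/x.csv' OVERWRITE INTO TABLE t"
    print(sqlglot.parse_one(sql).sql())
    # e.g. LOAD DATA LOCAL INPATH '/files/x.csv' OVERWRITE INTO TABLE t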