
Merging upstream version 17.9.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann, 2025-02-13 20:48:36 +01:00
commit 9777880e00 (parent 2bf6699c56)
Signed by: daniel (GPG key ID: FBB4F0E80A80222F)
87 changed files with 45907 additions and 42511 deletions


@ -67,19 +67,22 @@ schema = MappingSchema()
"""The default schema used by SQLGlot (e.g. in the optimizer)."""
def parse(sql: str, read: DialectType = None, **opts) -> t.List[t.Optional[Expression]]:
def parse(
sql: str, read: DialectType = None, dialect: DialectType = None, **opts
) -> t.List[t.Optional[Expression]]:
"""
Parses the given SQL string into a collection of syntax trees, one per parsed SQL statement.
Args:
sql: the SQL code string to parse.
read: the SQL dialect to apply during parsing (e.g. "spark", "hive", "presto", "mysql").
dialect: the SQL dialect (alias for read).
**opts: other `sqlglot.parser.Parser` options.
Returns:
The resulting syntax tree collection.
"""
dialect = Dialect.get_or_raise(read)()
dialect = Dialect.get_or_raise(read or dialect)()
return dialect.parse(sql, **opts)
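
A minimal usage sketch of the new argument (statement contents arbitrary): `dialect` is accepted as an alias for `read`, so both spellings below should be equivalent.

    import sqlglot

    # Parse two statements with Spark rules; `dialect` aliases `read`.
    expressions = sqlglot.parse("SELECT 1; SELECT 2", dialect="spark")
    assert expressions == sqlglot.parse("SELECT 1; SELECT 2", read="spark")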


@ -386,7 +386,7 @@ def input_file_name() -> Column:
def isnan(col: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col, "ISNAN")
return Column.invoke_expression_over_column(col, expression.IsNan)
def isnull(col: ColumnOrName) -> Column:


@ -211,6 +211,10 @@ class BigQuery(Dialect):
"TZH": "%z",
}
# The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement
# https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table
PSEUDOCOLUMNS = {"_PARTITIONTIME", "_PARTITIONDATE"}
@classmethod
def normalize_identifier(cls, expression: E) -> E:
# In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least).


@ -380,7 +380,7 @@ class ClickHouse(Dialect):
]
def parameterizedagg_sql(self, expression: exp.Anonymous) -> str:
params = self.expressions(expression, "params", flat=True)
params = self.expressions(expression, key="params", flat=True)
return self.func(expression.name, *expression.expressions) + f"({params})"
def placeholder_sql(self, expression: exp.Placeholder) -> str:


@ -5,6 +5,7 @@ from enum import Enum
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
@ -168,6 +169,10 @@ class Dialect(metaclass=_Dialect):
# special syntax cast(x as date format 'yyyy') defaults to time_mapping
FORMAT_MAPPING: t.Dict[str, str] = {}
# Columns that are auto-generated by the engine corresponding to this dialect
# Such columns may be excluded from SELECT * queries, for example
PSEUDOCOLUMNS: t.Set[str] = set()
# Autofilled
tokenizer_class = Tokenizer
parser_class = Parser
@ -497,6 +502,10 @@ def parse_date_delta_with_interval(
return None
interval = args[1]
if not isinstance(interval, exp.Interval):
raise ParseError(f"INTERVAL expression expected but got '{interval}'")
expression = interval.this
if expression and expression.is_string:
expression = exp.Literal.number(expression.this)
@ -555,11 +564,11 @@ def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"
return self.sql(exp.cast(expression.this, "timestamp"))
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
return f"CAST({self.sql(expression, 'this')} AS DATE)"
return self.sql(exp.cast(expression.this, "date"))
def min_or_least(self: Generator, expression: exp.Min) -> str:
@ -608,8 +617,9 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
_dialect = Dialect.get_or_raise(dialect)
time_format = self.format_time(expression)
if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
return f"CAST({self.sql(expression, 'this')} AS DATE)"
return self.sql(exp.cast(str_to_time_sql(self, expression), "date"))
return self.sql(exp.cast(self.sql(expression, "this"), "date"))
return _ts_or_ds_to_date_sql
@ -664,5 +674,15 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp
return names
def simplify_literal(expression: E, copy: bool = True) -> E:
if not isinstance(expression.expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify
expression = exp.maybe_copy(expression, copy)
simplify(expression.expression)
return expression
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
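
Why the exp.cast change matters: the helpers now build a real Cast node, so the target dialect's generator can remap the type instead of emitting a fixed string. A minimal sketch (column name hypothetical; T-SQL's TIMESTAMP -> DATETIME2 mapping is added later in this commit):

    from sqlglot import exp

    # The Cast node's target type goes through the dialect's TYPE_MAPPING.
    print(exp.cast(exp.column("x"), "timestamp").sql(dialect="tsql"))
    # expected: CAST(x AS DATETIME2)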


@ -359,14 +359,16 @@ class Hive(Dialect):
TABLE_HINTS = False
QUERY_HINTS = False
INDEX_ON = "ON TABLE"
EXTRACT_ALLOWS_QUOTES = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.DATETIME: "TIMESTAMP",
exp.DataType.Type.VARBINARY: "BINARY",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.BIT: "BOOLEAN",
exp.DataType.Type.DATETIME: "TIMESTAMP",
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.TIME: "TIMESTAMP",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.VARBINARY: "BINARY",
}
TRANSFORMS = {
@ -396,6 +398,7 @@ class Hive(Dialect):
exp.FromBase64: rename_func("UNBASE64"),
exp.If: if_sql,
exp.ILike: no_ilike_sql,
exp.IsNan: rename_func("ISNAN"),
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
exp.JSONFormat: _json_format_sql,


@ -18,6 +18,7 @@ from sqlglot.dialects.dialect import (
no_trycast_sql,
parse_date_delta_with_interval,
rename_func,
simplify_literal,
strposition_to_locate_sql,
)
from sqlglot.helper import seq_get
@ -303,6 +304,22 @@ class MySQL(Dialect):
"NAMES": lambda self: self._parse_set_item_names(),
}
CONSTRAINT_PARSERS = {
**parser.Parser.CONSTRAINT_PARSERS,
"FULLTEXT": lambda self: self._parse_index_constraint(kind="FULLTEXT"),
"INDEX": lambda self: self._parse_index_constraint(),
"KEY": lambda self: self._parse_index_constraint(),
"SPATIAL": lambda self: self._parse_index_constraint(kind="SPATIAL"),
}
SCHEMA_UNNAMED_CONSTRAINTS = {
*parser.Parser.SCHEMA_UNNAMED_CONSTRAINTS,
"FULLTEXT",
"INDEX",
"KEY",
"SPATIAL",
}
PROFILE_TYPES = {
"ALL",
"BLOCK IO",
@ -327,6 +344,57 @@ class MySQL(Dialect):
LOG_DEFAULTS_TO_LN = True
def _parse_index_constraint(
self, kind: t.Optional[str] = None
) -> exp.IndexColumnConstraint:
if kind:
self._match_texts({"INDEX", "KEY"})
this = self._parse_id_var(any_token=False)
type_ = self._match(TokenType.USING) and self._advance_any() and self._prev.text
schema = self._parse_schema()
options = []
while True:
if self._match_text_seq("KEY_BLOCK_SIZE"):
self._match(TokenType.EQ)
opt = exp.IndexConstraintOption(key_block_size=self._parse_number())
elif self._match(TokenType.USING):
opt = exp.IndexConstraintOption(using=self._advance_any() and self._prev.text)
elif self._match_text_seq("WITH", "PARSER"):
opt = exp.IndexConstraintOption(parser=self._parse_var(any_token=True))
elif self._match(TokenType.COMMENT):
opt = exp.IndexConstraintOption(comment=self._parse_string())
elif self._match_text_seq("VISIBLE"):
opt = exp.IndexConstraintOption(visible=True)
elif self._match_text_seq("INVISIBLE"):
opt = exp.IndexConstraintOption(visible=False)
elif self._match_text_seq("ENGINE_ATTRIBUTE"):
self._match(TokenType.EQ)
opt = exp.IndexConstraintOption(engine_attr=self._parse_string())
elif self._match_text_seq("ENGINE_ATTRIBUTE"):
self._match(TokenType.EQ)
opt = exp.IndexConstraintOption(engine_attr=self._parse_string())
elif self._match_text_seq("SECONDARY_ENGINE_ATTRIBUTE"):
self._match(TokenType.EQ)
opt = exp.IndexConstraintOption(secondary_engine_attr=self._parse_string())
else:
opt = None
if not opt:
break
options.append(opt)
return self.expression(
exp.IndexColumnConstraint,
this=this,
schema=schema,
kind=kind,
type=type_,
options=options,
)
def _parse_show_mysql(
self,
this: str,
@ -454,6 +522,7 @@ class MySQL(Dialect):
exp.StrToTime: _str_to_date_sql,
exp.TableSample: no_tablesample_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime")),
exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
exp.Trim: _trim_sql,
exp.TryCast: no_trycast_sql,
@ -485,6 +554,16 @@ class MySQL(Dialect):
exp.DataType.Type.VARCHAR: "CHAR",
}
def limit_sql(self, expression: exp.Limit, top: bool = False) -> str:
# MySQL requires simple literal values for its LIMIT clause.
expression = simplify_literal(expression)
return super().limit_sql(expression, top=top)
def offset_sql(self, expression: exp.Offset) -> str:
# MySQL requires simple literal values for its OFFSET clause.
expression = simplify_literal(expression)
return super().offset_sql(expression)
def xor_sql(self, expression: exp.Xor) -> str:
if expression.expressions:
return self.expressions(expression, sep=" XOR ")
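
A sketch of what the simplify_literal calls buy here (table name hypothetical, output hedged): MySQL only accepts plain literals in LIMIT/OFFSET, so constant arithmetic is folded at generation time.

    from sqlglot import exp, parse_one

    select = parse_one("SELECT x FROM t")
    select.set("limit", exp.Limit(expression=parse_one("1 + 1")))
    print(select.sql(dialect="mysql"))
    # expected: SELECT x FROM t LIMIT 2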


@ -30,6 +30,9 @@ def _parse_xml_table(self: parser.Parser) -> exp.XMLTable:
class Oracle(Dialect):
ALIAS_POST_TABLESAMPLE = True
# See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm
RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
# https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212
# https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
TIME_MAPPING = {


@ -17,6 +17,7 @@ from sqlglot.dialects.dialect import (
no_tablesample_sql,
no_trycast_sql,
rename_func,
simplify_literal,
str_position_sql,
timestamptrunc_sql,
timestrtotime_sql,
@ -39,16 +40,13 @@ DATE_DIFF_FACTOR = {
def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
from sqlglot.optimizer.simplify import simplify
this = self.sql(expression, "this")
unit = expression.args.get("unit")
expression = simplify(expression.args["expression"])
expression = simplify_literal(expression.copy(), copy=False).expression
if not isinstance(expression, exp.Literal):
self.unsupported("Cannot add non literal")
expression = expression.copy()
expression.args["is_string"] = True
return f"{this} {kind} {self.sql(exp.Interval(this=expression, unit=unit))}"


@ -192,6 +192,8 @@ class Presto(Dialect):
"START": TokenType.BEGIN,
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"ROW": TokenType.STRUCT,
"IPADDRESS": TokenType.IPADDRESS,
"IPPREFIX": TokenType.IPPREFIX,
}
class Parser(parser.Parser):


@ -3,6 +3,7 @@ from __future__ import annotations
import typing as t
from sqlglot import exp
from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.spark2 import Spark2
from sqlglot.helper import seq_get
@ -47,7 +48,11 @@ class Spark(Spark2):
exp.DataType.Type.SMALLMONEY: "DECIMAL(6, 4)",
exp.DataType.Type.UNIQUEIDENTIFIER: "STRING",
}
TRANSFORMS = Spark2.Generator.TRANSFORMS.copy()
TRANSFORMS = {
**Spark2.Generator.TRANSFORMS,
exp.StartsWith: rename_func("STARTSWITH"),
}
TRANSFORMS.pop(exp.DateDiff)
TRANSFORMS.pop(exp.Group)
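
With exp.StartsWith registered (see the expressions.py hunk further down), the function name should now transpile across dialects, e.g. (a hedged sketch):

    import sqlglot

    print(sqlglot.transpile("SELECT STARTS_WITH(x, 'a')", read="bigquery", write="spark")[0])
    # expected: SELECT STARTSWITH(x, 'a')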


@ -19,9 +19,13 @@ def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
kind = e.args["kind"]
properties = e.args.get("properties")
if kind.upper() == "TABLE" and any(
isinstance(prop, exp.TemporaryProperty)
for prop in (properties.expressions if properties else [])
if (
kind.upper() == "TABLE"
and e.expression
and any(
isinstance(prop, exp.TemporaryProperty)
for prop in (properties.expressions if properties else [])
)
):
return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
return create_with_partitions_sql(self, e)


@ -33,8 +33,10 @@ class Teradata(Dialect):
**tokens.Tokenizer.KEYWORDS,
"^=": TokenType.NEQ,
"BYTEINT": TokenType.SMALLINT,
"COLLECT": TokenType.COMMAND,
"GE": TokenType.GTE,
"GT": TokenType.GT,
"HELP": TokenType.COMMAND,
"INS": TokenType.INSERT,
"LE": TokenType.LTE,
"LT": TokenType.LT,


@ -1,5 +1,6 @@
from __future__ import annotations
import datetime
import re
import typing as t
@ -10,6 +11,7 @@ from sqlglot.dialects.dialect import (
min_or_least,
parse_date_delta,
rename_func,
timestrtotime_sql,
)
from sqlglot.expressions import DataType
from sqlglot.helper import seq_get
@ -52,6 +54,8 @@ DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{
# N = Numeric, C=Currency
TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
DEFAULT_START_DATE = datetime.date(1900, 1, 1)
def _format_time_lambda(
exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None
@ -166,6 +170,34 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> s
return f"STRING_AGG({self.format_args(this, separator)}){order}"
def _parse_date_delta(
exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
def inner_func(args: t.List) -> E:
unit = seq_get(args, 0)
if unit and unit_mapping:
unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))
start_date = seq_get(args, 1)
if start_date and start_date.is_number:
# Numeric types are valid DATETIME values
if start_date.is_int:
adds = DEFAULT_START_DATE + datetime.timedelta(days=int(start_date.this))
start_date = exp.Literal.string(adds.strftime("%F"))
else:
# We currently don't handle float values, i.e. they're not converted to equivalent DATETIMEs.
# This is not a problem when generating T-SQL code, but it is when transpiling to other dialects.
return exp_class(this=seq_get(args, 2), expression=start_date, unit=unit)
return exp_class(
this=exp.TimeStrToTime(this=seq_get(args, 2)),
expression=exp.TimeStrToTime(this=start_date),
unit=unit,
)
return inner_func
class TSQL(Dialect):
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
NULL_ORDERING = "nulls_are_small"
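
A sketch of the integer-rebasing behavior above (output hedged): T-SQL accepts numbers as DATETIME values, counted in days from 1900-01-01, so the 1 below should become '1900-01-02' before the diff is taken.

    import sqlglot

    print(sqlglot.transpile("SELECT DATEDIFF(day, 1, '2020-01-02')", read="tsql", write="tsql")[0])
    # expected: SELECT DATEDIFF(day, '1900-01-02', '2020-01-02')
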
@ -298,7 +330,6 @@ class TSQL(Dialect):
"SMALLDATETIME": TokenType.DATETIME,
"SMALLMONEY": TokenType.SMALLMONEY,
"SQL_VARIANT": TokenType.VARIANT,
"TIME": TokenType.TIMESTAMP,
"TOP": TokenType.TOP,
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
"VARCHAR(MAX)": TokenType.TEXT,
@ -307,10 +338,6 @@ class TSQL(Dialect):
"SYSTEM_USER": TokenType.CURRENT_USER,
}
# TSQL allows @, # to appear as a variable/identifier prefix
SINGLE_TOKENS = tokens.Tokenizer.SINGLE_TOKENS.copy()
SINGLE_TOKENS.pop("#")
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@ -320,7 +347,7 @@ class TSQL(Dialect):
position=seq_get(args, 2),
),
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
"DATEDIFF": _parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
"DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
"DATEPART": _format_time_lambda(exp.TimeToStr),
"EOMONTH": _parse_eomonth,
@ -518,6 +545,36 @@ class TSQL(Dialect):
expressions = self._parse_csv(self._parse_function_parameter)
return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
def _parse_id_var(
self,
any_token: bool = True,
tokens: t.Optional[t.Collection[TokenType]] = None,
) -> t.Optional[exp.Expression]:
is_temporary = self._match(TokenType.HASH)
is_global = is_temporary and self._match(TokenType.HASH)
this = super()._parse_id_var(any_token=any_token, tokens=tokens)
if this:
if is_global:
this.set("global", True)
elif is_temporary:
this.set("temporary", True)
return this
def _parse_create(self) -> exp.Create | exp.Command:
create = super()._parse_create()
if isinstance(create, exp.Create):
table = create.this.this if isinstance(create.this, exp.Schema) else create.this
if isinstance(table, exp.Table) and table.this.args.get("temporary"):
if not create.args.get("properties"):
create.set("properties", exp.Properties(expressions=[]))
create.args["properties"].append("expressions", exp.TemporaryProperty())
return create
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
LIMIT_IS_TOP = True
@ -526,9 +583,11 @@ class TSQL(Dialect):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.DECIMAL: "NUMERIC",
exp.DataType.Type.DATETIME: "DATETIME2",
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.TIMESTAMP: "DATETIME2",
exp.DataType.Type.TIMESTAMPTZ: "DATETIMEOFFSET",
exp.DataType.Type.VARIANT: "SQL_VARIANT",
}
@ -552,6 +611,8 @@ class TSQL(Dialect):
exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"),
e.this,
),
exp.TemporaryProperty: lambda self, e: "",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: _format_sql,
}
@ -564,6 +625,22 @@ class TSQL(Dialect):
LIMIT_FETCH = "FETCH"
def createable_sql(
self,
expression: exp.Create,
locations: dict[exp.Properties.Location, list[exp.Property]],
) -> str:
sql = self.sql(expression, "this")
properties = expression.args.get("properties")
if sql[:1] != "#" and any(
isinstance(prop, exp.TemporaryProperty)
for prop in (properties.expressions if properties else [])
):
sql = f"#{sql}"
return sql
def offset_sql(self, expression: exp.Offset) -> str:
return f"{super().offset_sql(expression)} ROWS"
@ -616,3 +693,13 @@ class TSQL(Dialect):
this = self.sql(expression, "this")
this = f" {this}" if this else ""
return f"ROLLBACK TRANSACTION{this}"
def identifier_sql(self, expression: exp.Identifier) -> str:
identifier = super().identifier_sql(expression)
if expression.args.get("global"):
identifier = f"##{identifier}"
elif expression.args.get("temporary"):
identifier = f"#{identifier}"
return identifier
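
Taken together, the identifier flags, the _parse_create hook and createable_sql should let temp tables round-trip (table name hypothetical):

    import sqlglot

    # '#' marks a local temp table, '##' a global one.
    print(sqlglot.transpile("CREATE TABLE #tmp (x INT)", read="tsql", write="tsql")[0])
    # expected: CREATE TABLE #tmp (x INTEGER)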


@ -67,8 +67,9 @@ class Expression(metaclass=_Expression):
uses to refer to it.
comments: a list of comments that are associated with a given expression. This is used in
order to preserve comments when transpiling SQL code.
_type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the
type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the
optimizer, in order to enable some transformations that require type information.
meta: a dictionary that can be used to store useful metadata for a given expression.
Example:
>>> class Foo(Expression):
@ -767,7 +768,7 @@ class Condition(Expression):
**opts,
) -> In:
return In(
this=_maybe_copy(self, copy),
this=maybe_copy(self, copy),
expressions=[convert(e, copy=copy) for e in expressions],
query=maybe_parse(query, copy=copy, **opts) if query else None,
unnest=Unnest(
@ -781,7 +782,7 @@ class Condition(Expression):
def between(self, low: t.Any, high: t.Any, copy: bool = True, **opts) -> Between:
return Between(
this=_maybe_copy(self, copy),
this=maybe_copy(self, copy),
low=convert(low, copy=copy, **opts),
high=convert(high, copy=copy, **opts),
)
@ -990,7 +991,28 @@ class Uncache(Expression):
arg_types = {"this": True, "exists": False}
class Create(Expression):
class DDL(Expression):
@property
def ctes(self):
with_ = self.args.get("with")
if not with_:
return []
return with_.expressions
@property
def named_selects(self) -> t.List[str]:
if isinstance(self.expression, Subqueryable):
return self.expression.named_selects
return []
@property
def selects(self) -> t.List[Expression]:
if isinstance(self.expression, Subqueryable):
return self.expression.selects
return []
class Create(DDL):
arg_types = {
"with": False,
"this": True,
@ -1206,6 +1228,19 @@ class MergeTreeTTL(Expression):
}
# https://dev.mysql.com/doc/refman/8.0/en/create-table.html
class IndexConstraintOption(Expression):
arg_types = {
"key_block_size": False,
"using": False,
"parser": False,
"comment": False,
"visible": False,
"engine_attr": False,
"secondary_engine_attr": False,
}
class ColumnConstraint(Expression):
arg_types = {"this": False, "kind": True}
@ -1272,6 +1307,11 @@ class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
}
# https://dev.mysql.com/doc/refman/8.0/en/create-table.html
class IndexColumnConstraint(ColumnConstraintKind):
arg_types = {"this": False, "schema": True, "kind": False, "type": False, "options": False}
class InlineLengthColumnConstraint(ColumnConstraintKind):
pass
@ -1496,7 +1536,7 @@ class JoinHint(Expression):
class Identifier(Expression):
arg_types = {"this": True, "quoted": False}
arg_types = {"this": True, "quoted": False, "global": False, "temporary": False}
@property
def quoted(self) -> bool:
@ -1525,7 +1565,7 @@ class Index(Expression):
}
class Insert(Expression):
class Insert(DDL):
arg_types = {
"with": False,
"this": True,
@ -1892,6 +1932,10 @@ class EngineProperty(Property):
arg_types = {"this": True}
class HeapProperty(Property):
arg_types = {}
class ToTableProperty(Property):
arg_types = {"this": True}
@ -2182,7 +2226,7 @@ class Tuple(Expression):
**opts,
) -> In:
return In(
this=_maybe_copy(self, copy),
this=maybe_copy(self, copy),
expressions=[convert(e, copy=copy) for e in expressions],
query=maybe_parse(query, copy=copy, **opts) if query else None,
unnest=Unnest(
@ -2212,7 +2256,7 @@ class Subqueryable(Unionable):
Returns:
Alias: the subquery
"""
instance = _maybe_copy(self, copy)
instance = maybe_copy(self, copy)
if not isinstance(alias, Expression):
alias = TableAlias(this=to_identifier(alias)) if alias else None
@ -2865,7 +2909,7 @@ class Select(Subqueryable):
self,
expression: ExpOrStr,
on: t.Optional[ExpOrStr] = None,
using: t.Optional[ExpOrStr | t.List[ExpOrStr]] = None,
using: t.Optional[ExpOrStr | t.Collection[ExpOrStr]] = None,
append: bool = True,
join_type: t.Optional[str] = None,
join_alias: t.Optional[Identifier | str] = None,
@ -2943,6 +2987,7 @@ class Select(Subqueryable):
arg="using",
append=append,
copy=copy,
into=Identifier,
**opts,
)
@ -3092,7 +3137,7 @@ class Select(Subqueryable):
Returns:
Select: the modified expression.
"""
instance = _maybe_copy(self, copy)
instance = maybe_copy(self, copy)
on = Tuple(expressions=[maybe_parse(on, copy=copy) for on in ons if on]) if ons else None
instance.set("distinct", Distinct(on=on) if distinct else None)
return instance
@ -3123,7 +3168,7 @@ class Select(Subqueryable):
Returns:
The new Create expression.
"""
instance = _maybe_copy(self, copy)
instance = maybe_copy(self, copy)
table_expression = maybe_parse(
table,
into=Table,
@ -3159,7 +3204,7 @@ class Select(Subqueryable):
Returns:
The modified expression.
"""
inst = _maybe_copy(self, copy)
inst = maybe_copy(self, copy)
inst.set("locks", [Lock(update=update)])
return inst
@ -3181,7 +3226,7 @@ class Select(Subqueryable):
Returns:
The modified expression.
"""
inst = _maybe_copy(self, copy)
inst = maybe_copy(self, copy)
inst.set(
"hint", Hint(expressions=[maybe_parse(h, copy=copy, dialect=dialect) for h in hints])
)
@ -3376,6 +3421,8 @@ class DataType(Expression):
HSTORE = auto()
IMAGE = auto()
INET = auto()
IPADDRESS = auto()
IPPREFIX = auto()
INT = auto()
INT128 = auto()
INT256 = auto()
@ -3987,7 +4034,7 @@ class Case(Func):
arg_types = {"this": False, "ifs": True, "default": False}
def when(self, condition: ExpOrStr, then: ExpOrStr, copy: bool = True, **opts) -> Case:
instance = _maybe_copy(self, copy)
instance = maybe_copy(self, copy)
instance.append(
"ifs",
If(
@ -3998,7 +4045,7 @@ class Case(Func):
return instance
def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case:
instance = _maybe_copy(self, copy)
instance = maybe_copy(self, copy)
instance.set("default", maybe_parse(condition, copy=copy, **opts))
return instance
@ -4263,6 +4310,10 @@ class Initcap(Func):
arg_types = {"this": True, "expression": False}
class IsNan(Func):
_sql_names = ["IS_NAN", "ISNAN"]
class JSONKeyValue(Expression):
arg_types = {"this": True, "expression": True}
@ -4549,6 +4600,11 @@ class StandardHash(Func):
arg_types = {"this": True, "expression": False}
class StartsWith(Func):
_sql_names = ["STARTS_WITH", "STARTSWITH"]
arg_types = {"this": True, "expression": True}
class StrPosition(Func):
arg_types = {
"this": True,
@ -4804,7 +4860,7 @@ def maybe_parse(
return sqlglot.parse_one(sql, read=dialect, into=into, **opts)
def _maybe_copy(instance: E, copy: bool = True) -> E:
def maybe_copy(instance: E, copy: bool = True) -> E:
return instance.copy() if copy else instance
@ -4824,7 +4880,7 @@ def _apply_builder(
):
if _is_wrong_expression(expression, into):
expression = into(this=expression)
instance = _maybe_copy(instance, copy)
instance = maybe_copy(instance, copy)
expression = maybe_parse(
sql_or_expression=expression,
prefix=prefix,
@ -4848,7 +4904,7 @@ def _apply_child_list_builder(
properties=None,
**opts,
):
instance = _maybe_copy(instance, copy)
instance = maybe_copy(instance, copy)
parsed = []
for expression in expressions:
if expression is not None:
@ -4887,7 +4943,7 @@ def _apply_list_builder(
dialect=None,
**opts,
):
inst = _maybe_copy(instance, copy)
inst = maybe_copy(instance, copy)
expressions = [
maybe_parse(
@ -4923,7 +4979,7 @@ def _apply_conjunction_builder(
if not expressions:
return instance
inst = _maybe_copy(instance, copy)
inst = maybe_copy(instance, copy)
existing = inst.args.get(arg)
if append and existing is not None:
@ -5398,7 +5454,7 @@ def to_identifier(name, quoted=None, copy=True):
return None
if isinstance(name, Identifier):
identifier = _maybe_copy(name, copy)
identifier = maybe_copy(name, copy)
elif isinstance(name, str):
identifier = Identifier(
this=name,
@ -5735,7 +5791,7 @@ def convert(value: t.Any, copy: bool = False) -> Expression:
Expression: the equivalent expression object.
"""
if isinstance(value, Expression):
return _maybe_copy(value, copy)
return maybe_copy(value, copy)
if isinstance(value, str):
return Literal.string(value)
if isinstance(value, bool):


@ -68,6 +68,7 @@ class Generator:
exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
exp.ExternalProperty: lambda self, e: "EXTERNAL",
exp.HeapProperty: lambda self, e: "HEAP",
exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
@ -161,6 +162,9 @@ class Generator:
# Whether or not to generate the (+) suffix for columns used in old-style join conditions
COLUMN_JOIN_MARKS_SUPPORTED = False
# Whether or not to generate an unquoted value for EXTRACT's date part argument
EXTRACT_ALLOWS_QUOTES = True
# https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE")
@ -224,6 +228,7 @@ class Generator:
exp.FallbackProperty: exp.Properties.Location.POST_NAME,
exp.FileFormatProperty: exp.Properties.Location.POST_WITH,
exp.FreespaceProperty: exp.Properties.Location.POST_NAME,
exp.HeapProperty: exp.Properties.Location.POST_WITH,
exp.IsolatedLoadingProperty: exp.Properties.Location.POST_NAME,
exp.JournalProperty: exp.Properties.Location.POST_NAME,
exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA,
@ -265,9 +270,12 @@ class Generator:
# Expressions whose comments are separated from them for better formatting
WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = (
exp.Delete,
exp.Drop,
exp.From,
exp.Insert,
exp.Select,
exp.Update,
exp.Where,
exp.With,
)
@ -985,8 +993,9 @@ class Generator:
) -> str:
if properties.expressions:
expressions = self.expressions(properties, sep=sep, indent=False)
expressions = self.wrap(expressions) if wrapped else expressions
return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}"
if expressions:
expressions = self.wrap(expressions) if wrapped else expressions
return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}"
return ""
def with_properties(self, properties: exp.Properties) -> str:
@ -1905,7 +1914,7 @@ class Generator:
return f"NEXT VALUE FOR {self.sql(expression, 'this')}{order}"
def extract_sql(self, expression: exp.Extract) -> str:
this = self.sql(expression, "this")
this = self.sql(expression, "this") if self.EXTRACT_ALLOWS_QUOTES else expression.this.name
expression_sql = self.sql(expression, "expression")
return f"EXTRACT({this} FROM {expression_sql})"
@ -2370,7 +2379,12 @@ class Generator:
elif arg_value is not None:
args.append(arg_value)
return self.func(expression.sql_name(), *args)
if self.normalize_functions:
name = expression.sql_name()
else:
name = (expression._meta and expression.meta.get("name")) or expression.sql_name()
return self.func(name, *args)
def func(
self,
@ -2412,7 +2426,7 @@ class Generator:
return ""
if flat:
return sep.join(self.sql(e) for e in expressions)
return sep.join(sql for sql in (self.sql(e) for e in expressions) if sql)
num_sqls = len(expressions)
@ -2423,6 +2437,9 @@ class Generator:
result_sqls = []
for i, e in enumerate(expressions):
sql = self.sql(e, comment=False)
if not sql:
continue
comments = self.maybe_comment("", e) if isinstance(e, exp.Expression) else ""
if self.pretty:
@ -2562,6 +2579,51 @@ class Generator:
record_reader = f" RECORDREADER {record_reader}" if record_reader else ""
return f"{transform}{row_format_before}{record_writer}{using}{schema}{row_format_after}{record_reader}"
def indexconstraintoption_sql(self, expression: exp.IndexConstraintOption) -> str:
key_block_size = self.sql(expression, "key_block_size")
if key_block_size:
return f"KEY_BLOCK_SIZE = {key_block_size}"
using = self.sql(expression, "using")
if using:
return f"USING {using}"
parser = self.sql(expression, "parser")
if parser:
return f"WITH PARSER {parser}"
comment = self.sql(expression, "comment")
if comment:
return f"COMMENT {comment}"
visible = expression.args.get("visible")
if visible is not None:
return "VISIBLE" if visible else "INVISIBLE"
engine_attr = self.sql(expression, "engine_attr")
if engine_attr:
return f"ENGINE_ATTRIBUTE = {engine_attr}"
secondary_engine_attr = self.sql(expression, "secondary_engine_attr")
if secondary_engine_attr:
return f"SECONDARY_ENGINE_ATTRIBUTE = {secondary_engine_attr}"
self.unsupported("Unsupported index constraint option.")
return ""
def indexcolumnconstraint_sql(self, expression: exp.IndexColumnConstraint) -> str:
kind = self.sql(expression, "kind")
kind = f"{kind} INDEX" if kind else "INDEX"
this = self.sql(expression, "this")
this = f" {this}" if this else ""
type_ = self.sql(expression, "type")
type_ = f" USING {type_}" if type_ else ""
schema = self.sql(expression, "schema")
schema = f" {schema}" if schema else ""
options = self.expressions(expression, key="options", sep=" ")
options = f" {options}" if options else ""
return f"{kind}{this}{type_}{schema}{options}"
def cached_generator(
cache: t.Optional[t.Dict[int, str]] = None
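
A hedged round-trip sketch for the new index-constraint generation (table and index names hypothetical):

    import sqlglot

    sql = "CREATE TABLE t (x TEXT, FULLTEXT INDEX idx (x) WITH PARSER ngram COMMENT 'c')"
    print(sqlglot.transpile(sql, read="mysql", write="mysql")[0])
    # expected to round-trip essentially unchanged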


@ -136,8 +136,10 @@ def _eliminate_union(scope, existing_ctes, taken):
def _eliminate_derived_table(scope, existing_ctes, taken):
# This ensures we don't drop the "pivot" arg from a pivoted subquery
if scope.parent.pivots:
# This makes sure that we don't:
# - drop the "pivot" arg from a pivoted subquery
# - eliminate a lateral correlated subquery
if scope.parent.pivots or isinstance(scope.parent.expression, exp.Lateral):
return None
parent = scope.expression.parent


@ -1,8 +1,23 @@
from __future__ import annotations
import typing as t
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
@t.overload
def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
...
@t.overload
def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Expression:
...
def normalize_identifiers(expression, dialect=None):
"""
Normalize all unquoted identifiers to either lower or upper case, depending
on the dialect. This essentially makes those identifiers case-insensitive.
@ -16,6 +31,8 @@ def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
>>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
>>> normalize_identifiers(expression).sql()
'SELECT bar.a AS a FROM "Foo".bar'
>>> normalize_identifiers("foo", dialect="snowflake").sql(dialect="snowflake")
'FOO'
Args:
expression: The expression to transform.
@ -24,4 +41,5 @@ def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
Returns:
The transformed expression.
"""
expression = exp.maybe_parse(expression, dialect=dialect)
return expression.transform(Dialect.get_or_raise(dialect).normalize_identifier, copy=False)


@ -39,6 +39,7 @@ def qualify_columns(
"""
schema = ensure_schema(schema)
infer_schema = schema.empty if infer_schema is None else infer_schema
pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS
for scope in traverse_scope(expression):
resolver = Resolver(scope, schema, infer_schema=infer_schema)
@ -55,7 +56,7 @@ def qualify_columns(
_expand_alias_refs(scope, resolver)
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver, using_column_tables)
_expand_stars(scope, resolver, using_column_tables, pseudocolumns)
_qualify_outputs(scope)
_expand_group_by(scope)
_expand_order_by(scope, resolver)
@ -326,7 +327,10 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
def _expand_stars(
scope: Scope, resolver: Resolver, using_column_tables: t.Dict[str, t.Any]
scope: Scope,
resolver: Resolver,
using_column_tables: t.Dict[str, t.Any],
pseudocolumns: t.Set[str],
) -> None:
"""Expand stars to lists of column selections"""
@ -367,14 +371,8 @@ def _expand_stars(
columns = resolver.get_source_columns(table, only_visible=True)
# The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement
# https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table
if resolver.schema.dialect == "bigquery":
columns = [
name
for name in columns
if name.upper() not in ("_PARTITIONTIME", "_PARTITIONDATE")
]
if pseudocolumns:
columns = [name for name in columns if name.upper() not in pseudocolumns]
if columns and "*" not in columns:
if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
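
End to end, the dialect-driven filter should behave as follows (schema contents hypothetical; output hedged):

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_columns import qualify_columns
    from sqlglot.schema import MappingSchema

    schema = MappingSchema({"t": {"x": "int", "_PARTITIONTIME": "timestamp"}}, dialect="bigquery")
    print(qualify_columns(parse_one("SELECT * FROM t"), schema).sql())
    # expected: SELECT t.x AS x FROM t  (the pseudo-column is not expanded)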


@ -80,7 +80,9 @@ def qualify_tables(
header = next(reader)
columns = next(reader)
schema.add_table(
source, {k: type(v).__name__ for k, v in zip(header, columns)}
source,
{k: type(v).__name__ for k, v in zip(header, columns)},
match_depth=False,
)
elif isinstance(source, Scope) and source.is_udtf:
udtf = source.expression


@ -435,7 +435,10 @@ class Scope:
@property
def is_correlated_subquery(self):
"""Determine if this scope is a correlated subquery"""
return bool(self.is_subquery and self.external_columns)
return bool(
(self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
and self.external_columns
)
def rename_source(self, old_name, new_name):
"""Rename a source in this scope"""
@ -486,7 +489,7 @@ class Scope:
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
"""
Traverse an expression by it's "scopes".
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
@ -509,9 +512,12 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
Returns:
list[Scope]: scope instances
"""
if not isinstance(expression, exp.Unionable):
return []
return list(_traverse_scope(Scope(expression)))
if isinstance(expression, exp.Unionable) or (
isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable)
):
return list(_traverse_scope(Scope(expression)))
return []
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
@ -539,7 +545,9 @@ def _traverse_scope(scope):
elif isinstance(scope.expression, exp.Table):
yield from _traverse_tables(scope)
elif isinstance(scope.expression, exp.UDTF):
pass
yield from _traverse_udtfs(scope)
elif isinstance(scope.expression, exp.DDL):
yield from _traverse_ddl(scope)
else:
logger.warning(
"Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
@ -576,10 +584,10 @@ def _traverse_ctes(scope):
for cte in scope.ctes:
recursive_scope = None
# if the scope is a recursive cte, it must be in the form of
# base_case UNION recursive. thus the recursive scope is the first
# section of the union.
if scope.expression.args["with"].recursive:
# if the scope is a recursive cte, it must be in the form of base_case UNION recursive.
# thus the recursive scope is the first section of the union.
with_ = scope.expression.args.get("with")
if with_ and with_.recursive:
union = cte.this
if isinstance(union, exp.Union):
@ -692,8 +700,7 @@ def _traverse_tables(scope):
# This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
# Until then, this means that only a single, unaliased derived table is allowed (rather,
# the latest one wins).
alias = expression.alias
sources[alias] = child_scope
sources[expression.alias] = child_scope
# append the final child_scope yielded
scopes.append(child_scope)
@ -711,6 +718,47 @@ def _traverse_subqueries(scope):
scope.subquery_scopes.append(top)
def _traverse_udtfs(scope):
if isinstance(scope.expression, exp.Unnest):
expressions = scope.expression.expressions
elif isinstance(scope.expression, exp.Lateral):
expressions = [scope.expression.this]
else:
expressions = []
sources = {}
for expression in expressions:
if isinstance(expression, exp.Subquery) and _is_derived_table(expression):
top = None
for child_scope in _traverse_scope(
scope.branch(
expression,
scope_type=ScopeType.DERIVED_TABLE,
outer_column_list=expression.alias_column_names,
)
):
yield child_scope
top = child_scope
sources[expression.alias] = child_scope
scope.derived_table_scopes.append(top)
scope.table_scopes.append(top)
scope.sources.update(sources)
def _traverse_ddl(scope):
yield from _traverse_ctes(scope)
query_scope = scope.branch(
scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, chain_sources=scope.sources
)
query_scope._collect()
query_scope._ctes = scope.ctes + query_scope._ctes
yield from _traverse_scope(query_scope)
def walk_in_scope(expression, bfs=True):
"""
Returns a generator object which visits all nodes in the syntax tree, stopping at
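
A sketch of the new DDL traversal (scope order hedged):

    from sqlglot import parse_one
    from sqlglot.optimizer.scope import traverse_scope

    scopes = traverse_scope(parse_one("CREATE TABLE t AS SELECT a FROM x"))
    print([type(s.expression).__name__ for s in scopes])
    # expected: the inner Select scope, then the Create scope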


@ -46,20 +46,24 @@ def unnest(select, parent_select, next_alias_name):
if not predicate or parent_select is not predicate.parent_select:
return
# this subquery returns a scalar and can just be converted to a cross join
# This subquery returns a scalar and can just be converted to a cross join
if not isinstance(predicate, (exp.In, exp.Any)):
having = predicate.find_ancestor(exp.Having)
column = exp.column(select.selects[0].alias_or_name, alias)
if having and having.parent_select is parent_select:
column = exp.Max(this=column)
_replace(select.parent, column)
parent_select.join(
select,
join_type="CROSS",
join_alias=alias,
copy=False,
)
clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
clause_parent_select = clause.parent_select if clause else None
if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
(not clause or clause_parent_select is not parent_select)
and (
parent_select.args.get("group")
or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
)
):
column = exp.Max(this=column)
_replace(select.parent, column)
parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
return
if select.find(exp.Limit, exp.Offset):


@ -185,6 +185,8 @@ class Parser(metaclass=_Parser):
TokenType.VARIANT,
TokenType.OBJECT,
TokenType.INET,
TokenType.IPADDRESS,
TokenType.IPPREFIX,
TokenType.ENUM,
*NESTED_TYPE_TOKENS,
}
@ -603,6 +605,7 @@ class Parser(metaclass=_Parser):
"FALLBACK": lambda self, **kwargs: self._parse_fallback(**kwargs),
"FORMAT": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
"FREESPACE": lambda self: self._parse_freespace(),
"HEAP": lambda self: self.expression(exp.HeapProperty),
"IMMUTABLE": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
),
@ -832,6 +835,7 @@ class Parser(metaclass=_Parser):
UNNEST_COLUMN_ONLY: bool = False
ALIAS_POST_TABLESAMPLE: bool = False
STRICT_STRING_CONCAT = False
NORMALIZE_FUNCTIONS = "upper"
NULL_ORDERING: str = "nulls_are_small"
SHOW_TRIE: t.Dict = {}
SET_TRIE: t.Dict = {}
@ -1187,7 +1191,7 @@ class Parser(metaclass=_Parser):
exists = self._parse_exists(not_=True)
this = None
expression = None
expression: t.Optional[exp.Expression] = None
indexes = None
no_schema_binding = None
begin = None
@ -1207,12 +1211,16 @@ class Parser(metaclass=_Parser):
extend_props(self._parse_properties())
self._match(TokenType.ALIAS)
begin = self._match(TokenType.BEGIN)
return_ = self._match_text_seq("RETURN")
expression = self._parse_statement()
if return_:
expression = self.expression(exp.Return, this=expression)
if self._match(TokenType.COMMAND):
expression = self._parse_as_command(self._prev)
else:
begin = self._match(TokenType.BEGIN)
return_ = self._match_text_seq("RETURN")
expression = self._parse_statement()
if return_:
expression = self.expression(exp.Return, this=expression)
elif create_token.token_type == TokenType.INDEX:
this = self._parse_index(index=self._parse_id_var())
elif create_token.token_type in self.DB_CREATABLES:
@ -1692,6 +1700,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Describe, this=this, kind=kind)
def _parse_insert(self) -> exp.Insert:
comments = ensure_list(self._prev_comments)
overwrite = self._match(TokenType.OVERWRITE)
ignore = self._match(TokenType.IGNORE)
local = self._match_text_seq("LOCAL")
@ -1709,6 +1718,7 @@ class Parser(metaclass=_Parser):
alternative = self._match_texts(self.INSERT_ALTERNATIVES) and self._prev.text
self._match(TokenType.INTO)
comments += ensure_list(self._prev_comments)
self._match(TokenType.TABLE)
this = self._parse_table(schema=True)
@ -1716,6 +1726,7 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.Insert,
comments=comments,
this=this,
exists=self._parse_exists(),
partition=self._parse_partition(),
@ -1840,6 +1851,7 @@ class Parser(metaclass=_Parser):
# This handles MySQL's "Multiple-Table Syntax"
# https://dev.mysql.com/doc/refman/8.0/en/delete.html
tables = None
comments = self._prev_comments
if not self._match(TokenType.FROM, advance=False):
tables = self._parse_csv(self._parse_table) or None
@ -1847,6 +1859,7 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.Delete,
comments=comments,
tables=tables,
this=self._match(TokenType.FROM) and self._parse_table(joins=True),
using=self._match(TokenType.USING) and self._parse_table(joins=True),
@ -1856,11 +1869,13 @@ class Parser(metaclass=_Parser):
)
def _parse_update(self) -> exp.Update:
comments = self._prev_comments
this = self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS)
expressions = self._match(TokenType.SET) and self._parse_csv(self._parse_equality)
returning = self._parse_returning()
return self.expression(
exp.Update,
comments=comments,
**{ # type: ignore
"this": this,
"expressions": expressions,
@ -2235,7 +2250,12 @@ class Parser(metaclass=_Parser):
return None
if not this:
this = self._parse_function() or self._parse_id_var(any_token=False)
this = (
self._parse_unnest()
or self._parse_function()
or self._parse_id_var(any_token=False)
)
while self._match(TokenType.DOT):
this = exp.Dot(
this=this,
@ -3341,7 +3361,10 @@ class Parser(metaclass=_Parser):
args = self._parse_csv(lambda: self._parse_lambda(alias=alias))
if function and not anonymous:
this = self.validate_expression(function(args), args)
func = self.validate_expression(function(args), args)
if not self.NORMALIZE_FUNCTIONS:
func.meta["name"] = this
this = func
else:
this = self.expression(exp.Anonymous, this=this, expressions=args)
@ -3842,13 +3865,11 @@ class Parser(metaclass=_Parser):
args = self._parse_csv(self._parse_conjunction)
index = self._index
if not self._match(TokenType.R_PAREN):
if not self._match(TokenType.R_PAREN) and args:
# postgres: STRING_AGG([DISTINCT] expression, separator [ORDER BY expression1 {ASC | DESC} [, ...]])
return self.expression(
exp.GroupConcat,
this=seq_get(args, 0),
separator=self._parse_order(this=seq_get(args, 1)),
)
# bigquery: STRING_AGG([DISTINCT] expression [, separator] [ORDER BY key [{ASC | DESC}] [, ... ]] [LIMIT n])
args[-1] = self._parse_limit(this=self._parse_order(this=args[-1]))
return self.expression(exp.GroupConcat, this=args[0], separator=seq_get(args, 1))
# Checks if we can parse an order clause: WITHIN GROUP (ORDER BY <order_by_expression_list> [ASC | DESC]).
# This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that
@ -4172,7 +4193,7 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
return self.expression(
window = self.expression(
exp.Window,
this=this,
partition_by=partition,
@ -4183,6 +4204,12 @@ class Parser(metaclass=_Parser):
first=first,
)
# This covers Oracle's FIRST/LAST syntax: aggregate KEEP (...) OVER (...)
if self._match_set(self.WINDOW_BEFORE_PAREN_TOKENS, advance=False):
return self._parse_window(window, alias=alias)
return window
def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]:
self._match(TokenType.BETWEEN)
@ -4276,19 +4303,19 @@ class Parser(metaclass=_Parser):
def _parse_null(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.NULL):
return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev)
return None
return self._parse_placeholder()
def _parse_boolean(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.TRUE):
return self.PRIMARY_PARSERS[TokenType.TRUE](self, self._prev)
if self._match(TokenType.FALSE):
return self.PRIMARY_PARSERS[TokenType.FALSE](self, self._prev)
return None
return self._parse_placeholder()
def _parse_star(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.STAR):
return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev)
return None
return self._parse_placeholder()
def _parse_parameter(self) -> exp.Parameter:
wrapped = self._match(TokenType.L_BRACE)


@ -31,14 +31,19 @@ class Schema(abc.ABC):
table: exp.Table | str,
column_mapping: t.Optional[ColumnMapping] = None,
dialect: DialectType = None,
normalize: t.Optional[bool] = None,
match_depth: bool = True,
) -> None:
"""
Register or update a table. Some implementing classes may require column information to also be provided.
The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Args:
table: the `Table` expression instance or string representing the table.
column_mapping: a column mapping that describes the structure of the table.
dialect: the SQL dialect that will be used to parse `table` if it's a string.
normalize: whether to normalize identifiers according to the dialect of interest.
match_depth: whether to enforce that the table must match the schema's depth or not.
"""
@abc.abstractmethod
@ -47,6 +52,7 @@ class Schema(abc.ABC):
table: exp.Table | str,
only_visible: bool = False,
dialect: DialectType = None,
normalize: t.Optional[bool] = None,
) -> t.List[str]:
"""
Get the column names for a table.
@ -55,6 +61,7 @@ class Schema(abc.ABC):
table: the `Table` expression instance.
only_visible: whether to include invisible columns.
dialect: the SQL dialect that will be used to parse `table` if it's a string.
normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The list of column names.
@ -66,6 +73,7 @@ class Schema(abc.ABC):
table: exp.Table | str,
column: exp.Column,
dialect: DialectType = None,
normalize: t.Optional[bool] = None,
) -> exp.DataType:
"""
Get the `sqlglot.exp.DataType` type of a column in the schema.
@ -74,6 +82,7 @@ class Schema(abc.ABC):
table: the source table.
column: the target column.
dialect: the SQL dialect that will be used to parse `table` if it's a string.
normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
@ -99,7 +108,7 @@ class AbstractMappingSchema(t.Generic[T]):
) -> None:
self.mapping = mapping or {}
self.mapping_trie = new_trie(
tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
)
self._supported_table_args: t.Tuple[str, ...] = tuple()
@ -107,13 +116,13 @@ class AbstractMappingSchema(t.Generic[T]):
def empty(self) -> bool:
return not self.mapping
def _depth(self) -> int:
def depth(self) -> int:
return dict_depth(self.mapping)
@property
def supported_table_args(self) -> t.Tuple[str, ...]:
if not self._supported_table_args and self.mapping:
depth = self._depth()
depth = self.depth()
if not depth: # None
self._supported_table_args = tuple()
@ -191,6 +200,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
self.visible = visible or {}
self.normalize = normalize
self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
self._depth = 0
super().__init__(self._normalize(schema or {}))
@ -200,6 +210,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
schema=mapping_schema.mapping,
visible=mapping_schema.visible,
dialect=mapping_schema.dialect,
normalize=mapping_schema.normalize,
)
def copy(self, **kwargs) -> MappingSchema:
@ -208,6 +219,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
"schema": self.mapping.copy(),
"visible": self.visible.copy(),
"dialect": self.dialect,
"normalize": self.normalize,
**kwargs,
}
)
@ -217,19 +229,30 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
table: exp.Table | str,
column_mapping: t.Optional[ColumnMapping] = None,
dialect: DialectType = None,
normalize: t.Optional[bool] = None,
match_depth: bool = True,
) -> None:
"""
Register or update a table. Updates are only performed if a new column mapping is provided.
The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Args:
table: the `Table` expression instance or string representing the table.
column_mapping: a column mapping that describes the structure of the table.
dialect: the SQL dialect that will be used to parse `table` if it's a string.
normalize: whether to normalize identifiers according to the dialect of interest.
match_depth: whether to enforce that the table must match the schema's depth or not.
"""
normalized_table = self._normalize_table(table, dialect=dialect)
normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
raise SchemaError(
f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
f"schema's nesting level: {self.depth()}."
)
normalized_column_mapping = {
self._normalize_name(key, dialect=dialect): value
self._normalize_name(key, dialect=dialect, normalize=normalize): value
for key, value in ensure_column_mapping(column_mapping).items()
}
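
A sketch of the new depth check (schema contents hypothetical):

    from sqlglot.schema import MappingSchema

    schema = MappingSchema({"db": {"t": {"x": "int"}}})  # nesting depth 2 (db.table)
    schema.add_table("db.t2", {"y": "int"})  # ok: two qualifiers match the depth
    # schema.add_table("t3", {"z": "int"})   # expected to raise SchemaError
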
@ -247,8 +270,9 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
table: exp.Table | str,
only_visible: bool = False,
dialect: DialectType = None,
normalize: t.Optional[bool] = None,
) -> t.List[str]:
normalized_table = self._normalize_table(table, dialect=dialect)
normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
schema = self.find(normalized_table)
if schema is None:
@ -265,11 +289,12 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
table: exp.Table | str,
column: exp.Column,
dialect: DialectType = None,
normalize: t.Optional[bool] = None,
) -> exp.DataType:
normalized_table = self._normalize_table(table, dialect=dialect)
normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
normalized_column_name = self._normalize_name(
column if isinstance(column, str) else column.this, dialect=dialect
column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
)
table_schema = self.find(normalized_table, raise_on_missing=False)
@ -293,12 +318,16 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
Returns:
The normalized schema mapping.
"""
normalized_mapping: t.Dict = {}
flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)
normalized_mapping: t.Dict = {}
for keys in flattened_schema:
columns = nested_get(schema, *zip(keys, keys))
assert columns is not None
if not isinstance(columns, dict):
raise SchemaError(
f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
)
normalized_keys = [
self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
@ -312,7 +341,12 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
return normalized_mapping
def _normalize_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
def _normalize_table(
self,
table: exp.Table | str,
dialect: DialectType = None,
normalize: t.Optional[bool] = None,
) -> exp.Table:
normalized_table = exp.maybe_parse(
table, into=exp.Table, dialect=dialect or self.dialect, copy=True
)
@ -322,15 +356,24 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
if isinstance(value, (str, exp.Identifier)):
normalized_table.set(
arg,
exp.to_identifier(self._normalize_name(value, dialect=dialect, is_table=True)),
exp.to_identifier(
self._normalize_name(
value, dialect=dialect, is_table=True, normalize=normalize
)
),
)
return normalized_table
def _normalize_name(
self, name: str | exp.Identifier, dialect: DialectType = None, is_table: bool = False
self,
name: str | exp.Identifier,
dialect: DialectType = None,
is_table: bool = False,
normalize: t.Optional[bool] = None,
) -> str:
dialect = dialect or self.dialect
normalize = self.normalize if normalize is None else normalize
try:
identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
@ -338,16 +381,18 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
return name if isinstance(name, str) else name.name
name = identifier.name
if not self.normalize:
if not normalize:
return name
# This can be useful for normalize_identifier
identifier.meta["is_table"] = is_table
return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name
def _depth(self) -> int:
# The columns themselves are a mapping, but we don't want to include those
return super()._depth() - 1
def depth(self) -> int:
if not self.empty and not self._depth:
# The columns themselves are a mapping, but we don't want to include those
self._depth = super().depth() - 1
return self._depth
def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
"""


@ -147,6 +147,8 @@ class TokenType(AutoName):
VARIANT = auto()
OBJECT = auto()
INET = auto()
IPADDRESS = auto()
IPPREFIX = auto()
ENUM = auto()
# keywords


@ -100,7 +100,8 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
qualify_filters = expression.args["qualify"].pop().this
for expr in qualify_filters.find_all((exp.Window, exp.Column)):
select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
for expr in qualify_filters.find_all(select_candidates):
if isinstance(expr, exp.Window):
alias = find_new_name(expression.named_selects, "_w")
expression.select(exp.alias_(expr, alias), copy=False)
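
A hedged sketch of the star-aware behavior (exact output may differ):

    from sqlglot import parse_one
    from sqlglot.transforms import eliminate_qualify

    sql = "SELECT * FROM t QUALIFY ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) = 1"
    print(parse_one(sql, read="snowflake").transform(eliminate_qualify).sql())
    # With a star projection, only the window expression needs to be lifted into
    # the derived table; plain columns are already covered by *.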