Merging upstream version 17.9.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent 2bf6699c56
commit 9777880e00
87 changed files with 45907 additions and 42511 deletions
@@ -67,19 +67,22 @@ schema = MappingSchema()
 """The default schema used by SQLGlot (e.g. in the optimizer)."""
 
 
-def parse(sql: str, read: DialectType = None, **opts) -> t.List[t.Optional[Expression]]:
+def parse(
+    sql: str, read: DialectType = None, dialect: DialectType = None, **opts
+) -> t.List[t.Optional[Expression]]:
     """
     Parses the given SQL string into a collection of syntax trees, one per parsed SQL statement.
 
     Args:
         sql: the SQL code string to parse.
         read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
+        dialect: the SQL dialect (alias for read).
         **opts: other `sqlglot.parser.Parser` options.
 
     Returns:
         The resulting syntax tree collection.
     """
-    dialect = Dialect.get_or_raise(read)()
+    dialect = Dialect.get_or_raise(read or dialect)()
     return dialect.parse(sql, **opts)
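
The new dialect keyword is a drop-in alias for read; a minimal usage sketch (not part of the diff itself):

    import sqlglot

    # Equivalent calls: "dialect" now works as an alias for "read".
    expressions = sqlglot.parse("SELECT 1; SELECT 2", read="duckdb")
    expressions = sqlglot.parse("SELECT 1; SELECT 2", dialect="duckdb")
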
@@ -386,7 +386,7 @@ def input_file_name() -> Column:
 
 
 def isnan(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "ISNAN")
+    return Column.invoke_expression_over_column(col, expression.IsNan)
 
 
 def isnull(col: ColumnOrName) -> Column:
@@ -211,6 +211,10 @@ class BigQuery(Dialect):
         "TZH": "%z",
     }
 
+    # The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement
+    # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table
+    PSEUDOCOLUMNS = {"_PARTITIONTIME", "_PARTITIONDATE"}
+
     @classmethod
     def normalize_identifier(cls, expression: E) -> E:
         # In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least).
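
Star expansion now consults the dialect's PSEUDOCOLUMNS hook instead of hardcoding BigQuery; a sketch of the effect (the _PARTITIONTIME schema entry below is illustrative — it only exists so the exclusion is visible):

    import sqlglot
    from sqlglot.optimizer.qualify_columns import qualify_columns
    from sqlglot.schema import MappingSchema

    # The pseudo-column is known to the schema but skipped when * is expanded.
    schema = MappingSchema({"t": {"x": "INT", "_PARTITIONTIME": "TIMESTAMP"}}, dialect="bigquery")
    expression = sqlglot.parse_one("SELECT * FROM t", read="bigquery")
    print(qualify_columns(expression, schema).sql("bigquery"))  # SELECT t.x AS x FROM t
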
@@ -380,7 +380,7 @@ class ClickHouse(Dialect):
         ]
 
         def parameterizedagg_sql(self, expression: exp.Anonymous) -> str:
-            params = self.expressions(expression, "params", flat=True)
+            params = self.expressions(expression, key="params", flat=True)
             return self.func(expression.name, *expression.expressions) + f"({params})"
 
         def placeholder_sql(self, expression: exp.Placeholder) -> str:
@@ -5,6 +5,7 @@ from enum import Enum
 
 from sqlglot import exp
 from sqlglot._typing import E
+from sqlglot.errors import ParseError
 from sqlglot.generator import Generator
 from sqlglot.helper import flatten, seq_get
 from sqlglot.parser import Parser

@@ -168,6 +169,10 @@ class Dialect(metaclass=_Dialect):
     # special syntax cast(x as date format 'yyyy') defaults to time_mapping
     FORMAT_MAPPING: t.Dict[str, str] = {}
 
+    # Columns that are auto-generated by the engine corresponding to this dialect
+    # Such columns may be excluded from SELECT * queries, for example
+    PSEUDOCOLUMNS: t.Set[str] = set()
+
     # Autofilled
     tokenizer_class = Tokenizer
     parser_class = Parser

@@ -497,6 +502,10 @@ def parse_date_delta_with_interval(
         return None
 
     interval = args[1]
+
+    if not isinstance(interval, exp.Interval):
+        raise ParseError(f"INTERVAL expression expected but got '{interval}'")
+
     expression = interval.this
     if expression and expression.is_string:
         expression = exp.Literal.number(expression.this)

@@ -555,11 +564,11 @@ def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
 
 
 def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
-    return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"
+    return self.sql(exp.cast(expression.this, "timestamp"))
 
 
 def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
-    return f"CAST({self.sql(expression, 'this')} AS DATE)"
+    return self.sql(exp.cast(expression.this, "date"))
 
 
 def min_or_least(self: Generator, expression: exp.Min) -> str:

@@ -608,8 +617,9 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
         _dialect = Dialect.get_or_raise(dialect)
         time_format = self.format_time(expression)
         if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
-            return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
-        return f"CAST({self.sql(expression, 'this')} AS DATE)"
+            return self.sql(exp.cast(str_to_time_sql(self, expression), "date"))
+
+        return self.sql(exp.cast(self.sql(expression, "this"), "date"))
 
     return _ts_or_ds_to_date_sql

@@ -664,5 +674,15 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp
     return names
 
 
+def simplify_literal(expression: E, copy: bool = True) -> E:
+    if not isinstance(expression.expression, exp.Literal):
+        from sqlglot.optimizer.simplify import simplify
+
+        expression = exp.maybe_copy(expression, copy)
+        simplify(expression.expression)
+
+    return expression
+
+
 def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
     return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
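
simplify_literal folds a constant operand down to a single literal without eagerly pulling in the full optimizer pass; a small sketch grounded in the function above:

    from sqlglot import exp, parse_one
    from sqlglot.dialects.dialect import simplify_literal

    # The operand "9 - 4" is constant arithmetic, so it collapses to the literal 5.
    limit = exp.Limit(expression=parse_one("9 - 4"))
    print(simplify_literal(limit).args["expression"])  # 5
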
@@ -359,14 +359,16 @@ class Hive(Dialect):
         TABLE_HINTS = False
         QUERY_HINTS = False
         INDEX_ON = "ON TABLE"
+        EXTRACT_ALLOWS_QUOTES = False
 
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,
-            exp.DataType.Type.TEXT: "STRING",
-            exp.DataType.Type.DATETIME: "TIMESTAMP",
-            exp.DataType.Type.VARBINARY: "BINARY",
-            exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
             exp.DataType.Type.BIT: "BOOLEAN",
+            exp.DataType.Type.DATETIME: "TIMESTAMP",
+            exp.DataType.Type.TEXT: "STRING",
+            exp.DataType.Type.TIME: "TIMESTAMP",
+            exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
+            exp.DataType.Type.VARBINARY: "BINARY",
         }
 
         TRANSFORMS = {

@@ -396,6 +398,7 @@ class Hive(Dialect):
             exp.FromBase64: rename_func("UNBASE64"),
             exp.If: if_sql,
             exp.ILike: no_ilike_sql,
+            exp.IsNan: rename_func("ISNAN"),
             exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
             exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
             exp.JSONFormat: _json_format_sql,
@@ -18,6 +18,7 @@ from sqlglot.dialects.dialect import (
     no_trycast_sql,
     parse_date_delta_with_interval,
     rename_func,
+    simplify_literal,
     strposition_to_locate_sql,
 )
 from sqlglot.helper import seq_get

@@ -303,6 +304,22 @@ class MySQL(Dialect):
             "NAMES": lambda self: self._parse_set_item_names(),
         }
 
+        CONSTRAINT_PARSERS = {
+            **parser.Parser.CONSTRAINT_PARSERS,
+            "FULLTEXT": lambda self: self._parse_index_constraint(kind="FULLTEXT"),
+            "INDEX": lambda self: self._parse_index_constraint(),
+            "KEY": lambda self: self._parse_index_constraint(),
+            "SPATIAL": lambda self: self._parse_index_constraint(kind="SPATIAL"),
+        }
+
+        SCHEMA_UNNAMED_CONSTRAINTS = {
+            *parser.Parser.SCHEMA_UNNAMED_CONSTRAINTS,
+            "FULLTEXT",
+            "INDEX",
+            "KEY",
+            "SPATIAL",
+        }
+
         PROFILE_TYPES = {
             "ALL",
             "BLOCK IO",

@@ -327,6 +344,57 @@ class MySQL(Dialect):
 
         LOG_DEFAULTS_TO_LN = True
 
+        def _parse_index_constraint(
+            self, kind: t.Optional[str] = None
+        ) -> exp.IndexColumnConstraint:
+            if kind:
+                self._match_texts({"INDEX", "KEY"})
+
+            this = self._parse_id_var(any_token=False)
+            type_ = self._match(TokenType.USING) and self._advance_any() and self._prev.text
+            schema = self._parse_schema()
+
+            options = []
+            while True:
+                if self._match_text_seq("KEY_BLOCK_SIZE"):
+                    self._match(TokenType.EQ)
+                    opt = exp.IndexConstraintOption(key_block_size=self._parse_number())
+                elif self._match(TokenType.USING):
+                    opt = exp.IndexConstraintOption(using=self._advance_any() and self._prev.text)
+                elif self._match_text_seq("WITH", "PARSER"):
+                    opt = exp.IndexConstraintOption(parser=self._parse_var(any_token=True))
+                elif self._match(TokenType.COMMENT):
+                    opt = exp.IndexConstraintOption(comment=self._parse_string())
+                elif self._match_text_seq("VISIBLE"):
+                    opt = exp.IndexConstraintOption(visible=True)
+                elif self._match_text_seq("INVISIBLE"):
+                    opt = exp.IndexConstraintOption(visible=False)
+                elif self._match_text_seq("ENGINE_ATTRIBUTE"):
+                    self._match(TokenType.EQ)
+                    opt = exp.IndexConstraintOption(engine_attr=self._parse_string())
+                elif self._match_text_seq("ENGINE_ATTRIBUTE"):
+                    self._match(TokenType.EQ)
+                    opt = exp.IndexConstraintOption(engine_attr=self._parse_string())
+                elif self._match_text_seq("SECONDARY_ENGINE_ATTRIBUTE"):
+                    self._match(TokenType.EQ)
+                    opt = exp.IndexConstraintOption(secondary_engine_attr=self._parse_string())
+                else:
+                    opt = None
+
+                if not opt:
+                    break
+
+                options.append(opt)
+
+            return self.expression(
+                exp.IndexColumnConstraint,
+                this=this,
+                schema=schema,
+                kind=kind,
+                type=type_,
+                options=options,
+            )
+
         def _parse_show_mysql(
             self,
             this: str,

@@ -454,6 +522,7 @@ class MySQL(Dialect):
             exp.StrToTime: _str_to_date_sql,
             exp.TableSample: no_tablesample_sql,
             exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
+            exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime")),
             exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
             exp.Trim: _trim_sql,
             exp.TryCast: no_trycast_sql,

@@ -485,6 +554,16 @@ class MySQL(Dialect):
             exp.DataType.Type.VARCHAR: "CHAR",
         }
 
+        def limit_sql(self, expression: exp.Limit, top: bool = False) -> str:
+            # MySQL requires simple literal values for its LIMIT clause.
+            expression = simplify_literal(expression)
+            return super().limit_sql(expression, top=top)
+
+        def offset_sql(self, expression: exp.Offset) -> str:
+            # MySQL requires simple literal values for its OFFSET clause.
+            expression = simplify_literal(expression)
+            return super().offset_sql(expression)
+
         def xor_sql(self, expression: exp.Xor) -> str:
             if expression.expressions:
                 return self.expressions(expression, sep=" XOR ")
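
The new constraint parsers let MySQL table-level index definitions survive a round trip; a quick sketch (exact output formatting may vary slightly by version):

    import sqlglot

    sql = "CREATE TABLE t (x TEXT, FULLTEXT INDEX ft_idx (x) KEY_BLOCK_SIZE = 8)"
    print(sqlglot.transpile(sql, read="mysql", write="mysql")[0])
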
@@ -30,6 +30,9 @@ def _parse_xml_table(self: parser.Parser) -> exp.XMLTable:
 class Oracle(Dialect):
     ALIAS_POST_TABLESAMPLE = True
 
+    # See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm
+    RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
+
     # https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212
     # https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
     TIME_MAPPING = {
@@ -17,6 +17,7 @@ from sqlglot.dialects.dialect import (
     no_tablesample_sql,
     no_trycast_sql,
     rename_func,
+    simplify_literal,
     str_position_sql,
     timestamptrunc_sql,
     timestrtotime_sql,

@@ -39,16 +40,13 @@ DATE_DIFF_FACTOR = {
 
 def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
     def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
-        from sqlglot.optimizer.simplify import simplify
-
         this = self.sql(expression, "this")
         unit = expression.args.get("unit")
-        expression = simplify(expression.args["expression"])
 
+        expression = simplify_literal(expression.copy(), copy=False).expression
         if not isinstance(expression, exp.Literal):
             self.unsupported("Cannot add non literal")
 
-        expression = expression.copy()
         expression.args["is_string"] = True
         return f"{this} {kind} {self.sql(exp.Interval(this=expression, unit=unit))}"
@@ -192,6 +192,8 @@ class Presto(Dialect):
             "START": TokenType.BEGIN,
             "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
             "ROW": TokenType.STRUCT,
+            "IPADDRESS": TokenType.IPADDRESS,
+            "IPPREFIX": TokenType.IPPREFIX,
         }
 
     class Parser(parser.Parser):
@@ -3,6 +3,7 @@ from __future__ import annotations
 import typing as t
 
 from sqlglot import exp
+from sqlglot.dialects.dialect import rename_func
 from sqlglot.dialects.spark2 import Spark2
 from sqlglot.helper import seq_get
 

@@ -47,7 +48,11 @@ class Spark(Spark2):
             exp.DataType.Type.SMALLMONEY: "DECIMAL(6, 4)",
             exp.DataType.Type.UNIQUEIDENTIFIER: "STRING",
         }
-        TRANSFORMS = Spark2.Generator.TRANSFORMS.copy()
+
+        TRANSFORMS = {
+            **Spark2.Generator.TRANSFORMS,
+            exp.StartsWith: rename_func("STARTSWITH"),
+        }
         TRANSFORMS.pop(exp.DateDiff)
         TRANSFORMS.pop(exp.Group)
 
@@ -19,9 +19,13 @@ def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
     kind = e.args["kind"]
     properties = e.args.get("properties")
 
-    if kind.upper() == "TABLE" and any(
-        isinstance(prop, exp.TemporaryProperty)
-        for prop in (properties.expressions if properties else [])
+    if (
+        kind.upper() == "TABLE"
+        and e.expression
+        and any(
+            isinstance(prop, exp.TemporaryProperty)
+            for prop in (properties.expressions if properties else [])
+        )
     ):
         return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
     return create_with_partitions_sql(self, e)
@@ -33,8 +33,10 @@ class Teradata(Dialect):
             **tokens.Tokenizer.KEYWORDS,
             "^=": TokenType.NEQ,
             "BYTEINT": TokenType.SMALLINT,
+            "COLLECT": TokenType.COMMAND,
             "GE": TokenType.GTE,
             "GT": TokenType.GT,
+            "HELP": TokenType.COMMAND,
             "INS": TokenType.INSERT,
             "LE": TokenType.LTE,
             "LT": TokenType.LT,
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import datetime
 import re
 import typing as t
 

@@ -10,6 +11,7 @@ from sqlglot.dialects.dialect import (
     min_or_least,
     parse_date_delta,
     rename_func,
+    timestrtotime_sql,
 )
 from sqlglot.expressions import DataType
 from sqlglot.helper import seq_get

@@ -52,6 +54,8 @@ DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{
 # N = Numeric, C=Currency
 TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
 
+DEFAULT_START_DATE = datetime.date(1900, 1, 1)
+
 
 def _format_time_lambda(
     exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None

@@ -166,6 +170,34 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> s
     return f"STRING_AGG({self.format_args(this, separator)}){order}"
 
 
+def _parse_date_delta(
+    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
+) -> t.Callable[[t.List], E]:
+    def inner_func(args: t.List) -> E:
+        unit = seq_get(args, 0)
+        if unit and unit_mapping:
+            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))
+
+        start_date = seq_get(args, 1)
+        if start_date and start_date.is_number:
+            # Numeric types are valid DATETIME values
+            if start_date.is_int:
+                adds = DEFAULT_START_DATE + datetime.timedelta(days=int(start_date.this))
+                start_date = exp.Literal.string(adds.strftime("%F"))
+            else:
+                # We currently don't handle float values, i.e. they're not converted to equivalent DATETIMEs.
+                # This is not a problem when generating T-SQL code, it is when transpiling to other dialects.
+                return exp_class(this=seq_get(args, 2), expression=start_date, unit=unit)
+
+        return exp_class(
+            this=exp.TimeStrToTime(this=seq_get(args, 2)),
+            expression=exp.TimeStrToTime(this=start_date),
+            unit=unit,
+        )
+
+    return inner_func
+
+
 class TSQL(Dialect):
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
     NULL_ORDERING = "nulls_are_small"
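
T-SQL accepts integer DATETIME arguments as day offsets from 1900-01-01; the snippet below mirrors the conversion _parse_date_delta applies to integer start dates:

    import datetime

    DEFAULT_START_DATE = datetime.date(1900, 1, 1)

    # DATEDIFF(day, 7, x): the integer 7 is pinned to a concrete date before transpiling.
    print((DEFAULT_START_DATE + datetime.timedelta(days=7)).strftime("%F"))  # 1900-01-08
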
@@ -298,7 +330,6 @@ class TSQL(Dialect):
             "SMALLDATETIME": TokenType.DATETIME,
             "SMALLMONEY": TokenType.SMALLMONEY,
             "SQL_VARIANT": TokenType.VARIANT,
-            "TIME": TokenType.TIMESTAMP,
             "TOP": TokenType.TOP,
             "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
             "VARCHAR(MAX)": TokenType.TEXT,

@@ -307,10 +338,6 @@ class TSQL(Dialect):
             "SYSTEM_USER": TokenType.CURRENT_USER,
         }
 
-        # TSQL allows @, # to appear as a variable/identifier prefix
-        SINGLE_TOKENS = tokens.Tokenizer.SINGLE_TOKENS.copy()
-        SINGLE_TOKENS.pop("#")
-
     class Parser(parser.Parser):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,

@@ -320,7 +347,7 @@ class TSQL(Dialect):
                 position=seq_get(args, 2),
             ),
             "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
-            "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
+            "DATEDIFF": _parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
             "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
             "DATEPART": _format_time_lambda(exp.TimeToStr),
             "EOMONTH": _parse_eomonth,

@@ -518,6 +545,36 @@ class TSQL(Dialect):
             expressions = self._parse_csv(self._parse_function_parameter)
             return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
 
+        def _parse_id_var(
+            self,
+            any_token: bool = True,
+            tokens: t.Optional[t.Collection[TokenType]] = None,
+        ) -> t.Optional[exp.Expression]:
+            is_temporary = self._match(TokenType.HASH)
+            is_global = is_temporary and self._match(TokenType.HASH)
+
+            this = super()._parse_id_var(any_token=any_token, tokens=tokens)
+            if this:
+                if is_global:
+                    this.set("global", True)
+                elif is_temporary:
+                    this.set("temporary", True)
+
+            return this
+
+        def _parse_create(self) -> exp.Create | exp.Command:
+            create = super()._parse_create()
+
+            if isinstance(create, exp.Create):
+                table = create.this.this if isinstance(create.this, exp.Schema) else create.this
+                if isinstance(table, exp.Table) and table.this.args.get("temporary"):
+                    if not create.args.get("properties"):
+                        create.set("properties", exp.Properties(expressions=[]))
+
+                    create.args["properties"].append("expressions", exp.TemporaryProperty())
+
+            return create
+
     class Generator(generator.Generator):
         LOCKING_READS_SUPPORTED = True
         LIMIT_IS_TOP = True

@@ -526,9 +583,11 @@ class TSQL(Dialect):
 
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,
-            exp.DataType.Type.INT: "INTEGER",
             exp.DataType.Type.DECIMAL: "NUMERIC",
             exp.DataType.Type.DATETIME: "DATETIME2",
+            exp.DataType.Type.INT: "INTEGER",
+            exp.DataType.Type.TIMESTAMP: "DATETIME2",
+            exp.DataType.Type.TIMESTAMPTZ: "DATETIMEOFFSET",
             exp.DataType.Type.VARIANT: "SQL_VARIANT",
         }
 

@@ -552,6 +611,8 @@ class TSQL(Dialect):
                 exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"),
                 e.this,
             ),
+            exp.TemporaryProperty: lambda self, e: "",
+            exp.TimeStrToTime: timestrtotime_sql,
             exp.TimeToStr: _format_sql,
         }
 

@@ -564,6 +625,22 @@ class TSQL(Dialect):
 
         LIMIT_FETCH = "FETCH"
 
+        def createable_sql(
+            self,
+            expression: exp.Create,
+            locations: dict[exp.Properties.Location, list[exp.Property]],
+        ) -> str:
+            sql = self.sql(expression, "this")
+            properties = expression.args.get("properties")
+
+            if sql[:1] != "#" and any(
+                isinstance(prop, exp.TemporaryProperty)
+                for prop in (properties.expressions if properties else [])
+            ):
+                sql = f"#{sql}"
+
+            return sql
+
         def offset_sql(self, expression: exp.Offset) -> str:
             return f"{super().offset_sql(expression)} ROWS"
 

@@ -616,3 +693,13 @@ class TSQL(Dialect):
             this = self.sql(expression, "this")
             this = f" {this}" if this else ""
             return f"ROLLBACK TRANSACTION{this}"
+
+        def identifier_sql(self, expression: exp.Identifier) -> str:
+            identifier = super().identifier_sql(expression)
+
+            if expression.args.get("global"):
+                identifier = f"##{identifier}"
+            elif expression.args.get("temporary"):
+                identifier = f"#{identifier}"
+
+            return identifier
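
Taken together, the #/## identifier handling, TemporaryProperty, and createable_sql keep T-SQL temp tables intact; a round-trip sketch (expected output shown as a comment):

    import sqlglot

    # The leading # marks a local temp table; it is parsed into Identifier(temporary=True)
    # and re-emitted by identifier_sql. Expected: CREATE TABLE #mytemp (a INTEGER)
    print(sqlglot.parse_one("CREATE TABLE #mytemp (a INT)", read="tsql").sql("tsql"))
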
@@ -67,8 +67,9 @@ class Expression(metaclass=_Expression):
             uses to refer to it.
         comments: a list of comments that are associated with a given expression. This is used in
             order to preserve comments when transpiling SQL code.
-        _type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the
+        type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the
             optimizer, in order to enable some transformations that require type information.
+        meta: a dictionary that can be used to store useful metadata for a given expression.
 
     Example:
         >>> class Foo(Expression):

@@ -767,7 +768,7 @@ class Condition(Expression):
         **opts,
     ) -> In:
         return In(
-            this=_maybe_copy(self, copy),
+            this=maybe_copy(self, copy),
             expressions=[convert(e, copy=copy) for e in expressions],
             query=maybe_parse(query, copy=copy, **opts) if query else None,
             unnest=Unnest(

@@ -781,7 +782,7 @@ class Condition(Expression):
 
     def between(self, low: t.Any, high: t.Any, copy: bool = True, **opts) -> Between:
         return Between(
-            this=_maybe_copy(self, copy),
+            this=maybe_copy(self, copy),
             low=convert(low, copy=copy, **opts),
             high=convert(high, copy=copy, **opts),
         )

@@ -990,7 +991,28 @@ class Uncache(Expression):
     arg_types = {"this": True, "exists": False}
 
 
-class Create(Expression):
+class DDL(Expression):
+    @property
+    def ctes(self):
+        with_ = self.args.get("with")
+        if not with_:
+            return []
+        return with_.expressions
+
+    @property
+    def named_selects(self) -> t.List[str]:
+        if isinstance(self.expression, Subqueryable):
+            return self.expression.named_selects
+        return []
+
+    @property
+    def selects(self) -> t.List[Expression]:
+        if isinstance(self.expression, Subqueryable):
+            return self.expression.selects
+        return []
+
+
+class Create(DDL):
     arg_types = {
         "with": False,
         "this": True,

@@ -1206,6 +1228,19 @@ class MergeTreeTTL(Expression):
     }
 
 
+# https://dev.mysql.com/doc/refman/8.0/en/create-table.html
+class IndexConstraintOption(Expression):
+    arg_types = {
+        "key_block_size": False,
+        "using": False,
+        "parser": False,
+        "comment": False,
+        "visible": False,
+        "engine_attr": False,
+        "secondary_engine_attr": False,
+    }
+
+
 class ColumnConstraint(Expression):
     arg_types = {"this": False, "kind": True}
 

@@ -1272,6 +1307,11 @@ class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
     }
 
 
+# https://dev.mysql.com/doc/refman/8.0/en/create-table.html
+class IndexColumnConstraint(ColumnConstraintKind):
+    arg_types = {"this": False, "schema": True, "kind": False, "type": False, "options": False}
+
+
 class InlineLengthColumnConstraint(ColumnConstraintKind):
     pass
 

@@ -1496,7 +1536,7 @@ class JoinHint(Expression):
 
 
 class Identifier(Expression):
-    arg_types = {"this": True, "quoted": False}
+    arg_types = {"this": True, "quoted": False, "global": False, "temporary": False}
 
     @property
     def quoted(self) -> bool:

@@ -1525,7 +1565,7 @@ class Index(Expression):
     }
 
 
-class Insert(Expression):
+class Insert(DDL):
     arg_types = {
         "with": False,
         "this": True,

@@ -1892,6 +1932,10 @@ class EngineProperty(Property):
     arg_types = {"this": True}
 
 
+class HeapProperty(Property):
+    arg_types = {}
+
+
 class ToTableProperty(Property):
     arg_types = {"this": True}
 

@@ -2182,7 +2226,7 @@ class Tuple(Expression):
         **opts,
     ) -> In:
         return In(
-            this=_maybe_copy(self, copy),
+            this=maybe_copy(self, copy),
             expressions=[convert(e, copy=copy) for e in expressions],
             query=maybe_parse(query, copy=copy, **opts) if query else None,
             unnest=Unnest(

@@ -2212,7 +2256,7 @@ class Subqueryable(Unionable):
         Returns:
             Alias: the subquery
         """
-        instance = _maybe_copy(self, copy)
+        instance = maybe_copy(self, copy)
         if not isinstance(alias, Expression):
             alias = TableAlias(this=to_identifier(alias)) if alias else None
 

@@ -2865,7 +2909,7 @@ class Select(Subqueryable):
         self,
         expression: ExpOrStr,
         on: t.Optional[ExpOrStr] = None,
-        using: t.Optional[ExpOrStr | t.List[ExpOrStr]] = None,
+        using: t.Optional[ExpOrStr | t.Collection[ExpOrStr]] = None,
         append: bool = True,
         join_type: t.Optional[str] = None,
         join_alias: t.Optional[Identifier | str] = None,

@@ -2943,6 +2987,7 @@ class Select(Subqueryable):
             arg="using",
             append=append,
             copy=copy,
+            into=Identifier,
             **opts,
         )
 

@@ -3092,7 +3137,7 @@ class Select(Subqueryable):
         Returns:
             Select: the modified expression.
         """
-        instance = _maybe_copy(self, copy)
+        instance = maybe_copy(self, copy)
         on = Tuple(expressions=[maybe_parse(on, copy=copy) for on in ons if on]) if ons else None
         instance.set("distinct", Distinct(on=on) if distinct else None)
         return instance

@@ -3123,7 +3168,7 @@ class Select(Subqueryable):
         Returns:
             The new Create expression.
         """
-        instance = _maybe_copy(self, copy)
+        instance = maybe_copy(self, copy)
         table_expression = maybe_parse(
             table,
             into=Table,

@@ -3159,7 +3204,7 @@ class Select(Subqueryable):
         Returns:
             The modified expression.
         """
-        inst = _maybe_copy(self, copy)
+        inst = maybe_copy(self, copy)
         inst.set("locks", [Lock(update=update)])
 
         return inst

@@ -3181,7 +3226,7 @@ class Select(Subqueryable):
         Returns:
             The modified expression.
         """
-        inst = _maybe_copy(self, copy)
+        inst = maybe_copy(self, copy)
         inst.set(
             "hint", Hint(expressions=[maybe_parse(h, copy=copy, dialect=dialect) for h in hints])
         )

@@ -3376,6 +3421,8 @@ class DataType(Expression):
         HSTORE = auto()
         IMAGE = auto()
         INET = auto()
+        IPADDRESS = auto()
+        IPPREFIX = auto()
         INT = auto()
         INT128 = auto()
         INT256 = auto()

@@ -3987,7 +4034,7 @@ class Case(Func):
     arg_types = {"this": False, "ifs": True, "default": False}
 
     def when(self, condition: ExpOrStr, then: ExpOrStr, copy: bool = True, **opts) -> Case:
-        instance = _maybe_copy(self, copy)
+        instance = maybe_copy(self, copy)
         instance.append(
             "ifs",
             If(

@@ -3998,7 +4045,7 @@ class Case(Func):
         return instance
 
     def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case:
-        instance = _maybe_copy(self, copy)
+        instance = maybe_copy(self, copy)
         instance.set("default", maybe_parse(condition, copy=copy, **opts))
         return instance
 

@@ -4263,6 +4310,10 @@ class Initcap(Func):
     arg_types = {"this": True, "expression": False}
 
 
+class IsNan(Func):
+    _sql_names = ["IS_NAN", "ISNAN"]
+
+
 class JSONKeyValue(Expression):
     arg_types = {"this": True, "expression": True}
 

@@ -4549,6 +4600,11 @@ class StandardHash(Func):
     arg_types = {"this": True, "expression": False}
 
 
+class StartsWith(Func):
+    _sql_names = ["STARTS_WITH", "STARTSWITH"]
+    arg_types = {"this": True, "expression": True}
+
+
 class StrPosition(Func):
     arg_types = {
         "this": True,
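
The new Func nodes give the dialect-specific spellings a shared expression to hang transforms on; a transpilation sketch (expected outputs as comments):

    import sqlglot

    # STARTS_WITH/STARTSWITH and IS_NAN/ISNAN each parse into one expression type.
    print(sqlglot.transpile("SELECT STARTS_WITH(x, 'a')", write="spark")[0])  # SELECT STARTSWITH(x, 'a')
    print(sqlglot.transpile("SELECT IS_NAN(x)", write="spark")[0])            # SELECT ISNAN(x)
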
@@ -4804,7 +4860,7 @@ def maybe_parse(
     return sqlglot.parse_one(sql, read=dialect, into=into, **opts)
 
 
-def _maybe_copy(instance: E, copy: bool = True) -> E:
+def maybe_copy(instance: E, copy: bool = True) -> E:
     return instance.copy() if copy else instance
 
 

@@ -4824,7 +4880,7 @@ def _apply_builder(
 ):
     if _is_wrong_expression(expression, into):
         expression = into(this=expression)
-    instance = _maybe_copy(instance, copy)
+    instance = maybe_copy(instance, copy)
     expression = maybe_parse(
         sql_or_expression=expression,
         prefix=prefix,

@@ -4848,7 +4904,7 @@ def _apply_child_list_builder(
     properties=None,
     **opts,
 ):
-    instance = _maybe_copy(instance, copy)
+    instance = maybe_copy(instance, copy)
     parsed = []
     for expression in expressions:
         if expression is not None:

@@ -4887,7 +4943,7 @@ def _apply_list_builder(
     dialect=None,
     **opts,
 ):
-    inst = _maybe_copy(instance, copy)
+    inst = maybe_copy(instance, copy)
 
     expressions = [
         maybe_parse(

@@ -4923,7 +4979,7 @@ def _apply_conjunction_builder(
     if not expressions:
         return instance
 
-    inst = _maybe_copy(instance, copy)
+    inst = maybe_copy(instance, copy)
 
     existing = inst.args.get(arg)
     if append and existing is not None:

@@ -5398,7 +5454,7 @@ def to_identifier(name, quoted=None, copy=True):
         return None
 
     if isinstance(name, Identifier):
-        identifier = _maybe_copy(name, copy)
+        identifier = maybe_copy(name, copy)
     elif isinstance(name, str):
         identifier = Identifier(
             this=name,

@@ -5735,7 +5791,7 @@ def convert(value: t.Any, copy: bool = False) -> Expression:
         Expression: the equivalent expression object.
     """
     if isinstance(value, Expression):
-        return _maybe_copy(value, copy)
+        return maybe_copy(value, copy)
     if isinstance(value, str):
         return Literal.string(value)
     if isinstance(value, bool):
@@ -68,6 +68,7 @@ class Generator:
         exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
         exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
         exp.ExternalProperty: lambda self, e: "EXTERNAL",
+        exp.HeapProperty: lambda self, e: "HEAP",
        exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
         exp.LanguageProperty: lambda self, e: self.naked_property(e),
         exp.LocationProperty: lambda self, e: self.naked_property(e),

@@ -161,6 +162,9 @@ class Generator:
     # Whether or not to generate the (+) suffix for columns used in old-style join conditions
     COLUMN_JOIN_MARKS_SUPPORTED = False
 
+    # Whether or not to generate an unquoted value for EXTRACT's date part argument
+    EXTRACT_ALLOWS_QUOTES = True
+
     # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
     SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE")
 
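
EXTRACT_ALLOWS_QUOTES is consumed by extract_sql further down in this diff; a sketch of the intended behavior (Hive-family dialects set it to False, and the Snowflake input is illustrative):

    import sqlglot

    # Snowflake accepts a quoted date part; Spark (via Hive) must emit it bare.
    print(sqlglot.transpile("SELECT EXTRACT('year' FROM x)", read="snowflake", write="spark")[0])
    # SELECT EXTRACT(year FROM x)
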
@@ -224,6 +228,7 @@
         exp.FallbackProperty: exp.Properties.Location.POST_NAME,
         exp.FileFormatProperty: exp.Properties.Location.POST_WITH,
         exp.FreespaceProperty: exp.Properties.Location.POST_NAME,
+        exp.HeapProperty: exp.Properties.Location.POST_WITH,
         exp.IsolatedLoadingProperty: exp.Properties.Location.POST_NAME,
         exp.JournalProperty: exp.Properties.Location.POST_NAME,
         exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA,

@@ -265,9 +270,12 @@
 
     # Expressions whose comments are separated from them for better formatting
     WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = (
+        exp.Delete,
         exp.Drop,
         exp.From,
+        exp.Insert,
         exp.Select,
+        exp.Update,
         exp.Where,
         exp.With,
     )

@@ -985,8 +993,9 @@
     ) -> str:
         if properties.expressions:
             expressions = self.expressions(properties, sep=sep, indent=False)
-            expressions = self.wrap(expressions) if wrapped else expressions
-            return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}"
+            if expressions:
+                expressions = self.wrap(expressions) if wrapped else expressions
+                return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}"
         return ""
 
     def with_properties(self, properties: exp.Properties) -> str:

@@ -1905,7 +1914,7 @@
         return f"NEXT VALUE FOR {self.sql(expression, 'this')}{order}"
 
     def extract_sql(self, expression: exp.Extract) -> str:
-        this = self.sql(expression, "this")
+        this = self.sql(expression, "this") if self.EXTRACT_ALLOWS_QUOTES else expression.this.name
         expression_sql = self.sql(expression, "expression")
         return f"EXTRACT({this} FROM {expression_sql})"
 

@@ -2370,7 +2379,12 @@
             elif arg_value is not None:
                 args.append(arg_value)
 
-        return self.func(expression.sql_name(), *args)
+        if self.normalize_functions:
+            name = expression.sql_name()
+        else:
+            name = (expression._meta and expression.meta.get("name")) or expression.sql_name()
+
+        return self.func(name, *args)
 
     def func(
         self,

@@ -2412,7 +2426,7 @@
             return ""
 
         if flat:
-            return sep.join(self.sql(e) for e in expressions)
+            return sep.join(sql for sql in (self.sql(e) for e in expressions) if sql)
 
         num_sqls = len(expressions)
 

@@ -2423,6 +2437,9 @@
         result_sqls = []
         for i, e in enumerate(expressions):
             sql = self.sql(e, comment=False)
+            if not sql:
+                continue
+
             comments = self.maybe_comment("", e) if isinstance(e, exp.Expression) else ""
 
             if self.pretty:

@@ -2562,6 +2579,51 @@
         record_reader = f" RECORDREADER {record_reader}" if record_reader else ""
         return f"{transform}{row_format_before}{record_writer}{using}{schema}{row_format_after}{record_reader}"
 
+    def indexconstraintoption_sql(self, expression: exp.IndexConstraintOption) -> str:
+        key_block_size = self.sql(expression, "key_block_size")
+        if key_block_size:
+            return f"KEY_BLOCK_SIZE = {key_block_size}"
+
+        using = self.sql(expression, "using")
+        if using:
+            return f"USING {using}"
+
+        parser = self.sql(expression, "parser")
+        if parser:
+            return f"WITH PARSER {parser}"
+
+        comment = self.sql(expression, "comment")
+        if comment:
+            return f"COMMENT {comment}"
+
+        visible = expression.args.get("visible")
+        if visible is not None:
+            return "VISIBLE" if visible else "INVISIBLE"
+
+        engine_attr = self.sql(expression, "engine_attr")
+        if engine_attr:
+            return f"ENGINE_ATTRIBUTE = {engine_attr}"
+
+        secondary_engine_attr = self.sql(expression, "secondary_engine_attr")
+        if secondary_engine_attr:
+            return f"SECONDARY_ENGINE_ATTRIBUTE = {secondary_engine_attr}"
+
+        self.unsupported("Unsupported index constraint option.")
+        return ""
+
+    def indexcolumnconstraint_sql(self, expression: exp.IndexColumnConstraint) -> str:
+        kind = self.sql(expression, "kind")
+        kind = f"{kind} INDEX" if kind else "INDEX"
+        this = self.sql(expression, "this")
+        this = f" {this}" if this else ""
+        type_ = self.sql(expression, "type")
+        type_ = f" USING {type_}" if type_ else ""
+        schema = self.sql(expression, "schema")
+        schema = f" {schema}" if schema else ""
+        options = self.expressions(expression, key="options", sep=" ")
+        options = f" {options}" if options else ""
+        return f"{kind}{this}{type_}{schema}{options}"
+
 
 def cached_generator(
     cache: t.Optional[t.Dict[int, str]] = None
@@ -136,8 +136,10 @@ def _eliminate_union(scope, existing_ctes, taken):
 
 
 def _eliminate_derived_table(scope, existing_ctes, taken):
-    # This ensures we don't drop the "pivot" arg from a pivoted subquery
-    if scope.parent.pivots:
+    # This makes sure that we don't:
+    # - drop the "pivot" arg from a pivoted subquery
+    # - eliminate a lateral correlated subquery
+    if scope.parent.pivots or isinstance(scope.parent.expression, exp.Lateral):
         return None
 
     parent = scope.expression.parent
@@ -1,8 +1,23 @@
 from __future__ import annotations
 
+import typing as t
+
 from sqlglot import exp
 from sqlglot._typing import E
 from sqlglot.dialects.dialect import Dialect, DialectType
 
 
-def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
+@t.overload
+def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
+    ...
+
+
+@t.overload
+def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Expression:
+    ...
+
+
+def normalize_identifiers(expression, dialect=None):
     """
     Normalize all unquoted identifiers to either lower or upper case, depending
     on the dialect. This essentially makes those identifiers case-insensitive.

@@ -16,6 +31,8 @@
         >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
         >>> normalize_identifiers(expression).sql()
         'SELECT bar.a AS a FROM "Foo".bar'
+        >>> normalize_identifiers("foo", dialect="snowflake").sql(dialect="snowflake")
+        'FOO'
 
     Args:
         expression: The expression to transform.

@@ -24,4 +41,5 @@
     Returns:
         The transformed expression.
     """
+    expression = exp.maybe_parse(expression, dialect=dialect)
     return expression.transform(Dialect.get_or_raise(dialect).normalize_identifier, copy=False)
@@ -39,6 +39,7 @@ def qualify_columns(
     """
     schema = ensure_schema(schema)
     infer_schema = schema.empty if infer_schema is None else infer_schema
+    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS
 
     for scope in traverse_scope(expression):
         resolver = Resolver(scope, schema, infer_schema=infer_schema)

@@ -55,7 +56,7 @@
             _expand_alias_refs(scope, resolver)
 
         if not isinstance(scope.expression, exp.UDTF):
-            _expand_stars(scope, resolver, using_column_tables)
+            _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
             _qualify_outputs(scope)
         _expand_group_by(scope)
         _expand_order_by(scope, resolver)

@@ -326,7 +327,10 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
 
 
 def _expand_stars(
-    scope: Scope, resolver: Resolver, using_column_tables: t.Dict[str, t.Any]
+    scope: Scope,
+    resolver: Resolver,
+    using_column_tables: t.Dict[str, t.Any],
+    pseudocolumns: t.Set[str],
 ) -> None:
     """Expand stars to lists of column selections"""
 

@@ -367,14 +371,8 @@
 
         columns = resolver.get_source_columns(table, only_visible=True)
 
-        # The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement
-        # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table
-        if resolver.schema.dialect == "bigquery":
-            columns = [
-                name
-                for name in columns
-                if name.upper() not in ("_PARTITIONTIME", "_PARTITIONDATE")
-            ]
+        if pseudocolumns:
+            columns = [name for name in columns if name.upper() not in pseudocolumns]
 
         if columns and "*" not in columns:
             if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
@@ -80,7 +80,9 @@ def qualify_tables(
                 header = next(reader)
                 columns = next(reader)
                 schema.add_table(
-                    source, {k: type(v).__name__ for k, v in zip(header, columns)}
+                    source,
+                    {k: type(v).__name__ for k, v in zip(header, columns)},
+                    match_depth=False,
                 )
             elif isinstance(source, Scope) and source.is_udtf:
                 udtf = source.expression
@@ -435,7 +435,10 @@ class Scope:
     @property
    def is_correlated_subquery(self):
         """Determine if this scope is a correlated subquery"""
-        return bool(self.is_subquery and self.external_columns)
+        return bool(
+            (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
+            and self.external_columns
+        )
 
     def rename_source(self, old_name, new_name):
         """Rename a source in this scope"""

@@ -486,7 +489,7 @@
 
 def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
     """
-    Traverse an expression by it's "scopes".
+    Traverse an expression by its "scopes".
 
     "Scope" represents the current context of a Select statement.
 

@@ -509,9 +512,12 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
     Returns:
         list[Scope]: scope instances
     """
-    if not isinstance(expression, exp.Unionable):
-        return []
-    return list(_traverse_scope(Scope(expression)))
+    if isinstance(expression, exp.Unionable) or (
+        isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable)
+    ):
+        return list(_traverse_scope(Scope(expression)))
+
+    return []
 
 
 def build_scope(expression: exp.Expression) -> t.Optional[Scope]:

@@ -539,7 +545,9 @@ def _traverse_scope(scope):
     elif isinstance(scope.expression, exp.Table):
         yield from _traverse_tables(scope)
     elif isinstance(scope.expression, exp.UDTF):
-        pass
+        yield from _traverse_udtfs(scope)
+    elif isinstance(scope.expression, exp.DDL):
+        yield from _traverse_ddl(scope)
     else:
         logger.warning(
             "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)

@@ -576,10 +584,10 @@ def _traverse_ctes(scope):
     for cte in scope.ctes:
         recursive_scope = None
 
-        # if the scope is a recursive cte, it must be in the form of
-        # base_case UNION recursive. thus the recursive scope is the first
-        # section of the union.
-        if scope.expression.args["with"].recursive:
+        # if the scope is a recursive cte, it must be in the form of base_case UNION recursive.
+        # thus the recursive scope is the first section of the union.
+        with_ = scope.expression.args.get("with")
+        if with_ and with_.recursive:
             union = cte.this
 
             if isinstance(union, exp.Union):

@@ -692,8 +700,7 @@ def _traverse_tables(scope):
             # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
             # Until then, this means that only a single, unaliased derived table is allowed (rather,
             # the latest one wins.
-            alias = expression.alias
-            sources[alias] = child_scope
+            sources[expression.alias] = child_scope
 
             # append the final child_scope yielded
             scopes.append(child_scope)

@@ -711,6 +718,47 @@ def _traverse_subqueries(scope):
         scope.subquery_scopes.append(top)
 
 
+def _traverse_udtfs(scope):
+    if isinstance(scope.expression, exp.Unnest):
+        expressions = scope.expression.expressions
+    elif isinstance(scope.expression, exp.Lateral):
+        expressions = [scope.expression.this]
+    else:
+        expressions = []
+
+    sources = {}
+    for expression in expressions:
+        if isinstance(expression, exp.Subquery) and _is_derived_table(expression):
+            top = None
+            for child_scope in _traverse_scope(
+                scope.branch(
+                    expression,
+                    scope_type=ScopeType.DERIVED_TABLE,
+                    outer_column_list=expression.alias_column_names,
+                )
+            ):
+                yield child_scope
+                top = child_scope
+                sources[expression.alias] = child_scope
+
+            scope.derived_table_scopes.append(top)
+            scope.table_scopes.append(top)
+
+    scope.sources.update(sources)
+
+
+def _traverse_ddl(scope):
+    yield from _traverse_ctes(scope)
+
+    query_scope = scope.branch(
+        scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, chain_sources=scope.sources
+    )
+    query_scope._collect()
+    query_scope._ctes = scope.ctes + query_scope._ctes
+
+    yield from _traverse_scope(query_scope)
+
+
 def walk_in_scope(expression, bfs=True):
     """
     Returns a generator object which visits all nodes in the syntax tree, stopping at
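
With _traverse_ddl in place, CTAS-style statements are no longer opaque to the optimizer's scope machinery; a sketch:

    import sqlglot
    from sqlglot.optimizer.scope import build_scope

    # The root scope is the CREATE itself; its SELECT becomes a derived-table child scope.
    root = build_scope(sqlglot.parse_one("CREATE TABLE t AS SELECT a FROM x"))
    print(type(root.expression).__name__)  # Create
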
@@ -46,20 +46,24 @@ def unnest(select, parent_select, next_alias_name):
     if not predicate or parent_select is not predicate.parent_select:
         return
 
-    # this subquery returns a scalar and can just be converted to a cross join
+    # This subquery returns a scalar and can just be converted to a cross join
     if not isinstance(predicate, (exp.In, exp.Any)):
-        having = predicate.find_ancestor(exp.Having)
         column = exp.column(select.selects[0].alias_or_name, alias)
-        if having and having.parent_select is parent_select:
-            column = exp.Max(this=column)
-        _replace(select.parent, column)
 
-        parent_select.join(
-            select,
-            join_type="CROSS",
-            join_alias=alias,
-            copy=False,
-        )
+        clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
+        clause_parent_select = clause.parent_select if clause else None
+
+        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
+            (not clause or clause_parent_select is not parent_select)
+            and (
+                parent_select.args.get("group")
+                or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
+            )
+        ):
+            column = exp.Max(this=column)
+
+        _replace(select.parent, column)
+        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
         return
 
     if select.find(exp.Limit, exp.Offset):
@@ -185,6 +185,8 @@ class Parser(metaclass=_Parser):
         TokenType.VARIANT,
         TokenType.OBJECT,
         TokenType.INET,
+        TokenType.IPADDRESS,
+        TokenType.IPPREFIX,
         TokenType.ENUM,
         *NESTED_TYPE_TOKENS,
     }

@@ -603,6 +605,7 @@
         "FALLBACK": lambda self, **kwargs: self._parse_fallback(**kwargs),
         "FORMAT": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
         "FREESPACE": lambda self: self._parse_freespace(),
+        "HEAP": lambda self: self.expression(exp.HeapProperty),
         "IMMUTABLE": lambda self: self.expression(
             exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
         ),

@@ -832,6 +835,7 @@
     UNNEST_COLUMN_ONLY: bool = False
     ALIAS_POST_TABLESAMPLE: bool = False
     STRICT_STRING_CONCAT = False
+    NORMALIZE_FUNCTIONS = "upper"
     NULL_ORDERING: str = "nulls_are_small"
     SHOW_TRIE: t.Dict = {}
     SET_TRIE: t.Dict = {}

@@ -1187,7 +1191,7 @@
 
         exists = self._parse_exists(not_=True)
         this = None
-        expression = None
+        expression: t.Optional[exp.Expression] = None
         indexes = None
         no_schema_binding = None
         begin = None

@@ -1207,12 +1211,16 @@
                 extend_props(self._parse_properties())
 
             self._match(TokenType.ALIAS)
-            begin = self._match(TokenType.BEGIN)
-            return_ = self._match_text_seq("RETURN")
-            expression = self._parse_statement()
 
-            if return_:
-                expression = self.expression(exp.Return, this=expression)
+            if self._match(TokenType.COMMAND):
+                expression = self._parse_as_command(self._prev)
+            else:
+                begin = self._match(TokenType.BEGIN)
+                return_ = self._match_text_seq("RETURN")
+                expression = self._parse_statement()
+
+                if return_:
+                    expression = self.expression(exp.Return, this=expression)
         elif create_token.token_type == TokenType.INDEX:
             this = self._parse_index(index=self._parse_id_var())
         elif create_token.token_type in self.DB_CREATABLES:

@@ -1692,6 +1700,7 @@
         return self.expression(exp.Describe, this=this, kind=kind)
 
     def _parse_insert(self) -> exp.Insert:
+        comments = ensure_list(self._prev_comments)
         overwrite = self._match(TokenType.OVERWRITE)
         ignore = self._match(TokenType.IGNORE)
         local = self._match_text_seq("LOCAL")

@@ -1709,6 +1718,7 @@
             alternative = self._match_texts(self.INSERT_ALTERNATIVES) and self._prev.text
 
         self._match(TokenType.INTO)
+        comments += ensure_list(self._prev_comments)
         self._match(TokenType.TABLE)
         this = self._parse_table(schema=True)
 

@@ -1716,6 +1726,7 @@
 
         return self.expression(
             exp.Insert,
+            comments=comments,
            this=this,
             exists=self._parse_exists(),
             partition=self._parse_partition(),

@@ -1840,6 +1851,7 @@
         # This handles MySQL's "Multiple-Table Syntax"
         # https://dev.mysql.com/doc/refman/8.0/en/delete.html
         tables = None
+        comments = self._prev_comments
         if not self._match(TokenType.FROM, advance=False):
             tables = self._parse_csv(self._parse_table) or None
 

@@ -1847,6 +1859,7 @@
 
         return self.expression(
             exp.Delete,
+            comments=comments,
             tables=tables,
             this=self._match(TokenType.FROM) and self._parse_table(joins=True),
             using=self._match(TokenType.USING) and self._parse_table(joins=True),

@@ -1856,11 +1869,13 @@
         )
 
     def _parse_update(self) -> exp.Update:
+        comments = self._prev_comments
         this = self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS)
         expressions = self._match(TokenType.SET) and self._parse_csv(self._parse_equality)
         returning = self._parse_returning()
         return self.expression(
             exp.Update,
+            comments=comments,
             **{  # type: ignore
                 "this": this,
                 "expressions": expressions,

@@ -2235,7 +2250,12 @@
             return None
 
         if not this:
-            this = self._parse_function() or self._parse_id_var(any_token=False)
+            this = (
+                self._parse_unnest()
+                or self._parse_function()
+                or self._parse_id_var(any_token=False)
+            )
 
         while self._match(TokenType.DOT):
             this = exp.Dot(
                 this=this,

@@ -3341,7 +3361,10 @@
             args = self._parse_csv(lambda: self._parse_lambda(alias=alias))
 
             if function and not anonymous:
-                this = self.validate_expression(function(args), args)
+                func = self.validate_expression(function(args), args)
+                if not self.NORMALIZE_FUNCTIONS:
+                    func.meta["name"] = this
+                this = func
             else:
                 this = self.expression(exp.Anonymous, this=this, expressions=args)
 
@@ -3842,13 +3865,11 @@
         args = self._parse_csv(self._parse_conjunction)
 
         index = self._index
-        if not self._match(TokenType.R_PAREN):
+        if not self._match(TokenType.R_PAREN) and args:
             # postgres: STRING_AGG([DISTINCT] expression, separator [ORDER BY expression1 {ASC | DESC} [, ...]])
-            return self.expression(
-                exp.GroupConcat,
-                this=seq_get(args, 0),
-                separator=self._parse_order(this=seq_get(args, 1)),
-            )
+            # bigquery: STRING_AGG([DISTINCT] expression [, separator] [ORDER BY key [{ASC | DESC}] [, ... ]] [LIMIT n])
+            args[-1] = self._parse_limit(this=self._parse_order(this=args[-1]))
+            return self.expression(exp.GroupConcat, this=args[0], separator=seq_get(args, 1))
 
         # Checks if we can parse an order clause: WITHIN GROUP (ORDER BY <order_by_expression_list> [ASC | DESC]).
         # This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that

@@ -4172,7 +4193,7 @@
 
         self._match_r_paren()
 
-        return self.expression(
+        window = self.expression(
             exp.Window,
             this=this,
             partition_by=partition,

@@ -4183,6 +4204,12 @@
             first=first,
         )
 
+        # This covers Oracle's FIRST/LAST syntax: aggregate KEEP (...) OVER (...)
+        if self._match_set(self.WINDOW_BEFORE_PAREN_TOKENS, advance=False):
+            return self._parse_window(window, alias=alias)
+
+        return window
+
     def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]:
         self._match(TokenType.BETWEEN)
 

@@ -4276,19 +4303,19 @@
     def _parse_null(self) -> t.Optional[exp.Expression]:
         if self._match(TokenType.NULL):
             return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev)
-        return None
+        return self._parse_placeholder()
 
     def _parse_boolean(self) -> t.Optional[exp.Expression]:
         if self._match(TokenType.TRUE):
             return self.PRIMARY_PARSERS[TokenType.TRUE](self, self._prev)
         if self._match(TokenType.FALSE):
             return self.PRIMARY_PARSERS[TokenType.FALSE](self, self._prev)
-        return None
+        return self._parse_placeholder()
 
     def _parse_star(self) -> t.Optional[exp.Expression]:
         if self._match(TokenType.STAR):
             return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev)
-        return None
+        return self._parse_placeholder()
 
     def _parse_parameter(self) -> exp.Parameter:
         wrapped = self._match(TokenType.L_BRACE)
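
The window chaining covers Oracle's FIRST/LAST aggregates, where KEEP (...) may itself be followed by OVER (...); a parsing sketch (round-trip output assumed unchanged):

    import sqlglot

    sql = "SELECT MAX(x) KEEP (DENSE_RANK LAST ORDER BY y) OVER (PARTITION BY z) FROM t"
    print(sqlglot.parse_one(sql, read="oracle").sql("oracle"))
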
@@ -31,14 +31,19 @@ class Schema(abc.ABC):
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

    @abc.abstractmethod

@@ -47,6 +52,7 @@ class Schema(abc.ABC):
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

@@ -55,6 +61,7 @@ class Schema(abc.ABC):
            table: the `Table` expression instance.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names.

@@ -66,6 +73,7 @@ class Schema(abc.ABC):
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

@@ -74,6 +82,7 @@ class Schema(abc.ABC):
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
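Taken together, the abstract methods above now thread `normalize` (and, for `add_table`, `match_depth`) through the whole Schema interface. A hedged usage sketch against the concrete MappingSchema (identifiers are invented):

from sqlglot import exp
from sqlglot.schema import MappingSchema

schema = MappingSchema(dialect="bigquery")
schema.add_table("db.users", {"id": "INT64", "name": "STRING"})

print(schema.column_names("db.users"))                       # expected: ['id', 'name']
print(schema.get_column_type("db.users", exp.column("id")))  # expected: a DataType for INT64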
@@ -99,7 +108,7 @@ class AbstractMappingSchema(t.Generic[T]):
    ) -> None:
        self.mapping = mapping or {}
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

@@ -107,13 +116,13 @@ class AbstractMappingSchema(t.Generic[T]):
    def empty(self) -> bool:
        return not self.mapping

    def _depth(self) -> int:
    def depth(self) -> int:
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        if not self._supported_table_args and self.mapping:
            depth = self._depth()
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
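The `_depth` to `depth` rename promotes the nesting level to the public surface. The arithmetic it relies on, sketched with sqlglot's own `dict_depth` helper:

from sqlglot.helper import dict_depth

# db -> table -> column -> type has dict depth 3; MappingSchema.depth() (shown
# further below) subtracts the column layer, giving a schema depth of 2, i.e.
# tables addressed as db.table.
print(dict_depth({"db": {"users": {"id": "int"}}}))  # 3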
@@ -191,6 +200,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
        self.visible = visible or {}
        self.normalize = normalize
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        self._depth = 0

        super().__init__(self._normalize(schema or {}))

@@ -200,6 +210,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:

@@ -208,6 +219,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )
@@ -217,19 +229,30 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect)
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect): value
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }
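The new `match_depth` guard above turns a silently mis-nested registration into an explicit error. A hedged sketch (table and column names are invented):

from sqlglot.errors import SchemaError
from sqlglot.schema import MappingSchema

schema = MappingSchema({"db": {"users": {"id": "int"}}})  # nesting level 2: db.users

try:
    schema.add_table("orders", {"id": "int"})  # one qualifier vs. depth 2: should raise
except SchemaError as exc:
    print(exc)

schema.add_table("orders", {"id": "int"}, match_depth=False)  # check skipped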
@@ -247,8 +270,9 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        normalized_table = self._normalize_table(table, dialect=dialect)
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
@@ -265,11 +289,12 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        normalized_table = self._normalize_table(table, dialect=dialect)
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
@@ -293,12 +318,16 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        normalized_mapping: t.Dict = {}
        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))
            assert columns is not None

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
@@ -312,7 +341,12 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):

        return normalized_mapping

    def _normalize_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        normalized_table = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )
@@ -322,15 +356,24 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg,
                    exp.to_identifier(self._normalize_name(value, dialect=dialect, is_table=True)),
                    exp.to_identifier(
                        self._normalize_name(
                            value, dialect=dialect, is_table=True, normalize=normalize
                        )
                    ),
                )

        return normalized_table

    def _normalize_name(
        self, name: str | exp.Identifier, dialect: DialectType = None, is_table: bool = False
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        try:
            identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
@@ -338,16 +381,18 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
            return name if isinstance(name, str) else name.name

        name = identifier.name
        if not self.normalize:
        if not normalize:
            return name

        # This can be useful for normalize_identifier
        identifier.meta["is_table"] = is_table
        return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name

    def _depth(self) -> int:
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1
    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
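With `normalize` now overridable per call, one schema can mix dialect-normalized and verbatim identifiers. A hedged sketch, assuming Snowflake's default uppercasing of unquoted names:

from sqlglot.schema import MappingSchema

schema = MappingSchema(dialect="snowflake")
schema.add_table("users", {"id": "int"})                    # stored uppercased: USERS / ID
schema.add_table("Events", {"id": "int"}, normalize=False)  # kept verbatim

print(schema.column_names("users"))                    # lookup normalized the same way
print(schema.column_names("Events", normalize=False))  # verbatim lookup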
@@ -147,6 +147,8 @@ class TokenType(AutoName):
    VARIANT = auto()
    OBJECT = auto()
    INET = auto()
    IPADDRESS = auto()
    IPPREFIX = auto()
    ENUM = auto()

    # keywords
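The two new token types back IPADDRESS/IPPREFIX column types as found in engines such as Trino/Presto; the dialect wiring lives in other hunks of this release, so treat this as an assumption-laden sketch:

import sqlglot

# Assumes the Trino/Presto dialects in this version map the new tokens to
# data types; if not, the cast type would come back unrecognized.
sql = "SELECT CAST('192.168.0.1' AS IPADDRESS)"
print(sqlglot.parse_one(sql, read="trino").sql(dialect="trino"))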
@@ -100,7 +100,8 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
    qualify_filters = expression.args["qualify"].pop().this

    for expr in qualify_filters.find_all((exp.Window, exp.Column)):
    select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
    for expr in qualify_filters.find_all(select_candidates):
        if isinstance(expr, exp.Window):
            alias = find_new_name(expression.named_selects, "_w")
            expression.select(exp.alias_(expr, alias), copy=False)
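The `is_star` special case above keeps a `SELECT *` from gaining duplicate column projections: only the window expressions need hoisting. A hedged sketch of the transform in isolation:

import sqlglot
from sqlglot.transforms import eliminate_qualify

# QUALIFY isn't portable; the transform projects the window function under a
# generated alias (find_new_name yields e.g. "_w") and filters in an outer query.
ast = sqlglot.parse_one(
    "SELECT * FROM t QUALIFY ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) = 1",
    read="snowflake",
)
print(eliminate_qualify(ast).sql())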