Merging upstream version 17.4.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
f4a8b128b0
commit
bf82c6c1c0
78 changed files with 35859 additions and 34717 deletions
|
@ -620,7 +620,16 @@ def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat
|
|||
return self.sql(this)
|
||||
|
||||
|
||||
# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator
|
||||
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
|
||||
bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
|
||||
if bad_args:
|
||||
self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
|
||||
|
||||
return self.func(
|
||||
"REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
|
||||
)
|
||||
|
||||
|
||||
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
|
||||
names = []
|
||||
for agg in aggregations:
|
||||
|
|
|
@ -15,6 +15,7 @@ from sqlglot.dialects.dialect import (
|
|||
no_properties_sql,
|
||||
no_safe_divide_sql,
|
||||
pivot_column_names,
|
||||
regexp_extract_sql,
|
||||
rename_func,
|
||||
str_position_sql,
|
||||
str_to_time_sql,
|
||||
|
@ -88,19 +89,6 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
|
|||
return self.datatype_sql(expression)
|
||||
|
||||
|
||||
def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract) -> str:
|
||||
bad_args = list(filter(expression.args.get, ("position", "occurrence")))
|
||||
if bad_args:
|
||||
self.unsupported(f"REGEXP_EXTRACT does not support arg(s) {bad_args}")
|
||||
|
||||
return self.func(
|
||||
"REGEXP_EXTRACT",
|
||||
expression.args.get("this"),
|
||||
expression.args.get("expression"),
|
||||
expression.args.get("group"),
|
||||
)
|
||||
|
||||
|
||||
def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
|
||||
sql = self.func("TO_JSON", expression.this, expression.args.get("options"))
|
||||
return f"CAST({sql} AS TEXT)"
|
||||
|
@ -156,6 +144,9 @@ class DuckDB(Dialect):
|
|||
"LIST_REVERSE_SORT": _sort_array_reverse,
|
||||
"LIST_SORT": exp.SortArray.from_arg_list,
|
||||
"LIST_VALUE": exp.Array.from_arg_list,
|
||||
"REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
|
||||
this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2)
|
||||
),
|
||||
"REGEXP_MATCHES": exp.RegexpLike.from_arg_list,
|
||||
"STRFTIME": format_time_lambda(exp.TimeToStr, "duckdb"),
|
||||
"STRING_SPLIT": exp.Split.from_arg_list,
|
||||
|
@ -227,7 +218,7 @@ class DuckDB(Dialect):
|
|||
exp.LogicalOr: rename_func("BOOL_OR"),
|
||||
exp.LogicalAnd: rename_func("BOOL_AND"),
|
||||
exp.Properties: no_properties_sql,
|
||||
exp.RegexpExtract: _regexp_extract_sql,
|
||||
exp.RegexpExtract: regexp_extract_sql,
|
||||
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
|
||||
exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"),
|
||||
exp.SafeDivide: no_safe_divide_sql,
|
||||
|
|
|
@ -17,6 +17,7 @@ from sqlglot.dialects.dialect import (
|
|||
no_recursive_cte_sql,
|
||||
no_safe_divide_sql,
|
||||
no_trycast_sql,
|
||||
regexp_extract_sql,
|
||||
rename_func,
|
||||
right_to_substring_sql,
|
||||
strposition_to_locate_sql,
|
||||
|
@ -230,24 +231,25 @@ class Hive(Dialect):
|
|||
**parser.Parser.FUNCTIONS,
|
||||
"BASE64": exp.ToBase64.from_arg_list,
|
||||
"COLLECT_LIST": exp.ArrayAgg.from_arg_list,
|
||||
"COLLECT_SET": exp.SetAgg.from_arg_list,
|
||||
"DATE_ADD": lambda args: exp.TsOrDsAdd(
|
||||
this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY")
|
||||
),
|
||||
"DATEDIFF": lambda args: exp.DateDiff(
|
||||
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
|
||||
expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
|
||||
),
|
||||
"DATE_SUB": lambda args: exp.TsOrDsAdd(
|
||||
this=seq_get(args, 0),
|
||||
expression=exp.Mul(this=seq_get(args, 1), expression=exp.Literal.number(-1)),
|
||||
unit=exp.Literal.string("DAY"),
|
||||
),
|
||||
"DATE_FORMAT": lambda args: format_time_lambda(exp.TimeToStr, "hive")(
|
||||
[
|
||||
exp.TimeStrToTime(this=seq_get(args, 0)),
|
||||
seq_get(args, 1),
|
||||
]
|
||||
),
|
||||
"DATE_SUB": lambda args: exp.TsOrDsAdd(
|
||||
this=seq_get(args, 0),
|
||||
expression=exp.Mul(this=seq_get(args, 1), expression=exp.Literal.number(-1)),
|
||||
unit=exp.Literal.string("DAY"),
|
||||
),
|
||||
"DATEDIFF": lambda args: exp.DateDiff(
|
||||
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
|
||||
expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
|
||||
),
|
||||
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
|
||||
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
|
||||
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
|
||||
|
@ -256,7 +258,9 @@ class Hive(Dialect):
|
|||
"MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
|
||||
"PERCENTILE": exp.Quantile.from_arg_list,
|
||||
"PERCENTILE_APPROX": exp.ApproxQuantile.from_arg_list,
|
||||
"COLLECT_SET": exp.SetAgg.from_arg_list,
|
||||
"REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
|
||||
this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2)
|
||||
),
|
||||
"SIZE": exp.ArraySize.from_arg_list,
|
||||
"SPLIT": exp.RegexpSplit.from_arg_list,
|
||||
"TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"),
|
||||
|
@ -363,6 +367,7 @@ class Hive(Dialect):
|
|||
exp.Create: create_with_partitions_sql,
|
||||
exp.Quantile: rename_func("PERCENTILE"),
|
||||
exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
|
||||
exp.RegexpExtract: regexp_extract_sql,
|
||||
exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"),
|
||||
exp.RegexpSplit: rename_func("SPLIT"),
|
||||
exp.Right: right_to_substring_sql,
|
||||
|
@ -422,5 +427,12 @@ class Hive(Dialect):
|
|||
expression = exp.DataType.build("text")
|
||||
elif expression.this in exp.DataType.TEMPORAL_TYPES:
|
||||
expression = exp.DataType.build(expression.this)
|
||||
elif expression.is_type("float"):
|
||||
size_expression = expression.find(exp.DataTypeSize)
|
||||
if size_expression:
|
||||
size = int(size_expression.name)
|
||||
expression = (
|
||||
exp.DataType.build("float") if size <= 32 else exp.DataType.build("double")
|
||||
)
|
||||
|
||||
return super().datatype_sql(expression)
|
||||
|
|
|
@ -193,6 +193,12 @@ class MySQL(Dialect):
|
|||
TokenType.VALUES,
|
||||
}
|
||||
|
||||
CONJUNCTION = {
|
||||
**parser.Parser.CONJUNCTION,
|
||||
TokenType.DAMP: exp.And,
|
||||
TokenType.XOR: exp.Xor,
|
||||
}
|
||||
|
||||
TABLE_ALIAS_TOKENS = (
|
||||
parser.Parser.TABLE_ALIAS_TOKENS - parser.Parser.TABLE_INDEX_HINT_TOKENS
|
||||
)
|
||||
|
|
|
@ -99,6 +99,9 @@ class Oracle(Dialect):
|
|||
LOCKING_READS_SUPPORTED = True
|
||||
JOIN_HINTS = False
|
||||
TABLE_HINTS = False
|
||||
COLUMN_JOIN_MARKS_SUPPORTED = True
|
||||
|
||||
LIMIT_FETCH = "FETCH"
|
||||
|
||||
TYPE_MAPPING = {
|
||||
**generator.Generator.TYPE_MAPPING,
|
||||
|
@ -110,6 +113,7 @@ class Oracle(Dialect):
|
|||
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
|
||||
exp.DataType.Type.VARCHAR: "VARCHAR2",
|
||||
exp.DataType.Type.NVARCHAR: "NVARCHAR2",
|
||||
exp.DataType.Type.NCHAR: "NCHAR",
|
||||
exp.DataType.Type.TEXT: "CLOB",
|
||||
exp.DataType.Type.BINARY: "BLOB",
|
||||
exp.DataType.Type.VARBINARY: "BLOB",
|
||||
|
@ -140,15 +144,9 @@ class Oracle(Dialect):
|
|||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
LIMIT_FETCH = "FETCH"
|
||||
|
||||
def offset_sql(self, expression: exp.Offset) -> str:
|
||||
return f"{super().offset_sql(expression)} ROWS"
|
||||
|
||||
def column_sql(self, expression: exp.Column) -> str:
|
||||
column = super().column_sql(expression)
|
||||
return f"{column} (+)" if expression.args.get("join_mark") else column
|
||||
|
||||
def xmltable_sql(self, expression: exp.XMLTable) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
passing = self.expressions(expression, key="passing")
|
||||
|
|
|
@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
|
|||
no_ilike_sql,
|
||||
no_pivot_sql,
|
||||
no_safe_divide_sql,
|
||||
regexp_extract_sql,
|
||||
rename_func,
|
||||
right_to_substring_sql,
|
||||
struct_extract_sql,
|
||||
|
@ -215,6 +216,9 @@ class Presto(Dialect):
|
|||
this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
|
||||
),
|
||||
"NOW": exp.CurrentTimestamp.from_arg_list,
|
||||
"REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
|
||||
this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2)
|
||||
),
|
||||
"SEQUENCE": exp.GenerateSeries.from_arg_list,
|
||||
"STRPOS": lambda args: exp.StrPosition(
|
||||
this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2)
|
||||
|
@ -293,6 +297,7 @@ class Presto(Dialect):
|
|||
exp.LogicalOr: rename_func("BOOL_OR"),
|
||||
exp.Pivot: no_pivot_sql,
|
||||
exp.Quantile: _quantile_sql,
|
||||
exp.RegexpExtract: regexp_extract_sql,
|
||||
exp.Right: right_to_substring_sql,
|
||||
exp.SafeBracket: lambda self, e: self.func(
|
||||
"ELEMENT_AT", e.this, seq_get(apply_index_offset(e.this, e.expressions, 1), 0)
|
||||
|
|
|
@ -223,13 +223,14 @@ class Snowflake(Dialect):
|
|||
"IFF": exp.If.from_arg_list,
|
||||
"NULLIFZERO": _nullifzero_to_if,
|
||||
"OBJECT_CONSTRUCT": _parse_object_construct,
|
||||
"REGEXP_SUBSTR": exp.RegexpExtract.from_arg_list,
|
||||
"RLIKE": exp.RegexpLike.from_arg_list,
|
||||
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
|
||||
"TIMEDIFF": _parse_datediff,
|
||||
"TIMESTAMPDIFF": _parse_datediff,
|
||||
"TO_ARRAY": exp.Array.from_arg_list,
|
||||
"TO_VARCHAR": exp.ToChar.from_arg_list,
|
||||
"TO_TIMESTAMP": _snowflake_to_timestamp,
|
||||
"TO_VARCHAR": exp.ToChar.from_arg_list,
|
||||
"ZEROIFNULL": _zeroifnull_to_if,
|
||||
}
|
||||
|
||||
|
@ -361,12 +362,12 @@ class Snowflake(Dialect):
|
|||
"OBJECT_CONSTRUCT",
|
||||
*(arg for expression in e.expressions for arg in expression.flatten()),
|
||||
),
|
||||
exp.TimestampTrunc: timestamptrunc_sql,
|
||||
exp.TimeStrToTime: timestrtotime_sql,
|
||||
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
|
||||
exp.TimeToStr: lambda self, e: self.func(
|
||||
"TO_CHAR", exp.cast(e.this, "timestamp"), self.format_time(e)
|
||||
),
|
||||
exp.TimestampTrunc: timestamptrunc_sql,
|
||||
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
|
||||
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
|
||||
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
|
||||
exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"),
|
||||
|
@ -390,6 +391,24 @@ class Snowflake(Dialect):
|
|||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
def regexpextract_sql(self, expression: exp.RegexpExtract) -> str:
|
||||
# Other dialects don't support all of the following parameters, so we need to
|
||||
# generate default values as necessary to ensure the transpilation is correct
|
||||
group = expression.args.get("group")
|
||||
parameters = expression.args.get("parameters") or (group and exp.Literal.string("c"))
|
||||
occurrence = expression.args.get("occurrence") or (parameters and exp.Literal.number(1))
|
||||
position = expression.args.get("position") or (occurrence and exp.Literal.number(1))
|
||||
|
||||
return self.func(
|
||||
"REGEXP_SUBSTR",
|
||||
expression.this,
|
||||
expression.expression,
|
||||
position,
|
||||
occurrence,
|
||||
parameters,
|
||||
group,
|
||||
)
|
||||
|
||||
def except_op(self, expression: exp.Except) -> str:
|
||||
if not expression.args.get("distinct", False):
|
||||
self.unsupported("EXCEPT with All is not supported in Snowflake")
|
||||
|
|
|
@ -302,6 +302,7 @@ class TSQL(Dialect):
|
|||
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
|
||||
"VARCHAR(MAX)": TokenType.TEXT,
|
||||
"XML": TokenType.XML,
|
||||
"OUTPUT": TokenType.RETURNING,
|
||||
"SYSTEM_USER": TokenType.CURRENT_USER,
|
||||
}
|
||||
|
||||
|
@ -469,6 +470,7 @@ class TSQL(Dialect):
|
|||
LOCKING_READS_SUPPORTED = True
|
||||
LIMIT_IS_TOP = True
|
||||
QUERY_HINTS = False
|
||||
RETURNING_END = False
|
||||
|
||||
TYPE_MAPPING = {
|
||||
**generator.Generator.TYPE_MAPPING,
|
||||
|
@ -532,3 +534,8 @@ class TSQL(Dialect):
|
|||
table = expression.args.get("table")
|
||||
table = f"{table} " if table else ""
|
||||
return f"RETURNS {table}{self.sql(expression, 'this')}"
|
||||
|
||||
def returning_sql(self, expression: exp.Returning) -> str:
|
||||
into = self.sql(expression, "into")
|
||||
into = self.seg(f"INTO {into}") if into else ""
|
||||
return f"{self.seg('OUTPUT')} {self.expressions(expression, flat=True)}{into}"
|
||||
|
|
|
@ -878,11 +878,11 @@ class DerivedTable(Expression):
|
|||
return [c.name for c in table_alias.args.get("columns") or []]
|
||||
|
||||
@property
|
||||
def selects(self):
|
||||
def selects(self) -> t.List[Expression]:
|
||||
return self.this.selects if isinstance(self.this, Subqueryable) else []
|
||||
|
||||
@property
|
||||
def named_selects(self):
|
||||
def named_selects(self) -> t.List[str]:
|
||||
return [select.output_name for select in self.selects]
|
||||
|
||||
|
||||
|
@ -959,7 +959,7 @@ class Unionable(Expression):
|
|||
|
||||
class UDTF(DerivedTable, Unionable):
|
||||
@property
|
||||
def selects(self):
|
||||
def selects(self) -> t.List[Expression]:
|
||||
alias = self.args.get("alias")
|
||||
return alias.columns if alias else []
|
||||
|
||||
|
@ -1576,7 +1576,7 @@ class OnConflict(Expression):
|
|||
|
||||
|
||||
class Returning(Expression):
|
||||
arg_types = {"expressions": True}
|
||||
arg_types = {"expressions": True, "into": False}
|
||||
|
||||
|
||||
# https://dev.mysql.com/doc/refman/8.0/en/charset-introducer.html
|
||||
|
@ -2194,11 +2194,11 @@ class Subqueryable(Unionable):
|
|||
return with_.expressions
|
||||
|
||||
@property
|
||||
def selects(self):
|
||||
def selects(self) -> t.List[Expression]:
|
||||
raise NotImplementedError("Subqueryable objects must implement `selects`")
|
||||
|
||||
@property
|
||||
def named_selects(self):
|
||||
def named_selects(self) -> t.List[str]:
|
||||
raise NotImplementedError("Subqueryable objects must implement `named_selects`")
|
||||
|
||||
def with_(
|
||||
|
@ -2282,7 +2282,6 @@ class Table(Expression):
|
|||
"pivots": False,
|
||||
"hints": False,
|
||||
"system_time": False,
|
||||
"wrapped": False,
|
||||
}
|
||||
|
||||
@property
|
||||
|
@ -2299,14 +2298,28 @@ class Table(Expression):
|
|||
def catalog(self) -> str:
|
||||
return self.text("catalog")
|
||||
|
||||
@property
|
||||
def selects(self) -> t.List[Expression]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def named_selects(self) -> t.List[str]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def parts(self) -> t.List[Identifier]:
|
||||
"""Return the parts of a table in order catalog, db, table."""
|
||||
return [
|
||||
t.cast(Identifier, self.args[part])
|
||||
for part in ("catalog", "db", "this")
|
||||
if self.args.get(part)
|
||||
]
|
||||
parts: t.List[Identifier] = []
|
||||
|
||||
for arg in ("catalog", "db", "this"):
|
||||
part = self.args.get(arg)
|
||||
|
||||
if isinstance(part, Identifier):
|
||||
parts.append(part)
|
||||
elif isinstance(part, Dot):
|
||||
parts.extend(part.flatten())
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
# See the TSQL "Querying data in a system-versioned temporal table" page
|
||||
|
@ -2390,7 +2403,7 @@ class Union(Subqueryable):
|
|||
return this
|
||||
|
||||
@property
|
||||
def named_selects(self):
|
||||
def named_selects(self) -> t.List[str]:
|
||||
return self.this.unnest().named_selects
|
||||
|
||||
@property
|
||||
|
@ -2398,7 +2411,7 @@ class Union(Subqueryable):
|
|||
return self.this.is_star or self.expression.is_star
|
||||
|
||||
@property
|
||||
def selects(self):
|
||||
def selects(self) -> t.List[Expression]:
|
||||
return self.this.unnest().selects
|
||||
|
||||
@property
|
||||
|
@ -3517,6 +3530,10 @@ class Or(Connector):
|
|||
pass
|
||||
|
||||
|
||||
class Xor(Connector):
|
||||
pass
|
||||
|
||||
|
||||
class BitwiseAnd(Binary):
|
||||
pass
|
||||
|
||||
|
@ -4409,6 +4426,7 @@ class RegexpExtract(Func):
|
|||
"expression": True,
|
||||
"position": False,
|
||||
"occurrence": False,
|
||||
"parameters": False,
|
||||
"group": False,
|
||||
}
|
||||
|
||||
|
@ -5756,7 +5774,9 @@ def table_name(table: Table | str, dialect: DialectType = None) -> str:
|
|||
raise ValueError(f"Cannot parse {table}")
|
||||
|
||||
return ".".join(
|
||||
part.sql(dialect=dialect) if not SAFE_IDENTIFIER_RE.match(part.name) else part.name
|
||||
part.sql(dialect=dialect, identify=True)
|
||||
if not SAFE_IDENTIFIER_RE.match(part.name)
|
||||
else part.name
|
||||
for part in table.parts
|
||||
)
|
||||
|
||||
|
|
|
@ -155,6 +155,12 @@ class Generator:
|
|||
# Whether or not to generate the limit as TOP <value> instead of LIMIT <value>
|
||||
LIMIT_IS_TOP = False
|
||||
|
||||
# Whether or not to generate INSERT INTO ... RETURNING or INSERT INTO RETURNING ...
|
||||
RETURNING_END = True
|
||||
|
||||
# Whether or not to generate the (+) suffix for columns used in old-style join conditions
|
||||
COLUMN_JOIN_MARKS_SUPPORTED = False
|
||||
|
||||
# https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
|
||||
SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE")
|
||||
|
||||
|
@ -556,7 +562,13 @@ class Generator:
|
|||
return f"{default}CHARACTER SET={self.sql(expression, 'this')}"
|
||||
|
||||
def column_sql(self, expression: exp.Column) -> str:
|
||||
return ".".join(
|
||||
join_mark = " (+)" if expression.args.get("join_mark") else ""
|
||||
|
||||
if join_mark and not self.COLUMN_JOIN_MARKS_SUPPORTED:
|
||||
join_mark = ""
|
||||
self.unsupported("Outer join syntax using the (+) operator is not supported.")
|
||||
|
||||
column = ".".join(
|
||||
self.sql(part)
|
||||
for part in (
|
||||
expression.args.get("catalog"),
|
||||
|
@ -567,6 +579,8 @@ class Generator:
|
|||
if part
|
||||
)
|
||||
|
||||
return f"{column}{join_mark}"
|
||||
|
||||
def columnposition_sql(self, expression: exp.ColumnPosition) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
this = f" {this}" if this else ""
|
||||
|
@ -836,8 +850,11 @@ class Generator:
|
|||
limit = self.sql(expression, "limit")
|
||||
tables = self.expressions(expression, key="tables")
|
||||
tables = f" {tables}" if tables else ""
|
||||
sql = f"DELETE{tables}{this}{using}{where}{returning}{limit}"
|
||||
return self.prepend_ctes(expression, sql)
|
||||
if self.RETURNING_END:
|
||||
expression_sql = f"{this}{using}{where}{returning}{limit}"
|
||||
else:
|
||||
expression_sql = f"{returning}{this}{using}{where}{limit}"
|
||||
return self.prepend_ctes(expression, f"DELETE{tables}{expression_sql}")
|
||||
|
||||
def drop_sql(self, expression: exp.Drop) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
|
@ -887,7 +904,8 @@ class Generator:
|
|||
unique = "UNIQUE " if expression.args.get("unique") else ""
|
||||
primary = "PRIMARY " if expression.args.get("primary") else ""
|
||||
amp = "AMP " if expression.args.get("amp") else ""
|
||||
name = f"{expression.name} " if expression.name else ""
|
||||
name = self.sql(expression, "this")
|
||||
name = f"{name} " if name else ""
|
||||
table = self.sql(expression, "table")
|
||||
table = f"{self.INDEX_ON} {table} " if table else ""
|
||||
using = self.sql(expression, "using")
|
||||
|
@ -1134,7 +1152,13 @@ class Generator:
|
|||
expression_sql = f"{self.sep()}{self.sql(expression, 'expression')}"
|
||||
conflict = self.sql(expression, "conflict")
|
||||
returning = self.sql(expression, "returning")
|
||||
sql = f"INSERT{alternative}{ignore}{this}{exists}{partition_sql}{where}{expression_sql}{conflict}{returning}"
|
||||
|
||||
if self.RETURNING_END:
|
||||
expression_sql = f"{expression_sql}{conflict}{returning}"
|
||||
else:
|
||||
expression_sql = f"{returning}{expression_sql}{conflict}"
|
||||
|
||||
sql = f"INSERT{alternative}{ignore}{this}{exists}{partition_sql}{where}{expression_sql}"
|
||||
return self.prepend_ctes(expression, sql)
|
||||
|
||||
def intersect_sql(self, expression: exp.Intersect) -> str:
|
||||
|
@ -1215,8 +1239,7 @@ class Generator:
|
|||
system_time = expression.args.get("system_time")
|
||||
system_time = f" {self.sql(expression, 'system_time')}" if system_time else ""
|
||||
|
||||
sql = f"{table}{system_time}{alias}{hints}{pivots}{joins}{laterals}"
|
||||
return f"({sql})" if expression.args.get("wrapped") else sql
|
||||
return f"{table}{system_time}{alias}{hints}{pivots}{joins}{laterals}"
|
||||
|
||||
def tablesample_sql(
|
||||
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
|
||||
|
@ -1276,7 +1299,11 @@ class Generator:
|
|||
where_sql = self.sql(expression, "where")
|
||||
returning = self.sql(expression, "returning")
|
||||
limit = self.sql(expression, "limit")
|
||||
sql = f"UPDATE {this} SET {set_sql}{from_sql}{where_sql}{returning}{limit}"
|
||||
if self.RETURNING_END:
|
||||
expression_sql = f"{from_sql}{where_sql}{returning}{limit}"
|
||||
else:
|
||||
expression_sql = f"{returning}{from_sql}{where_sql}{limit}"
|
||||
sql = f"UPDATE {this} SET {set_sql}{expression_sql}"
|
||||
return self.prepend_ctes(expression, sql)
|
||||
|
||||
def values_sql(self, expression: exp.Values) -> str:
|
||||
|
@ -2016,6 +2043,9 @@ class Generator:
|
|||
def and_sql(self, expression: exp.And) -> str:
|
||||
return self.connector_sql(expression, "AND")
|
||||
|
||||
def xor_sql(self, expression: exp.And) -> str:
|
||||
return self.connector_sql(expression, "XOR")
|
||||
|
||||
def connector_sql(self, expression: exp.Connector, op: str) -> str:
|
||||
if not self.pretty:
|
||||
return self.binary(expression, op)
|
||||
|
|
|
@ -104,7 +104,7 @@ def lineage(
|
|||
# Find the specific select clause that is the source of the column we want.
|
||||
# This can either be a specific, named select or a generic `*` clause.
|
||||
select = next(
|
||||
(select for select in scope.selects if select.alias_or_name == column_name),
|
||||
(select for select in scope.expression.selects if select.alias_or_name == column_name),
|
||||
exp.Star() if scope.expression.is_star else None,
|
||||
)
|
||||
|
||||
|
|
|
@ -85,7 +85,7 @@ def _unique_outputs(scope):
|
|||
grouped_outputs = set()
|
||||
|
||||
unique_outputs = set()
|
||||
for select in scope.selects:
|
||||
for select in scope.expression.selects:
|
||||
output = select.unalias()
|
||||
if output in grouped_expressions:
|
||||
grouped_outputs.add(output)
|
||||
|
@ -105,7 +105,7 @@ def _unique_outputs(scope):
|
|||
|
||||
def _has_single_output_row(scope):
|
||||
return isinstance(scope.expression, exp.Select) and (
|
||||
all(isinstance(e.unalias(), exp.AggFunc) for e in scope.selects)
|
||||
all(isinstance(e.unalias(), exp.AggFunc) for e in scope.expression.selects)
|
||||
or _is_limit_1(scope)
|
||||
or not scope.expression.args.get("from")
|
||||
)
|
||||
|
|
|
@ -113,7 +113,7 @@ def _eliminate_union(scope, existing_ctes, taken):
|
|||
taken[alias] = scope
|
||||
|
||||
# Try to maintain the selections
|
||||
expressions = scope.selects
|
||||
expressions = scope.expression.selects
|
||||
selects = [
|
||||
exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name, copy=False)
|
||||
for e in expressions
|
||||
|
|
|
@ -12,7 +12,12 @@ def isolate_table_selects(expression, schema=None):
|
|||
continue
|
||||
|
||||
for _, source in scope.selected_sources.values():
|
||||
if not isinstance(source, exp.Table) or not schema.column_names(source):
|
||||
if (
|
||||
not isinstance(source, exp.Table)
|
||||
or not schema.column_names(source)
|
||||
or isinstance(source.parent, exp.Subquery)
|
||||
or isinstance(source.parent.parent, exp.Table)
|
||||
):
|
||||
continue
|
||||
|
||||
if not source.alias:
|
||||
|
|
|
@ -107,6 +107,7 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
|
|||
_merge_order(outer_scope, inner_scope)
|
||||
_merge_hints(outer_scope, inner_scope)
|
||||
outer_scope.clear_cache()
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
|
@ -166,7 +167,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
|
|||
if not inner_from:
|
||||
return False
|
||||
inner_from_table = inner_from.alias_or_name
|
||||
inner_projections = {s.alias_or_name: s for s in inner_scope.selects}
|
||||
inner_projections = {s.alias_or_name: s for s in inner_scope.expression.selects}
|
||||
return any(
|
||||
col.table != inner_from_table
|
||||
for selection in selections
|
||||
|
|
|
@ -59,7 +59,7 @@ def reorder_joins(expression):
|
|||
dag = {name: other_table_names(join) for name, join in joins.items()}
|
||||
parent.set(
|
||||
"joins",
|
||||
[joins[name] for name in tsort(dag) if name != from_.alias_or_name],
|
||||
[joins[name] for name in tsort(dag) if name != from_.alias_or_name and name in joins],
|
||||
)
|
||||
return expression
|
||||
|
||||
|
|
|
@ -42,7 +42,10 @@ def pushdown_predicates(expression):
|
|||
# so we limit the selected sources to only itself
|
||||
for join in select.args.get("joins") or []:
|
||||
name = join.alias_or_name
|
||||
pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)
|
||||
if name in scope.selected_sources:
|
||||
pushdown(
|
||||
join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count
|
||||
)
|
||||
|
||||
return expression
|
||||
|
||||
|
|
|
@ -48,12 +48,12 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
|
|||
left, right = scope.union_scopes
|
||||
referenced_columns[left] = parent_selections
|
||||
|
||||
if any(select.is_star for select in right.selects):
|
||||
if any(select.is_star for select in right.expression.selects):
|
||||
referenced_columns[right] = parent_selections
|
||||
elif not any(select.is_star for select in left.selects):
|
||||
elif not any(select.is_star for select in left.expression.selects):
|
||||
referenced_columns[right] = [
|
||||
right.selects[i].alias_or_name
|
||||
for i, select in enumerate(left.selects)
|
||||
right.expression.selects[i].alias_or_name
|
||||
for i, select in enumerate(left.expression.selects)
|
||||
if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections
|
||||
]
|
||||
|
||||
|
@ -90,7 +90,7 @@ def _remove_unused_selections(scope, parent_selections, schema):
|
|||
removed = False
|
||||
star = False
|
||||
|
||||
for selection in scope.selects:
|
||||
for selection in scope.expression.selects:
|
||||
name = selection.alias_or_name
|
||||
|
||||
if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs:
|
||||
|
|
|
@ -192,13 +192,13 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
|
|||
if table and (not alias_expr or double_agg):
|
||||
column.set("table", table)
|
||||
elif not column.table and alias_expr and not double_agg:
|
||||
if isinstance(alias_expr, exp.Literal):
|
||||
if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
|
||||
if literal_index:
|
||||
column.replace(exp.Literal.number(i))
|
||||
else:
|
||||
column.replace(alias_expr.copy())
|
||||
|
||||
for i, projection in enumerate(scope.selects):
|
||||
for i, projection in enumerate(scope.expression.selects):
|
||||
replace_columns(projection)
|
||||
|
||||
if isinstance(projection, exp.Alias):
|
||||
|
@ -239,7 +239,7 @@ def _expand_order_by(scope: Scope, resolver: Resolver):
|
|||
ordered.set("this", new_expression)
|
||||
|
||||
if scope.expression.args.get("group"):
|
||||
selects = {s.this: exp.column(s.alias_or_name) for s in scope.selects}
|
||||
selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}
|
||||
|
||||
for ordered in ordereds:
|
||||
ordered = ordered.this
|
||||
|
@ -270,7 +270,7 @@ def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t
|
|||
|
||||
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
|
||||
try:
|
||||
return scope.selects[int(node.this) - 1].assert_is(exp.Alias)
|
||||
return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
|
||||
except IndexError:
|
||||
raise OptimizeError(f"Unknown output column: {node.name}")
|
||||
|
||||
|
@ -347,7 +347,7 @@ def _expand_stars(
|
|||
if not pivot_output_columns:
|
||||
pivot_output_columns = [col.alias_or_name for col in pivot.expressions]
|
||||
|
||||
for expression in scope.selects:
|
||||
for expression in scope.expression.selects:
|
||||
if isinstance(expression, exp.Star):
|
||||
tables = list(scope.selected_sources)
|
||||
_add_except_columns(expression, tables, except_columns)
|
||||
|
@ -446,7 +446,7 @@ def _qualify_outputs(scope: Scope):
|
|||
new_selections = []
|
||||
|
||||
for i, (selection, aliased_column) in enumerate(
|
||||
itertools.zip_longest(scope.selects, scope.outer_column_list)
|
||||
itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
|
||||
):
|
||||
if isinstance(selection, exp.Subquery):
|
||||
if not selection.output_name:
|
||||
|
|
|
@ -15,7 +15,8 @@ def qualify_tables(
|
|||
schema: t.Optional[Schema] = None,
|
||||
) -> E:
|
||||
"""
|
||||
Rewrite sqlglot AST to have fully qualified, unnested tables.
|
||||
Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
|
||||
(t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.
|
||||
|
||||
Examples:
|
||||
>>> import sqlglot
|
||||
|
@ -23,18 +24,9 @@ def qualify_tables(
|
|||
>>> qualify_tables(expression, db="db").sql()
|
||||
'SELECT 1 FROM db.tbl AS tbl'
|
||||
>>>
|
||||
>>> expression = sqlglot.parse_one("SELECT * FROM (tbl)")
|
||||
>>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
|
||||
>>> qualify_tables(expression).sql()
|
||||
'SELECT * FROM tbl AS tbl'
|
||||
>>>
|
||||
>>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)")
|
||||
>>> qualify_tables(expression).sql()
|
||||
'SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2'
|
||||
|
||||
Note:
|
||||
This rule effectively enforces a left-to-right join order, since all joins
|
||||
are unnested. This means that the optimizer doesn't necessarily preserve the
|
||||
original join order, e.g. when parentheses are used to specify it explicitly.
|
||||
'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
|
||||
|
||||
Args:
|
||||
expression: Expression to qualify
|
||||
|
@ -49,6 +41,13 @@ def qualify_tables(
|
|||
|
||||
for scope in traverse_scope(expression):
|
||||
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
|
||||
if isinstance(derived_table, exp.Subquery):
|
||||
unnested = derived_table.unnest()
|
||||
if isinstance(unnested, exp.Table):
|
||||
joins = unnested.args.pop("joins", None)
|
||||
derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
|
||||
derived_table.this.set("joins", joins)
|
||||
|
||||
if not derived_table.args.get("alias"):
|
||||
alias_ = next_alias_name()
|
||||
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
|
||||
|
@ -66,19 +65,9 @@ def qualify_tables(
|
|||
if not source.args.get("catalog"):
|
||||
source.set("catalog", exp.to_identifier(catalog))
|
||||
|
||||
# Unnest joins attached in tables by appending them to the closest query
|
||||
for join in source.args.get("joins") or []:
|
||||
scope.expression.append("joins", join)
|
||||
|
||||
source.set("joins", None)
|
||||
source.set("wrapped", None)
|
||||
|
||||
if not source.alias:
|
||||
source = source.replace(
|
||||
alias(
|
||||
source, name or source.name or next_alias_name(), copy=True, table=True
|
||||
)
|
||||
)
|
||||
# Mutates the source by attaching an alias to it
|
||||
alias(source, name or source.name or next_alias_name(), copy=False, table=True)
|
||||
|
||||
pivots = source.args.get("pivots")
|
||||
if pivots and not pivots[0].alias:
|
||||
|
|
|
@ -122,7 +122,11 @@ class Scope:
|
|||
self._udtfs.append(node)
|
||||
elif isinstance(node, exp.CTE):
|
||||
self._ctes.append(node)
|
||||
elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
|
||||
elif (
|
||||
isinstance(node, exp.Subquery)
|
||||
and isinstance(parent, (exp.From, exp.Join))
|
||||
and _is_subquery_scope(node)
|
||||
):
|
||||
self._derived_tables.append(node)
|
||||
elif isinstance(node, exp.Subqueryable):
|
||||
self._subqueries.append(node)
|
||||
|
@ -274,6 +278,7 @@ class Scope:
|
|||
not ancestor
|
||||
or column.table
|
||||
or isinstance(ancestor, exp.Select)
|
||||
or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
|
||||
or (
|
||||
isinstance(ancestor, exp.Order)
|
||||
and (
|
||||
|
@ -340,23 +345,6 @@ class Scope:
|
|||
if isinstance(scope, Scope) and scope.is_cte
|
||||
}
|
||||
|
||||
@property
|
||||
def selects(self):
|
||||
"""
|
||||
Select expressions of this scope.
|
||||
|
||||
For example, for the following expression:
|
||||
SELECT 1 as a, 2 as b FROM x
|
||||
|
||||
The outputs are the "1 as a" and "2 as b" expressions.
|
||||
|
||||
Returns:
|
||||
list[exp.Expression]: expressions
|
||||
"""
|
||||
if isinstance(self.expression, exp.Union):
|
||||
return self.expression.unnest().selects
|
||||
return self.expression.selects
|
||||
|
||||
@property
|
||||
def external_columns(self):
|
||||
"""
|
||||
|
@ -548,6 +536,8 @@ def _traverse_scope(scope):
|
|||
yield from _traverse_union(scope)
|
||||
elif isinstance(scope.expression, exp.Subquery):
|
||||
yield from _traverse_subqueries(scope)
|
||||
elif isinstance(scope.expression, exp.Table):
|
||||
yield from _traverse_tables(scope)
|
||||
elif isinstance(scope.expression, exp.UDTF):
|
||||
pass
|
||||
else:
|
||||
|
@ -620,6 +610,15 @@ def _traverse_ctes(scope):
|
|||
scope.sources.update(sources)
|
||||
|
||||
|
||||
def _is_subquery_scope(expression: exp.Subquery) -> bool:
|
||||
"""
|
||||
We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a new scope.
|
||||
If an alias is present, it shadows all names under the Subquery, so that's an
|
||||
exception to this rule.
|
||||
"""
|
||||
return bool(not isinstance(expression.unnest(), exp.Table) or expression.alias)
|
||||
|
||||
|
||||
def _traverse_tables(scope):
|
||||
sources = {}
|
||||
|
||||
|
@ -629,9 +628,8 @@ def _traverse_tables(scope):
|
|||
if from_:
|
||||
expressions.append(from_.this)
|
||||
|
||||
for expression in (scope.expression, *scope.find_all(exp.Table)):
|
||||
for join in expression.args.get("joins") or []:
|
||||
expressions.append(join.this)
|
||||
for join in scope.expression.args.get("joins") or []:
|
||||
expressions.append(join.this)
|
||||
|
||||
if isinstance(scope.expression, exp.Table):
|
||||
expressions.append(scope.expression)
|
||||
|
@ -655,6 +653,8 @@ def _traverse_tables(scope):
|
|||
sources[find_new_name(sources, table_name)] = expression
|
||||
else:
|
||||
sources[source_name] = expression
|
||||
|
||||
expressions.extend(join.this for join in expression.args.get("joins") or [])
|
||||
continue
|
||||
|
||||
if not isinstance(expression, exp.DerivedTable):
|
||||
|
@ -664,10 +664,15 @@ def _traverse_tables(scope):
|
|||
lateral_sources = sources
|
||||
scope_type = ScopeType.UDTF
|
||||
scopes = scope.udtf_scopes
|
||||
else:
|
||||
elif _is_subquery_scope(expression):
|
||||
lateral_sources = None
|
||||
scope_type = ScopeType.DERIVED_TABLE
|
||||
scopes = scope.derived_table_scopes
|
||||
else:
|
||||
# Makes sure we check for possible sources in nested table constructs
|
||||
expressions.append(expression.this)
|
||||
expressions.extend(join.this for join in expression.args.get("joins") or [])
|
||||
continue
|
||||
|
||||
for child_scope in _traverse_scope(
|
||||
scope.branch(
|
||||
|
@ -728,7 +733,11 @@ def walk_in_scope(expression, bfs=True):
|
|||
continue
|
||||
if (
|
||||
isinstance(node, exp.CTE)
|
||||
or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)))
|
||||
or (
|
||||
isinstance(node, exp.Subquery)
|
||||
and isinstance(parent, (exp.From, exp.Join))
|
||||
and _is_subquery_scope(node)
|
||||
)
|
||||
or isinstance(node, exp.UDTF)
|
||||
or isinstance(node, exp.Subqueryable)
|
||||
):
|
||||
|
|
|
@ -1708,6 +1708,8 @@ class Parser(metaclass=_Parser):
|
|||
self._match(TokenType.TABLE)
|
||||
this = self._parse_table(schema=True)
|
||||
|
||||
returning = self._parse_returning()
|
||||
|
||||
return self.expression(
|
||||
exp.Insert,
|
||||
this=this,
|
||||
|
@ -1717,7 +1719,7 @@ class Parser(metaclass=_Parser):
|
|||
and self._parse_conjunction(),
|
||||
expression=self._parse_ddl_select(),
|
||||
conflict=self._parse_on_conflict(),
|
||||
returning=self._parse_returning(),
|
||||
returning=returning or self._parse_returning(),
|
||||
overwrite=overwrite,
|
||||
alternative=alternative,
|
||||
ignore=ignore,
|
||||
|
@ -1761,8 +1763,11 @@ class Parser(metaclass=_Parser):
|
|||
def _parse_returning(self) -> t.Optional[exp.Returning]:
|
||||
if not self._match(TokenType.RETURNING):
|
||||
return None
|
||||
|
||||
return self.expression(exp.Returning, expressions=self._parse_csv(self._parse_column))
|
||||
return self.expression(
|
||||
exp.Returning,
|
||||
expressions=self._parse_csv(self._parse_expression),
|
||||
into=self._match(TokenType.INTO) and self._parse_table_part(),
|
||||
)
|
||||
|
||||
def _parse_row(self) -> t.Optional[exp.RowFormatSerdeProperty | exp.RowFormatDelimitedProperty]:
|
||||
if not self._match(TokenType.FORMAT):
|
||||
|
@ -1824,25 +1829,30 @@ class Parser(metaclass=_Parser):
|
|||
if not self._match(TokenType.FROM, advance=False):
|
||||
tables = self._parse_csv(self._parse_table) or None
|
||||
|
||||
returning = self._parse_returning()
|
||||
|
||||
return self.expression(
|
||||
exp.Delete,
|
||||
tables=tables,
|
||||
this=self._match(TokenType.FROM) and self._parse_table(joins=True),
|
||||
using=self._match(TokenType.USING) and self._parse_table(joins=True),
|
||||
where=self._parse_where(),
|
||||
returning=self._parse_returning(),
|
||||
returning=returning or self._parse_returning(),
|
||||
limit=self._parse_limit(),
|
||||
)
|
||||
|
||||
def _parse_update(self) -> exp.Update:
|
||||
this = self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS)
|
||||
expressions = self._match(TokenType.SET) and self._parse_csv(self._parse_equality)
|
||||
returning = self._parse_returning()
|
||||
return self.expression(
|
||||
exp.Update,
|
||||
**{ # type: ignore
|
||||
"this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
|
||||
"expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality),
|
||||
"this": this,
|
||||
"expressions": expressions,
|
||||
"from": self._parse_from(joins=True),
|
||||
"where": self._parse_where(),
|
||||
"returning": self._parse_returning(),
|
||||
"returning": returning or self._parse_returning(),
|
||||
"limit": self._parse_limit(),
|
||||
},
|
||||
)
|
||||
|
@ -1969,31 +1979,9 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
self._match_r_paren()
|
||||
|
||||
alias = None
|
||||
|
||||
# Ensure "wrapped" tables are not parsed as Subqueries. The exception to this is when there's
|
||||
# an alias that can be applied to the parentheses, because that would shadow all wrapped table
|
||||
# names, and so we want to parse it as a Subquery to represent the inner scope appropriately.
|
||||
# Additionally, we want the node under the Subquery to be an actual query, so we will replace
|
||||
# the table reference with a star query that selects from it.
|
||||
if isinstance(this, exp.Table):
|
||||
alias = self._parse_table_alias()
|
||||
if not alias:
|
||||
this.set("wrapped", True)
|
||||
return this
|
||||
|
||||
this.set("wrapped", None)
|
||||
joins = this.args.pop("joins", None)
|
||||
this = this.replace(exp.select("*").from_(this.copy(), copy=False))
|
||||
this.set("joins", joins)
|
||||
|
||||
subquery = self._parse_subquery(this, parse_alias=parse_subquery_alias and not alias)
|
||||
if subquery and alias:
|
||||
subquery.set("alias", alias)
|
||||
|
||||
# We return early here so that the UNION isn't attached to the subquery by the
|
||||
# following call to _parse_set_operations, but instead becomes the parent node
|
||||
return subquery
|
||||
return self._parse_subquery(this, parse_alias=parse_subquery_alias)
|
||||
elif self._match(TokenType.VALUES):
|
||||
this = self.expression(
|
||||
exp.Values,
|
||||
|
@ -3086,7 +3074,13 @@ class Parser(metaclass=_Parser):
|
|||
if self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
|
||||
this = exp.DataType(
|
||||
this=exp.DataType.Type.ARRAY,
|
||||
expressions=[exp.DataType.build(type_token.value, expressions=expressions)],
|
||||
expressions=[
|
||||
exp.DataType(
|
||||
this=exp.DataType.Type[type_token.value],
|
||||
expressions=expressions,
|
||||
nested=nested,
|
||||
)
|
||||
],
|
||||
nested=True,
|
||||
)
|
||||
|
||||
|
@ -3147,7 +3141,7 @@ class Parser(metaclass=_Parser):
|
|||
return value
|
||||
|
||||
return exp.DataType(
|
||||
this=exp.DataType.Type[type_token.value.upper()],
|
||||
this=exp.DataType.Type[type_token.value],
|
||||
expressions=expressions,
|
||||
nested=nested,
|
||||
values=values,
|
||||
|
|
|
@ -52,6 +52,7 @@ class TokenType(AutoName):
|
|||
PARAMETER = auto()
|
||||
SESSION_PARAMETER = auto()
|
||||
DAMP = auto()
|
||||
XOR = auto()
|
||||
|
||||
BLOCK_START = auto()
|
||||
BLOCK_END = auto()
|
||||
|
@ -590,6 +591,7 @@ class Tokenizer(metaclass=_Tokenizer):
|
|||
"OFFSET": TokenType.OFFSET,
|
||||
"ON": TokenType.ON,
|
||||
"OR": TokenType.OR,
|
||||
"XOR": TokenType.XOR,
|
||||
"ORDER BY": TokenType.ORDER_BY,
|
||||
"ORDINALITY": TokenType.ORDINALITY,
|
||||
"OUTER": TokenType.OUTER,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue