Adding upstream version 7.1.3.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
291e0c125c
commit
768d386bf5
42 changed files with 1430 additions and 253 deletions
|
@ -2,7 +2,7 @@ import logging
|
|||
|
||||
from sqlglot import exp
|
||||
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors
|
||||
from sqlglot.helper import apply_index_offset, csv, ensure_list
|
||||
from sqlglot.helper import apply_index_offset, csv
|
||||
from sqlglot.time import format_time
|
||||
from sqlglot.tokens import TokenType
|
||||
|
||||
|
@ -43,14 +43,18 @@ class Generator:
|
|||
Default: 3
|
||||
leading_comma (bool): if the the comma is leading or trailing in select statements
|
||||
Default: False
|
||||
max_text_width: The max number of characters in a segment before creating new lines in pretty mode.
|
||||
The default is on the smaller end because the length only represents a segment and not the true
|
||||
line length.
|
||||
Default: 80
|
||||
"""
|
||||
|
||||
TRANSFORMS = {
|
||||
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}",
|
||||
exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
|
||||
exp.DateDiff: lambda self, e: f"DATEDIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
|
||||
exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
|
||||
exp.VarMap: lambda self, e: f"MAP({self.sql(e.args['keys'])}, {self.sql(e.args['values'])})",
|
||||
exp.DateAdd: lambda self, e: f"DATE_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
|
||||
exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})",
|
||||
exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
|
||||
exp.VarMap: lambda self, e: f"MAP({self.format_args(e.args['keys'], e.args['values'])})",
|
||||
exp.LanguageProperty: lambda self, e: self.naked_property(e),
|
||||
exp.LocationProperty: lambda self, e: self.naked_property(e),
|
||||
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
|
||||
|
@ -111,6 +115,7 @@ class Generator:
|
|||
"_replace_backslash",
|
||||
"_escaped_quote_end",
|
||||
"_leading_comma",
|
||||
"_max_text_width",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
|
@ -135,6 +140,7 @@ class Generator:
|
|||
null_ordering=None,
|
||||
max_unsupported=3,
|
||||
leading_comma=False,
|
||||
max_text_width=80,
|
||||
):
|
||||
import sqlglot
|
||||
|
||||
|
@ -162,6 +168,7 @@ class Generator:
|
|||
self._replace_backslash = self.escape == "\\"
|
||||
self._escaped_quote_end = self.escape + self.quote_end
|
||||
self._leading_comma = leading_comma
|
||||
self._max_text_width = max_text_width
|
||||
|
||||
def generate(self, expression):
|
||||
"""
|
||||
|
@ -268,7 +275,7 @@ class Generator:
|
|||
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
|
||||
|
||||
def annotation_sql(self, expression):
|
||||
return self.sql(expression, "expression")
|
||||
return f"{self.sql(expression, 'expression')} # {expression.name.strip()}"
|
||||
|
||||
def uncache_sql(self, expression):
|
||||
table = self.sql(expression, "this")
|
||||
|
@ -364,6 +371,9 @@ class Generator:
|
|||
)
|
||||
return self.prepend_ctes(expression, expression_sql)
|
||||
|
||||
def describe_sql(self, expression):
|
||||
return f"DESCRIBE {self.sql(expression, 'this')}"
|
||||
|
||||
def prepend_ctes(self, expression, sql):
|
||||
with_ = self.sql(expression, "with")
|
||||
if with_:
|
||||
|
@ -405,6 +415,12 @@ class Generator:
|
|||
)
|
||||
return f"{type_sql}{nested}"
|
||||
|
||||
def directory_sql(self, expression):
|
||||
local = "LOCAL " if expression.args.get("local") else ""
|
||||
row_format = self.sql(expression, "row_format")
|
||||
row_format = f" {row_format}" if row_format else ""
|
||||
return f"{local}DIRECTORY {self.sql(expression, 'this')}{row_format}"
|
||||
|
||||
def delete_sql(self, expression):
|
||||
this = self.sql(expression, "this")
|
||||
where_sql = self.sql(expression, "where")
|
||||
|
@ -513,13 +529,19 @@ class Generator:
|
|||
return f"{key}={value}"
|
||||
|
||||
def insert_sql(self, expression):
|
||||
kind = "OVERWRITE TABLE" if expression.args.get("overwrite") else "INTO"
|
||||
this = self.sql(expression, "this")
|
||||
overwrite = expression.args.get("overwrite")
|
||||
|
||||
if isinstance(expression.this, exp.Directory):
|
||||
this = "OVERWRITE " if overwrite else "INTO "
|
||||
else:
|
||||
this = "OVERWRITE TABLE " if overwrite else "INTO "
|
||||
|
||||
this = f"{this}{self.sql(expression, 'this')}"
|
||||
exists = " IF EXISTS " if expression.args.get("exists") else " "
|
||||
partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else ""
|
||||
expression_sql = self.sql(expression, "expression")
|
||||
sep = self.sep() if partition_sql else ""
|
||||
sql = f"INSERT {kind} {this}{exists}{partition_sql}{sep}{expression_sql}"
|
||||
sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}"
|
||||
return self.prepend_ctes(expression, sql)
|
||||
|
||||
def intersect_sql(self, expression):
|
||||
|
@ -534,6 +556,21 @@ class Generator:
|
|||
def introducer_sql(self, expression):
|
||||
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
|
||||
|
||||
def rowformat_sql(self, expression):
|
||||
fields = expression.args.get("fields")
|
||||
fields = f" FIELDS TERMINATED BY {fields}" if fields else ""
|
||||
escaped = expression.args.get("escaped")
|
||||
escaped = f" ESCAPED BY {escaped}" if escaped else ""
|
||||
items = expression.args.get("collection_items")
|
||||
items = f" COLLECTION ITEMS TERMINATED BY {items}" if items else ""
|
||||
keys = expression.args.get("map_keys")
|
||||
keys = f" MAP KEYS TERMINATED BY {keys}" if keys else ""
|
||||
lines = expression.args.get("lines")
|
||||
lines = f" LINES TERMINATED BY {lines}" if lines else ""
|
||||
null = expression.args.get("null")
|
||||
null = f" NULL DEFINED AS {null}" if null else ""
|
||||
return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}"
|
||||
|
||||
def table_sql(self, expression):
|
||||
table = ".".join(
|
||||
part
|
||||
|
@ -688,6 +725,19 @@ class Generator:
|
|||
return f"{self.quote_start}{text}{self.quote_end}"
|
||||
return text
|
||||
|
||||
def loaddata_sql(self, expression):
|
||||
local = " LOCAL" if expression.args.get("local") else ""
|
||||
inpath = f" INPATH {self.sql(expression, 'inpath')}"
|
||||
overwrite = " OVERWRITE" if expression.args.get("overwrite") else ""
|
||||
this = f" INTO TABLE {self.sql(expression, 'this')}"
|
||||
partition = self.sql(expression, "partition")
|
||||
partition = f" {partition}" if partition else ""
|
||||
input_format = self.sql(expression, "input_format")
|
||||
input_format = f" INPUTFORMAT {input_format}" if input_format else ""
|
||||
serde = self.sql(expression, "serde")
|
||||
serde = f" SERDE {serde}" if serde else ""
|
||||
return f"LOAD DATA{local}{inpath}{overwrite}{this}{partition}{input_format}{serde}"
|
||||
|
||||
def null_sql(self, *_):
|
||||
return "NULL"
|
||||
|
||||
|
@ -885,20 +935,24 @@ class Generator:
|
|||
return f"EXISTS{self.wrap(expression)}"
|
||||
|
||||
def case_sql(self, expression):
|
||||
this = self.indent(self.sql(expression, "this"), skip_first=True)
|
||||
this = f" {this}" if this else ""
|
||||
ifs = []
|
||||
this = self.sql(expression, "this")
|
||||
statements = [f"CASE {this}" if this else "CASE"]
|
||||
|
||||
for e in expression.args["ifs"]:
|
||||
ifs.append(self.indent(f"WHEN {self.sql(e, 'this')}"))
|
||||
ifs.append(self.indent(f"THEN {self.sql(e, 'true')}"))
|
||||
statements.append(f"WHEN {self.sql(e, 'this')}")
|
||||
statements.append(f"THEN {self.sql(e, 'true')}")
|
||||
|
||||
if expression.args.get("default") is not None:
|
||||
ifs.append(self.indent(f"ELSE {self.sql(expression, 'default')}"))
|
||||
default = self.sql(expression, "default")
|
||||
|
||||
ifs = "".join(self.seg(self.indent(e, skip_first=True)) for e in ifs)
|
||||
statement = f"CASE{this}{ifs}{self.seg('END')}"
|
||||
return statement
|
||||
if default:
|
||||
statements.append(f"ELSE {default}")
|
||||
|
||||
statements.append("END")
|
||||
|
||||
if self.pretty and self.text_width(statements) > self._max_text_width:
|
||||
return self.indent("\n".join(statements), skip_first=True, skip_last=True)
|
||||
|
||||
return " ".join(statements)
|
||||
|
||||
def constraint_sql(self, expression):
|
||||
this = self.sql(expression, "this")
|
||||
|
@ -970,7 +1024,7 @@ class Generator:
|
|||
return f"REFERENCES {this}({expressions})"
|
||||
|
||||
def anonymous_sql(self, expression):
|
||||
args = self.indent(self.expressions(expression, flat=True), skip_first=True, skip_last=True)
|
||||
args = self.format_args(*expression.expressions)
|
||||
return f"{self.normalize_func(self.sql(expression, 'this'))}({args})"
|
||||
|
||||
def paren_sql(self, expression):
|
||||
|
@ -1008,7 +1062,9 @@ class Generator:
|
|||
if not self.pretty:
|
||||
return self.binary(expression, op)
|
||||
|
||||
return f"\n{op} ".join(self.sql(e) for e in expression.flatten(unnest=False))
|
||||
sqls = tuple(self.sql(e) for e in expression.flatten(unnest=False))
|
||||
sep = "\n" if self.text_width(sqls) > self._max_text_width else " "
|
||||
return f"{sep}{op} ".join(sqls)
|
||||
|
||||
def bitwiseand_sql(self, expression):
|
||||
return self.binary(expression, "&")
|
||||
|
@ -1039,7 +1095,7 @@ class Generator:
|
|||
return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}"
|
||||
|
||||
def distinct_sql(self, expression):
|
||||
this = self.sql(expression, "this")
|
||||
this = self.expressions(expression, flat=True)
|
||||
this = f" {this}" if this else ""
|
||||
|
||||
on = self.sql(expression, "on")
|
||||
|
@ -1128,13 +1184,23 @@ class Generator:
|
|||
|
||||
def function_fallback_sql(self, expression):
|
||||
args = []
|
||||
for arg_key in expression.arg_types:
|
||||
arg_value = ensure_list(expression.args.get(arg_key) or [])
|
||||
for a in arg_value:
|
||||
args.append(self.sql(a))
|
||||
for arg_value in expression.args.values():
|
||||
if isinstance(arg_value, list):
|
||||
for value in arg_value:
|
||||
args.append(value)
|
||||
elif arg_value:
|
||||
args.append(arg_value)
|
||||
|
||||
args_str = self.indent(", ".join(args), skip_first=True, skip_last=True)
|
||||
return f"{self.normalize_func(expression.sql_name())}({args_str})"
|
||||
return f"{self.normalize_func(expression.sql_name())}({self.format_args(*args)})"
|
||||
|
||||
def format_args(self, *args):
|
||||
args = tuple(self.sql(arg) for arg in args if arg is not None)
|
||||
if self.pretty and self.text_width(args) > self._max_text_width:
|
||||
return self.indent("\n" + f",\n".join(args) + "\n", skip_first=True, skip_last=True)
|
||||
return ", ".join(args)
|
||||
|
||||
def text_width(self, args):
|
||||
return sum(len(arg) for arg in args)
|
||||
|
||||
def format_time(self, expression):
|
||||
return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue