1
0
Fork 0

Adding upstream version 16.2.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 16:00:14 +01:00
parent 577b79f5a7
commit d61627452f
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
106 changed files with 41940 additions and 40162 deletions

View file

@ -14,47 +14,32 @@ logger = logging.getLogger("sqlglot")
class Generator:
"""
Generator interprets the given syntax tree and produces a SQL string as an output.
Generator converts a given syntax tree to the corresponding SQL string.
Args:
time_mapping (dict): the dictionary of custom time mappings in which the key
represents a python time format and the output the target time format
time_trie (trie): a trie of the time_mapping keys
pretty (bool): if set to True the returned string will be formatted. Default: False.
quote_start (str): specifies which starting character to use to delimit quotes. Default: '.
quote_end (str): specifies which ending character to use to delimit quotes. Default: '.
identifier_start (str): specifies which starting character to use to delimit identifiers. Default: ".
identifier_end (str): specifies which ending character to use to delimit identifiers. Default: ".
bit_start (str): specifies which starting character to use to delimit bit literals. Default: None.
bit_end (str): specifies which ending character to use to delimit bit literals. Default: None.
hex_start (str): specifies which starting character to use to delimit hex literals. Default: None.
hex_end (str): specifies which ending character to use to delimit hex literals. Default: None.
byte_start (str): specifies which starting character to use to delimit byte literals. Default: None.
byte_end (str): specifies which ending character to use to delimit byte literals. Default: None.
raw_start (str): specifies which starting character to use to delimit raw literals. Default: None.
raw_end (str): specifies which ending character to use to delimit raw literals. Default: None.
identify (bool | str): 'always': always quote, 'safe': quote identifiers if they don't contain an upcase, True defaults to always.
normalize (bool): if set to True all identifiers will lower cased
string_escape (str): specifies a string escape character. Default: '.
identifier_escape (str): specifies an identifier escape character. Default: ".
pad (int): determines padding in a formatted string. Default: 2.
indent (int): determines the size of indentation in a formatted string. Default: 4.
unnest_column_only (bool): if true unnest table aliases are considered only as column aliases
normalize_functions (str): normalize function names, "upper", "lower", or None
Default: "upper"
alias_post_tablesample (bool): if the table alias comes after tablesample
Default: False
identifiers_can_start_with_digit (bool): if an unquoted identifier can start with digit
Default: False
unsupported_level (ErrorLevel): determines the generator's behavior when it encounters
unsupported expressions. Default ErrorLevel.WARN.
null_ordering (str): Indicates the default null ordering method to use if not explicitly set.
Options are "nulls_are_small", "nulls_are_large", "nulls_are_last".
Default: "nulls_are_small"
max_unsupported (int): Maximum number of unsupported messages to include in a raised UnsupportedError.
pretty: Whether or not to format the produced SQL string.
Default: False.
identify: Determines when an identifier should be quoted. Possible values are:
False (default): Never quote, except in cases where it's mandatory by the dialect.
True or 'always': Always quote.
'safe': Only quote identifiers that are case insensitive.
normalize: Whether or not to normalize identifiers to lowercase.
Default: False.
pad: Determines the pad size in a formatted string.
Default: 2.
indent: Determines the indentation size in a formatted string.
Default: 2.
normalize_functions: Whether or not to normalize all function names. Possible values are:
"upper" or True (default): Convert names to uppercase.
"lower": Convert names to lowercase.
False: Disables function name normalization.
unsupported_level: Determines the generator's behavior when it encounters unsupported expressions.
Default ErrorLevel.WARN.
max_unsupported: Maximum number of unsupported messages to include in a raised UnsupportedError.
This is only relevant if unsupported_level is ErrorLevel.RAISE.
Default: 3
leading_comma (bool): if the the comma is leading or trailing in select statements
leading_comma: Determines whether or not the comma is leading or trailing in select expressions.
This is only relevant when generating in pretty mode.
Default: False
max_text_width: The max number of characters in a segment before creating new lines in pretty mode.
The default is on the smaller end because the length only represents a segment and not the true
@ -86,6 +71,7 @@ class Generator:
exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
exp.TemporaryProperty: lambda self, e: f"TEMPORARY",
exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
exp.TransientProperty: lambda self, e: "TRANSIENT",
exp.StabilityProperty: lambda self, e: e.name,
exp.VolatileProperty: lambda self, e: "VOLATILE",
@ -138,15 +124,24 @@ class Generator:
# Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH")
LIMIT_FETCH = "ALL"
# Whether a table is allowed to be renamed with a db
# Whether or not a table is allowed to be renamed with a db
RENAME_TABLE_WITH_DB = True
# The separator for grouping sets and rollups
GROUPINGS_SEP = ","
# The string used for creating index on a table
# The string used for creating an index on a table
INDEX_ON = "ON"
# Whether or not join hints should be generated
JOIN_HINTS = True
# Whether or not table hints should be generated
TABLE_HINTS = True
# Whether or not comparing against booleans (e.g. x IS TRUE) is supported
IS_BOOL_ALLOWED = True
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@ -228,6 +223,7 @@ class Generator:
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA,
exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,
exp.ToTableProperty: exp.Properties.Location.POST_SCHEMA,
exp.TransientProperty: exp.Properties.Location.POST_CREATE,
exp.MergeTreeTTL: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
@ -235,128 +231,110 @@ class Generator:
exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
}
JOIN_HINTS = True
TABLE_HINTS = True
IS_BOOL = True
# Keywords that can't be used as unquoted identifier names
RESERVED_KEYWORDS: t.Set[str] = set()
WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.With)
UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Neg, exp.Paren)
# Expressions whose comments are separated from them for better formatting
WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = (
exp.Select,
exp.From,
exp.Where,
exp.With,
)
# Expressions that can remain unwrapped when appearing in the context of an INTERVAL
UNWRAPPED_INTERVAL_VALUES: t.Tuple[t.Type[exp.Expression], ...] = (
exp.Column,
exp.Literal,
exp.Neg,
exp.Paren,
)
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
# Autofilled
INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
INVERSE_TIME_TRIE: t.Dict = {}
INDEX_OFFSET = 0
UNNEST_COLUMN_ONLY = False
ALIAS_POST_TABLESAMPLE = False
IDENTIFIERS_CAN_START_WITH_DIGIT = False
STRICT_STRING_CONCAT = False
NORMALIZE_FUNCTIONS: bool | str = "upper"
NULL_ORDERING = "nulls_are_small"
# Delimiters for quotes, identifiers and the corresponding escape characters
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
STRING_ESCAPE = "'"
IDENTIFIER_ESCAPE = '"'
# Delimiters for bit, hex, byte and raw literals
BIT_START: t.Optional[str] = None
BIT_END: t.Optional[str] = None
HEX_START: t.Optional[str] = None
HEX_END: t.Optional[str] = None
BYTE_START: t.Optional[str] = None
BYTE_END: t.Optional[str] = None
RAW_START: t.Optional[str] = None
RAW_END: t.Optional[str] = None
__slots__ = (
"time_mapping",
"time_trie",
"pretty",
"quote_start",
"quote_end",
"identifier_start",
"identifier_end",
"bit_start",
"bit_end",
"hex_start",
"hex_end",
"byte_start",
"byte_end",
"raw_start",
"raw_end",
"identify",
"normalize",
"string_escape",
"identifier_escape",
"pad",
"index_offset",
"unnest_column_only",
"alias_post_tablesample",
"identifiers_can_start_with_digit",
"_indent",
"normalize_functions",
"unsupported_level",
"unsupported_messages",
"null_ordering",
"max_unsupported",
"_indent",
"leading_comma",
"max_text_width",
"comments",
"unsupported_messages",
"_escaped_quote_end",
"_escaped_identifier_end",
"_leading_comma",
"_max_text_width",
"_comments",
"_cache",
)
def __init__(
self,
time_mapping=None,
time_trie=None,
pretty=None,
quote_start=None,
quote_end=None,
identifier_start=None,
identifier_end=None,
bit_start=None,
bit_end=None,
hex_start=None,
hex_end=None,
byte_start=None,
byte_end=None,
raw_start=None,
raw_end=None,
identify=False,
normalize=False,
string_escape=None,
identifier_escape=None,
pad=2,
indent=2,
index_offset=0,
unnest_column_only=False,
alias_post_tablesample=False,
identifiers_can_start_with_digit=False,
normalize_functions="upper",
unsupported_level=ErrorLevel.WARN,
null_ordering=None,
max_unsupported=3,
leading_comma=False,
max_text_width=80,
comments=True,
pretty: t.Optional[bool] = None,
identify: str | bool = False,
normalize: bool = False,
pad: int = 2,
indent: int = 2,
normalize_functions: t.Optional[str | bool] = None,
unsupported_level: ErrorLevel = ErrorLevel.WARN,
max_unsupported: int = 3,
leading_comma: bool = False,
max_text_width: int = 80,
comments: bool = True,
):
import sqlglot
self.time_mapping = time_mapping or {}
self.time_trie = time_trie
self.pretty = pretty if pretty is not None else sqlglot.pretty
self.quote_start = quote_start or "'"
self.quote_end = quote_end or "'"
self.identifier_start = identifier_start or '"'
self.identifier_end = identifier_end or '"'
self.bit_start = bit_start
self.bit_end = bit_end
self.hex_start = hex_start
self.hex_end = hex_end
self.byte_start = byte_start
self.byte_end = byte_end
self.raw_start = raw_start
self.raw_end = raw_end
self.identify = identify
self.normalize = normalize
self.string_escape = string_escape or "'"
self.identifier_escape = identifier_escape or '"'
self.pad = pad
self.index_offset = index_offset
self.unnest_column_only = unnest_column_only
self.alias_post_tablesample = alias_post_tablesample
self.identifiers_can_start_with_digit = identifiers_can_start_with_digit
self.normalize_functions = normalize_functions
self.unsupported_level = unsupported_level
self.unsupported_messages = []
self.max_unsupported = max_unsupported
self.null_ordering = null_ordering
self._indent = indent
self._escaped_quote_end = self.string_escape + self.quote_end
self._escaped_identifier_end = self.identifier_escape + self.identifier_end
self._leading_comma = leading_comma
self._max_text_width = max_text_width
self._comments = comments
self._cache = None
self.unsupported_level = unsupported_level
self.max_unsupported = max_unsupported
self.leading_comma = leading_comma
self.max_text_width = max_text_width
self.comments = comments
# This is both a Dialect property and a Generator argument, so we prioritize the latter
self.normalize_functions = (
self.NORMALIZE_FUNCTIONS if normalize_functions is None else normalize_functions
)
self.unsupported_messages: t.List[str] = []
self._escaped_quote_end: str = self.STRING_ESCAPE + self.QUOTE_END
self._escaped_identifier_end: str = self.IDENTIFIER_ESCAPE + self.IDENTIFIER_END
self._cache: t.Optional[t.Dict[int, str]] = None
def generate(
self,
@ -364,17 +342,19 @@ class Generator:
cache: t.Optional[t.Dict[int, str]] = None,
) -> str:
"""
Generates a SQL string by interpreting the given syntax tree.
Generates the SQL string corresponding to the given syntax tree.
Args
expression: the syntax tree.
cache: an optional sql string cache. this leverages the hash of an expression which is slow, so only use this if you set _hash on each node.
Args:
expression: The syntax tree.
cache: An optional sql string cache. This leverages the hash of an Expression
which can be slow to compute, so only use it if you set _hash on each node.
Returns
the SQL string.
Returns:
The SQL string corresponding to `expression`.
"""
if cache is not None:
self._cache = cache
self.unsupported_messages = []
sql = self.sql(expression).strip()
self._cache = None
@ -414,7 +394,11 @@ class Generator:
expression: t.Optional[exp.Expression] = None,
comments: t.Optional[t.List[str]] = None,
) -> str:
comments = ((expression and expression.comments) if comments is None else comments) if self._comments else None # type: ignore
comments = (
((expression and expression.comments) if comments is None else comments) # type: ignore
if self.comments
else None
)
if not comments or isinstance(expression, exp.Binary):
return sql
@ -454,7 +438,7 @@ class Generator:
return result
def normalize_func(self, name: str) -> str:
if self.normalize_functions == "upper":
if self.normalize_functions == "upper" or self.normalize_functions is True:
return name.upper()
if self.normalize_functions == "lower":
return name.lower()
@ -522,7 +506,7 @@ class Generator:
else:
raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
sql = self.maybe_comment(sql, expression) if self._comments and comment else sql
sql = self.maybe_comment(sql, expression) if self.comments and comment else sql
if self._cache is not None:
self._cache[expression_id] = sql
@ -770,25 +754,25 @@ class Generator:
def bitstring_sql(self, expression: exp.BitString) -> str:
this = self.sql(expression, "this")
if self.bit_start:
return f"{self.bit_start}{this}{self.bit_end}"
if self.BIT_START:
return f"{self.BIT_START}{this}{self.BIT_END}"
return f"{int(this, 2)}"
def hexstring_sql(self, expression: exp.HexString) -> str:
this = self.sql(expression, "this")
if self.hex_start:
return f"{self.hex_start}{this}{self.hex_end}"
if self.HEX_START:
return f"{self.HEX_START}{this}{self.HEX_END}"
return f"{int(this, 16)}"
def bytestring_sql(self, expression: exp.ByteString) -> str:
this = self.sql(expression, "this")
if self.byte_start:
return f"{self.byte_start}{this}{self.byte_end}"
if self.BYTE_START:
return f"{self.BYTE_START}{this}{self.BYTE_END}"
return this
def rawstring_sql(self, expression: exp.RawString) -> str:
if self.raw_start:
return f"{self.raw_start}{expression.name}{self.raw_end}"
if self.RAW_START:
return f"{self.RAW_START}{expression.name}{self.RAW_END}"
return self.sql(exp.Literal.string(expression.name.replace("\\", "\\\\")))
def datatypesize_sql(self, expression: exp.DataTypeSize) -> str:
@ -883,24 +867,27 @@ class Generator:
name = f"{expression.name} " if expression.name else ""
table = self.sql(expression, "table")
table = f"{self.INDEX_ON} {table} " if table else ""
using = self.sql(expression, "using")
using = f"USING {using} " if using else ""
index = "INDEX " if not table else ""
columns = self.expressions(expression, key="columns", flat=True)
columns = f"({columns})" if columns else ""
partition_by = self.expressions(expression, key="partition_by", flat=True)
partition_by = f" PARTITION BY {partition_by}" if partition_by else ""
return f"{unique}{primary}{amp}{index}{name}{table}({columns}){partition_by}"
return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{partition_by}"
def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
lower = text.lower()
text = lower if self.normalize and not expression.quoted else text
text = text.replace(self.identifier_end, self._escaped_identifier_end)
text = text.replace(self.IDENTIFIER_END, self._escaped_identifier_end)
if (
expression.quoted
or should_identify(text, self.identify)
or lower in self.RESERVED_KEYWORDS
or (not self.identifiers_can_start_with_digit and text[:1].isdigit())
or (not self.IDENTIFIERS_CAN_START_WITH_DIGIT and text[:1].isdigit())
):
text = f"{self.identifier_start}{text}{self.identifier_end}"
text = f"{self.IDENTIFIER_START}{text}{self.IDENTIFIER_END}"
return text
def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str:
@ -1197,7 +1184,7 @@ class Generator:
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
) -> str:
if self.alias_post_tablesample and expression.this.alias:
if self.ALIAS_POST_TABLESAMPLE and expression.this.alias:
table = expression.this.copy()
table.set("alias", None)
this = self.sql(table)
@ -1372,7 +1359,15 @@ class Generator:
def limit_sql(self, expression: exp.Limit) -> str:
this = self.sql(expression, "this")
return f"{this}{self.seg('LIMIT')} {self.sql(expression, 'expression')}"
args = ", ".join(
sql
for sql in (
self.sql(expression, "offset"),
self.sql(expression, "expression"),
)
if sql
)
return f"{this}{self.seg('LIMIT')} {args}"
def offset_sql(self, expression: exp.Offset) -> str:
this = self.sql(expression, "this")
@ -1418,10 +1413,10 @@ class Generator:
def literal_sql(self, expression: exp.Literal) -> str:
text = expression.this or ""
if expression.is_string:
text = text.replace(self.quote_end, self._escaped_quote_end)
text = text.replace(self.QUOTE_END, self._escaped_quote_end)
if self.pretty:
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
text = f"{self.quote_start}{text}{self.quote_end}"
text = f"{self.QUOTE_START}{text}{self.QUOTE_END}"
return text
def loaddata_sql(self, expression: exp.LoadData) -> str:
@ -1463,9 +1458,9 @@ class Generator:
nulls_first = expression.args.get("nulls_first")
nulls_last = not nulls_first
nulls_are_large = self.null_ordering == "nulls_are_large"
nulls_are_small = self.null_ordering == "nulls_are_small"
nulls_are_last = self.null_ordering == "nulls_are_last"
nulls_are_large = self.NULL_ORDERING == "nulls_are_large"
nulls_are_small = self.NULL_ORDERING == "nulls_are_small"
nulls_are_last = self.NULL_ORDERING == "nulls_are_last"
sort_order = " DESC" if desc else ""
nulls_sort_change = ""
@ -1521,7 +1516,7 @@ class Generator:
return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}{alias}"
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
limit = expression.args.get("limit")
limit: t.Optional[exp.Fetch | exp.Limit] = expression.args.get("limit")
if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch):
limit = exp.Limit(expression=limit.args.get("count"))
@ -1540,12 +1535,19 @@ class Generator:
self.sql(expression, "having"),
*self.after_having_modifiers(expression),
self.sql(expression, "order"),
self.sql(expression, "offset") if fetch else self.sql(limit),
self.sql(limit) if fetch else self.sql(expression, "offset"),
*self.offset_limit_modifiers(expression, fetch, limit),
*self.after_limit_modifiers(expression),
sep="",
)
def offset_limit_modifiers(
self, expression: exp.Expression, fetch: bool, limit: t.Optional[exp.Fetch | exp.Limit]
) -> t.List[str]:
return [
self.sql(expression, "offset") if fetch else self.sql(limit),
self.sql(limit) if fetch else self.sql(expression, "offset"),
]
def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
return [
self.sql(expression, "qualify"),
@ -1634,7 +1636,7 @@ class Generator:
def unnest_sql(self, expression: exp.Unnest) -> str:
args = self.expressions(expression, flat=True)
alias = expression.args.get("alias")
if alias and self.unnest_column_only:
if alias and self.UNNEST_COLUMN_ONLY:
columns = alias.columns
alias = self.sql(columns[0]) if columns else ""
else:
@ -1697,7 +1699,7 @@ class Generator:
return f"{this} BETWEEN {low} AND {high}"
def bracket_sql(self, expression: exp.Bracket) -> str:
expressions = apply_index_offset(expression.this, expression.expressions, self.index_offset)
expressions = apply_index_offset(expression.this, expression.expressions, self.INDEX_OFFSET)
expressions_sql = ", ".join(self.sql(e) for e in expressions)
return f"{self.sql(expression, 'this')}[{expressions_sql}]"
@ -1729,7 +1731,7 @@ class Generator:
statements.append("END")
if self.pretty and self.text_width(statements) > self._max_text_width:
if self.pretty and self.text_width(statements) > self.max_text_width:
return self.indent("\n".join(statements), skip_first=True, skip_last=True)
return " ".join(statements)
@ -1759,10 +1761,11 @@ class Generator:
else:
return self.func("TRIM", expression.this, expression.expression)
def concat_sql(self, expression: exp.Concat) -> str:
if len(expression.expressions) == 1:
return self.sql(expression.expressions[0])
return self.function_fallback_sql(expression)
def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
expressions = expression.expressions
if self.STRICT_STRING_CONCAT:
expressions = (exp.cast(e, "text") for e in expressions)
return self.func("CONCAT", *expressions)
def check_sql(self, expression: exp.Check) -> str:
this = self.sql(expression, key="this")
@ -1785,9 +1788,7 @@ class Generator:
return f"PRIMARY KEY ({expressions}){options}"
def if_sql(self, expression: exp.If) -> str:
return self.case_sql(
exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
)
return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))
def matchagainst_sql(self, expression: exp.MatchAgainst) -> str:
modifier = expression.args.get("modifier")
@ -1798,7 +1799,6 @@ class Generator:
return f"{self.sql(expression, 'this')}: {self.sql(expression, 'expression')}"
def jsonobject_sql(self, expression: exp.JSONObject) -> str:
expressions = self.expressions(expression)
null_handling = expression.args.get("null_handling")
null_handling = f" {null_handling}" if null_handling else ""
unique_keys = expression.args.get("unique_keys")
@ -1811,7 +1811,11 @@ class Generator:
format_json = " FORMAT JSON" if expression.args.get("format_json") else ""
encoding = self.sql(expression, "encoding")
encoding = f" ENCODING {encoding}" if encoding else ""
return f"JSON_OBJECT({expressions}{null_handling}{unique_keys}{return_type}{format_json}{encoding})"
return self.func(
"JSON_OBJECT",
*expression.expressions,
suffix=f"{null_handling}{unique_keys}{return_type}{format_json}{encoding})",
)
def openjsoncolumndef_sql(self, expression: exp.OpenJSONColumnDef) -> str:
this = self.sql(expression, "this")
@ -1930,7 +1934,7 @@ class Generator:
for i, e in enumerate(expression.flatten(unnest=False))
)
sep = "\n" if self.text_width(sqls) > self._max_text_width else " "
sep = "\n" if self.text_width(sqls) > self.max_text_width else " "
return f"{sep}{op} ".join(sqls)
def bitwiseand_sql(self, expression: exp.BitwiseAnd) -> str:
@ -2093,6 +2097,11 @@ class Generator:
def dpipe_sql(self, expression: exp.DPipe) -> str:
return self.binary(expression, "||")
def safedpipe_sql(self, expression: exp.SafeDPipe) -> str:
if self.STRICT_STRING_CONCAT:
return self.func("CONCAT", *(exp.cast(e, "text") for e in expression.flatten()))
return self.dpipe_sql(expression)
def div_sql(self, expression: exp.Div) -> str:
return self.binary(expression, "/")
@ -2127,7 +2136,7 @@ class Generator:
return self.binary(expression, "ILIKE ANY")
def is_sql(self, expression: exp.Is) -> str:
if not self.IS_BOOL and isinstance(expression.expression, exp.Boolean):
if not self.IS_BOOL_ALLOWED and isinstance(expression.expression, exp.Boolean):
return self.sql(
expression.this if expression.expression.this else exp.not_(expression.this)
)
@ -2197,12 +2206,18 @@ class Generator:
return self.func(expression.sql_name(), *args)
def func(self, name: str, *args: t.Optional[exp.Expression | str]) -> str:
return f"{self.normalize_func(name)}({self.format_args(*args)})"
def func(
self,
name: str,
*args: t.Optional[exp.Expression | str],
prefix: str = "(",
suffix: str = ")",
) -> str:
return f"{self.normalize_func(name)}{prefix}{self.format_args(*args)}{suffix}"
def format_args(self, *args: t.Optional[str | exp.Expression]) -> str:
arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None)
if self.pretty and self.text_width(arg_sqls) > self._max_text_width:
if self.pretty and self.text_width(arg_sqls) > self.max_text_width:
return self.indent("\n" + f",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True)
return ", ".join(arg_sqls)
@ -2210,7 +2225,9 @@ class Generator:
return sum(len(arg) for arg in args)
def format_time(self, expression: exp.Expression) -> t.Optional[str]:
return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie)
return format_time(
self.sql(expression, "format"), self.INVERSE_TIME_MAPPING, self.INVERSE_TIME_TRIE
)
def expressions(
self,
@ -2242,7 +2259,7 @@ class Generator:
comments = self.maybe_comment("", e) if isinstance(e, exp.Expression) else ""
if self.pretty:
if self._leading_comma:
if self.leading_comma:
result_sqls.append(f"{sep if i > 0 else pad}{prefix}{sql}{comments}")
else:
result_sqls.append(