1
0
Fork 0

Merging upstream version 20.1.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:17:09 +01:00
parent d4fe7bdb16
commit 90988d8258
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
127 changed files with 73384 additions and 73067 deletions

View file

@ -9,10 +9,11 @@ from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
from sqlglot.helper import apply_index_offset, csv, seq_get
from sqlglot.time import format_time
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.tokens import TokenType
if t.TYPE_CHECKING:
from sqlglot._typing import E
from sqlglot.dialects.dialect import DialectType
logger = logging.getLogger("sqlglot")
@ -58,9 +59,6 @@ class Generator:
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
),
exp.TsOrDsAdd: lambda self, e: self.func(
"TS_OR_DS_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
),
exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}",
@ -108,9 +106,6 @@ class Generator:
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
}
# Whether the base comes first
LOG_BASE_FIRST = True
# Whether or not null ordering is supported in order by
NULL_ORDERING_SUPPORTED = True
@ -201,7 +196,7 @@ class Generator:
VALUES_AS_TABLE = True
# Whether or not the word COLUMN is included when adding a column with ALTER TABLE
ALTER_TABLE_ADD_COLUMN_KEYWORD = True
ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = True
# UNNEST WITH ORDINALITY (presto) instead of UNNEST WITH OFFSET (bigquery)
UNNEST_WITH_ORDINALITY = True
@ -212,9 +207,6 @@ class Generator:
# Whether or not JOIN sides (LEFT, RIGHT) are supported in conjunction with SEMI/ANTI join kinds
SEMI_ANTI_JOIN_WITH_SIDE = True
# Whether or not session variables / parameters are supported, e.g. @x in T-SQL
SUPPORTS_PARAMETERS = True
# Whether or not to include the type of a computed column in the CREATE DDL
COMPUTED_COLUMN_WITH_TYPE = True
@ -230,12 +222,15 @@ class Generator:
# Whether or not data types support additional specifiers like e.g. CHAR or BYTE (oracle)
DATA_TYPE_SPECIFIERS_ALLOWED = False
# Whether or not nested CTEs (e.g. defined inside of subqueries) are allowed
SUPPORTS_NESTED_CTES = True
# Whether or not conditions require booleans WHERE x = 0 vs WHERE x
ENSURE_BOOLS = False
# Whether or not the "RECURSIVE" keyword is required when defining recursive CTEs
CTE_RECURSIVE_KEYWORD_REQUIRED = True
# Whether or not CONCAT requires >1 arguments
SUPPORTS_SINGLE_ARG_CONCAT = True
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@ -335,6 +330,7 @@ class Generator:
exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
exp.WithSystemVersioningProperty: exp.Properties.Location.POST_SCHEMA,
}
# Keywords that can't be used as unquoted identifier names
@ -368,37 +364,13 @@ class Generator:
exp.Paren,
)
# Expressions that need to have all CTEs under them bubbled up to them
EXPRESSIONS_WITHOUT_NESTED_CTES: t.Set[t.Type[exp.Expression]] = set()
KEY_VALUE_DEFINITONS = (exp.Bracket, exp.EQ, exp.PropertyEQ, exp.Slice)
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
# Autofilled
INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
INVERSE_TIME_TRIE: t.Dict = {}
INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
INDEX_OFFSET = 0
UNNEST_COLUMN_ONLY = False
ALIAS_POST_TABLESAMPLE = False
IDENTIFIERS_CAN_START_WITH_DIGIT = False
STRICT_STRING_CONCAT = False
NORMALIZE_FUNCTIONS: bool | str = "upper"
NULL_ORDERING = "nulls_are_small"
can_identify: t.Callable[[str, str | bool], bool]
# Delimiters for quotes, identifiers and the corresponding escape characters
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
TOKENIZER_CLASS = Tokenizer
# Delimiters for bit, hex, byte and raw literals
BIT_START: t.Optional[str] = None
BIT_END: t.Optional[str] = None
HEX_START: t.Optional[str] = None
HEX_END: t.Optional[str] = None
BYTE_START: t.Optional[str] = None
BYTE_END: t.Optional[str] = None
__slots__ = (
"pretty",
"identify",
@ -411,6 +383,7 @@ class Generator:
"leading_comma",
"max_text_width",
"comments",
"dialect",
"unsupported_messages",
"_escaped_quote_end",
"_escaped_identifier_end",
@ -429,8 +402,10 @@ class Generator:
leading_comma: bool = False,
max_text_width: int = 80,
comments: bool = True,
dialect: DialectType = None,
):
import sqlglot
from sqlglot.dialects import Dialect
self.pretty = pretty if pretty is not None else sqlglot.pretty
self.identify = identify
@ -442,16 +417,19 @@ class Generator:
self.leading_comma = leading_comma
self.max_text_width = max_text_width
self.comments = comments
self.dialect = Dialect.get_or_raise(dialect)
# This is both a Dialect property and a Generator argument, so we prioritize the latter
self.normalize_functions = (
self.NORMALIZE_FUNCTIONS if normalize_functions is None else normalize_functions
self.dialect.NORMALIZE_FUNCTIONS if normalize_functions is None else normalize_functions
)
self.unsupported_messages: t.List[str] = []
self._escaped_quote_end: str = self.TOKENIZER_CLASS.STRING_ESCAPES[0] + self.QUOTE_END
self._escaped_quote_end: str = (
self.dialect.tokenizer_class.STRING_ESCAPES[0] + self.dialect.QUOTE_END
)
self._escaped_identifier_end: str = (
self.TOKENIZER_CLASS.IDENTIFIER_ESCAPES[0] + self.IDENTIFIER_END
self.dialect.tokenizer_class.IDENTIFIER_ESCAPES[0] + self.dialect.IDENTIFIER_END
)
def generate(self, expression: exp.Expression, copy: bool = True) -> str:
@ -469,23 +447,14 @@ class Generator:
if copy:
expression = expression.copy()
# Some dialects only support CTEs at the top level expression, so we need to bubble up nested
# CTEs to that level in order to produce a syntactically valid expression. This transformation
# happens here to minimize code duplication, since many expressions support CTEs.
if (
not self.SUPPORTS_NESTED_CTES
and isinstance(expression, exp.Expression)
and not expression.parent
and "with" in expression.arg_types
and any(node.parent is not expression for node in expression.find_all(exp.With))
):
from sqlglot.transforms import move_ctes_to_top_level
expression = move_ctes_to_top_level(expression)
expression = self.preprocess(expression)
self.unsupported_messages = []
sql = self.sql(expression).strip()
if self.pretty:
sql = sql.replace(self.SENTINEL_LINE_BREAK, "\n")
if self.unsupported_level == ErrorLevel.IGNORE:
return sql
@ -495,10 +464,26 @@ class Generator:
elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
raise UnsupportedError(concat_messages(self.unsupported_messages, self.max_unsupported))
if self.pretty:
sql = sql.replace(self.SENTINEL_LINE_BREAK, "\n")
return sql
def preprocess(self, expression: exp.Expression) -> exp.Expression:
"""Apply generic preprocessing transformations to a given expression."""
if (
not expression.parent
and type(expression) in self.EXPRESSIONS_WITHOUT_NESTED_CTES
and any(node.parent is not expression for node in expression.find_all(exp.With))
):
from sqlglot.transforms import move_ctes_to_top_level
expression = move_ctes_to_top_level(expression)
if self.ENSURE_BOOLS:
from sqlglot.transforms import ensure_bools
expression = ensure_bools(expression)
return expression
def unsupported(self, message: str) -> None:
if self.unsupported_level == ErrorLevel.IMMEDIATE:
raise UnsupportedError(message)
@ -752,9 +737,24 @@ class Generator:
return f"GENERATED{this} AS {expr}{sequence_opts}"
def generatedasrowcolumnconstraint_sql(
self, expression: exp.GeneratedAsRowColumnConstraint
) -> str:
start = "START" if expression.args["start"] else "END"
hidden = " HIDDEN" if expression.args.get("hidden") else ""
return f"GENERATED ALWAYS AS ROW {start}{hidden}"
def periodforsystemtimeconstraint_sql(
self, expression: exp.PeriodForSystemTimeConstraint
) -> str:
return f"PERIOD FOR SYSTEM_TIME ({self.sql(expression, 'this')}, {self.sql(expression, 'expression')})"
def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"
def transformcolumnconstraint_sql(self, expression: exp.TransformColumnConstraint) -> str:
return f"AS {self.sql(expression, 'this')}"
def primarykeycolumnconstraint_sql(self, expression: exp.PrimaryKeyColumnConstraint) -> str:
desc = expression.args.get("desc")
if desc is not None:
@ -900,32 +900,32 @@ class Generator:
columns = self.expressions(expression, key="columns", flat=True)
columns = f"({columns})" if columns else ""
if not alias and not self.UNNEST_COLUMN_ONLY:
if not alias and not self.dialect.UNNEST_COLUMN_ONLY:
alias = "_t"
return f"{alias}{columns}"
def bitstring_sql(self, expression: exp.BitString) -> str:
this = self.sql(expression, "this")
if self.BIT_START:
return f"{self.BIT_START}{this}{self.BIT_END}"
if self.dialect.BIT_START:
return f"{self.dialect.BIT_START}{this}{self.dialect.BIT_END}"
return f"{int(this, 2)}"
def hexstring_sql(self, expression: exp.HexString) -> str:
this = self.sql(expression, "this")
if self.HEX_START:
return f"{self.HEX_START}{this}{self.HEX_END}"
if self.dialect.HEX_START:
return f"{self.dialect.HEX_START}{this}{self.dialect.HEX_END}"
return f"{int(this, 16)}"
def bytestring_sql(self, expression: exp.ByteString) -> str:
this = self.sql(expression, "this")
if self.BYTE_START:
return f"{self.BYTE_START}{this}{self.BYTE_END}"
if self.dialect.BYTE_START:
return f"{self.dialect.BYTE_START}{this}{self.dialect.BYTE_END}"
return this
def rawstring_sql(self, expression: exp.RawString) -> str:
string = self.escape_str(expression.this.replace("\\", "\\\\"))
return f"{self.QUOTE_START}{string}{self.QUOTE_END}"
return f"{self.dialect.QUOTE_START}{string}{self.dialect.QUOTE_END}"
def datatypeparam_sql(self, expression: exp.DataTypeParam) -> str:
this = self.sql(expression, "this")
@ -1065,14 +1065,14 @@ class Generator:
text = expression.name
lower = text.lower()
text = lower if self.normalize and not expression.quoted else text
text = text.replace(self.IDENTIFIER_END, self._escaped_identifier_end)
text = text.replace(self.dialect.IDENTIFIER_END, self._escaped_identifier_end)
if (
expression.quoted
or self.can_identify(text, self.identify)
or self.dialect.can_identify(text, self.identify)
or lower in self.RESERVED_KEYWORDS
or (not self.IDENTIFIERS_CAN_START_WITH_DIGIT and text[:1].isdigit())
or (not self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT and text[:1].isdigit())
):
text = f"{self.IDENTIFIER_START}{text}{self.IDENTIFIER_END}"
text = f"{self.dialect.IDENTIFIER_START}{text}{self.dialect.IDENTIFIER_END}"
return text
def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str:
@ -1121,7 +1121,7 @@ class Generator:
expressions = self.expressions(properties, sep=sep, indent=False)
if expressions:
expressions = self.wrap(expressions) if wrapped else expressions
return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}"
return f"{prefix}{' ' if prefix.strip() else ''}{expressions}{suffix}"
return ""
def with_properties(self, properties: exp.Properties) -> str:
@ -1286,6 +1286,21 @@ class Generator:
statistics_sql = f" AND {'NO ' if not statistics else ''}STATISTICS"
return f"{data_sql}{statistics_sql}"
def withsystemversioningproperty_sql(self, expression: exp.WithSystemVersioningProperty) -> str:
sql = "WITH(SYSTEM_VERSIONING=ON"
if expression.this:
history_table = self.sql(expression, "this")
sql = f"{sql}(HISTORY_TABLE={history_table}"
if expression.expression:
data_consistency_check = self.sql(expression, "expression")
sql = f"{sql}, DATA_CONSISTENCY_CHECK={data_consistency_check}"
sql = f"{sql})"
return f"{sql})"
def insert_sql(self, expression: exp.Insert) -> str:
overwrite = expression.args.get("overwrite")
@ -1387,13 +1402,13 @@ class Generator:
def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:
table = ".".join(
part
for part in [
self.sql(expression, "catalog"),
self.sql(expression, "db"),
self.sql(expression, "this"),
]
if part
self.sql(part)
for part in (
expression.args.get("catalog"),
expression.args.get("db"),
expression.args.get("this"),
)
if part is not None
)
version = self.sql(expression, "version")
@ -1426,7 +1441,7 @@ class Generator:
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
) -> str:
if self.ALIAS_POST_TABLESAMPLE and expression.this.alias:
if self.dialect.ALIAS_POST_TABLESAMPLE and expression.this and expression.this.alias:
table = expression.this.copy()
table.set("alias", None)
this = self.sql(table)
@ -1676,12 +1691,16 @@ class Generator:
def limit_sql(self, expression: exp.Limit, top: bool = False) -> str:
this = self.sql(expression, "this")
args = ", ".join(
self.sql(self._simplify_unless_literal(e) if self.LIMIT_ONLY_LITERALS else e)
args = [
self._simplify_unless_literal(e) if self.LIMIT_ONLY_LITERALS else e
for e in (expression.args.get(k) for k in ("offset", "expression"))
if e
)
return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args}"
]
args_sql = ", ".join(self.sql(e) for e in args)
args_sql = f"({args_sql})" if any(top and not e.is_number for e in args) else args_sql
return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args_sql}"
def offset_sql(self, expression: exp.Offset) -> str:
this = self.sql(expression, "this")
@ -1732,13 +1751,13 @@ class Generator:
def literal_sql(self, expression: exp.Literal) -> str:
text = expression.this or ""
if expression.is_string:
text = f"{self.QUOTE_START}{self.escape_str(text)}{self.QUOTE_END}"
text = f"{self.dialect.QUOTE_START}{self.escape_str(text)}{self.dialect.QUOTE_END}"
return text
def escape_str(self, text: str) -> str:
text = text.replace(self.QUOTE_END, self._escaped_quote_end)
if self.INVERSE_ESCAPE_SEQUENCES:
text = "".join(self.INVERSE_ESCAPE_SEQUENCES.get(ch, ch) for ch in text)
text = text.replace(self.dialect.QUOTE_END, self._escaped_quote_end)
if self.dialect.INVERSE_ESCAPE_SEQUENCES:
text = "".join(self.dialect.INVERSE_ESCAPE_SEQUENCES.get(ch, ch) for ch in text)
elif self.pretty:
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
return text
@ -1782,9 +1801,11 @@ class Generator:
nulls_first = expression.args.get("nulls_first")
nulls_last = not nulls_first
nulls_are_large = self.NULL_ORDERING == "nulls_are_large"
nulls_are_small = self.NULL_ORDERING == "nulls_are_small"
nulls_are_last = self.NULL_ORDERING == "nulls_are_last"
nulls_are_large = self.dialect.NULL_ORDERING == "nulls_are_large"
nulls_are_small = self.dialect.NULL_ORDERING == "nulls_are_small"
nulls_are_last = self.dialect.NULL_ORDERING == "nulls_are_last"
this = self.sql(expression, "this")
sort_order = " DESC" if desc else (" ASC" if desc is False else "")
nulls_sort_change = ""
@ -1799,13 +1820,13 @@ class Generator:
):
nulls_sort_change = " NULLS LAST"
# If the NULLS FIRST/LAST clause is unsupported, we add another sort key to simulate it
if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED:
self.unsupported(
"Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect"
)
null_sort_order = " DESC" if nulls_sort_change == " NULLS FIRST" else ""
this = f"CASE WHEN {this} IS NULL THEN 1 ELSE 0 END{null_sort_order}, {this}"
nulls_sort_change = ""
return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}"
return f"{this}{sort_order}{nulls_sort_change}"
def matchrecognize_sql(self, expression: exp.MatchRecognize) -> str:
partition = self.partition_by_sql(expression)
@ -1933,10 +1954,13 @@ class Generator:
)
kind = ""
# We use LIMIT_IS_TOP as a proxy for whether DISTINCT should go first because tsql and Teradata
# are the only dialects that use LIMIT_IS_TOP and both place DISTINCT first.
top_distinct = f"{distinct}{hint}{top}" if self.LIMIT_IS_TOP else f"{top}{hint}{distinct}"
expressions = f"{self.sep()}{expressions}" if expressions else expressions
sql = self.query_modifiers(
expression,
f"SELECT{top}{hint}{distinct}{kind}{expressions}",
f"SELECT{top_distinct}{kind}{expressions}",
self.sql(expression, "into", comment=False),
self.sql(expression, "from", comment=False),
)
@ -1961,7 +1985,7 @@ class Generator:
def parameter_sql(self, expression: exp.Parameter) -> str:
this = self.sql(expression, "this")
return f"{self.PARAMETER_TOKEN}{this}" if self.SUPPORTS_PARAMETERS else this
return f"{self.PARAMETER_TOKEN}{this}"
def sessionparameter_sql(self, expression: exp.SessionParameter) -> str:
this = self.sql(expression, "this")
@ -2009,7 +2033,7 @@ class Generator:
if alias and isinstance(offset, exp.Expression):
alias.append("columns", offset)
if alias and self.UNNEST_COLUMN_ONLY:
if alias and self.dialect.UNNEST_COLUMN_ONLY:
columns = alias.columns
alias = self.sql(columns[0]) if columns else ""
else:
@ -2080,14 +2104,14 @@ class Generator:
return f"{this} BETWEEN {low} AND {high}"
def bracket_sql(self, expression: exp.Bracket) -> str:
expressions = apply_index_offset(expression.this, expression.expressions, self.INDEX_OFFSET)
expressions = apply_index_offset(
expression.this,
expression.expressions,
self.dialect.INDEX_OFFSET - expression.args.get("offset", 0),
)
expressions_sql = ", ".join(self.sql(e) for e in expressions)
return f"{self.sql(expression, 'this')}[{expressions_sql}]"
def safebracket_sql(self, expression: exp.SafeBracket) -> str:
return self.bracket_sql(expression)
def all_sql(self, expression: exp.All) -> str:
return f"ALL {self.wrap(expression)}"
@ -2145,12 +2169,33 @@ class Generator:
else:
return self.func("TRIM", expression.this, expression.expression)
def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
expressions = expression.expressions
if self.STRICT_STRING_CONCAT:
expressions = (exp.cast(e, "text") for e in expressions)
def convert_concat_args(self, expression: exp.Concat | exp.ConcatWs) -> t.List[exp.Expression]:
args = expression.expressions
if isinstance(expression, exp.ConcatWs):
args = args[1:] # Skip the delimiter
if self.dialect.STRICT_STRING_CONCAT and expression.args.get("safe"):
args = [exp.cast(e, "text") for e in args]
if not self.dialect.CONCAT_COALESCE and expression.args.get("coalesce"):
args = [exp.func("coalesce", e, exp.Literal.string("")) for e in args]
return args
def concat_sql(self, expression: exp.Concat) -> str:
expressions = self.convert_concat_args(expression)
# Some dialects don't allow a single-argument CONCAT call
if not self.SUPPORTS_SINGLE_ARG_CONCAT and len(expressions) == 1:
return self.sql(expressions[0])
return self.func("CONCAT", *expressions)
def concatws_sql(self, expression: exp.ConcatWs) -> str:
return self.func(
"CONCAT_WS", seq_get(expression.expressions, 0), *self.convert_concat_args(expression)
)
def check_sql(self, expression: exp.Check) -> str:
this = self.sql(expression, key="this")
return f"CHECK ({this})"
@ -2493,14 +2538,7 @@ class Generator:
actions = expression.args["actions"]
if isinstance(actions[0], exp.ColumnDef):
if self.ALTER_TABLE_ADD_COLUMN_KEYWORD:
actions = self.expressions(
expression,
key="actions",
prefix="ADD COLUMN ",
)
else:
actions = f"ADD {self.expressions(expression, key='actions')}"
actions = self.add_column_sql(expression)
elif isinstance(actions[0], exp.Schema):
actions = self.expressions(expression, key="actions", prefix="ADD COLUMNS ")
elif isinstance(actions[0], exp.Delete):
@ -2512,6 +2550,15 @@ class Generator:
only = " ONLY" if expression.args.get("only") else ""
return f"ALTER TABLE{exists}{only} {self.sql(expression, 'this')} {actions}"
def add_column_sql(self, expression: exp.AlterTable) -> str:
if self.ALTER_TABLE_INCLUDE_COLUMN_KEYWORD:
return self.expressions(
expression,
key="actions",
prefix="ADD COLUMN ",
)
return f"ADD {self.expressions(expression, key='actions', flat=True)}"
def droppartition_sql(self, expression: exp.DropPartition) -> str:
expressions = self.expressions(expression)
exists = " IF EXISTS " if expression.args.get("exists") else " "
@ -2551,14 +2598,31 @@ class Generator:
)
def dpipe_sql(self, expression: exp.DPipe) -> str:
if self.dialect.STRICT_STRING_CONCAT and expression.args.get("safe"):
return self.func("CONCAT", *(exp.cast(e, "text") for e in expression.flatten()))
return self.binary(expression, "||")
def safedpipe_sql(self, expression: exp.SafeDPipe) -> str:
if self.STRICT_STRING_CONCAT:
return self.func("CONCAT", *(exp.cast(e, "text") for e in expression.flatten()))
return self.dpipe_sql(expression)
def div_sql(self, expression: exp.Div) -> str:
l, r = expression.left, expression.right
if not self.dialect.SAFE_DIVISION and expression.args.get("safe"):
r.replace(exp.Nullif(this=r.copy(), expression=exp.Literal.number(0)))
if self.dialect.TYPED_DIVISION and not expression.args.get("typed"):
if not l.is_type(*exp.DataType.FLOAT_TYPES) and not r.is_type(
*exp.DataType.FLOAT_TYPES
):
l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DOUBLE))
elif not self.dialect.TYPED_DIVISION and expression.args.get("typed"):
if l.is_type(*exp.DataType.INTEGER_TYPES) and r.is_type(*exp.DataType.INTEGER_TYPES):
return self.sql(
exp.cast(
l / r,
to=exp.DataType.Type.BIGINT,
)
)
return self.binary(expression, "/")
def overlaps_sql(self, expression: exp.Overlaps) -> str:
@ -2573,6 +2637,9 @@ class Generator:
def eq_sql(self, expression: exp.EQ) -> str:
return self.binary(expression, "=")
def propertyeq_sql(self, expression: exp.PropertyEQ) -> str:
return self.binary(expression, ":=")
def escape_sql(self, expression: exp.Escape) -> str:
return self.binary(expression, "ESCAPE")
@ -2641,10 +2708,13 @@ class Generator:
return self.cast_sql(expression, safe_prefix="TRY_")
def log_sql(self, expression: exp.Log) -> str:
args = list(expression.args.values())
if not self.LOG_BASE_FIRST:
args.reverse()
return self.func("LOG", *args)
this = expression.this
expr = expression.expression
if not self.dialect.LOG_BASE_FIRST:
this, expr = expr, this
return self.func("LOG", this, expr)
def use_sql(self, expression: exp.Use) -> str:
kind = self.sql(expression, "kind")
@ -2696,7 +2766,9 @@ class Generator:
def format_time(self, expression: exp.Expression) -> t.Optional[str]:
return format_time(
self.sql(expression, "format"), self.INVERSE_TIME_MAPPING, self.INVERSE_TIME_TRIE
self.sql(expression, "format"),
self.dialect.INVERSE_TIME_MAPPING,
self.dialect.INVERSE_TIME_TRIE,
)
def expressions(
@ -2963,6 +3035,19 @@ class Generator:
parameters = self.sql(expression, "params_struct")
return self.func("PREDICT", model, table, parameters or None)
def forin_sql(self, expression: exp.ForIn) -> str:
this = self.sql(expression, "this")
expression_sql = self.sql(expression, "expression")
return f"FOR {this} DO {expression_sql}"
def refresh_sql(self, expression: exp.Refresh) -> str:
this = self.sql(expression, "this")
table = "" if isinstance(expression.this, exp.Literal) else "TABLE "
return f"REFRESH {table}{this}"
def operator_sql(self, expression: exp.Operator) -> str:
return self.binary(expression, f"OPERATOR({self.sql(expression, 'operator')})")
def _simplify_unless_literal(self, expression: E) -> E:
if not isinstance(expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify
@ -2970,3 +3055,10 @@ class Generator:
expression = simplify(expression)
return expression
def _ensure_string_if_null(self, values: t.List[exp.Expression]) -> t.List[exp.Expression]:
return [
exp.func("COALESCE", exp.cast(value, "text"), exp.Literal.string(""))
for value in values
if value
]