1
0
Fork 0

Merging upstream version 21.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:20:36 +01:00
parent 3759c601a7
commit 96b10de29a
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
115 changed files with 66603 additions and 60920 deletions

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import logging
import typing as t
from enum import Enum, auto
from functools import reduce
@ -7,7 +8,8 @@ from functools import reduce
from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, seq_get
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
@ -17,7 +19,11 @@ DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsD
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
if t.TYPE_CHECKING:
from sqlglot._typing import B, E
from sqlglot._typing import B, E, F
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]
logger = logging.getLogger("sqlglot")
class Dialects(str, Enum):
@ -256,7 +262,7 @@ class Dialect(metaclass=_Dialect):
INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
# Delimiters for quotes, identifiers and the corresponding escape characters
# Delimiters for string literals and identifiers
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
@ -373,7 +379,7 @@ class Dialect(metaclass=_Dialect):
"""
if (
isinstance(expression, exp.Identifier)
and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
and (
not expression.quoted
or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
@ -440,6 +446,19 @@ class Dialect(metaclass=_Dialect):
return expression
def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if isinstance(path, exp.Literal):
path_text = path.name
if path.is_number:
path_text = f"[{path_text}]"
try:
return parse_json_path(path_text)
except ParseError as e:
logger.warning(f"Invalid JSON path syntax. {str(e)}")
return path
def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
return self.parser(**opts).parse(self.tokenize(sql), sql)
@ -500,14 +519,12 @@ def if_sql(
return _if_sql
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
return self.binary(expression, "->")
def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
this = expression.this
if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
this.replace(exp.cast(this, "json"))
def arrow_json_extract_scalar_sql(
self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
return self.binary(expression, "->>")
return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
@ -552,11 +569,6 @@ def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
return self.cast_sql(expression)
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
self.unsupported("Properties unsupported")
return ""
def no_comment_column_constraint_sql(
self: Generator, expression: exp.CommentColumnConstraint
) -> str:
@ -965,32 +977,6 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
return _delta_sql
def prepend_dollar_to_path(expression: exp.GetPath) -> exp.GetPath:
from sqlglot.optimizer.simplify import simplify
# Makes sure the path will be evaluated correctly at runtime to include the path root.
# For example, `[0].foo` will become `$[0].foo`, and `foo` will become `$.foo`.
path = expression.expression
path = exp.func(
"if",
exp.func("startswith", path, "'['"),
exp.func("concat", "'$'", path),
exp.func("concat", "'$.'", path),
)
expression.expression.replace(simplify(path))
return expression
def path_to_jsonpath(
name: str = "JSON_EXTRACT",
) -> t.Callable[[Generator, exp.GetPath], str]:
def _transform(self: Generator, expression: exp.GetPath) -> str:
return rename_func(name)(self, prepend_dollar_to_path(expression))
return _transform
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
trunc_curr_date = exp.func("date_trunc", "month", expression.this)
plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
@ -1003,9 +989,8 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
"""Remove table refs from columns in when statements."""
alias = expression.this.args.get("alias")
normalize = lambda identifier: (
self.dialect.normalize_identifier(identifier).name if identifier else None
)
def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
return self.dialect.normalize_identifier(identifier).name if identifier else None
targets = {normalize(expression.this.this)}
@ -1023,3 +1008,60 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
)
return self.merge_sql(expression)
def parse_json_extract_path(
expr_type: t.Type[F], zero_based_indexing: bool = True
) -> t.Callable[[t.List], F]:
def _parse_json_extract_path(args: t.List) -> F:
segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
for arg in args[1:]:
if not isinstance(arg, exp.Literal):
# We use the fallback parser because we can't really transpile non-literals safely
return expr_type.from_arg_list(args)
text = arg.name
if is_int(text):
index = int(text)
segments.append(
exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
)
else:
segments.append(exp.JSONPathKey(this=text))
# This is done to avoid failing in the expression validator due to the arg count
del args[2:]
return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
return _parse_json_extract_path
def json_extract_segments(
name: str, quoted_index: bool = True
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
path = expression.expression
if not isinstance(path, exp.JSONPath):
return rename_func(name)(self, expression)
segments = []
for segment in path.expressions:
path = self.sql(segment)
if path:
if isinstance(segment, exp.JSONPathPart) and (
quoted_index or not isinstance(segment, exp.JSONPathSubscript)
):
path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
segments.append(path)
return self.func(name, expression.this, *segments)
return _json_extract_segments
def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
if isinstance(expression.this, exp.JSONPathWildcard):
self.unsupported("Unsupported wildcard in JSONPathKey expression")
return expression.name