sqlglot.dialects.dialect

   1from __future__ import annotations
   2
   3import logging
   4import typing as t
   5from enum import Enum, auto
   6from functools import reduce
   7
   8from sqlglot import exp
   9from sqlglot.errors import ParseError
  10from sqlglot.generator import Generator
  11from sqlglot.helper import AutoName, flatten, is_int, seq_get
  12from sqlglot.jsonpath import parse as parse_json_path
  13from sqlglot.parser import Parser
  14from sqlglot.time import TIMEZONES, format_time
  15from sqlglot.tokens import Token, Tokenizer, TokenType
  16from sqlglot.trie import new_trie
  17
  18DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
  19DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
  20JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]
  21
  22
  23if t.TYPE_CHECKING:
  24    from sqlglot._typing import B, E, F
  25
  26logger = logging.getLogger("sqlglot")
  27
  28
  29class Dialects(str, Enum):
  30    """Dialects supported by SQLGlot."""
  31
  32    DIALECT = ""
  33
  34    BIGQUERY = "bigquery"
  35    CLICKHOUSE = "clickhouse"
  36    DATABRICKS = "databricks"
  37    DORIS = "doris"
  38    DRILL = "drill"
  39    DUCKDB = "duckdb"
  40    HIVE = "hive"
  41    MYSQL = "mysql"
  42    ORACLE = "oracle"
  43    POSTGRES = "postgres"
  44    PRESTO = "presto"
  45    REDSHIFT = "redshift"
  46    SNOWFLAKE = "snowflake"
  47    SPARK = "spark"
  48    SPARK2 = "spark2"
  49    SQLITE = "sqlite"
  50    STARROCKS = "starrocks"
  51    TABLEAU = "tableau"
  52    TERADATA = "teradata"
  53    TRINO = "trino"
  54    TSQL = "tsql"
  55
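Since Dialects subclasses str, its members compare equal to their lowercase names and can be passed anywhere a dialect name string is expected. A minimal illustration (not part of the module source):

    >>> from sqlglot.dialects.dialect import Dialects
    >>> Dialects.DUCKDB == "duckdb"
    True
    >>> Dialects.SNOWFLAKE.value
    'snowflake'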
  56
  57class NormalizationStrategy(str, AutoName):
  58    """Specifies the strategy according to which identifiers should be normalized."""
  59
  60    LOWERCASE = auto()
  61    """Unquoted identifiers are lowercased."""
  62
  63    UPPERCASE = auto()
  64    """Unquoted identifiers are uppercased."""
  65
  66    CASE_SENSITIVE = auto()
  67    """Always case-sensitive, regardless of quotes."""
  68
  69    CASE_INSENSITIVE = auto()
  70    """Always case-insensitive, regardless of quotes."""
  71
  72
  73class _Dialect(type):
  74    classes: t.Dict[str, t.Type[Dialect]] = {}
  75
  76    def __eq__(cls, other: t.Any) -> bool:
  77        if cls is other:
  78            return True
  79        if isinstance(other, str):
  80            return cls is cls.get(other)
  81        if isinstance(other, Dialect):
  82            return cls is type(other)
  83
  84        return False
  85
  86    def __hash__(cls) -> int:
  87        return hash(cls.__name__.lower())
  88
  89    @classmethod
  90    def __getitem__(cls, key: str) -> t.Type[Dialect]:
  91        return cls.classes[key]
  92
  93    @classmethod
  94    def get(
  95        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
  96    ) -> t.Optional[t.Type[Dialect]]:
  97        return cls.classes.get(key, default)
  98
  99    def __new__(cls, clsname, bases, attrs):
 100        klass = super().__new__(cls, clsname, bases, attrs)
 101        enum = Dialects.__members__.get(clsname.upper())
 102        cls.classes[enum.value if enum is not None else clsname.lower()] = klass
 103
 104        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
 105        klass.FORMAT_TRIE = (
 106            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
 107        )
 108        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
 109        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 110
 111        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}
 112
 113        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
 114        klass.parser_class = getattr(klass, "Parser", Parser)
 115        klass.generator_class = getattr(klass, "Generator", Generator)
 116
 117        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
 118        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
 119            klass.tokenizer_class._IDENTIFIERS.items()
 120        )[0]
 121
 122        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
 123            return next(
 124                (
 125                    (s, e)
 126                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
 127                    if t == token_type
 128                ),
 129                (None, None),
 130            )
 131
 132        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
 133        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
 134        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
 135        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
 136
 137        if enum not in ("", "bigquery"):
 138            klass.generator_class.SELECT_KINDS = ()
 139
 140        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
 141            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
 142                TokenType.ANTI,
 143                TokenType.SEMI,
 144            }
 145
 146        return klass
 147
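The metaclass above registers every Dialect subclass under its lowercase name (or its Dialects value) and makes the class comparable to that name. For instance, assuming the bundled dialect modules have been imported:

    >>> from sqlglot.dialects.dialect import Dialect
    >>> Dialect.get("duckdb")
    <class 'sqlglot.dialects.duckdb.DuckDB'>
    >>> Dialect.get("duckdb") == "duckdb"
    True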
 148
 149class Dialect(metaclass=_Dialect):
 150    INDEX_OFFSET = 0
 151    """The base index offset for arrays."""
 152
 153    WEEK_OFFSET = 0
 154    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
 155
 156    UNNEST_COLUMN_ONLY = False
 157    """Whether `UNNEST` table aliases are treated as column aliases."""
 158
 159    ALIAS_POST_TABLESAMPLE = False
 160    """Whether the table alias comes after tablesample."""
 161
 162    TABLESAMPLE_SIZE_IS_PERCENT = False
 163    """Whether a size in the table sample clause represents percentage."""
 164
 165    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
 166    """Specifies the strategy according to which identifiers should be normalized."""
 167
 168    IDENTIFIERS_CAN_START_WITH_DIGIT = False
 169    """Whether an unquoted identifier can start with a digit."""
 170
 171    DPIPE_IS_STRING_CONCAT = True
 172    """Whether the DPIPE token (`||`) is a string concatenation operator."""
 173
 174    STRICT_STRING_CONCAT = False
 175    """Whether `CONCAT`'s arguments must be strings."""
 176
 177    SUPPORTS_USER_DEFINED_TYPES = True
 178    """Whether user-defined data types are supported."""
 179
 180    SUPPORTS_SEMI_ANTI_JOIN = True
 181    """Whether `SEMI` or `ANTI` joins are supported."""
 182
 183    NORMALIZE_FUNCTIONS: bool | str = "upper"
 184    """
 185    Determines how function names are going to be normalized.
 186    Possible values:
 187        "upper" or True: Convert names to uppercase.
 188        "lower": Convert names to lowercase.
 189        False: Disables function name normalization.
 190    """
 191
 192    LOG_BASE_FIRST = True
 193    """Whether the base comes first in the `LOG` function."""
 194
 195    NULL_ORDERING = "nulls_are_small"
 196    """
 197    Default `NULL` ordering method to use if not explicitly set.
 198    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
 199    """
 200
 201    TYPED_DIVISION = False
 202    """
 203    Whether the behavior of `a / b` depends on the types of `a` and `b`.
 204    False means `a / b` is always float division.
 205    True means `a / b` is integer division if both `a` and `b` are integers.
 206    """
 207
 208    SAFE_DIVISION = False
 209    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""
 210
 211    CONCAT_COALESCE = False
 212    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
 213
 214    DATE_FORMAT = "'%Y-%m-%d'"
 215    DATEINT_FORMAT = "'%Y%m%d'"
 216    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
 217
 218    TIME_MAPPING: t.Dict[str, str] = {}
 219    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""
 220
 221    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
 222    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
 223    FORMAT_MAPPING: t.Dict[str, str] = {}
 224    """
 225    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
 226    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
 227    """
 228
 229    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
 230    """Mapping of an unescaped escape sequence to the corresponding character."""
 231
 232    PSEUDOCOLUMNS: t.Set[str] = set()
 233    """
 234    Columns that are auto-generated by the engine corresponding to this dialect.
 235    For example, such columns may be excluded from `SELECT *` queries.
 236    """
 237
 238    PREFER_CTE_ALIAS_COLUMN = False
 239    """
 240    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
 241    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
 242    any projection aliases in the subquery.
 243
 244    For example,
 245        WITH y(c) AS (
 246            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
 247        ) SELECT c FROM y;
 248
 249        will be rewritten as
 250
 251        WITH y(c) AS (
 252            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
 253        ) SELECT c FROM y;
 254    """
 255
 256    # --- Autofilled ---
 257
 258    tokenizer_class = Tokenizer
 259    parser_class = Parser
 260    generator_class = Generator
 261
 262    # A trie of the time_mapping keys
 263    TIME_TRIE: t.Dict = {}
 264    FORMAT_TRIE: t.Dict = {}
 265
 266    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
 267    INVERSE_TIME_TRIE: t.Dict = {}
 268
 269    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
 270
 271    # Delimiters for string literals and identifiers
 272    QUOTE_START = "'"
 273    QUOTE_END = "'"
 274    IDENTIFIER_START = '"'
 275    IDENTIFIER_END = '"'
 276
 277    # Delimiters for bit, hex, byte and unicode literals
 278    BIT_START: t.Optional[str] = None
 279    BIT_END: t.Optional[str] = None
 280    HEX_START: t.Optional[str] = None
 281    HEX_END: t.Optional[str] = None
 282    BYTE_START: t.Optional[str] = None
 283    BYTE_END: t.Optional[str] = None
 284    UNICODE_START: t.Optional[str] = None
 285    UNICODE_END: t.Optional[str] = None
 286
 287    @classmethod
 288    def get_or_raise(cls, dialect: DialectType) -> Dialect:
 289        """
 290        Look up a dialect in the global dialect registry and return it if it exists.
 291
 292        Args:
 293            dialect: The target dialect. If this is a string, it can be optionally followed by
 294                additional key-value pairs that are separated by commas and are used to specify
 295                dialect settings, such as whether the dialect's identifiers are case-sensitive.
 296
 297        Example:
 298            >>> dialect = dialect_class = get_or_raise("duckdb")
 299            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
 300
 301        Returns:
 302            The corresponding Dialect instance.
 303        """
 304
 305        if not dialect:
 306            return cls()
 307        if isinstance(dialect, _Dialect):
 308            return dialect()
 309        if isinstance(dialect, Dialect):
 310            return dialect
 311        if isinstance(dialect, str):
 312            try:
 313                dialect_name, *kv_pairs = dialect.split(",")
 314                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
 315            except ValueError:
 316                raise ValueError(
 317                    f"Invalid dialect format: '{dialect}'. "
 318                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
 319                )
 320
 321            result = cls.get(dialect_name.strip())
 322            if not result:
 323                from difflib import get_close_matches
 324
 325                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
 326                if similar:
 327                    similar = f" Did you mean {similar}?"
 328
 329                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
 330
 331            return result(**kwargs)
 332
 333        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
 334
 335    @classmethod
 336    def format_time(
 337        cls, expression: t.Optional[str | exp.Expression]
 338    ) -> t.Optional[exp.Expression]:
 339        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
 340        if isinstance(expression, str):
 341            return exp.Literal.string(
 342                # the time formats are quoted
 343                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
 344            )
 345
 346        if expression and expression.is_string:
 347            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
 348
 349        return expression
 350
 351    def __init__(self, **kwargs) -> None:
 352        normalization_strategy = kwargs.get("normalization_strategy")
 353
 354        if normalization_strategy is None:
 355            self.normalization_strategy = self.NORMALIZATION_STRATEGY
 356        else:
 357            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
 358
 359    def __eq__(self, other: t.Any) -> bool:
 360        # Does not currently take dialect state into account
 361        return type(self) == other
 362
 363    def __hash__(self) -> int:
 364        # Does not currently take dialect state into account
 365        return hash(type(self))
 366
 367    def normalize_identifier(self, expression: E) -> E:
 368        """
 369        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
 370
 371        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
 372        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
 373        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
 374        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
 375
 376        There are also dialects like Spark, which are case-insensitive even when quotes are
 377        present, and dialects like MySQL, whose resolution rules match those employed by the
 378        underlying operating system; for example, they may always be case-sensitive on Linux.
 379
 380        Finally, the normalization behavior of some engines can even be controlled through flags,
 381        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
 382
 383        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
 384        that it can analyze queries in the optimizer and successfully capture their semantics.
 385        """
 386        if (
 387            isinstance(expression, exp.Identifier)
 388            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
 389            and (
 390                not expression.quoted
 391                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
 392            )
 393        ):
 394            expression.set(
 395                "this",
 396                (
 397                    expression.this.upper()
 398                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 399                    else expression.this.lower()
 400                ),
 401            )
 402
 403        return expression
 404
 405    def case_sensitive(self, text: str) -> bool:
 406        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
 407        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
 408            return False
 409
 410        unsafe = (
 411            str.islower
 412            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 413            else str.isupper
 414        )
 415        return any(unsafe(char) for char in text)
 416
 417    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
 418        """Checks if text can be identified given an identify option.
 419
 420        Args:
 421            text: The text to check.
 422            identify:
 423                `"always"` or `True`: Always returns `True`.
 424                `"safe"`: Only returns `True` if the identifier is case-insensitive.
 425
 426        Returns:
 427            Whether the given text can be identified.
 428        """
 429        if identify is True or identify == "always":
 430            return True
 431
 432        if identify == "safe":
 433            return not self.case_sensitive(text)
 434
 435        return False
 436
 437    def quote_identifier(self, expression: E, identify: bool = True) -> E:
 438        """
 439        Adds quotes to a given identifier.
 440
 441        Args:
 442            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
 443            identify: If set to `False`, the quotes will only be added if the identifier is deemed
 444                "unsafe", with respect to its characters and this dialect's normalization strategy.
 445        """
 446        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
 447            name = expression.this
 448            expression.set(
 449                "quoted",
 450                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
 451            )
 452
 453        return expression
 454
 455    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
 456        if isinstance(path, exp.Literal):
 457            path_text = path.name
 458            if path.is_number:
 459                path_text = f"[{path_text}]"
 460
 461            try:
 462                return parse_json_path(path_text)
 463            except ParseError as e:
 464                logger.warning(f"Invalid JSON path syntax. {str(e)}")
 465
 466        return path
 467
 468    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
 469        return self.parser(**opts).parse(self.tokenize(sql), sql)
 470
 471    def parse_into(
 472        self, expression_type: exp.IntoType, sql: str, **opts
 473    ) -> t.List[t.Optional[exp.Expression]]:
 474        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
 475
 476    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
 477        return self.generator(**opts).generate(expression, copy=copy)
 478
 479    def transpile(self, sql: str, **opts) -> t.List[str]:
 480        return [
 481            self.generate(expression, copy=False, **opts) if expression else ""
 482            for expression in self.parse(sql)
 483        ]
 484
 485    def tokenize(self, sql: str) -> t.List[Token]:
 486        return self.tokenizer.tokenize(sql)
 487
 488    @property
 489    def tokenizer(self) -> Tokenizer:
 490        if not hasattr(self, "_tokenizer"):
 491            self._tokenizer = self.tokenizer_class(dialect=self)
 492        return self._tokenizer
 493
 494    def parser(self, **opts) -> Parser:
 495        return self.parser_class(dialect=self, **opts)
 496
 497    def generator(self, **opts) -> Generator:
 498        return self.generator_class(dialect=self, **opts)
 499
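A brief usage sketch of the instance API defined above; the outputs follow from the normalization and generation logic listed here (Snowflake uppercases unquoted identifiers):

    >>> from sqlglot import exp
    >>> from sqlglot.dialects.dialect import Dialect
    >>> snowflake = Dialect.get_or_raise("snowflake")
    >>> snowflake.normalize_identifier(exp.to_identifier("FoO")).name
    'FOO'
    >>> Dialect.get_or_raise("duckdb").transpile("SELECT 1 AS x")
    ['SELECT 1 AS x']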
 500
 501DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
 502
 503
 504def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
 505    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
 506
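rename_func and the *_sql helpers below are typically wired into a dialect Generator's TRANSFORMS mapping. A minimal sketch with a hypothetical dialect:

    from sqlglot import exp, generator
    from sqlglot.dialects.dialect import Dialect, rename_func

    class MyDialect(Dialect):
        class Generator(generator.Generator):
            TRANSFORMS = {
                **generator.Generator.TRANSFORMS,
                # Render ApproxDistinct under a dialect-specific function name.
                exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
            }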
 507
 508def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
 509    if expression.args.get("accuracy"):
 510        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
 511    return self.func("APPROX_COUNT_DISTINCT", expression.this)
 512
 513
 514def if_sql(
 515    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
 516) -> t.Callable[[Generator, exp.If], str]:
 517    def _if_sql(self: Generator, expression: exp.If) -> str:
 518        return self.func(
 519            name,
 520            expression.this,
 521            expression.args.get("true"),
 522            expression.args.get("false") or false_value,
 523        )
 524
 525    return _if_sql
 526
 527
 528def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
 529    this = expression.this
 530    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
 531        this.replace(exp.cast(this, "json"))
 532
 533    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
 534
 535
 536def inline_array_sql(self: Generator, expression: exp.Array) -> str:
 537    return f"[{self.expressions(expression, flat=True)}]"
 538
 539
 540def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
 541    return self.like_sql(
 542        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
 543    )
 544
 545
 546def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
 547    zone = self.sql(expression, "this")
 548    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
 549
 550
 551def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
 552    if expression.args.get("recursive"):
 553        self.unsupported("Recursive CTEs are unsupported")
 554        expression.args["recursive"] = False
 555    return self.with_sql(expression)
 556
 557
 558def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
 559    n = self.sql(expression, "this")
 560    d = self.sql(expression, "expression")
 561    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
 562
 563
 564def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
 565    self.unsupported("TABLESAMPLE unsupported")
 566    return self.sql(expression.this)
 567
 568
 569def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
 570    self.unsupported("PIVOT unsupported")
 571    return ""
 572
 573
 574def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
 575    return self.cast_sql(expression)
 576
 577
 578def no_comment_column_constraint_sql(
 579    self: Generator, expression: exp.CommentColumnConstraint
 580) -> str:
 581    self.unsupported("CommentColumnConstraint unsupported")
 582    return ""
 583
 584
 585def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
 586    self.unsupported("MAP_FROM_ENTRIES unsupported")
 587    return ""
 588
 589
 590def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
 591    this = self.sql(expression, "this")
 592    substr = self.sql(expression, "substr")
 593    position = self.sql(expression, "position")
 594    if position:
 595        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
 596    return f"STRPOS({this}, {substr})"
 597
 598
 599def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
 600    return (
 601        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
 602    )
 603
 604
 605def var_map_sql(
 606    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
 607) -> str:
 608    keys = expression.args["keys"]
 609    values = expression.args["values"]
 610
 611    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
 612        self.unsupported("Cannot convert array columns into map.")
 613        return self.func(map_func_name, keys, values)
 614
 615    args = []
 616    for key, value in zip(keys.expressions, values.expressions):
 617        args.append(self.sql(key))
 618        args.append(self.sql(value))
 619
 620    return self.func(map_func_name, *args)
 621
 622
 623def build_formatted_time(
 624    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
 625) -> t.Callable[[t.List], E]:
 626    """Helper used for time expressions.
 627
 628    Args:
 629        exp_class: the expression class to instantiate.
 630        dialect: target sql dialect.
 631        default: the default format, True being time.
 632
 633    Returns:
 634        A callable that can be used to return the appropriately formatted time expression.
 635    """
 636
 637    def _builder(args: t.List):
 638        return exp_class(
 639            this=seq_get(args, 0),
 640            format=Dialect[dialect].format_time(
 641                seq_get(args, 1)
 642                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
 643            ),
 644        )
 645
 646    return _builder
 647
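A sketch of how a dialect might register this builder in its parser; the dialect, its TIME_MAPPING entries, and the TO_DATE key are hypothetical:

    from sqlglot import exp, parser
    from sqlglot.dialects.dialect import Dialect, build_formatted_time

    class CustomDialect(Dialect):
        # Map this dialect's format tokens to Python strftime tokens.
        TIME_MAPPING = {"yyyy": "%Y", "mm": "%m", "dd": "%d"}

        class Parser(parser.Parser):
            FUNCTIONS = {
                **parser.Parser.FUNCTIONS,
                # TO_DATE(<value>, <format>) -> exp.StrToDate, with the format
                # string converted through TIME_MAPPING.
                "TO_DATE": build_formatted_time(exp.StrToDate, "customdialect"),
            }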
 648
 649def time_format(
 650    dialect: DialectType = None,
 651) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
 652    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
 653        """
 654        Returns the time format for a given expression, unless it's equivalent
 655        to the default time format of the dialect of interest.
 656        """
 657        time_format = self.format_time(expression)
 658        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
 659
 660    return _time_format
 661
 662
 663def build_date_delta(
 664    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
 665) -> t.Callable[[t.List], E]:
 666    def _builder(args: t.List) -> E:
 667        unit_based = len(args) == 3
 668        this = args[2] if unit_based else seq_get(args, 0)
 669        unit = args[0] if unit_based else exp.Literal.string("DAY")
 670        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
 671        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
 672
 673    return _builder
 674
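What the returned builder does with a unit-first, three-argument call (a minimal illustration; the "dd" unit alias is made up):

    >>> from sqlglot import exp
    >>> from sqlglot.dialects.dialect import build_date_delta
    >>> builder = build_date_delta(exp.DateAdd, {"dd": "DAY"})
    >>> node = builder([exp.var("dd"), exp.Literal.number(3), exp.column("dt")])
    >>> isinstance(node, exp.DateAdd), node.text("unit")
    (True, 'DAY')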
 675
 676def build_date_delta_with_interval(
 677    expression_class: t.Type[E],
 678) -> t.Callable[[t.List], t.Optional[E]]:
 679    def _builder(args: t.List) -> t.Optional[E]:
 680        if len(args) < 2:
 681            return None
 682
 683        interval = args[1]
 684
 685        if not isinstance(interval, exp.Interval):
 686            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
 687
 688        expression = interval.this
 689        if expression and expression.is_string:
 690            expression = exp.Literal.number(expression.this)
 691
 692        return expression_class(
 693            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
 694        )
 695
 696    return _builder
 697
 698
 699def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
 700    unit = seq_get(args, 0)
 701    this = seq_get(args, 1)
 702
 703    if isinstance(this, exp.Cast) and this.is_type("date"):
 704        return exp.DateTrunc(unit=unit, this=this)
 705    return exp.TimestampTrunc(this=this, unit=unit)
 706
 707
 708def date_add_interval_sql(
 709    data_type: str, kind: str
 710) -> t.Callable[[Generator, exp.Expression], str]:
 711    def func(self: Generator, expression: exp.Expression) -> str:
 712        this = self.sql(expression, "this")
 713        unit = expression.args.get("unit")
 714        unit = exp.var(unit.name.upper() if unit else "DAY")
 715        interval = exp.Interval(this=expression.expression, unit=unit)
 716        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
 717
 718    return func
 719
 720
 721def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
 722    return self.func(
 723        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
 724    )
 725
 726
 727def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
 728    if not expression.expression:
 729        from sqlglot.optimizer.annotate_types import annotate_types
 730
 731        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
 732        return self.sql(exp.cast(expression.this, to=target_type))
 733    if expression.text("expression").lower() in TIMEZONES:
 734        return self.sql(
 735            exp.AtTimeZone(
 736                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
 737                zone=expression.expression,
 738            )
 739        )
 740    return self.func("TIMESTAMP", expression.this, expression.expression)
 741
 742
 743def locate_to_strposition(args: t.List) -> exp.Expression:
 744    return exp.StrPosition(
 745        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
 746    )
 747
 748
 749def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
 750    return self.func(
 751        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
 752    )
 753
 754
 755def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
 756    return self.sql(
 757        exp.Substring(
 758            this=expression.this, start=exp.Literal.number(1), length=expression.expression
 759        )
 760    )
 761
 762
 763def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
 764    return self.sql(
 765        exp.Substring(
 766            this=expression.this,
 767            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
 768        )
 769    )
 770
 771
 772def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
 773    return self.sql(exp.cast(expression.this, "timestamp"))
 774
 775
 776def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
 777    return self.sql(exp.cast(expression.this, "date"))
 778
 779
 780# Used for Presto and DuckDB, which use functions that don't support a charset and assume UTF-8
 781def encode_decode_sql(
 782    self: Generator, expression: exp.Expression, name: str, replace: bool = True
 783) -> str:
 784    charset = expression.args.get("charset")
 785    if charset and charset.name.lower() != "utf-8":
 786        self.unsupported(f"Expected utf-8 character set, got {charset}.")
 787
 788    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
 789
 790
 791def min_or_least(self: Generator, expression: exp.Min) -> str:
 792    name = "LEAST" if expression.expressions else "MIN"
 793    return rename_func(name)(self, expression)
 794
 795
 796def max_or_greatest(self: Generator, expression: exp.Max) -> str:
 797    name = "GREATEST" if expression.expressions else "MAX"
 798    return rename_func(name)(self, expression)
 799
 800
 801def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
 802    cond = expression.this
 803
 804    if isinstance(expression.this, exp.Distinct):
 805        cond = expression.this.expressions[0]
 806        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
 807
 808    return self.func("sum", exp.func("if", cond, 1, 0))
 809
 810
 811def trim_sql(self: Generator, expression: exp.Trim) -> str:
 812    target = self.sql(expression, "this")
 813    trim_type = self.sql(expression, "position")
 814    remove_chars = self.sql(expression, "expression")
 815    collation = self.sql(expression, "collation")
 816
 817    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
 818    if not remove_chars and not collation:
 819        return self.trim_sql(expression)
 820
 821    trim_type = f"{trim_type} " if trim_type else ""
 822    remove_chars = f"{remove_chars} " if remove_chars else ""
 823    from_part = "FROM " if trim_type or remove_chars else ""
 824    collation = f" COLLATE {collation}" if collation else ""
 825    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
 826
 827
 828def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
 829    return self.func("STRPTIME", expression.this, self.format_time(expression))
 830
 831
 832def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
 833    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
 834
 835
 836def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
 837    delim, *rest_args = expression.expressions
 838    return self.sql(
 839        reduce(
 840            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
 841            rest_args,
 842        )
 843    )
 844
 845
 846def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
 847    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
 848    if bad_args:
 849        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
 850
 851    return self.func(
 852        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
 853    )
 854
 855
 856def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
 857    bad_args = list(
 858        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
 859    )
 860    if bad_args:
 861        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
 862
 863    return self.func(
 864        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
 865    )
 866
 867
 868def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
 869    names = []
 870    for agg in aggregations:
 871        if isinstance(agg, exp.Alias):
 872            names.append(agg.alias)
 873        else:
 874            """
 875            This case corresponds to aggregations without aliases being used as suffixes
 876            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
 877            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
 878            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
 879            """
 880            agg_all_unquoted = agg.transform(
 881                lambda node: (
 882                    exp.Identifier(this=node.name, quoted=False)
 883                    if isinstance(node, exp.Identifier)
 884                    else node
 885                )
 886            )
 887            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
 888
 889    return names
 890
 891
 892def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
 893    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
 894
 895
 896# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
 897def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
 898    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
 899
 900
 901def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
 902    return self.func("MAX", expression.this)
 903
 904
 905def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
 906    a = self.sql(expression.left)
 907    b = self.sql(expression.right)
 908    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
 909
 910
 911def is_parse_json(expression: exp.Expression) -> bool:
 912    return isinstance(expression, exp.ParseJSON) or (
 913        isinstance(expression, exp.Cast) and expression.is_type("json")
 914    )
 915
 916
 917def isnull_to_is_null(args: t.List) -> exp.Expression:
 918    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
 919
 920
 921def generatedasidentitycolumnconstraint_sql(
 922    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
 923) -> str:
 924    start = self.sql(expression, "start") or "1"
 925    increment = self.sql(expression, "increment") or "1"
 926    return f"IDENTITY({start}, {increment})"
 927
 928
 929def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
 930    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
 931        if expression.args.get("count"):
 932            self.unsupported(f"Only two arguments are supported in function {name}.")
 933
 934        return self.func(name, expression.this, expression.expression)
 935
 936    return _arg_max_or_min_sql
 937
 938
 939def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
 940    this = expression.this.copy()
 941
 942    return_type = expression.return_type
 943    if return_type.is_type(exp.DataType.Type.DATE):
 944        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
 945        # can truncate timestamp strings, because some dialects can't cast them to DATE
 946        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
 947
 948    expression.this.replace(exp.cast(this, return_type))
 949    return expression
 950
 951
 952def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
 953    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
 954        if cast and isinstance(expression, exp.TsOrDsAdd):
 955            expression = ts_or_ds_add_cast(expression)
 956
 957        return self.func(
 958            name,
 959            exp.var(expression.text("unit").upper() or "DAY"),
 960            expression.expression,
 961            expression.this,
 962        )
 963
 964    return _delta_sql
 965
 966
 967def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
 968    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
 969    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
 970    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
 971
 972    return self.sql(exp.cast(minus_one_day, "date"))
 973
 974
 975def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
 976    """Remove table refs from columns in when statements."""
 977    alias = expression.this.args.get("alias")
 978
 979    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
 980        return self.dialect.normalize_identifier(identifier).name if identifier else None
 981
 982    targets = {normalize(expression.this.this)}
 983
 984    if alias:
 985        targets.add(normalize(alias.this))
 986
 987    for when in expression.expressions:
 988        when.transform(
 989            lambda node: (
 990                exp.column(node.this)
 991                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
 992                else node
 993            ),
 994            copy=False,
 995        )
 996
 997    return self.merge_sql(expression)
 998
 999
1000def build_json_extract_path(
1001    expr_type: t.Type[F], zero_based_indexing: bool = True
1002) -> t.Callable[[t.List], F]:
1003    def _builder(args: t.List) -> F:
1004        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
1005        for arg in args[1:]:
1006            if not isinstance(arg, exp.Literal):
1007                # We use the fallback parser because we can't really transpile non-literals safely
1008                return expr_type.from_arg_list(args)
1009
1010            text = arg.name
1011            if is_int(text):
1012                index = int(text)
1013                segments.append(
1014                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1015                )
1016            else:
1017                segments.append(exp.JSONPathKey(this=text))
1018
1019        # This is done to avoid failing in the expression validator due to the arg count
1020        del args[2:]
1021        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
1022
1023    return _builder
1024
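A quick illustration of the path segments this builder produces from literal arguments (not part of the module source):

    >>> from sqlglot import exp
    >>> from sqlglot.dialects.dialect import build_json_extract_path
    >>> build = build_json_extract_path(exp.JSONExtract)
    >>> node = build([exp.column("doc"), exp.Literal.string("a"), exp.Literal.number(0)])
    >>> [type(part).__name__ for part in node.expression.expressions]
    ['JSONPathRoot', 'JSONPathKey', 'JSONPathSubscript']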
1025
1026def json_extract_segments(
1027    name: str, quoted_index: bool = True, op: t.Optional[str] = None
1028) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
1029    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
1030        path = expression.expression
1031        if not isinstance(path, exp.JSONPath):
1032            return rename_func(name)(self, expression)
1033
1034        segments = []
1035        for segment in path.expressions:
1036            path = self.sql(segment)
1037            if path:
1038                if isinstance(segment, exp.JSONPathPart) and (
1039                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
1040                ):
1041                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
1042
1043                segments.append(path)
1044
1045        if op:
1046            return f" {op} ".join([self.sql(expression.this), *segments])
1047        return self.func(name, expression.this, *segments)
1048
1049    return _json_extract_segments
1050
1051
1052def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
1053    if isinstance(expression.this, exp.JSONPathWildcard):
1054        self.unsupported("Unsupported wildcard in JSONPathKey expression")
1055
1056    return expression.name
1057
1058
1059def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
1060    cond = expression.expression
1061    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
1062        alias = cond.expressions[0]
1063        cond = cond.this
1064    elif isinstance(cond, exp.Predicate):
1065        alias = "_u"
1066    else:
1067        self.unsupported("Unsupported filter condition")
1068        return ""
1069
1070    unnest = exp.Unnest(expressions=[expression.this])
1071    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
1072    return self.sql(exp.Array(expressions=[filtered]))
494
495    def parser(self, **opts) -> Parser:
496        return self.parser_class(dialect=self, **opts)
497
498    def generator(self, **opts) -> Generator:
499        return self.generator_class(dialect=self, **opts)
Dialect(**kwargs)
352    def __init__(self, **kwargs) -> None:
353        normalization_strategy = kwargs.get("normalization_strategy")
354
355        if normalization_strategy is None:
356            self.normalization_strategy = self.NORMALIZATION_STRATEGY
357        else:
358            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
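A minimal constructor sketch, assuming the base Dialect is instantiated directly: the strategy keyword is matched case-insensitively, since the value is uppercased before the NormalizationStrategy lookup.

>>> from sqlglot.dialects.dialect import Dialect, NormalizationStrategy
>>> Dialect().normalization_strategy is NormalizationStrategy.LOWERCASE  # class default
True
>>> Dialect(normalization_strategy="case_sensitive").normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
True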
INDEX_OFFSET = 0

The base index offset for arrays.

WEEK_OFFSET = 0

First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

UNNEST_COLUMN_ONLY = False

Whether UNNEST table aliases are treated as column aliases.

ALIAS_POST_TABLESAMPLE = False

Whether the table alias comes after tablesample.

TABLESAMPLE_SIZE_IS_PERCENT = False

Whether a size in the table sample clause represents a percentage.

NORMALIZATION_STRATEGY = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Specifies the strategy according to which identifiers should be normalized.

IDENTIFIERS_CAN_START_WITH_DIGIT = False

Whether an unquoted identifier can start with a digit.

DPIPE_IS_STRING_CONCAT = True

Whether the DPIPE token (||) is a string concatenation operator.

STRICT_STRING_CONCAT = False

Whether CONCAT's arguments must be strings.

SUPPORTS_USER_DEFINED_TYPES = True

Whether user-defined data types are supported.

SUPPORTS_SEMI_ANTI_JOIN = True

Whether SEMI or ANTI joins are supported.

NORMALIZE_FUNCTIONS: bool | str = 'upper'

Determines how function names are going to be normalized.

Possible values:

"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.

LOG_BASE_FIRST = True

Whether the base comes first in the LOG function.

NULL_ORDERING = 'nulls_are_small'

Default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last"

TYPED_DIVISION = False

Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division. True means a / b is integer division if both a and b are integers.

SAFE_DIVISION = False

Whether division by zero throws an error (False) or returns NULL (True).

CONCAT_COALESCE = False

A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.

DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}

Associates this dialect's time formats with their equivalent Python strftime formats.

FORMAT_MAPPING: Dict[str, str] = {}

Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed off of TIME_MAPPING.

ESCAPE_SEQUENCES: Dict[str, str] = {}

Mapping of an unescaped escape sequence to the corresponding character.

PSEUDOCOLUMNS: Set[str] = set()

Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.

PREFER_CTE_ALIAS_COLUMN = False

Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
INVERSE_ESCAPE_SEQUENCES: Dict[str, str] = {}
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START: Optional[str] = None
BIT_END: Optional[str] = None
HEX_START: Optional[str] = None
HEX_END: Optional[str] = None
BYTE_START: Optional[str] = None
BYTE_END: Optional[str] = None
UNICODE_START: Optional[str] = None
UNICODE_END: Optional[str] = None
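The settings above are plain class attributes, so a dialect's effective configuration can simply be read off an instance; a minimal sketch using the documented defaults of the base Dialect:

>>> from sqlglot.dialects.dialect import Dialect
>>> d = Dialect()
>>> d.INDEX_OFFSET, d.WEEK_OFFSET, d.NULL_ORDERING
(0, 0, 'nulls_are_small')
>>> d.QUOTE_START, d.IDENTIFIER_START
("'", '"')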
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Dialect:
288    @classmethod
289    def get_or_raise(cls, dialect: DialectType) -> Dialect:
290        """
291        Look up a dialect in the global dialect registry and return it if it exists.
292
293        Args:
294            dialect: The target dialect. If this is a string, it can be optionally followed by
295                additional key-value pairs that are separated by commas and are used to specify
296                dialect settings, such as whether the dialect's identifiers are case-sensitive.
297
298        Example:
299            >>> dialect = dialect_class = get_or_raise("duckdb")
300            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
301
302        Returns:
303            The corresponding Dialect instance.
304        """
305
306        if not dialect:
307            return cls()
308        if isinstance(dialect, _Dialect):
309            return dialect()
310        if isinstance(dialect, Dialect):
311            return dialect
312        if isinstance(dialect, str):
313            try:
314                dialect_name, *kv_pairs = dialect.split(",")
315                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
316            except ValueError:
317                raise ValueError(
318                    f"Invalid dialect format: '{dialect}'. "
319                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
320                )
321
322            result = cls.get(dialect_name.strip())
323            if not result:
324                from difflib import get_close_matches
325
326                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
327                if similar:
328                    similar = f" Did you mean {similar}?"
329
330                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
331
332            return result(**kwargs)
333
334        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

Look up a dialect in the global dialect registry and return it if it exists.

Arguments:
  • dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:

The corresponding Dialect instance.
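A usage sketch based on the examples above; the comma-separated settings are split into key/value pairs and forwarded to the dialect's constructor.

>>> from sqlglot.dialects.dialect import Dialect, NormalizationStrategy
>>> mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
>>> mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
True
>>> type(Dialect.get_or_raise(None)) is Dialect  # falsy input falls back to the base dialect
True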

@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
336    @classmethod
337    def format_time(
338        cls, expression: t.Optional[str | exp.Expression]
339    ) -> t.Optional[exp.Expression]:
340        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
341        if isinstance(expression, str):
342            return exp.Literal.string(
343                # the time formats are quoted
344                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
345            )
346
347        if expression and expression.is_string:
348            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
349
350        return expression

Converts a time format in this dialect to its equivalent Python strftime format.
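A short sketch, assuming the Hive dialect's TIME_MAPPING translates 'yyyy-MM-dd' to '%Y-%m-%d'; note that a plain string argument is expected to include its surrounding quotes.

>>> from sqlglot.dialects.dialect import Dialect
>>> hive = Dialect.get_or_raise("hive")
>>> hive.format_time("'yyyy-MM-dd'").this
'%Y-%m-%d'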

def normalize_identifier(self, expression: ~E) -> ~E:
368    def normalize_identifier(self, expression: E) -> E:
369        """
370        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
371
372        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
373        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
374        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
375        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
376
377        There are also dialects like Spark, which are case-insensitive even when quotes are
378        present, and dialects like MySQL, whose resolution rules match those employed by the
379        underlying operating system, for example they may always be case-sensitive in Linux.
380
381        Finally, the normalization behavior of some engines can even be controlled through flags,
382        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
383
384        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
385        that it can analyze queries in the optimizer and successfully capture their semantics.
386        """
387        if (
388            isinstance(expression, exp.Identifier)
389            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
390            and (
391                not expression.quoted
392                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
393            )
394        ):
395            expression.set(
396                "this",
397                (
398                    expression.this.upper()
399                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
400                    else expression.this.lower()
401                ),
402            )
403
404        return expression

Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
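A sketch of the behavior described above, reusing the FoO example and assuming the registered postgres and snowflake dialects keep their default strategies:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).name
'foo'
>>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).name
'FOO'
>>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO", quoted=True)).name
'FoO'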

def case_sensitive(self, text: str) -> bool:
406    def case_sensitive(self, text: str) -> bool:
407        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
408        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
409            return False
410
411        unsafe = (
412            str.islower
413            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
414            else str.isupper
415        )
416        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.
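For example, under the default lowercase strategy only uppercase characters count as unsafe, while an uppercasing dialect (Snowflake is assumed here) flags lowercase ones:

>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect().case_sensitive("foo"), Dialect().case_sensitive("Foo")
(False, True)
>>> Dialect.get_or_raise("snowflake").case_sensitive("FOO")
False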

def can_identify(self, text: str, identify: str | bool = 'safe') -> bool:
418    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
419        """Checks if text can be identified given an identify option.
420
421        Args:
422            text: The text to check.
423            identify:
424                `"always"` or `True`: Always returns `True`.
425                `"safe"`: Only returns `True` if the identifier is case-insensitive.
426
427        Returns:
428            Whether the given text can be identified.
429        """
430        if identify is True or identify == "always":
431            return True
432
433        if identify == "safe":
434            return not self.case_sensitive(text)
435
436        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns True. "safe": Only returns True if the identifier is case-insensitive.
Returns:

Whether the given text can be identified.
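A short sketch of both modes under the base dialect's lowercase strategy:

>>> from sqlglot.dialects.dialect import Dialect
>>> d = Dialect()
>>> d.can_identify("foo", "safe"), d.can_identify("Foo", "safe")
(True, False)
>>> d.can_identify("Foo", "always"), d.can_identify("Foo", True)
(True, True)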

def quote_identifier(self, expression: ~E, identify: bool = True) -> ~E:
438    def quote_identifier(self, expression: E, identify: bool = True) -> E:
439        """
440        Adds quotes to a given identifier.
441
442        Args:
443            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
444            identify: If set to `False`, the quotes will only be added if the identifier is deemed
445                "unsafe", with respect to its characters and this dialect's normalization strategy.
446        """
447        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
448            name = expression.this
449            expression.set(
450                "quoted",
451                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
452            )
453
454        return expression

Adds quotes to a given identifier.

Arguments:
  • expression: The expression of interest. If it's not an Identifier, this method is a no-op.
  • identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
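A sketch of the identify=False path: quotes are only added when the name is deemed unsafe under this dialect's normalization strategy or fails SAFE_IDENTIFIER_RE.

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect
>>> d = Dialect()
>>> d.quote_identifier(exp.to_identifier("foo"), identify=False).sql()
'foo'
>>> d.quote_identifier(exp.to_identifier("Foo"), identify=False).sql()
'"Foo"'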
def to_json_path( self, path: Optional[sqlglot.expressions.Expression]) -> Optional[sqlglot.expressions.Expression]:
456    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
457        if isinstance(path, exp.Literal):
458            path_text = path.name
459            if path.is_number:
460                path_text = f"[{path_text}]"
461
462            try:
463                return parse_json_path(path_text)
464            except ParseError as e:
465                logger.warning(f"Invalid JSON path syntax. {str(e)}")
466
467        return path
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
469    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
470        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
472    def parse_into(
473        self, expression_type: exp.IntoType, sql: str, **opts
474    ) -> t.List[t.Optional[exp.Expression]]:
475        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
477    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
478        return self.generator(**opts).generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> List[str]:
480    def transpile(self, sql: str, **opts) -> t.List[str]:
481        return [
482            self.generate(expression, copy=False, **opts) if expression else ""
483            for expression in self.parse(sql)
484        ]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
486    def tokenize(self, sql: str) -> t.List[Token]:
487        return self.tokenizer.tokenize(sql)
tokenizer: sqlglot.tokens.Tokenizer
489    @property
490    def tokenizer(self) -> Tokenizer:
491        if not hasattr(self, "_tokenizer"):
492            self._tokenizer = self.tokenizer_class(dialect=self)
493        return self._tokenizer
def parser(self, **opts) -> sqlglot.parser.Parser:
495    def parser(self, **opts) -> Parser:
496        return self.parser_class(dialect=self, **opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
498    def generator(self, **opts) -> Generator:
499        return self.generator_class(dialect=self, **opts)
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
505def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
506    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
509def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
510    if expression.args.get("accuracy"):
511        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
512    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
515def if_sql(
516    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
517) -> t.Callable[[Generator, exp.If], str]:
518    def _if_sql(self: Generator, expression: exp.If) -> str:
519        return self.func(
520            name,
521            expression.this,
522            expression.args.get("true"),
523            expression.args.get("false") or false_value,
524        )
525
526    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]) -> str:
529def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
530    this = expression.this
531    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
532        this.replace(exp.cast(this, "json"))
533
534    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
537def inline_array_sql(self: Generator, expression: exp.Array) -> str:
538    return f"[{self.expressions(expression, flat=True)}]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
541def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
542    return self.like_sql(
543        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
544    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
547def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
548    zone = self.sql(expression, "this")
549    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
552def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
553    if expression.args.get("recursive"):
554        self.unsupported("Recursive CTEs are unsupported")
555        expression.args["recursive"] = False
556    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
559def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
560    n = self.sql(expression, "this")
561    d = self.sql(expression, "expression")
562    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
565def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
566    self.unsupported("TABLESAMPLE unsupported")
567    return self.sql(expression.this)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
570def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
571    self.unsupported("PIVOT unsupported")
572    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
575def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
576    return self.cast_sql(expression)
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
579def no_comment_column_constraint_sql(
580    self: Generator, expression: exp.CommentColumnConstraint
581) -> str:
582    self.unsupported("CommentColumnConstraint unsupported")
583    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
586def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
587    self.unsupported("MAP_FROM_ENTRIES unsupported")
588    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
591def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
592    this = self.sql(expression, "this")
593    substr = self.sql(expression, "substr")
594    position = self.sql(expression, "position")
595    if position:
596        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
597    return f"STRPOS({this}, {substr})"
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
600def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
601    return (
602        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
603    )
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
606def var_map_sql(
607    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
608) -> str:
609    keys = expression.args["keys"]
610    values = expression.args["values"]
611
612    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
613        self.unsupported("Cannot convert array columns into map.")
614        return self.func(map_func_name, keys, values)
615
616    args = []
617    for key, value in zip(keys.expressions, values.expressions):
618        args.append(self.sql(key))
619        args.append(self.sql(value))
620
621    return self.func(map_func_name, *args)
def build_formatted_time( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
624def build_formatted_time(
625    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
626) -> t.Callable[[t.List], E]:
627    """Helper used for time expressions.
628
629    Args:
630        exp_class: the expression class to instantiate.
631        dialect: target sql dialect.
632        default: the default format, True being time.
633
634    Returns:
635        A callable that can be used to return the appropriately formatted time expression.
636    """
637
638    def _builder(args: t.List):
639        return exp_class(
640            this=seq_get(args, 0),
641            format=Dialect[dialect].format_time(
642                seq_get(args, 1)
643                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
644            ),
645        )
646
647    return _builder

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.
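A hedged sketch of the builder in isolation (it is normally wired into a dialect's parser); exp.StrToTime and the Hive format mapping are assumptions made for illustration.

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import build_formatted_time
>>> builder = build_formatted_time(exp.StrToTime, "hive")
>>> node = builder([exp.column("x"), exp.Literal.string("yyyy-MM-dd")])
>>> node.args["format"].this
'%Y-%m-%d'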

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
650def time_format(
651    dialect: DialectType = None,
652) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
653    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
654        """
655        Returns the time format for a given expression, unless it's equivalent
656        to the default time format of the dialect of interest.
657        """
658        time_format = self.format_time(expression)
659        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
660
661    return _time_format
def build_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
664def build_date_delta(
665    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
666) -> t.Callable[[t.List], E]:
667    def _builder(args: t.List) -> E:
668        unit_based = len(args) == 3
669        this = args[2] if unit_based else seq_get(args, 0)
670        unit = args[0] if unit_based else exp.Literal.string("DAY")
671        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
672        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
673
674    return _builder
def build_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
677def build_date_delta_with_interval(
678    expression_class: t.Type[E],
679) -> t.Callable[[t.List], t.Optional[E]]:
680    def _builder(args: t.List) -> t.Optional[E]:
681        if len(args) < 2:
682            return None
683
684        interval = args[1]
685
686        if not isinstance(interval, exp.Interval):
687            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
688
689        expression = interval.this
690        if expression and expression.is_string:
691            expression = exp.Literal.number(expression.this)
692
693        return expression_class(
694            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
695        )
696
697    return _builder
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
700def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
701    unit = seq_get(args, 0)
702    this = seq_get(args, 1)
703
704    if isinstance(this, exp.Cast) and this.is_type("date"):
705        return exp.DateTrunc(unit=unit, this=this)
706    return exp.TimestampTrunc(this=this, unit=unit)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
709def date_add_interval_sql(
710    data_type: str, kind: str
711) -> t.Callable[[Generator, exp.Expression], str]:
712    def func(self: Generator, expression: exp.Expression) -> str:
713        this = self.sql(expression, "this")
714        unit = expression.args.get("unit")
715        unit = exp.var(unit.name.upper() if unit else "DAY")
716        interval = exp.Interval(this=expression.expression, unit=unit)
717        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
718
719    return func
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
722def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
723    return self.func(
724        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
725    )
def no_timestamp_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Timestamp) -> str:
728def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
729    if not expression.expression:
730        from sqlglot.optimizer.annotate_types import annotate_types
731
732        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
733        return self.sql(exp.cast(expression.this, to=target_type))
734    if expression.text("expression").lower() in TIMEZONES:
735        return self.sql(
736            exp.AtTimeZone(
737                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
738                zone=expression.expression,
739            )
740        )
741    return self.func("TIMESTAMP", expression.this, expression.expression)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
744def locate_to_strposition(args: t.List) -> exp.Expression:
745    return exp.StrPosition(
746        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
747    )
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
750def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
751    return self.func(
752        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
753    )
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
756def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
757    return self.sql(
758        exp.Substring(
759            this=expression.this, start=exp.Literal.number(1), length=expression.expression
760        )
761    )
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
764def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
765    return self.sql(
766        exp.Substring(
767            this=expression.this,
768            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
769        )
770    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
773def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
774    return self.sql(exp.cast(expression.this, "timestamp"))
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
777def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
778    return self.sql(exp.cast(expression.this, "date"))
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
782def encode_decode_sql(
783    self: Generator, expression: exp.Expression, name: str, replace: bool = True
784) -> str:
785    charset = expression.args.get("charset")
786    if charset and charset.name.lower() != "utf-8":
787        self.unsupported(f"Expected utf-8 character set, got {charset}.")
788
789    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
792def min_or_least(self: Generator, expression: exp.Min) -> str:
793    name = "LEAST" if expression.expressions else "MIN"
794    return rename_func(name)(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
797def max_or_greatest(self: Generator, expression: exp.Max) -> str:
798    name = "GREATEST" if expression.expressions else "MAX"
799    return rename_func(name)(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
802def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
803    cond = expression.this
804
805    if isinstance(expression.this, exp.Distinct):
806        cond = expression.this.expressions[0]
807        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
808
809    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
812def trim_sql(self: Generator, expression: exp.Trim) -> str:
813    target = self.sql(expression, "this")
814    trim_type = self.sql(expression, "position")
815    remove_chars = self.sql(expression, "expression")
816    collation = self.sql(expression, "collation")
817
818    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
819    if not remove_chars and not collation:
820        return self.trim_sql(expression)
821
822    trim_type = f"{trim_type} " if trim_type else ""
823    remove_chars = f"{remove_chars} " if remove_chars else ""
824    from_part = "FROM " if trim_type or remove_chars else ""
825    collation = f" COLLATE {collation}" if collation else ""
826    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
829def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
830    return self.func("STRPTIME", expression.this, self.format_time(expression))
def concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat) -> str:
833def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
834    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
def concat_ws_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ConcatWs) -> str:
837def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
838    delim, *rest_args = expression.expressions
839    return self.sql(
840        reduce(
841            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
842            rest_args,
843        )
844    )
def regexp_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpExtract) -> str:
847def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
848    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
849    if bad_args:
850        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
851
852    return self.func(
853        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
854    )
def regexp_replace_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpReplace) -> str:
857def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
858    bad_args = list(
859        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
860    )
861    if bad_args:
862        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
863
864    return self.func(
865        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
866    )
def pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> List[str]:
869def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
870    names = []
871    for agg in aggregations:
872        if isinstance(agg, exp.Alias):
873            names.append(agg.alias)
874        else:
875            """
876            This case corresponds to aggregations without aliases being used as suffixes
877            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
878            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
879            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
880            """
881            agg_all_unquoted = agg.transform(
882                lambda node: (
883                    exp.Identifier(this=node.name, quoted=False)
884                    if isinstance(node, exp.Identifier)
885                    else node
886                )
887            )
888            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
889
890    return names
def binary_from_function(expr_type: Type[~B]) -> Callable[[List], ~B]:
893def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
894    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
def build_timestamp_trunc(args: List) -> sqlglot.expressions.TimestampTrunc:
898def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
899    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
def any_value_to_max_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.AnyValue) -> str:
902def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
903    return self.func("MAX", expression.this)
def bool_xor_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Xor) -> str:
906def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
907    a = self.sql(expression.left)
908    b = self.sql(expression.right)
909    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
def is_parse_json(expression: sqlglot.expressions.Expression) -> bool:
912def is_parse_json(expression: exp.Expression) -> bool:
913    return isinstance(expression, exp.ParseJSON) or (
914        isinstance(expression, exp.Cast) and expression.is_type("json")
915    )
def isnull_to_is_null(args: List) -> sqlglot.expressions.Expression:
918def isnull_to_is_null(args: t.List) -> exp.Expression:
919    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
def generatedasidentitycolumnconstraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.GeneratedAsIdentityColumnConstraint) -> str:
922def generatedasidentitycolumnconstraint_sql(
923    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
924) -> str:
925    start = self.sql(expression, "start") or "1"
926    increment = self.sql(expression, "increment") or "1"
927    return f"IDENTITY({start}, {increment})"
def arg_max_or_min_no_count( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.ArgMax | sqlglot.expressions.ArgMin], str]:
930def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
931    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
932        if expression.args.get("count"):
933            self.unsupported(f"Only two arguments are supported in function {name}.")
934
935        return self.func(name, expression.this, expression.expression)
936
937    return _arg_max_or_min_sql
def ts_or_ds_add_cast( expression: sqlglot.expressions.TsOrDsAdd) -> sqlglot.expressions.TsOrDsAdd:
940def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
941    this = expression.this.copy()
942
943    return_type = expression.return_type
944    if return_type.is_type(exp.DataType.Type.DATE):
945        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
946        # can truncate timestamp strings, because some dialects can't cast them to DATE
947        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
948
949    expression.this.replace(exp.cast(this, return_type))
950    return expression
def date_delta_sql( name: str, cast: bool = False) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.DateAdd, sqlglot.expressions.TsOrDsAdd, sqlglot.expressions.DateDiff, sqlglot.expressions.TsOrDsDiff]], str]:
953def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
954    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
955        if cast and isinstance(expression, exp.TsOrDsAdd):
956            expression = ts_or_ds_add_cast(expression)
957
958        return self.func(
959            name,
960            exp.var(expression.text("unit").upper() or "DAY"),
961            expression.expression,
962            expression.this,
963        )
964
965    return _delta_sql
def no_last_day_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.LastDay) -> str:
968def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
969    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
970    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
971    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
972
973    return self.sql(exp.cast(minus_one_day, "date"))
def merge_without_target_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Merge) -> str:
976def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
977    """Remove table refs from columns in when statements."""
978    alias = expression.this.args.get("alias")
979
980    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
981        return self.dialect.normalize_identifier(identifier).name if identifier else None
982
983    targets = {normalize(expression.this.this)}
984
985    if alias:
986        targets.add(normalize(alias.this))
987
988    for when in expression.expressions:
989        when.transform(
990            lambda node: (
991                exp.column(node.this)
992                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
993                else node
994            ),
995            copy=False,
996        )
997
998    return self.merge_sql(expression)

Remove table refs from columns in when statements.

def build_json_extract_path( expr_type: Type[~F], zero_based_indexing: bool = True) -> Callable[[List], ~F]:
1001def build_json_extract_path(
1002    expr_type: t.Type[F], zero_based_indexing: bool = True
1003) -> t.Callable[[t.List], F]:
1004    def _builder(args: t.List) -> F:
1005        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
1006        for arg in args[1:]:
1007            if not isinstance(arg, exp.Literal):
1008                # We use the fallback parser because we can't really transpile non-literals safely
1009                return expr_type.from_arg_list(args)
1010
1011            text = arg.name
1012            if is_int(text):
1013                index = int(text)
1014                segments.append(
1015                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1016                )
1017            else:
1018                segments.append(exp.JSONPathKey(this=text))
1019
1020        # This is done to avoid failing in the expression validator due to the arg count
1021        del args[2:]
1022        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
1023
1024    return _builder
def json_extract_segments( name: str, quoted_index: bool = True, op: Optional[str] = None) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]], str]:
1027def json_extract_segments(
1028    name: str, quoted_index: bool = True, op: t.Optional[str] = None
1029) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
1030    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
1031        path = expression.expression
1032        if not isinstance(path, exp.JSONPath):
1033            return rename_func(name)(self, expression)
1034
1035        segments = []
1036        for segment in path.expressions:
1037            path = self.sql(segment)
1038            if path:
1039                if isinstance(segment, exp.JSONPathPart) and (
1040                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
1041                ):
1042                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
1043
1044                segments.append(path)
1045
1046        if op:
1047            return f" {op} ".join([self.sql(expression.this), *segments])
1048        return self.func(name, expression.this, *segments)
1049
1050    return _json_extract_segments
def json_path_key_only_name( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONPathKey) -> str:
1053def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
1054    if isinstance(expression.this, exp.JSONPathWildcard):
1055        self.unsupported("Unsupported wildcard in JSONPathKey expression")
1056
1057    return expression.name
def filter_array_using_unnest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ArrayFilter) -> str:
1060def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
1061    cond = expression.expression
1062    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
1063        alias = cond.expressions[0]
1064        cond = cond.this
1065    elif isinstance(cond, exp.Predicate):
1066        alias = "_u"
1067    else:
1068        self.unsupported("Unsupported filter condition")
1069        return ""
1070
1071    unnest = exp.Unnest(expressions=[expression.this])
1072    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
1073    return self.sql(exp.Array(expressions=[filtered]))