sqlglot.dialects.dialect

from __future__ import annotations

import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

B = t.TypeVar("B", bound=exp.Binary)

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"

class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}

        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass
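
# Illustrative note (not part of the module): `_Dialect.__new__` makes dialects
# self-registering. Defining a `Dialect` subclass is enough to add it to
# `_Dialect.classes`, keyed by the matching `Dialects` enum value or, failing
# that, the lowercased class name. A hypothetical sketch:
#
#     class Custom(Dialect):
#         ...
#
#     assert Dialect.get("custom") is Custom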

class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """Determines the base index offset for arrays."""

    WEEK_OFFSET = 0
    """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Determines whether or not `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Determines whether or not the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Determines whether or not a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Determines whether or not an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Determines whether or not the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Determines whether or not `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Determines whether or not user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Determines whether or not `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """Determines how function names are going to be normalized."""

    LOG_BASE_FIRST = True
    """Determines whether the base comes first in the `LOG` function."""

    NULL_ORDERING = "nulls_are_small"
    """
    Indicates the default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Determines whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` format."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an unescaped escape sequence to the corresponding character."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

        will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for quotes, identifiers and the corresponding escape characters
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
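
    # Illustrative sketch: `get_or_raise` accepts a bare dialect name or a name
    # followed by comma-separated settings, which become constructor kwargs
    # (assumes the bundled dialects have been registered via `import sqlglot`):
    #
    #     >>> duckdb = Dialect.get_or_raise("duckdb")
    #     >>> mysql = Dialect.get_or_raise(
    #     ...     "mysql, normalization_strategy = case_sensitive"
    #     ... )
    #     >>> mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
    #     True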

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                expression.this.upper()
                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                else expression.this.lower(),
            )

        return expression
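
    # Illustrative sketch of the strategies above, using the base dialect so no
    # registry lookup is needed:
    #
    #     >>> Dialect().normalize_identifier(exp.to_identifier("FoO")).name
    #     'foo'
    #     >>> upper = Dialect(normalization_strategy="uppercase")
    #     >>> upper.normalize_identifier(exp.to_identifier("FoO")).name
    #     'FOO'
    #     >>> Dialect().normalize_identifier(exp.to_identifier("FoO", quoted=True)).name
    #     'FoO'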

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression
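
    # Illustrative sketch: with identify=False, only "unsafe" identifiers are
    # quoted, per `case_sensitive` and `exp.SAFE_IDENTIFIER_RE`:
    #
    #     >>> Dialect().quote_identifier(exp.to_identifier("foo"), identify=False).quoted
    #     False
    #     >>> Dialect().quote_identifier(exp.to_identifier("Foo"), identify=False).quoted
    #     True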

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
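
# Illustrative sketch (not part of the module): the parse -> generate round trip
# that `transpile` performs, using the base dialect.
def _example_round_trip() -> None:
    dialect = Dialect()

    # `parse` returns one (possibly None) expression per statement in the input.
    expression = dialect.parse("select 1 as x")[0]
    assert expression is not None
    assert dialect.generate(expression) == "SELECT 1 AS x"
    assert dialect.transpile("select 1 as x") == ["SELECT 1 AS x"]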

def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql
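
# Illustrative sketch (hypothetical dialect, not part of sqlglot): generator
# helpers such as `rename_func` and `if_sql` are designed to be dropped into a
# dialect's `Generator.TRANSFORMS` mapping. The IFF name below mirrors
# Snowflake's spelling and is used purely for illustration; note that defining
# the class also registers it via `_Dialect.__new__`.
class _ExampleIfDialect(Dialect):
    class Generator(Generator):
        TRANSFORMS = {
            **Generator.TRANSFORMS,
            exp.ApproxDistinct: approx_count_distinct_sql,
            exp.If: if_sql(name="IFF"),
        }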

def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    return self.binary(expression, "->")


def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    return self.binary(expression, "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, flat=True)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    self.unsupported("Properties unsupported")
    return ""


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format; if `True`, the dialect's default `TIME_FORMAT` is used.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time
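
# Illustrative sketch: dialects register the callable built by
# `format_time_lambda` under `Parser.FUNCTIONS`, e.g. Hive's TO_DATE. Assumes
# the hive dialect has been registered (e.g. via `import sqlglot`).
def _example_format_time_lambda() -> None:
    to_date = format_time_lambda(exp.TsOrDsToDate, "hive")
    node = to_date([exp.Literal.string("2023-01-01"), exp.Literal.string("yyyy-MM-dd")])

    # Hive's format tokens were translated to Python strftime notation.
    assert node.args["format"].this == "%Y-%m-%d"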

def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)


def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func


def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
    )


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.function_fallback_sql(expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )

def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, "timestamp"))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, "date"))

# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
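
# Illustrative sketch: `min_or_least` (and its MAX/GREATEST counterpart above)
# picks the variadic spelling only when extra expressions are present.
def _example_min_or_least() -> None:
    generator = Dialect().generator()

    assert min_or_least(generator, exp.Min(this=exp.column("a"))) == "MIN(a)"
    assert (
        min_or_least(generator, exp.Min(this=exp.column("a"), expressions=[exp.column("b")]))
        == "LEAST(a, b)"
    )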

def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
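
# Illustrative sketch: `bool_xor_sql` expands XOR for engines that lack a
# boolean XOR operator.
def _example_bool_xor() -> None:
    node = exp.Xor(this=exp.column("a"), expression=exp.column("b"))
    assert bool_xor_sql(Dialect().generator(), node) == "(a AND (NOT b)) OR ((NOT a) AND b)"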

def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            exp.var(expression.text("unit").upper() or "DAY"),
            expression.expression,
            expression.this,
        )

    return _delta_sql


def prepend_dollar_to_path(expression: exp.GetPath) -> exp.GetPath:
    from sqlglot.optimizer.simplify import simplify

    # Makes sure the path will be evaluated correctly at runtime to include the path root.
    # For example, `[0].foo` will become `$[0].foo`, and `foo` will become `$.foo`.
    path = expression.expression
    path = exp.func(
        "if",
        exp.func("startswith", path, "'['"),
        exp.func("concat", "'$'", path),
        exp.func("concat", "'$.'", path),
    )

    expression.expression.replace(simplify(path))
    return expression


def path_to_jsonpath(
    name: str = "JSON_EXTRACT",
) -> t.Callable[[Generator, exp.GetPath], str]:
    def _transform(self: Generator, expression: exp.GetPath) -> str:
        return rename_func(name)(self, prepend_dollar_to_path(expression))

    return _transform


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, "date"))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    normalize = (
        lambda identifier: self.dialect.normalize_identifier(identifier).name
        if identifier
        else None
    )

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: exp.column(node.this)
            if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
            else node,
            copy=False,
        )

    return self.merge_sql(expression)
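
# Illustrative sketch (hypothetical dialect, not part of sqlglot): a minimal
# custom dialect combining several helpers from this module. Defining the class
# is enough to register it, after which it can be looked up by name, e.g. with
# `Dialect.get_or_raise("_exampledialect")`.
class _ExampleDialect(Dialect):
    NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE

    class Generator(Generator):
        TRANSFORMS = {
            **Generator.TRANSFORMS,
            exp.ILike: no_ilike_sql,
            exp.Min: min_or_least,
            exp.Max: max_or_greatest,
            exp.Xor: bool_xor_sql,
        }


# e.g. _ExampleDialect().transpile("select max(a, b)") == ["SELECT GREATEST(a, b)"]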
class Dialects(builtins.str, enum.Enum):
24class Dialects(str, Enum):
25    """Dialects supported by SQLGLot."""
26
27    DIALECT = ""
28
29    BIGQUERY = "bigquery"
30    CLICKHOUSE = "clickhouse"
31    DATABRICKS = "databricks"
32    DORIS = "doris"
33    DRILL = "drill"
34    DUCKDB = "duckdb"
35    HIVE = "hive"
36    MYSQL = "mysql"
37    ORACLE = "oracle"
38    POSTGRES = "postgres"
39    PRESTO = "presto"
40    REDSHIFT = "redshift"
41    SNOWFLAKE = "snowflake"
42    SPARK = "spark"
43    SPARK2 = "spark2"
44    SQLITE = "sqlite"
45    STARROCKS = "starrocks"
46    TABLEAU = "tableau"
47    TERADATA = "teradata"
48    TRINO = "trino"
49    TSQL = "tsql"

Dialects supported by SQLGLot.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DORIS = <Dialects.DORIS: 'doris'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class NormalizationStrategy(builtins.str, sqlglot.helper.AutoName):
52class NormalizationStrategy(str, AutoName):
53    """Specifies the strategy according to which identifiers should be normalized."""
54
55    LOWERCASE = auto()
56    """Unquoted identifiers are lowercased."""
57
58    UPPERCASE = auto()
59    """Unquoted identifiers are uppercased."""
60
61    CASE_SENSITIVE = auto()
62    """Always case-sensitive, regardless of quotes."""
63
64    CASE_INSENSITIVE = auto()
65    """Always case-insensitive, regardless of quotes."""

Specifies the strategy according to which identifiers should be normalized.

LOWERCASE = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Unquoted identifiers are lowercased.

UPPERCASE = <NormalizationStrategy.UPPERCASE: 'UPPERCASE'>

Unquoted identifiers are uppercased.

CASE_SENSITIVE = <NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>

Always case-sensitive, regardless of quotes.

CASE_INSENSITIVE = <NormalizationStrategy.CASE_INSENSITIVE: 'CASE_INSENSITIVE'>

Always case-insensitive, regardless of quotes.

Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
144class Dialect(metaclass=_Dialect):
145    INDEX_OFFSET = 0
146    """Determines the base index offset for arrays."""
147
148    WEEK_OFFSET = 0
149    """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
150
151    UNNEST_COLUMN_ONLY = False
152    """Determines whether or not `UNNEST` table aliases are treated as column aliases."""
153
154    ALIAS_POST_TABLESAMPLE = False
155    """Determines whether or not the table alias comes after tablesample."""
156
157    TABLESAMPLE_SIZE_IS_PERCENT = False
158    """Determines whether or not a size in the table sample clause represents percentage."""
159
160    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
161    """Specifies the strategy according to which identifiers should be normalized."""
162
163    IDENTIFIERS_CAN_START_WITH_DIGIT = False
164    """Determines whether or not an unquoted identifier can start with a digit."""
165
166    DPIPE_IS_STRING_CONCAT = True
167    """Determines whether or not the DPIPE token (`||`) is a string concatenation operator."""
168
169    STRICT_STRING_CONCAT = False
170    """Determines whether or not `CONCAT`'s arguments must be strings."""
171
172    SUPPORTS_USER_DEFINED_TYPES = True
173    """Determines whether or not user-defined data types are supported."""
174
175    SUPPORTS_SEMI_ANTI_JOIN = True
176    """Determines whether or not `SEMI` or `ANTI` joins are supported."""
177
178    NORMALIZE_FUNCTIONS: bool | str = "upper"
179    """Determines how function names are going to be normalized."""
180
181    LOG_BASE_FIRST = True
182    """Determines whether the base comes first in the `LOG` function."""
183
184    NULL_ORDERING = "nulls_are_small"
185    """
186    Indicates the default `NULL` ordering method to use if not explicitly set.
187    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
188    """
189
190    TYPED_DIVISION = False
191    """
192    Whether the behavior of `a / b` depends on the types of `a` and `b`.
193    False means `a / b` is always float division.
194    True means `a / b` is integer division if both `a` and `b` are integers.
195    """
196
197    SAFE_DIVISION = False
198    """Determines whether division by zero throws an error (`False`) or returns NULL (`True`)."""
199
200    CONCAT_COALESCE = False
201    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
202
203    DATE_FORMAT = "'%Y-%m-%d'"
204    DATEINT_FORMAT = "'%Y%m%d'"
205    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
206
207    TIME_MAPPING: t.Dict[str, str] = {}
208    """Associates this dialect's time formats with their equivalent Python `strftime` format."""
209
210    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
211    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
212    FORMAT_MAPPING: t.Dict[str, str] = {}
213    """
214    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
215    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
216    """
217
218    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
219    """Mapping of an unescaped escape sequence to the corresponding character."""
220
221    PSEUDOCOLUMNS: t.Set[str] = set()
222    """
223    Columns that are auto-generated by the engine corresponding to this dialect.
224    For example, such columns may be excluded from `SELECT *` queries.
225    """
226
227    PREFER_CTE_ALIAS_COLUMN = False
228    """
229    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
230    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
231    any projection aliases in the subquery.
232
233    For example,
234        WITH y(c) AS (
235            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
236        ) SELECT c FROM y;
237
238        will be rewritten as
239
240        WITH y(c) AS (
241            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
242        ) SELECT c FROM y;
243    """
244
245    # --- Autofilled ---
246
247    tokenizer_class = Tokenizer
248    parser_class = Parser
249    generator_class = Generator
250
251    # A trie of the time_mapping keys
252    TIME_TRIE: t.Dict = {}
253    FORMAT_TRIE: t.Dict = {}
254
255    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
256    INVERSE_TIME_TRIE: t.Dict = {}
257
258    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
259
260    # Delimiters for quotes, identifiers and the corresponding escape characters
261    QUOTE_START = "'"
262    QUOTE_END = "'"
263    IDENTIFIER_START = '"'
264    IDENTIFIER_END = '"'
265
266    # Delimiters for bit, hex, byte and unicode literals
267    BIT_START: t.Optional[str] = None
268    BIT_END: t.Optional[str] = None
269    HEX_START: t.Optional[str] = None
270    HEX_END: t.Optional[str] = None
271    BYTE_START: t.Optional[str] = None
272    BYTE_END: t.Optional[str] = None
273    UNICODE_START: t.Optional[str] = None
274    UNICODE_END: t.Optional[str] = None
275
276    @classmethod
277    def get_or_raise(cls, dialect: DialectType) -> Dialect:
278        """
279        Look up a dialect in the global dialect registry and return it if it exists.
280
281        Args:
282            dialect: The target dialect. If this is a string, it can be optionally followed by
283                additional key-value pairs that are separated by commas and are used to specify
284                dialect settings, such as whether the dialect's identifiers are case-sensitive.
285
286        Example:
287            >>> dialect = dialect_class = get_or_raise("duckdb")
288            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
289
290        Returns:
291            The corresponding Dialect instance.
292        """
293
294        if not dialect:
295            return cls()
296        if isinstance(dialect, _Dialect):
297            return dialect()
298        if isinstance(dialect, Dialect):
299            return dialect
300        if isinstance(dialect, str):
301            try:
302                dialect_name, *kv_pairs = dialect.split(",")
303                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
304            except ValueError:
305                raise ValueError(
306                    f"Invalid dialect format: '{dialect}'. "
307                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
308                )
309
310            result = cls.get(dialect_name.strip())
311            if not result:
312                from difflib import get_close_matches
313
314                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
315                if similar:
316                    similar = f" Did you mean {similar}?"
317
318                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
319
320            return result(**kwargs)
321
322        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
323
324    @classmethod
325    def format_time(
326        cls, expression: t.Optional[str | exp.Expression]
327    ) -> t.Optional[exp.Expression]:
328        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
329        if isinstance(expression, str):
330            return exp.Literal.string(
331                # the time formats are quoted
332                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
333            )
334
335        if expression and expression.is_string:
336            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
337
338        return expression
339
340    def __init__(self, **kwargs) -> None:
341        normalization_strategy = kwargs.get("normalization_strategy")
342
343        if normalization_strategy is None:
344            self.normalization_strategy = self.NORMALIZATION_STRATEGY
345        else:
346            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
347
348    def __eq__(self, other: t.Any) -> bool:
349        # Does not currently take dialect state into account
350        return type(self) == other
351
352    def __hash__(self) -> int:
353        # Does not currently take dialect state into account
354        return hash(type(self))
355
356    def normalize_identifier(self, expression: E) -> E:
357        """
358        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
359
360        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
361        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
362        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
363        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
364
365        There are also dialects like Spark, which are case-insensitive even when quotes are
366        present, and dialects like MySQL, whose resolution rules match those employed by the
367        underlying operating system, for example they may always be case-sensitive in Linux.
368
369        Finally, the normalization behavior of some engines can even be controlled through flags,
370        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
371
372        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
373        that it can analyze queries in the optimizer and successfully capture their semantics.
374        """
375        if (
376            isinstance(expression, exp.Identifier)
377            and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
378            and (
379                not expression.quoted
380                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
381            )
382        ):
383            expression.set(
384                "this",
385                expression.this.upper()
386                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
387                else expression.this.lower(),
388            )
389
390        return expression
391
392    def case_sensitive(self, text: str) -> bool:
393        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
394        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
395            return False
396
397        unsafe = (
398            str.islower
399            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
400            else str.isupper
401        )
402        return any(unsafe(char) for char in text)
403
404    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
405        """Checks if text can be identified given an identify option.
406
407        Args:
408            text: The text to check.
409            identify:
410                `"always"` or `True`: Always returns `True`.
411                `"safe"`: Only returns `True` if the identifier is case-insensitive.
412
413        Returns:
414            Whether or not the given text can be identified.
415        """
416        if identify is True or identify == "always":
417            return True
418
419        if identify == "safe":
420            return not self.case_sensitive(text)
421
422        return False
423
424    def quote_identifier(self, expression: E, identify: bool = True) -> E:
425        """
426        Adds quotes to a given identifier.
427
428        Args:
429            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
430            identify: If set to `False`, the quotes will only be added if the identifier is deemed
431                "unsafe", with respect to its characters and this dialect's normalization strategy.
432        """
433        if isinstance(expression, exp.Identifier):
434            name = expression.this
435            expression.set(
436                "quoted",
437                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
438            )
439
440        return expression
441
442    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
443        return self.parser(**opts).parse(self.tokenize(sql), sql)
444
445    def parse_into(
446        self, expression_type: exp.IntoType, sql: str, **opts
447    ) -> t.List[t.Optional[exp.Expression]]:
448        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
449
450    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
451        return self.generator(**opts).generate(expression, copy=copy)
452
453    def transpile(self, sql: str, **opts) -> t.List[str]:
454        return [
455            self.generate(expression, copy=False, **opts) if expression else ""
456            for expression in self.parse(sql)
457        ]
458
459    def tokenize(self, sql: str) -> t.List[Token]:
460        return self.tokenizer.tokenize(sql)
461
462    @property
463    def tokenizer(self) -> Tokenizer:
464        if not hasattr(self, "_tokenizer"):
465            self._tokenizer = self.tokenizer_class(dialect=self)
466        return self._tokenizer
467
468    def parser(self, **opts) -> Parser:
469        return self.parser_class(dialect=self, **opts)
470
471    def generator(self, **opts) -> Generator:
472        return self.generator_class(dialect=self, **opts)
Dialect(**kwargs)
340    def __init__(self, **kwargs) -> None:
341        normalization_strategy = kwargs.get("normalization_strategy")
342
343        if normalization_strategy is None:
344            self.normalization_strategy = self.NORMALIZATION_STRATEGY
345        else:
346            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
INDEX_OFFSET = 0

Determines the base index offset for arrays.

WEEK_OFFSET = 0

Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

UNNEST_COLUMN_ONLY = False

Determines whether or not UNNEST table aliases are treated as column aliases.

ALIAS_POST_TABLESAMPLE = False

Determines whether or not the table alias comes after tablesample.

TABLESAMPLE_SIZE_IS_PERCENT = False

Determines whether or not a size in the table sample clause represents a percentage.

NORMALIZATION_STRATEGY = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Specifies the strategy according to which identifiers should be normalized.

IDENTIFIERS_CAN_START_WITH_DIGIT = False

Determines whether or not an unquoted identifier can start with a digit.

DPIPE_IS_STRING_CONCAT = True

Determines whether or not the DPIPE token (||) is a string concatenation operator.

STRICT_STRING_CONCAT = False

Determines whether or not CONCAT's arguments must be strings.

SUPPORTS_USER_DEFINED_TYPES = True

Determines whether or not user-defined data types are supported.

SUPPORTS_SEMI_ANTI_JOIN = True

Determines whether or not SEMI or ANTI joins are supported.

NORMALIZE_FUNCTIONS: bool | str = 'upper'

Determines how function names are normalized.

LOG_BASE_FIRST = True

Determines whether the base comes first in the LOG function.

NULL_ORDERING = 'nulls_are_small'

Indicates the default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last"

TYPED_DIVISION = False

Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division. True means a / b is integer division if both a and b are integers.

SAFE_DIVISION = False

Determines whether division by zero throws an error (False) or returns NULL (True).

CONCAT_COALESCE = False

A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.

DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}

Associates this dialect's time formats with their equivalent Python strftime format.

FORMAT_MAPPING: Dict[str, str] = {}

Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed from TIME_MAPPING.

ESCAPE_SEQUENCES: Dict[str, str] = {}

Mapping of an unescaped escape sequence to the corresponding character.

PSEUDOCOLUMNS: Set[str] = set()

Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.

PREFER_CTE_ALIAS_COLUMN = False

Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
INVERSE_ESCAPE_SEQUENCES: Dict[str, str] = {}
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START: Optional[str] = None
BIT_END: Optional[str] = None
HEX_START: Optional[str] = None
HEX_END: Optional[str] = None
BYTE_START: Optional[str] = None
BYTE_END: Optional[str] = None
UNICODE_START: Optional[str] = None
UNICODE_END: Optional[str] = None
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Dialect:
276    @classmethod
277    def get_or_raise(cls, dialect: DialectType) -> Dialect:
278        """
279        Look up a dialect in the global dialect registry and return it if it exists.
280
281        Args:
282            dialect: The target dialect. If this is a string, it can be optionally followed by
283                additional key-value pairs that are separated by commas and are used to specify
284                dialect settings, such as whether the dialect's identifiers are case-sensitive.
285
286        Example:
287            >>> dialect = get_or_raise("duckdb")
288            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
289
290        Returns:
291            The corresponding Dialect instance.
292        """
293
294        if not dialect:
295            return cls()
296        if isinstance(dialect, _Dialect):
297            return dialect()
298        if isinstance(dialect, Dialect):
299            return dialect
300        if isinstance(dialect, str):
301            try:
302                dialect_name, *kv_pairs = dialect.split(",")
303                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
304            except ValueError:
305                raise ValueError(
306                    f"Invalid dialect format: '{dialect}'. "
307                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
308                )
309
310            result = cls.get(dialect_name.strip())
311            if not result:
312                from difflib import get_close_matches
313
314                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
315                if similar:
316                    similar = f" Did you mean {similar}?"
317
318                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
319
320            return result(**kwargs)
321
322        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

Look up a dialect in the global dialect registry and return it if it exists.

Arguments:
  • dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:

The corresponding Dialect instance.
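
For example, a minimal sketch of both lookup forms:

>>> from sqlglot.dialects.dialect import Dialect
>>> dialect = Dialect.get_or_raise("duckdb")
>>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
>>> dialect.normalization_strategy
<NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>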

@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
324    @classmethod
325    def format_time(
326        cls, expression: t.Optional[str | exp.Expression]
327    ) -> t.Optional[exp.Expression]:
328        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
329        if isinstance(expression, str):
330            return exp.Literal.string(
331                # the time formats are quoted
332                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
333            )
334
335        if expression and expression.is_string:
336            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
337
338        return expression

Converts a time format in this dialect to its equivalent Python strftime format.
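
For instance, assuming Hive's stock TIME_MAPPING (which maps yyyy/MM/dd to %Y/%m/%d), a format string is translated token by token; note that a plain str argument is expected to include its surrounding quotes:

>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect["hive"].format_time("'yyyy-MM-dd'").sql()
"'%Y-%m-%d'"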

def normalize_identifier(self, expression: ~E) -> ~E:
356    def normalize_identifier(self, expression: E) -> E:
357        """
358        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
359
360        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
361        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
362        it would resolve it as `FOO`. If it were quoted, it'd need to be treated as case-sensitive,
363        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
364
365        There are also dialects like Spark, which are case-insensitive even when quotes are
366        present, and dialects like MySQL, whose resolution rules match those employed by the
367        underlying operating system; for example, they may always be case-sensitive on Linux.
368
369        Finally, the normalization behavior of some engines can even be controlled through flags,
370        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
371
372        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
373        that it can analyze queries in the optimizer and successfully capture their semantics.
374        """
375        if (
376            isinstance(expression, exp.Identifier)
377            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
378            and (
379                not expression.quoted
380                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
381            )
382        ):
383            expression.set(
384                "this",
385                expression.this.upper()
386                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
387                else expression.this.lower(),
388            )
389
390        return expression

Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it were quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system; for example, they may always be case-sensitive on Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
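
A minimal sketch of the behaviors described above:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).this
'foo'
>>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).this
'FOO'
>>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO", quoted=True)).this
'FoO'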

def case_sensitive(self, text: str) -> bool:
392    def case_sensitive(self, text: str) -> bool:
393        """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
394        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
395            return False
396
397        unsafe = (
398            str.islower
399            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
400            else str.isupper
401        )
402        return any(unsafe(char) for char in text)

Checks if text contains any case-sensitive characters, based on the dialect's rules.
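
For example, under the default LOWERCASE strategy only uppercase characters make a name "unsafe":

>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect().case_sensitive("foo")
False
>>> Dialect().case_sensitive("Foo")
True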

def can_identify(self, text: str, identify: str | bool = 'safe') -> bool:
404    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
405        """Checks if text can be identified given an identify option.
406
407        Args:
408            text: The text to check.
409            identify:
410                `"always"` or `True`: Always returns `True`.
411                `"safe"`: Only returns `True` if the identifier is case-insensitive.
412
413        Returns:
414            Whether or not the given text can be identified.
415        """
416        if identify is True or identify == "always":
417            return True
418
419        if identify == "safe":
420            return not self.case_sensitive(text)
421
422        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns True. "safe": Only returns True if the identifier is case-insensitive.
Returns:

Whether or not the given text can be identified.
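
A small sketch tying this to case_sensitive under the default strategy:

>>> from sqlglot.dialects.dialect import Dialect
>>> d = Dialect()
>>> d.can_identify("foo", "safe"), d.can_identify("Foo", "safe"), d.can_identify("Foo", "always")
(True, False, True)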

def quote_identifier(self, expression: ~E, identify: bool = True) -> ~E:
424    def quote_identifier(self, expression: E, identify: bool = True) -> E:
425        """
426        Adds quotes to a given identifier.
427
428        Args:
429            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
430            identify: If set to `False`, the quotes will only be added if the identifier is deemed
431                "unsafe", with respect to its characters and this dialect's normalization strategy.
432        """
433        if isinstance(expression, exp.Identifier):
434            name = expression.this
435            expression.set(
436                "quoted",
437                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
438            )
439
440        return expression

Adds quotes to a given identifier.

Arguments:
  • expression: The expression of interest. If it's not an Identifier, this method is a no-op.
  • identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
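
For example, with the default (lowercase-normalizing) dialect:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect().quote_identifier(exp.to_identifier("foo"), identify=False).sql()
'foo'
>>> Dialect().quote_identifier(exp.to_identifier("Foo"), identify=False).sql()
'"Foo"'
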
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
442    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
443        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
445    def parse_into(
446        self, expression_type: exp.IntoType, sql: str, **opts
447    ) -> t.List[t.Optional[exp.Expression]]:
448        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
450    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
451        return self.generator(**opts).generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> List[str]:
453    def transpile(self, sql: str, **opts) -> t.List[str]:
454        return [
455            self.generate(expression, copy=False, **opts) if expression else ""
456            for expression in self.parse(sql)
457        ]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
459    def tokenize(self, sql: str) -> t.List[Token]:
460        return self.tokenizer.tokenize(sql)
tokenizer: sqlglot.tokens.Tokenizer
462    @property
463    def tokenizer(self) -> Tokenizer:
464        if not hasattr(self, "_tokenizer"):
465            self._tokenizer = self.tokenizer_class(dialect=self)
466        return self._tokenizer
def parser(self, **opts) -> sqlglot.parser.Parser:
468    def parser(self, **opts) -> Parser:
469        return self.parser_class(dialect=self, **opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
471    def generator(self, **opts) -> Generator:
472        return self.generator_class(dialect=self, **opts)
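
Taken together, these methods let a dialect instance drive the whole pipeline; transpile simply chains parse and generate:

>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect.get_or_raise("duckdb").transpile("SELECT 1")
['SELECT 1']
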
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
478def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
479    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
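
Helpers like this one are typically installed in a dialect Generator's TRANSFORMS mapping; called directly, the effect is just a function rename (a small sketch):

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect, rename_func
>>> node = exp.ApproxDistinct(this=exp.column("x"))
>>> rename_func("APPROX_COUNT_DISTINCT")(Dialect().generator(), node)
'APPROX_COUNT_DISTINCT(x)'
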
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
482def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
483    if expression.args.get("accuracy"):
484        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
485    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
488def if_sql(
489    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
490) -> t.Callable[[Generator, exp.If], str]:
491    def _if_sql(self: Generator, expression: exp.If) -> str:
492        return self.func(
493            name,
494            expression.this,
495            expression.args.get("true"),
496            expression.args.get("false") or false_value,
497        )
498
499    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
502def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
503    return self.binary(expression, "->")
def arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
506def arrow_json_extract_scalar_sql(
507    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
508) -> str:
509    return self.binary(expression, "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
512def inline_array_sql(self: Generator, expression: exp.Array) -> str:
513    return f"[{self.expressions(expression, flat=True)}]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
516def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
517    return self.like_sql(
518        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
519    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
522def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
523    zone = self.sql(expression, "this")
524    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
527def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
528    if expression.args.get("recursive"):
529        self.unsupported("Recursive CTEs are unsupported")
530        expression.args["recursive"] = False
531    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
534def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
535    n = self.sql(expression, "this")
536    d = self.sql(expression, "expression")
537    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
540def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
541    self.unsupported("TABLESAMPLE unsupported")
542    return self.sql(expression.this)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
545def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
546    self.unsupported("PIVOT unsupported")
547    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
550def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
551    return self.cast_sql(expression)
def no_properties_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Properties) -> str:
554def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
555    self.unsupported("Properties unsupported")
556    return ""
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
559def no_comment_column_constraint_sql(
560    self: Generator, expression: exp.CommentColumnConstraint
561) -> str:
562    self.unsupported("CommentColumnConstraint unsupported")
563    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
566def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
567    self.unsupported("MAP_FROM_ENTRIES unsupported")
568    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
571def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
572    this = self.sql(expression, "this")
573    substr = self.sql(expression, "substr")
574    position = self.sql(expression, "position")
575    if position:
576        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
577    return f"STRPOS({this}, {substr})"
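
The position argument shifts the search window, so the result has to be offset back into the original string (a sketch with the default dialect):

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect, str_position_sql
>>> node = exp.StrPosition(this=exp.column("s"), substr=exp.Literal.string("x"), position=exp.Literal.number(2))
>>> str_position_sql(Dialect().generator(), node)
"STRPOS(SUBSTR(s, 2), 'x') + 2 - 1"
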
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
580def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
581    return (
582        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
583    )
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
586def var_map_sql(
587    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
588) -> str:
589    keys = expression.args["keys"]
590    values = expression.args["values"]
591
592    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
593        self.unsupported("Cannot convert array columns into map.")
594        return self.func(map_func_name, keys, values)
595
596    args = []
597    for key, value in zip(keys.expressions, values.expressions):
598        args.append(self.sql(key))
599        args.append(self.sql(value))
600
601    return self.func(map_func_name, *args)
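
A small sketch of the interleaving this produces when both sides are array literals:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect, var_map_sql
>>> node = exp.VarMap(keys=exp.array("'a'", "'b'"), values=exp.array("1", "2"))
>>> var_map_sql(Dialect().generator(), node)
"MAP('a', 1, 'b', 2)"
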
def format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
604def format_time_lambda(
605    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
606) -> t.Callable[[t.List], E]:
607    """Helper used for time expressions.
608
609    Args:
610        exp_class: the expression class to instantiate.
611        dialect: target sql dialect.
612        default: the default format, True being time.
613
614    Returns:
615        A callable that can be used to return the appropriately formatted time expression.
616    """
617
618    def _format_time(args: t.List):
619        return exp_class(
620            this=seq_get(args, 0),
621            format=Dialect[dialect].format_time(
622                seq_get(args, 1)
623                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
624            ),
625        )
626
627    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.
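
For instance, assuming Hive's stock time mappings, a parser built with default=True falls back to the dialect's TIME_FORMAT when no explicit format argument is given:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import format_time_lambda
>>> parser = format_time_lambda(exp.UnixToStr, "hive", default=True)
>>> parser([exp.column("ts")]).args["format"].sql()
"'%Y-%m-%d %H:%M:%S'"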

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
630def time_format(
631    dialect: DialectType = None,
632) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
633    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
634        """
635        Returns the time format for a given expression, unless it's equivalent
636        to the default time format of the dialect of interest.
637        """
638        time_format = self.format_time(expression)
639        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
640
641    return _time_format
def create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
644def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
645    """
646    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
647    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
648    columns are removed from the create statement.
649    """
650    has_schema = isinstance(expression.this, exp.Schema)
651    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
652
653    if has_schema and is_partitionable:
654        prop = expression.find(exp.PartitionedByProperty)
655        if prop and prop.this and not isinstance(prop.this, exp.Schema):
656            schema = expression.this
657            columns = {v.name.upper() for v in prop.this.expressions}
658            partitions = [col for col in schema.expressions if col.name.upper() in columns]
659            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
660            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
661            expression.set("this", schema)
662
663    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
666def parse_date_delta(
667    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
668) -> t.Callable[[t.List], E]:
669    def inner_func(args: t.List) -> E:
670        unit_based = len(args) == 3
671        this = args[2] if unit_based else seq_get(args, 0)
672        unit = args[0] if unit_based else exp.Literal.string("DAY")
673        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
674        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
675
676    return inner_func
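
This returns a parser callback meant for a dialect Parser's FUNCTIONS map. With two arguments the unit defaults to DAY; with three (e.g. T-SQL's DATEADD(unit, n, date)) the leading unit argument is honored. A small sketch:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import parse_date_delta
>>> node = parse_date_delta(exp.DateAdd)([exp.column("d"), exp.Literal.number(1)])
>>> node.this.sql(), node.expression.sql(), node.args["unit"].sql()
('d', '1', "'DAY'")
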
def parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
679def parse_date_delta_with_interval(
680    expression_class: t.Type[E],
681) -> t.Callable[[t.List], t.Optional[E]]:
682    def func(args: t.List) -> t.Optional[E]:
683        if len(args) < 2:
684            return None
685
686        interval = args[1]
687
688        if not isinstance(interval, exp.Interval):
689            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
690
691        expression = interval.this
692        if expression and expression.is_string:
693            expression = exp.Literal.number(expression.this)
694
695        return expression_class(
696            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
697        )
698
699    return func
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
702def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
703    unit = seq_get(args, 0)
704    this = seq_get(args, 1)
705
706    if isinstance(this, exp.Cast) and this.is_type("date"):
707        return exp.DateTrunc(unit=unit, this=this)
708    return exp.TimestampTrunc(this=this, unit=unit)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
711def date_add_interval_sql(
712    data_type: str, kind: str
713) -> t.Callable[[Generator, exp.Expression], str]:
714    def func(self: Generator, expression: exp.Expression) -> str:
715        this = self.sql(expression, "this")
716        unit = expression.args.get("unit")
717        unit = exp.var(unit.name.upper() if unit else "DAY")
718        interval = exp.Interval(this=expression.expression, unit=unit)
719        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
720
721    return func
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
724def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
725    return self.func(
726        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
727    )
def no_timestamp_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Timestamp) -> str:
730def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
731    if not expression.expression:
732        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
733    if expression.text("expression").lower() in TIMEZONES:
734        return self.sql(
735            exp.AtTimeZone(
736                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
737                zone=expression.expression,
738            )
739        )
740    return self.function_fallback_sql(expression)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
743def locate_to_strposition(args: t.List) -> exp.Expression:
744    return exp.StrPosition(
745        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
746    )
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
749def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
750    return self.func(
751        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
752    )
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
755def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
756    return self.sql(
757        exp.Substring(
758            this=expression.this, start=exp.Literal.number(1), length=expression.expression
759        )
760    )
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Right) -> str:
763def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
764    return self.sql(
765        exp.Substring(
766            this=expression.this,
767            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
768        )
769    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
772def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
773    return self.sql(exp.cast(expression.this, "timestamp"))
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
776def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
777    return self.sql(exp.cast(expression.this, "date"))
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
781def encode_decode_sql(
782    self: Generator, expression: exp.Expression, name: str, replace: bool = True
783) -> str:
784    charset = expression.args.get("charset")
785    if charset and charset.name.lower() != "utf-8":
786        self.unsupported(f"Expected utf-8 character set, got {charset}.")
787
788    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
791def min_or_least(self: Generator, expression: exp.Min) -> str:
792    name = "LEAST" if expression.expressions else "MIN"
793    return rename_func(name)(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
796def max_or_greatest(self: Generator, expression: exp.Max) -> str:
797    name = "GREATEST" if expression.expressions else "MAX"
798    return rename_func(name)(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
801def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
802    cond = expression.this
803
804    if isinstance(expression.this, exp.Distinct):
805        cond = expression.this.expressions[0]
806        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
807
808    return self.func("sum", exp.func("if", cond, 1, 0))
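
A sketch using Hive's generator, since Hive emits IF natively (the exact rendering of the inner conditional varies by dialect):

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect, count_if_to_sum
>>> node = exp.CountIf(this=exp.condition("x > 0"))
>>> count_if_to_sum(Dialect.get_or_raise("hive").generator(), node)
'SUM(IF(x > 0, 1, 0))'
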
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
811def trim_sql(self: Generator, expression: exp.Trim) -> str:
812    target = self.sql(expression, "this")
813    trim_type = self.sql(expression, "position")
814    remove_chars = self.sql(expression, "expression")
815    collation = self.sql(expression, "collation")
816
817    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
818    if not remove_chars and not collation:
819        return self.trim_sql(expression)
820
821    trim_type = f"{trim_type} " if trim_type else ""
822    remove_chars = f"{remove_chars} " if remove_chars else ""
823    from_part = "FROM " if trim_type or remove_chars else ""
824    collation = f" COLLATE {collation}" if collation else ""
825    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
828def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
829    return self.func("STRPTIME", expression.this, self.format_time(expression))
def concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat) -> str:
832def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
833    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
def concat_ws_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ConcatWs) -> str:
836def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
837    delim, *rest_args = expression.expressions
838    return self.sql(
839        reduce(
840            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
841            rest_args,
842        )
843    )
def regexp_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpExtract) -> str:
846def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
847    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
848    if bad_args:
849        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
850
851    return self.func(
852        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
853    )
def regexp_replace_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpReplace) -> str:
856def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
857    bad_args = list(
858        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
859    )
860    if bad_args:
861        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
862
863    return self.func(
864        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
865    )
def pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> List[str]:
868def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
869    names = []
870    for agg in aggregations:
871        if isinstance(agg, exp.Alias):
872            names.append(agg.alias)
873        else:
874            """
875            This case corresponds to aggregations without aliases being used as suffixes
876            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
877            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
878            Otherwise, we'd end up with `col_avg(`foo`)` (notice the nested quoting).
879            """
880            agg_all_unquoted = agg.transform(
881                lambda node: exp.Identifier(this=node.name, quoted=False)
882                if isinstance(node, exp.Identifier)
883                else node
884            )
885            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
886
887    return names
def binary_from_function(expr_type: Type[~B]) -> Callable[[List], ~B]:
890def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
891    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
def parse_timestamp_trunc(args: List) -> sqlglot.expressions.TimestampTrunc:
895def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
896    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
def any_value_to_max_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.AnyValue) -> str:
899def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
900    return self.func("MAX", expression.this)
def bool_xor_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Xor) -> str:
903def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
904    a = self.sql(expression.left)
905    b = self.sql(expression.right)
906    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
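
For engines without a boolean XOR operator, the helper expands it into AND/OR/NOT:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect, bool_xor_sql
>>> node = exp.Xor(this=exp.column("a"), expression=exp.column("b"))
>>> bool_xor_sql(Dialect().generator(), node)
'(a AND (NOT b)) OR ((NOT a) AND b)'
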
def is_parse_json(expression: sqlglot.expressions.Expression) -> bool:
909def is_parse_json(expression: exp.Expression) -> bool:
910    return isinstance(expression, exp.ParseJSON) or (
911        isinstance(expression, exp.Cast) and expression.is_type("json")
912    )
def isnull_to_is_null(args: List) -> sqlglot.expressions.Expression:
915def isnull_to_is_null(args: t.List) -> exp.Expression:
916    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
def generatedasidentitycolumnconstraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.GeneratedAsIdentityColumnConstraint) -> str:
919def generatedasidentitycolumnconstraint_sql(
920    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
921) -> str:
922    start = self.sql(expression, "start") or "1"
923    increment = self.sql(expression, "increment") or "1"
924    return f"IDENTITY({start}, {increment})"
def arg_max_or_min_no_count( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.ArgMax | sqlglot.expressions.ArgMin], str]:
927def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
928    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
929        if expression.args.get("count"):
930            self.unsupported(f"Only two arguments are supported in function {name}.")
931
932        return self.func(name, expression.this, expression.expression)
933
934    return _arg_max_or_min_sql
def ts_or_ds_add_cast( expression: sqlglot.expressions.TsOrDsAdd) -> sqlglot.expressions.TsOrDsAdd:
937def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
938    this = expression.this.copy()
939
940    return_type = expression.return_type
941    if return_type.is_type(exp.DataType.Type.DATE):
942        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
943        # can truncate timestamp strings, because some dialects can't cast them to DATE
944        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
945
946    expression.this.replace(exp.cast(this, return_type))
947    return expression
def date_delta_sql( name: str, cast: bool = False) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.DateAdd, sqlglot.expressions.TsOrDsAdd, sqlglot.expressions.DateDiff, sqlglot.expressions.TsOrDsDiff]], str]:
950def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
951    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
952        if cast and isinstance(expression, exp.TsOrDsAdd):
953            expression = ts_or_ds_add_cast(expression)
954
955        return self.func(
956            name,
957            exp.var(expression.text("unit").upper() or "DAY"),
958            expression.expression,
959            expression.this,
960        )
961
962    return _delta_sql
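
The returned callback emits unit-first calls, upper-casing the unit and defaulting to DAY (a small sketch):

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect, date_delta_sql
>>> node = exp.DateAdd(this=exp.column("d"), expression=exp.Literal.number(1), unit=exp.var("month"))
>>> date_delta_sql("DATEADD")(Dialect().generator(), node)
'DATEADD(MONTH, 1, d)'
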
def prepend_dollar_to_path(expression: sqlglot.expressions.GetPath) -> sqlglot.expressions.GetPath:
965def prepend_dollar_to_path(expression: exp.GetPath) -> exp.GetPath:
966    from sqlglot.optimizer.simplify import simplify
967
968    # Makes sure the path will be evaluated correctly at runtime to include the path root.
969    # For example, `[0].foo` will become `$[0].foo`, and `foo` will become `$.foo`.
970    path = expression.expression
971    path = exp.func(
972        "if",
973        exp.func("startswith", path, "'['"),
974        exp.func("concat", "'$'", path),
975        exp.func("concat", "'$.'", path),
976    )
977
978    expression.expression.replace(simplify(path))
979    return expression
def path_to_jsonpath( name: str = 'JSON_EXTRACT') -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.GetPath], str]:
982def path_to_jsonpath(
983    name: str = "JSON_EXTRACT",
984) -> t.Callable[[Generator, exp.GetPath], str]:
985    def _transform(self: Generator, expression: exp.GetPath) -> str:
986        return rename_func(name)(self, prepend_dollar_to_path(expression))
987
988    return _transform
def no_last_day_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.LastDay) -> str:
991def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
992    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
993    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
994    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
995
996    return self.sql(exp.cast(minus_one_day, "date"))
def merge_without_target_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Merge) -> str:
 999def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
1000    """Remove table refs from columns in when statements."""
1001    alias = expression.this.args.get("alias")
1002
1003    normalize = (
1004        lambda identifier: self.dialect.normalize_identifier(identifier).name
1005        if identifier
1006        else None
1007    )
1008
1009    targets = {normalize(expression.this.this)}
1010
1011    if alias:
1012        targets.add(normalize(alias.this))
1013
1014    for when in expression.expressions:
1015        when.transform(
1016            lambda node: exp.column(node.this)
1017            if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1018            else node,
1019            copy=False,
1020        )
1021
1022    return self.merge_sql(expression)

Remove table refs from columns in when statements.