
sqlglot.dialects.dialect

   1from __future__ import annotations
   2
   3import logging
   4import typing as t
   5from enum import Enum, auto
   6from functools import reduce
   7
   8from sqlglot import exp
   9from sqlglot.errors import ParseError
  10from sqlglot.generator import Generator
  11from sqlglot.helper import AutoName, flatten, is_int, seq_get
  12from sqlglot.jsonpath import parse as parse_json_path
  13from sqlglot.parser import Parser
  14from sqlglot.time import TIMEZONES, format_time
  15from sqlglot.tokens import Token, Tokenizer, TokenType
  16from sqlglot.trie import new_trie
  17
  18DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
  19DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
  20JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]
  21
  22
  23if t.TYPE_CHECKING:
  24    from sqlglot._typing import B, E, F
  25
  26logger = logging.getLogger("sqlglot")
  27
  28
  29class Dialects(str, Enum):
  30    """Dialects supported by SQLGLot."""
  31
  32    DIALECT = ""
  33
  34    ATHENA = "athena"
  35    BIGQUERY = "bigquery"
  36    CLICKHOUSE = "clickhouse"
  37    DATABRICKS = "databricks"
  38    DORIS = "doris"
  39    DRILL = "drill"
  40    DUCKDB = "duckdb"
  41    HIVE = "hive"
  42    MYSQL = "mysql"
  43    ORACLE = "oracle"
  44    POSTGRES = "postgres"
  45    PRESTO = "presto"
  46    PRQL = "prql"
  47    REDSHIFT = "redshift"
  48    SNOWFLAKE = "snowflake"
  49    SPARK = "spark"
  50    SPARK2 = "spark2"
  51    SQLITE = "sqlite"
  52    STARROCKS = "starrocks"
  53    TABLEAU = "tableau"
  54    TERADATA = "teradata"
  55    TRINO = "trino"
  56    TSQL = "tsql"
  57
  58
  59class NormalizationStrategy(str, AutoName):
  60    """Specifies the strategy according to which identifiers should be normalized."""
  61
  62    LOWERCASE = auto()
  63    """Unquoted identifiers are lowercased."""
  64
  65    UPPERCASE = auto()
  66    """Unquoted identifiers are uppercased."""
  67
  68    CASE_SENSITIVE = auto()
  69    """Always case-sensitive, regardless of quotes."""
  70
  71    CASE_INSENSITIVE = auto()
  72    """Always case-insensitive, regardless of quotes."""
  73
  74
  75class _Dialect(type):
  76    classes: t.Dict[str, t.Type[Dialect]] = {}
  77
  78    def __eq__(cls, other: t.Any) -> bool:
  79        if cls is other:
  80            return True
  81        if isinstance(other, str):
  82            return cls is cls.get(other)
  83        if isinstance(other, Dialect):
  84            return cls is type(other)
  85
  86        return False
  87
  88    def __hash__(cls) -> int:
  89        return hash(cls.__name__.lower())
  90
  91    @classmethod
  92    def __getitem__(cls, key: str) -> t.Type[Dialect]:
  93        return cls.classes[key]
  94
  95    @classmethod
  96    def get(
  97        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
  98    ) -> t.Optional[t.Type[Dialect]]:
  99        return cls.classes.get(key, default)
 100
 101    def __new__(cls, clsname, bases, attrs):
 102        klass = super().__new__(cls, clsname, bases, attrs)
 103        enum = Dialects.__members__.get(clsname.upper())
 104        cls.classes[enum.value if enum is not None else clsname.lower()] = klass
 105
 106        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
 107        klass.FORMAT_TRIE = (
 108            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
 109        )
 110        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
 111        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 112
 113        base = seq_get(bases, 0)
 114        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
 115        base_parser = (getattr(base, "parser_class", Parser),)
 116        base_generator = (getattr(base, "generator_class", Generator),)
 117
 118        klass.tokenizer_class = klass.__dict__.get(
 119            "Tokenizer", type("Tokenizer", base_tokenizer, {})
 120        )
 121        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
 122        klass.generator_class = klass.__dict__.get(
 123            "Generator", type("Generator", base_generator, {})
 124        )
 125
 126        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
 127        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
 128            klass.tokenizer_class._IDENTIFIERS.items()
 129        )[0]
 130
 131        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
 132            return next(
 133                (
 134                    (s, e)
 135                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
 136                    if t == token_type
 137                ),
 138                (None, None),
 139            )
 140
 141        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
 142        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
 143        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
 144        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
 145
 146        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
 147            klass.UNESCAPED_SEQUENCES = {
 148                "\\a": "\a",
 149                "\\b": "\b",
 150                "\\f": "\f",
 151                "\\n": "\n",
 152                "\\r": "\r",
 153                "\\t": "\t",
 154                "\\v": "\v",
 155                "\\\\": "\\",
 156                **klass.UNESCAPED_SEQUENCES,
 157            }
 158
 159        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}
 160
 161        if enum not in ("", "bigquery"):
 162            klass.generator_class.SELECT_KINDS = ()
 163
 164        if enum not in ("", "athena", "presto", "trino"):
 165            klass.generator_class.TRY_SUPPORTED = False
 166
 167        if enum not in ("", "databricks", "hive", "spark", "spark2"):
 168            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
 169            for modifier in ("cluster", "distribute", "sort"):
 170                modifier_transforms.pop(modifier, None)
 171
 172            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms
 173
 174        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
 175            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
 176                TokenType.ANTI,
 177                TokenType.SEMI,
 178            }
 179
 180        return klass
 181
 182
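# Illustrative sketch (not part of the module): the metaclass registers every
# `Dialect` subclass under its `Dialects` enum value or, failing that, its
# lowercased class name, so custom dialects become reachable by string:
#
#     >>> class Custom(Dialect):
#     ...     pass
#     >>> Dialect["custom"] is Custom
#     True
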
 183class Dialect(metaclass=_Dialect):
 184    INDEX_OFFSET = 0
 185    """The base index offset for arrays."""
 186
 187    WEEK_OFFSET = 0
 188    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
 189
 190    UNNEST_COLUMN_ONLY = False
 191    """Whether `UNNEST` table aliases are treated as column aliases."""
 192
 193    ALIAS_POST_TABLESAMPLE = False
 194    """Whether the table alias comes after tablesample."""
 195
 196    TABLESAMPLE_SIZE_IS_PERCENT = False
 197    """Whether a size in the table sample clause represents percentage."""
 198
 199    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
 200    """Specifies the strategy according to which identifiers should be normalized."""
 201
 202    IDENTIFIERS_CAN_START_WITH_DIGIT = False
 203    """Whether an unquoted identifier can start with a digit."""
 204
 205    DPIPE_IS_STRING_CONCAT = True
 206    """Whether the DPIPE token (`||`) is a string concatenation operator."""
 207
 208    STRICT_STRING_CONCAT = False
 209    """Whether `CONCAT`'s arguments must be strings."""
 210
 211    SUPPORTS_USER_DEFINED_TYPES = True
 212    """Whether user-defined data types are supported."""
 213
 214    SUPPORTS_SEMI_ANTI_JOIN = True
 215    """Whether `SEMI` or `ANTI` joins are supported."""
 216
 217    NORMALIZE_FUNCTIONS: bool | str = "upper"
 218    """
 219    Determines how function names are going to be normalized.
 220    Possible values:
 221        "upper" or True: Convert names to uppercase.
 222        "lower": Convert names to lowercase.
 223        False: Disables function name normalization.
 224    """
 225
 226    LOG_BASE_FIRST: t.Optional[bool] = True
 227    """
 228    Whether the base comes first in the `LOG` function.
 229    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
 230    """
 231
 232    NULL_ORDERING = "nulls_are_small"
 233    """
 234    Default `NULL` ordering method to use if not explicitly set.
 235    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
 236    """
 237
 238    TYPED_DIVISION = False
 239    """
 240    Whether the behavior of `a / b` depends on the types of `a` and `b`.
 241    False means `a / b` is always float division.
 242    True means `a / b` is integer division if both `a` and `b` are integers.
 243    """
 244
 245    SAFE_DIVISION = False
 246    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""
 247
 248    CONCAT_COALESCE = False
 249    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
 250
 251    HEX_LOWERCASE = False
 252    """Whether the `HEX` function returns a lowercase hexadecimal string."""
 253
 254    DATE_FORMAT = "'%Y-%m-%d'"
 255    DATEINT_FORMAT = "'%Y%m%d'"
 256    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
 257
 258    TIME_MAPPING: t.Dict[str, str] = {}
 259    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""
 260
 261    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
 262    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
 263    FORMAT_MAPPING: t.Dict[str, str] = {}
 264    """
 265    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
 266    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
 267    """
 268
 269    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
 270    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
 271
 272    PSEUDOCOLUMNS: t.Set[str] = set()
 273    """
 274    Columns that are auto-generated by the engine corresponding to this dialect.
 275    For example, such columns may be excluded from `SELECT *` queries.
 276    """
 277
 278    PREFER_CTE_ALIAS_COLUMN = False
 279    """
 280    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
 281    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
 282    any projection aliases in the subquery.
 283
 284    For example,
 285        WITH y(c) AS (
 286            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
 287        ) SELECT c FROM y;
 288
 289        will be rewritten as
 290
 291        WITH y(c) AS (
 292            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
 293        ) SELECT c FROM y;
 294    """
 295
 296    # --- Autofilled ---
 297
 298    tokenizer_class = Tokenizer
 299    parser_class = Parser
 300    generator_class = Generator
 301
 302    # A trie of the time_mapping keys
 303    TIME_TRIE: t.Dict = {}
 304    FORMAT_TRIE: t.Dict = {}
 305
 306    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
 307    INVERSE_TIME_TRIE: t.Dict = {}
 308
 309    ESCAPED_SEQUENCES: t.Dict[str, str] = {}
 310
 311    # Delimiters for string literals and identifiers
 312    QUOTE_START = "'"
 313    QUOTE_END = "'"
 314    IDENTIFIER_START = '"'
 315    IDENTIFIER_END = '"'
 316
 317    # Delimiters for bit, hex, byte and unicode literals
 318    BIT_START: t.Optional[str] = None
 319    BIT_END: t.Optional[str] = None
 320    HEX_START: t.Optional[str] = None
 321    HEX_END: t.Optional[str] = None
 322    BYTE_START: t.Optional[str] = None
 323    BYTE_END: t.Optional[str] = None
 324    UNICODE_START: t.Optional[str] = None
 325    UNICODE_END: t.Optional[str] = None
 326
 327    # Separator of COPY statement parameters
 328    COPY_PARAMS_ARE_CSV = True
 329
 330    @classmethod
 331    def get_or_raise(cls, dialect: DialectType) -> Dialect:
 332        """
 333        Look up a dialect in the global dialect registry and return it if it exists.
 334
 335        Args:
 336            dialect: The target dialect. If this is a string, it can be optionally followed by
 337                additional key-value pairs that are separated by commas and are used to specify
 338                dialect settings, such as whether the dialect's identifiers are case-sensitive.
 339
 340        Example:
  341            >>> dialect = Dialect.get_or_raise("duckdb")
  342            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
 343
 344        Returns:
 345            The corresponding Dialect instance.
 346        """
 347
 348        if not dialect:
 349            return cls()
 350        if isinstance(dialect, _Dialect):
 351            return dialect()
 352        if isinstance(dialect, Dialect):
 353            return dialect
 354        if isinstance(dialect, str):
 355            try:
 356                dialect_name, *kv_pairs = dialect.split(",")
 357                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
 358            except ValueError:
 359                raise ValueError(
 360                    f"Invalid dialect format: '{dialect}'. "
  361                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
 362                )
 363
 364            result = cls.get(dialect_name.strip())
 365            if not result:
 366                from difflib import get_close_matches
 367
 368                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
 369                if similar:
 370                    similar = f" Did you mean {similar}?"
 371
 372                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
 373
 374            return result(**kwargs)
 375
 376        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
 377
 378    @classmethod
 379    def format_time(
 380        cls, expression: t.Optional[str | exp.Expression]
 381    ) -> t.Optional[exp.Expression]:
 382        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
 383        if isinstance(expression, str):
 384            return exp.Literal.string(
 385                # the time formats are quoted
 386                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
 387            )
 388
 389        if expression and expression.is_string:
 390            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
 391
 392        return expression
 393
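    # Illustrative sketch, assuming the hive dialect is registered and its
    # TIME_MAPPING maps yyyy/MM/dd to %Y/%m/%d; a quoted Hive format string
    # then converts as follows:
    #
    #     >>> Dialect["hive"].format_time("'yyyy-MM-dd'").sql()
    #     "'%Y-%m-%d'"
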
 394    def __init__(self, **kwargs) -> None:
 395        normalization_strategy = kwargs.get("normalization_strategy")
 396
 397        if normalization_strategy is None:
 398            self.normalization_strategy = self.NORMALIZATION_STRATEGY
 399        else:
 400            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
 401
 402    def __eq__(self, other: t.Any) -> bool:
 403        # Does not currently take dialect state into account
 404        return type(self) == other
 405
 406    def __hash__(self) -> int:
 407        # Does not currently take dialect state into account
 408        return hash(type(self))
 409
 410    def normalize_identifier(self, expression: E) -> E:
 411        """
 412        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
 413
 414        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
 415        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
 416        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
 417        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
 418
 419        There are also dialects like Spark, which are case-insensitive even when quotes are
 420        present, and dialects like MySQL, whose resolution rules match those employed by the
 421        underlying operating system, for example they may always be case-sensitive in Linux.
 422
 423        Finally, the normalization behavior of some engines can even be controlled through flags,
 424        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
 425
 426        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
 427        that it can analyze queries in the optimizer and successfully capture their semantics.
 428        """
 429        if (
 430            isinstance(expression, exp.Identifier)
 431            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
 432            and (
 433                not expression.quoted
 434                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
 435            )
 436        ):
 437            expression.set(
 438                "this",
 439                (
 440                    expression.this.upper()
 441                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 442                    else expression.this.lower()
 443                ),
 444            )
 445
 446        return expression
 447
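    # Illustrative sketch: Postgres lowercases unquoted identifiers, while
    # Snowflake uppercases them:
    #
    #     >>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).name
    #     'foo'
    #     >>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).name
    #     'FOO'
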
 448    def case_sensitive(self, text: str) -> bool:
  449        """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
 450        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
 451            return False
 452
 453        unsafe = (
 454            str.islower
 455            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 456            else str.isupper
 457        )
 458        return any(unsafe(char) for char in text)
 459
 460    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
 461        """Checks if text can be identified given an identify option.
 462
 463        Args:
 464            text: The text to check.
 465            identify:
 466                `"always"` or `True`: Always returns `True`.
 467                `"safe"`: Only returns `True` if the identifier is case-insensitive.
 468
 469        Returns:
 470            Whether the given text can be identified.
 471        """
 472        if identify is True or identify == "always":
 473            return True
 474
 475        if identify == "safe":
 476            return not self.case_sensitive(text)
 477
 478        return False
 479
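    # Illustrative sketch: under the default LOWERCASE strategy, only characters
    # of the "unsafe" case make a name case-sensitive:
    #
    #     >>> Dialect().can_identify("foo", "safe")
    #     True
    #     >>> Dialect().can_identify("FoO", "safe")
    #     False
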
 480    def quote_identifier(self, expression: E, identify: bool = True) -> E:
 481        """
 482        Adds quotes to a given identifier.
 483
 484        Args:
 485            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
 486            identify: If set to `False`, the quotes will only be added if the identifier is deemed
 487                "unsafe", with respect to its characters and this dialect's normalization strategy.
 488        """
 489        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
 490            name = expression.this
 491            expression.set(
 492                "quoted",
 493                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
 494            )
 495
 496        return expression
 497
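    # Illustrative sketch: with the default identify=True, any identifier that is
    # not part of a function name gets quoted:
    #
    #     >>> Dialect().quote_identifier(exp.to_identifier("foo")).sql()
    #     '"foo"'
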
 498    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
 499        if isinstance(path, exp.Literal):
 500            path_text = path.name
 501            if path.is_number:
 502                path_text = f"[{path_text}]"
 503
 504            try:
 505                return parse_json_path(path_text)
 506            except ParseError as e:
 507                logger.warning(f"Invalid JSON path syntax. {str(e)}")
 508
 509        return path
 510
 511    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
 512        return self.parser(**opts).parse(self.tokenize(sql), sql)
 513
 514    def parse_into(
 515        self, expression_type: exp.IntoType, sql: str, **opts
 516    ) -> t.List[t.Optional[exp.Expression]]:
 517        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
 518
 519    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
 520        return self.generator(**opts).generate(expression, copy=copy)
 521
 522    def transpile(self, sql: str, **opts) -> t.List[str]:
 523        return [
 524            self.generate(expression, copy=False, **opts) if expression else ""
 525            for expression in self.parse(sql)
 526        ]
 527
 528    def tokenize(self, sql: str) -> t.List[Token]:
 529        return self.tokenizer.tokenize(sql)
 530
 531    @property
 532    def tokenizer(self) -> Tokenizer:
 533        if not hasattr(self, "_tokenizer"):
 534            self._tokenizer = self.tokenizer_class(dialect=self)
 535        return self._tokenizer
 536
 537    def parser(self, **opts) -> Parser:
 538        return self.parser_class(dialect=self, **opts)
 539
 540    def generator(self, **opts) -> Generator:
 541        return self.generator_class(dialect=self, **opts)
 542
 543
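# Illustrative sketch: a Dialect instance wires its tokenizer, parser and
# generator together, so transpiling within one dialect is a parse/generate
# roundtrip:
#
#     >>> Dialect.get_or_raise("duckdb").transpile("SELECT 1")
#     ['SELECT 1']
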
 544DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
 545
 546
 547def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
 548    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
 549
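# Illustrative sketch: `rename_func` is typically registered in a dialect
# Generator's TRANSFORMS mapping to emit a node under a different function
# name, e.g.
#
#     TRANSFORMS = {exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT")}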
 550
 551def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
 552    if expression.args.get("accuracy"):
 553        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
 554    return self.func("APPROX_COUNT_DISTINCT", expression.this)
 555
 556
 557def if_sql(
 558    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
 559) -> t.Callable[[Generator, exp.If], str]:
 560    def _if_sql(self: Generator, expression: exp.If) -> str:
 561        return self.func(
 562            name,
 563            expression.this,
 564            expression.args.get("true"),
 565            expression.args.get("false") or false_value,
 566        )
 567
 568    return _if_sql
 569
 570
 571def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
 572    this = expression.this
 573    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
 574        this.replace(exp.cast(this, exp.DataType.Type.JSON))
 575
 576    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
 577
 578
 579def inline_array_sql(self: Generator, expression: exp.Array) -> str:
 580    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"
 581
 582
 583def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
 584    elem = seq_get(expression.expressions, 0)
 585    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
 586        return self.func("ARRAY", elem)
 587    return inline_array_sql(self, expression)
 588
 589
 590def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
 591    return self.like_sql(
 592        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
 593    )
 594
 595
 596def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
 597    zone = self.sql(expression, "this")
 598    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
 599
 600
 601def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
 602    if expression.args.get("recursive"):
 603        self.unsupported("Recursive CTEs are unsupported")
 604        expression.args["recursive"] = False
 605    return self.with_sql(expression)
 606
 607
 608def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
 609    n = self.sql(expression, "this")
 610    d = self.sql(expression, "expression")
 611    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
 612
 613
 614def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
 615    self.unsupported("TABLESAMPLE unsupported")
 616    return self.sql(expression.this)
 617
 618
 619def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
 620    self.unsupported("PIVOT unsupported")
 621    return ""
 622
 623
 624def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
 625    return self.cast_sql(expression)
 626
 627
 628def no_comment_column_constraint_sql(
 629    self: Generator, expression: exp.CommentColumnConstraint
 630) -> str:
 631    self.unsupported("CommentColumnConstraint unsupported")
 632    return ""
 633
 634
 635def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
 636    self.unsupported("MAP_FROM_ENTRIES unsupported")
 637    return ""
 638
 639
 640def str_position_sql(
 641    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
 642) -> str:
 643    this = self.sql(expression, "this")
 644    substr = self.sql(expression, "substr")
 645    position = self.sql(expression, "position")
 646    instance = expression.args.get("instance") if generate_instance else None
 647    position_offset = ""
 648
 649    if position:
 650        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
 651        this = self.func("SUBSTR", this, position)
 652        position_offset = f" + {position} - 1"
 653
 654    return self.func("STRPOS", this, substr, instance) + position_offset
 655
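# Illustrative sketch: a StrPosition with position=3 renders roughly as
# STRPOS(SUBSTR(this, 3), substr) + 3 - 1, shifting the search window and then
# adding the offset back.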
 656
 657def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
 658    return (
 659        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
 660    )
 661
 662
 663def var_map_sql(
 664    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
 665) -> str:
 666    keys = expression.args["keys"]
 667    values = expression.args["values"]
 668
 669    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
 670        self.unsupported("Cannot convert array columns into map.")
 671        return self.func(map_func_name, keys, values)
 672
 673    args = []
 674    for key, value in zip(keys.expressions, values.expressions):
 675        args.append(self.sql(key))
 676        args.append(self.sql(value))
 677
 678    return self.func(map_func_name, *args)
 679
 680
 681def build_formatted_time(
 682    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
 683) -> t.Callable[[t.List], E]:
 684    """Helper used for time expressions.
 685
 686    Args:
 687        exp_class: the expression class to instantiate.
 688        dialect: target sql dialect.
 689        default: the default format, True being time.
 690
 691    Returns:
 692        A callable that can be used to return the appropriately formatted time expression.
 693    """
 694
 695    def _builder(args: t.List):
 696        return exp_class(
 697            this=seq_get(args, 0),
 698            format=Dialect[dialect].format_time(
 699                seq_get(args, 1)
 700                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
 701            ),
 702        )
 703
 704    return _builder
 705
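# Illustrative sketch: dialects typically register these builders in their
# Parser's FUNCTIONS mapping so the format argument is rewritten at parse time,
# e.g.
#
#     FUNCTIONS = {"STR_TO_DATE": build_formatted_time(exp.StrToTime, "mysql")}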
 706
 707def time_format(
 708    dialect: DialectType = None,
 709) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
 710    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
 711        """
 712        Returns the time format for a given expression, unless it's equivalent
 713        to the default time format of the dialect of interest.
 714        """
 715        time_format = self.format_time(expression)
 716        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
 717
 718    return _time_format
 719
 720
 721def build_date_delta(
 722    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
 723) -> t.Callable[[t.List], E]:
 724    def _builder(args: t.List) -> E:
 725        unit_based = len(args) == 3
 726        this = args[2] if unit_based else seq_get(args, 0)
 727        unit = args[0] if unit_based else exp.Literal.string("DAY")
 728        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
 729        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
 730
 731    return _builder
 732
 733
 734def build_date_delta_with_interval(
 735    expression_class: t.Type[E],
 736) -> t.Callable[[t.List], t.Optional[E]]:
 737    def _builder(args: t.List) -> t.Optional[E]:
 738        if len(args) < 2:
 739            return None
 740
 741        interval = args[1]
 742
 743        if not isinstance(interval, exp.Interval):
 744            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
 745
 746        expression = interval.this
 747        if expression and expression.is_string:
 748            expression = exp.Literal.number(expression.this)
 749
 750        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
 751
 752    return _builder
 753
 754
 755def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
 756    unit = seq_get(args, 0)
 757    this = seq_get(args, 1)
 758
 759    if isinstance(this, exp.Cast) and this.is_type("date"):
 760        return exp.DateTrunc(unit=unit, this=this)
 761    return exp.TimestampTrunc(this=this, unit=unit)
 762
 763
 764def date_add_interval_sql(
 765    data_type: str, kind: str
 766) -> t.Callable[[Generator, exp.Expression], str]:
 767    def func(self: Generator, expression: exp.Expression) -> str:
 768        this = self.sql(expression, "this")
 769        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
 770        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
 771
 772    return func
 773
 774
 775def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
 776    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
 777        args = [unit_to_str(expression), expression.this]
 778        if zone:
 779            args.append(expression.args.get("zone"))
 780        return self.func("DATE_TRUNC", *args)
 781
 782    return _timestamptrunc_sql
 783
 784
 785def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
 786    if not expression.expression:
 787        from sqlglot.optimizer.annotate_types import annotate_types
 788
 789        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
 790        return self.sql(exp.cast(expression.this, target_type))
 791    if expression.text("expression").lower() in TIMEZONES:
 792        return self.sql(
 793            exp.AtTimeZone(
 794                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
 795                zone=expression.expression,
 796            )
 797        )
 798    return self.func("TIMESTAMP", expression.this, expression.expression)
 799
 800
 801def locate_to_strposition(args: t.List) -> exp.Expression:
 802    return exp.StrPosition(
 803        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
 804    )
 805
 806
 807def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
 808    return self.func(
 809        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
 810    )
 811
 812
 813def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
 814    return self.sql(
 815        exp.Substring(
 816            this=expression.this, start=exp.Literal.number(1), length=expression.expression
 817        )
 818    )
 819
 820
  821def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
 822    return self.sql(
 823        exp.Substring(
 824            this=expression.this,
 825            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
 826        )
 827    )
 828
 829
 830def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
 831    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))
 832
 833
 834def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
 835    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))
 836
 837
  838# Used for Presto and DuckDB, whose corresponding functions don't support a charset argument and assume utf-8
 839def encode_decode_sql(
 840    self: Generator, expression: exp.Expression, name: str, replace: bool = True
 841) -> str:
 842    charset = expression.args.get("charset")
 843    if charset and charset.name.lower() != "utf-8":
 844        self.unsupported(f"Expected utf-8 character set, got {charset}.")
 845
 846    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
 847
 848
 849def min_or_least(self: Generator, expression: exp.Min) -> str:
 850    name = "LEAST" if expression.expressions else "MIN"
 851    return rename_func(name)(self, expression)
 852
 853
 854def max_or_greatest(self: Generator, expression: exp.Max) -> str:
 855    name = "GREATEST" if expression.expressions else "MAX"
 856    return rename_func(name)(self, expression)
 857
 858
 859def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
 860    cond = expression.this
 861
 862    if isinstance(expression.this, exp.Distinct):
 863        cond = expression.this.expressions[0]
 864        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
 865
 866    return self.func("sum", exp.func("if", cond, 1, 0))
 867
 868
 869def trim_sql(self: Generator, expression: exp.Trim) -> str:
 870    target = self.sql(expression, "this")
 871    trim_type = self.sql(expression, "position")
 872    remove_chars = self.sql(expression, "expression")
 873    collation = self.sql(expression, "collation")
 874
 875    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
 876    if not remove_chars and not collation:
 877        return self.trim_sql(expression)
 878
 879    trim_type = f"{trim_type} " if trim_type else ""
 880    remove_chars = f"{remove_chars} " if remove_chars else ""
 881    from_part = "FROM " if trim_type or remove_chars else ""
 882    collation = f" COLLATE {collation}" if collation else ""
 883    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
 884
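# Illustrative sketch: a Trim node with remove_chars 'x' and position LEADING
# falls through to the long form and renders roughly as TRIM(LEADING 'x' FROM col).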
 885
 886def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
 887    return self.func("STRPTIME", expression.this, self.format_time(expression))
 888
 889
 890def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
 891    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
 892
 893
 894def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
 895    delim, *rest_args = expression.expressions
 896    return self.sql(
 897        reduce(
 898            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
 899            rest_args,
 900        )
 901    )
 902
 903
 904def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
 905    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
 906    if bad_args:
 907        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
 908
 909    return self.func(
 910        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
 911    )
 912
 913
 914def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
 915    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
 916    if bad_args:
 917        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
 918
 919    return self.func(
 920        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
 921    )
 922
 923
 924def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
 925    names = []
 926    for agg in aggregations:
 927        if isinstance(agg, exp.Alias):
 928            names.append(agg.alias)
 929        else:
 930            """
 931            This case corresponds to aggregations without aliases being used as suffixes
 932            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
 933            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
  934            Otherwise, we'd end up with `col_avg(`foo`)` (notice the quotes around `foo`).
 935            """
 936            agg_all_unquoted = agg.transform(
 937                lambda node: (
 938                    exp.Identifier(this=node.name, quoted=False)
 939                    if isinstance(node, exp.Identifier)
 940                    else node
 941                )
 942            )
 943            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
 944
 945    return names
 946
 947
 948def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
 949    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
 950
 951
 952# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
 953def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
 954    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
 955
 956
 957def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
 958    return self.func("MAX", expression.this)
 959
 960
 961def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
 962    a = self.sql(expression.left)
 963    b = self.sql(expression.right)
 964    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
 965
 966
 967def is_parse_json(expression: exp.Expression) -> bool:
 968    return isinstance(expression, exp.ParseJSON) or (
 969        isinstance(expression, exp.Cast) and expression.is_type("json")
 970    )
 971
 972
 973def isnull_to_is_null(args: t.List) -> exp.Expression:
 974    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
 975
 976
 977def generatedasidentitycolumnconstraint_sql(
 978    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
 979) -> str:
 980    start = self.sql(expression, "start") or "1"
 981    increment = self.sql(expression, "increment") or "1"
 982    return f"IDENTITY({start}, {increment})"
 983
 984
 985def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
 986    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
 987        if expression.args.get("count"):
 988            self.unsupported(f"Only two arguments are supported in function {name}.")
 989
 990        return self.func(name, expression.this, expression.expression)
 991
 992    return _arg_max_or_min_sql
 993
 994
 995def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
 996    this = expression.this.copy()
 997
 998    return_type = expression.return_type
 999    if return_type.is_type(exp.DataType.Type.DATE):
1000        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
1001        # can truncate timestamp strings, because some dialects can't cast them to DATE
1002        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
1003
1004    expression.this.replace(exp.cast(this, return_type))
1005    return expression
1006
1007
1008def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
1009    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
1010        if cast and isinstance(expression, exp.TsOrDsAdd):
1011            expression = ts_or_ds_add_cast(expression)
1012
1013        return self.func(
1014            name,
1015            unit_to_var(expression),
1016            expression.expression,
1017            expression.this,
1018        )
1019
1020    return _delta_sql
1021
1022
1023def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1024    unit = expression.args.get("unit")
1025
1026    if isinstance(unit, exp.Placeholder):
1027        return unit
1028    if unit:
1029        return exp.Literal.string(unit.name)
1030    return exp.Literal.string(default) if default else None
1031
1032
1033def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1034    unit = expression.args.get("unit")
1035
1036    if isinstance(unit, (exp.Var, exp.Placeholder)):
1037        return unit
1038    return exp.Var(this=default) if default else None
1039
1040
1041def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
1042    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
1043    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
1044    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
1045
1046    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1047
1048
1049def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
1050    """Remove table refs from columns in when statements."""
1051    alias = expression.this.args.get("alias")
1052
1053    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
1054        return self.dialect.normalize_identifier(identifier).name if identifier else None
1055
1056    targets = {normalize(expression.this.this)}
1057
1058    if alias:
1059        targets.add(normalize(alias.this))
1060
1061    for when in expression.expressions:
1062        when.transform(
1063            lambda node: (
1064                exp.column(node.this)
1065                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1066                else node
1067            ),
1068            copy=False,
1069        )
1070
1071    return self.merge_sql(expression)
1072
1073
1074def build_json_extract_path(
1075    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
1076) -> t.Callable[[t.List], F]:
1077    def _builder(args: t.List) -> F:
1078        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
1079        for arg in args[1:]:
1080            if not isinstance(arg, exp.Literal):
1081                # We use the fallback parser because we can't really transpile non-literals safely
1082                return expr_type.from_arg_list(args)
1083
1084            text = arg.name
1085            if is_int(text):
1086                index = int(text)
1087                segments.append(
1088                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1089                )
1090            else:
1091                segments.append(exp.JSONPathKey(this=text))
1092
1093        # This is done to avoid failing in the expression validator due to the arg count
1094        del args[2:]
1095        return expr_type(
1096            this=seq_get(args, 0),
1097            expression=exp.JSONPath(expressions=segments),
1098            only_json_types=arrow_req_json_type,
1099        )
1100
1101    return _builder
1102
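# Illustrative sketch: literal arguments are folded into a single JSONPath, so
# an argument list like [col, 'a', '0'] parses to
# expr_type(this=col, expression=JSONPath($.a[0])) under zero-based indexing;
# any non-literal argument falls back to expr_type.from_arg_list.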
1103
1104def json_extract_segments(
1105    name: str, quoted_index: bool = True, op: t.Optional[str] = None
1106) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
1107    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
1108        path = expression.expression
1109        if not isinstance(path, exp.JSONPath):
1110            return rename_func(name)(self, expression)
1111
1112        segments = []
1113        for segment in path.expressions:
1114            path = self.sql(segment)
1115            if path:
1116                if isinstance(segment, exp.JSONPathPart) and (
1117                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
1118                ):
1119                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
1120
1121                segments.append(path)
1122
1123        if op:
1124            return f" {op} ".join([self.sql(expression.this), *segments])
1125        return self.func(name, expression.this, *segments)
1126
1127    return _json_extract_segments
1128
1129
1130def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
1131    if isinstance(expression.this, exp.JSONPathWildcard):
1132        self.unsupported("Unsupported wildcard in JSONPathKey expression")
1133
1134    return expression.name
1135
1136
1137def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
1138    cond = expression.expression
1139    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
1140        alias = cond.expressions[0]
1141        cond = cond.this
1142    elif isinstance(cond, exp.Predicate):
1143        alias = "_u"
1144    else:
1145        self.unsupported("Unsupported filter condition")
1146        return ""
1147
1148    unnest = exp.Unnest(expressions=[expression.this])
1149    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
1150    return self.sql(exp.Array(expressions=[filtered]))
1151
1152
1153def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
1154    return self.func(
1155        "TO_NUMBER",
1156        expression.this,
1157        expression.args.get("format"),
1158        expression.args.get("nlsparam"),
1159    )
1160
1161
1162def build_default_decimal_type(
1163    precision: t.Optional[int] = None, scale: t.Optional[int] = None
1164) -> t.Callable[[exp.DataType], exp.DataType]:
1165    def _builder(dtype: exp.DataType) -> exp.DataType:
1166        if dtype.expressions or precision is None:
1167            return dtype
1168
1169        params = f"{precision}{f', {scale}' if scale is not None else ''}"
1170        return exp.DataType.build(f"DECIMAL({params})")
1171
1172    return _builder
logger = <Logger sqlglot (WARNING)>
class Dialects(builtins.str, enum.Enum):
30class Dialects(str, Enum):
31    """Dialects supported by SQLGLot."""
32
33    DIALECT = ""
34
35    ATHENA = "athena"
36    BIGQUERY = "bigquery"
37    CLICKHOUSE = "clickhouse"
38    DATABRICKS = "databricks"
39    DORIS = "doris"
40    DRILL = "drill"
41    DUCKDB = "duckdb"
42    HIVE = "hive"
43    MYSQL = "mysql"
44    ORACLE = "oracle"
45    POSTGRES = "postgres"
46    PRESTO = "presto"
47    PRQL = "prql"
48    REDSHIFT = "redshift"
49    SNOWFLAKE = "snowflake"
50    SPARK = "spark"
51    SPARK2 = "spark2"
52    SQLITE = "sqlite"
53    STARROCKS = "starrocks"
54    TABLEAU = "tableau"
55    TERADATA = "teradata"
56    TRINO = "trino"
57    TSQL = "tsql"

Dialects supported by SQLGLot.

DIALECT = <Dialects.DIALECT: ''>
ATHENA = <Dialects.ATHENA: 'athena'>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DORIS = <Dialects.DORIS: 'doris'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
PRQL = <Dialects.PRQL: 'prql'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class NormalizationStrategy(builtins.str, sqlglot.helper.AutoName):
60class NormalizationStrategy(str, AutoName):
61    """Specifies the strategy according to which identifiers should be normalized."""
62
63    LOWERCASE = auto()
64    """Unquoted identifiers are lowercased."""
65
66    UPPERCASE = auto()
67    """Unquoted identifiers are uppercased."""
68
69    CASE_SENSITIVE = auto()
70    """Always case-sensitive, regardless of quotes."""
71
72    CASE_INSENSITIVE = auto()
73    """Always case-insensitive, regardless of quotes."""

Specifies the strategy according to which identifiers should be normalized.

LOWERCASE = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Unquoted identifiers are lowercased.

UPPERCASE = <NormalizationStrategy.UPPERCASE: 'UPPERCASE'>

Unquoted identifiers are uppercased.

CASE_SENSITIVE = <NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>

Always case-sensitive, regardless of quotes.

CASE_INSENSITIVE = <NormalizationStrategy.CASE_INSENSITIVE: 'CASE_INSENSITIVE'>

Always case-insensitive, regardless of quotes.

Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
184class Dialect(metaclass=_Dialect):
185    INDEX_OFFSET = 0
186    """The base index offset for arrays."""
187
188    WEEK_OFFSET = 0
189    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
190
191    UNNEST_COLUMN_ONLY = False
192    """Whether `UNNEST` table aliases are treated as column aliases."""
193
194    ALIAS_POST_TABLESAMPLE = False
195    """Whether the table alias comes after tablesample."""
196
197    TABLESAMPLE_SIZE_IS_PERCENT = False
198    """Whether a size in the table sample clause represents percentage."""
199
200    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
201    """Specifies the strategy according to which identifiers should be normalized."""
202
203    IDENTIFIERS_CAN_START_WITH_DIGIT = False
204    """Whether an unquoted identifier can start with a digit."""
205
206    DPIPE_IS_STRING_CONCAT = True
207    """Whether the DPIPE token (`||`) is a string concatenation operator."""
208
209    STRICT_STRING_CONCAT = False
210    """Whether `CONCAT`'s arguments must be strings."""
211
212    SUPPORTS_USER_DEFINED_TYPES = True
213    """Whether user-defined data types are supported."""
214
215    SUPPORTS_SEMI_ANTI_JOIN = True
216    """Whether `SEMI` or `ANTI` joins are supported."""
217
218    NORMALIZE_FUNCTIONS: bool | str = "upper"
219    """
220    Determines how function names are going to be normalized.
221    Possible values:
222        "upper" or True: Convert names to uppercase.
223        "lower": Convert names to lowercase.
224        False: Disables function name normalization.
225    """
226
227    LOG_BASE_FIRST: t.Optional[bool] = True
228    """
229    Whether the base comes first in the `LOG` function.
230    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
231    """
232
233    NULL_ORDERING = "nulls_are_small"
234    """
235    Default `NULL` ordering method to use if not explicitly set.
236    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
237    """
238
239    TYPED_DIVISION = False
240    """
241    Whether the behavior of `a / b` depends on the types of `a` and `b`.
242    False means `a / b` is always float division.
243    True means `a / b` is integer division if both `a` and `b` are integers.
244    """
245
246    SAFE_DIVISION = False
247    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""
248
249    CONCAT_COALESCE = False
250    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
251
252    HEX_LOWERCASE = False
253    """Whether the `HEX` function returns a lowercase hexadecimal string."""
254
255    DATE_FORMAT = "'%Y-%m-%d'"
256    DATEINT_FORMAT = "'%Y%m%d'"
257    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
258
259    TIME_MAPPING: t.Dict[str, str] = {}
260    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""
261
262    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
263    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
264    FORMAT_MAPPING: t.Dict[str, str] = {}
265    """
266    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
267    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
268    """
269
270    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
271    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
272
273    PSEUDOCOLUMNS: t.Set[str] = set()
274    """
275    Columns that are auto-generated by the engine corresponding to this dialect.
276    For example, such columns may be excluded from `SELECT *` queries.
277    """
278
279    PREFER_CTE_ALIAS_COLUMN = False
280    """
281    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
282    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
283    any projection aliases in the subquery.
284
285    For example,
286        WITH y(c) AS (
287            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
288        ) SELECT c FROM y;
289
290        will be rewritten as
291
292        WITH y(c) AS (
293            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
294        ) SELECT c FROM y;
295    """
296
297    # --- Autofilled ---
298
299    tokenizer_class = Tokenizer
300    parser_class = Parser
301    generator_class = Generator
302
303    # A trie of the time_mapping keys
304    TIME_TRIE: t.Dict = {}
305    FORMAT_TRIE: t.Dict = {}
306
307    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
308    INVERSE_TIME_TRIE: t.Dict = {}
309
310    ESCAPED_SEQUENCES: t.Dict[str, str] = {}
311
312    # Delimiters for string literals and identifiers
313    QUOTE_START = "'"
314    QUOTE_END = "'"
315    IDENTIFIER_START = '"'
316    IDENTIFIER_END = '"'
317
318    # Delimiters for bit, hex, byte and unicode literals
319    BIT_START: t.Optional[str] = None
320    BIT_END: t.Optional[str] = None
321    HEX_START: t.Optional[str] = None
322    HEX_END: t.Optional[str] = None
323    BYTE_START: t.Optional[str] = None
324    BYTE_END: t.Optional[str] = None
325    UNICODE_START: t.Optional[str] = None
326    UNICODE_END: t.Optional[str] = None
327
328    # Separator of COPY statement parameters
329    COPY_PARAMS_ARE_CSV = True
330
331    @classmethod
332    def get_or_raise(cls, dialect: DialectType) -> Dialect:
333        """
334        Look up a dialect in the global dialect registry and return it if it exists.
335
336        Args:
337            dialect: The target dialect. If this is a string, it can be optionally followed by
338                additional key-value pairs that are separated by commas and are used to specify
339                dialect settings, such as whether the dialect's identifiers are case-sensitive.
340
341        Example:
342            >>> dialect = dialect_class = get_or_raise("duckdb")
343            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
344
345        Returns:
346            The corresponding Dialect instance.
347        """
348
349        if not dialect:
350            return cls()
351        if isinstance(dialect, _Dialect):
352            return dialect()
353        if isinstance(dialect, Dialect):
354            return dialect
355        if isinstance(dialect, str):
356            try:
357                dialect_name, *kv_pairs = dialect.split(",")
358                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
359            except ValueError:
360                raise ValueError(
361                    f"Invalid dialect format: '{dialect}'. "
362                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
363                )
364
365            result = cls.get(dialect_name.strip())
366            if not result:
367                from difflib import get_close_matches
368
369                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
370                if similar:
371                    similar = f" Did you mean {similar}?"
372
373                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
374
375            return result(**kwargs)
376
377        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
378
379    @classmethod
380    def format_time(
381        cls, expression: t.Optional[str | exp.Expression]
382    ) -> t.Optional[exp.Expression]:
383        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
384        if isinstance(expression, str):
385            return exp.Literal.string(
386                # the time formats are quoted
387                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
388            )
389
390        if expression and expression.is_string:
391            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
392
393        return expression
394
395    def __init__(self, **kwargs) -> None:
396        normalization_strategy = kwargs.get("normalization_strategy")
397
398        if normalization_strategy is None:
399            self.normalization_strategy = self.NORMALIZATION_STRATEGY
400        else:
401            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
402
403    def __eq__(self, other: t.Any) -> bool:
404        # Does not currently take dialect state into account
405        return type(self) == other
406
407    def __hash__(self) -> int:
408        # Does not currently take dialect state into account
409        return hash(type(self))
410
411    def normalize_identifier(self, expression: E) -> E:
412        """
413        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
414
415        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
416        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
417        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
418        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
419
420        There are also dialects like Spark, which are case-insensitive even when quotes are
421        present, and dialects like MySQL, whose resolution rules match those employed by the
422        underlying operating system, for example they may always be case-sensitive in Linux.
423
424        Finally, the normalization behavior of some engines can even be controlled through flags,
425        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
426
427        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
428        that it can analyze queries in the optimizer and successfully capture their semantics.
429        """
430        if (
431            isinstance(expression, exp.Identifier)
432            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
433            and (
434                not expression.quoted
435                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
436            )
437        ):
438            expression.set(
439                "this",
440                (
441                    expression.this.upper()
442                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
443                    else expression.this.lower()
444                ),
445            )
446
447        return expression
448
449    def case_sensitive(self, text: str) -> bool:
450        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
451        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
452            return False
453
454        unsafe = (
455            str.islower
456            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
457            else str.isupper
458        )
459        return any(unsafe(char) for char in text)
460
461    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
462        """Checks if text can be identified given an identify option.
463
464        Args:
465            text: The text to check.
466            identify:
467                `"always"` or `True`: Always returns `True`.
468                `"safe"`: Only returns `True` if the identifier is case-insensitive.
469
470        Returns:
471            Whether the given text can be identified.
472        """
473        if identify is True or identify == "always":
474            return True
475
476        if identify == "safe":
477            return not self.case_sensitive(text)
478
479        return False
480
481    def quote_identifier(self, expression: E, identify: bool = True) -> E:
482        """
483        Adds quotes to a given identifier.
484
485        Args:
486            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
487            identify: If set to `False`, the quotes will only be added if the identifier is deemed
488                "unsafe", with respect to its characters and this dialect's normalization strategy.
489        """
490        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
491            name = expression.this
492            expression.set(
493                "quoted",
494                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
495            )
496
497        return expression
498
499    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
500        if isinstance(path, exp.Literal):
501            path_text = path.name
502            if path.is_number:
503                path_text = f"[{path_text}]"
504
505            try:
506                return parse_json_path(path_text)
507            except ParseError as e:
508                logger.warning(f"Invalid JSON path syntax. {str(e)}")
509
510        return path
511
512    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
513        return self.parser(**opts).parse(self.tokenize(sql), sql)
514
515    def parse_into(
516        self, expression_type: exp.IntoType, sql: str, **opts
517    ) -> t.List[t.Optional[exp.Expression]]:
518        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
519
520    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
521        return self.generator(**opts).generate(expression, copy=copy)
522
523    def transpile(self, sql: str, **opts) -> t.List[str]:
524        return [
525            self.generate(expression, copy=False, **opts) if expression else ""
526            for expression in self.parse(sql)
527        ]
528
529    def tokenize(self, sql: str) -> t.List[Token]:
530        return self.tokenizer.tokenize(sql)
531
532    @property
533    def tokenizer(self) -> Tokenizer:
534        if not hasattr(self, "_tokenizer"):
535            self._tokenizer = self.tokenizer_class(dialect=self)
536        return self._tokenizer
537
538    def parser(self, **opts) -> Parser:
539        return self.parser_class(dialect=self, **opts)
540
541    def generator(self, **opts) -> Generator:
542        return self.generator_class(dialect=self, **opts)
Dialect(**kwargs)
395    def __init__(self, **kwargs) -> None:
396        normalization_strategy = kwargs.get("normalization_strategy")
397
398        if normalization_strategy is None:
399            self.normalization_strategy = self.NORMALIZATION_STRATEGY
400        else:
401            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
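
Example (a minimal sketch; only the normalization_strategy keyword shown above is exercised):

from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

# The strategy string is upper-cased and coerced into the enum by __init__.
d = Dialect(normalization_strategy="case_insensitive")
assert d.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE

# Without the keyword, the class-level default (LOWERCASE for the base class) applies.
assert Dialect().normalization_strategy is NormalizationStrategy.LOWERCASE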
INDEX_OFFSET = 0

The base index offset for arrays.

WEEK_OFFSET = 0

First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

UNNEST_COLUMN_ONLY = False

Whether UNNEST table aliases are treated as column aliases.

ALIAS_POST_TABLESAMPLE = False

Whether the table alias comes after tablesample.

TABLESAMPLE_SIZE_IS_PERCENT = False

Whether a size in the table sample clause represents a percentage.

NORMALIZATION_STRATEGY = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Specifies the strategy according to which identifiers should be normalized.

IDENTIFIERS_CAN_START_WITH_DIGIT = False

Whether an unquoted identifier can start with a digit.

DPIPE_IS_STRING_CONCAT = True

Whether the DPIPE token (||) is a string concatenation operator.

STRICT_STRING_CONCAT = False

Whether CONCAT's arguments must be strings.

SUPPORTS_USER_DEFINED_TYPES = True

Whether user-defined data types are supported.

SUPPORTS_SEMI_ANTI_JOIN = True

Whether SEMI or ANTI joins are supported.

NORMALIZE_FUNCTIONS: bool | str = 'upper'

Determines how function names are going to be normalized.

Possible values:

  • "upper" or True: Convert names to uppercase.
  • "lower": Convert names to lowercase.
  • False: Disables function name normalization.

LOG_BASE_FIRST: Optional[bool] = True

Whether the base comes first in the LOG function. Possible values: True, False, None (two arguments are not supported by LOG)

NULL_ORDERING = 'nulls_are_small'

Default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last"

TYPED_DIVISION = False

Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division. True means a / b is integer division if both a and b are integers.

SAFE_DIVISION = False

Whether division by zero throws an error (False) or returns NULL (True).

CONCAT_COALESCE = False

A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.

HEX_LOWERCASE = False

Whether the HEX function returns a lowercase hexadecimal string.

DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}

Associates this dialect's time formats with their equivalent Python strftime formats.

FORMAT_MAPPING: Dict[str, str] = {}

Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed off of TIME_MAPPING.

UNESCAPED_SEQUENCES: Dict[str, str] = {}

Mapping of an escaped sequence (\\n) to its unescaped version (a literal newline character).

PSEUDOCOLUMNS: Set[str] = set()

Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.

PREFER_CTE_ALIAS_COLUMN = False

Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
ESCAPED_SEQUENCES: Dict[str, str] = {}
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START: Optional[str] = None
BIT_END: Optional[str] = None
HEX_START: Optional[str] = None
HEX_END: Optional[str] = None
BYTE_START: Optional[str] = None
BYTE_END: Optional[str] = None
UNICODE_START: Optional[str] = None
UNICODE_END: Optional[str] = None
COPY_PARAMS_ARE_CSV = True
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Dialect:
331    @classmethod
332    def get_or_raise(cls, dialect: DialectType) -> Dialect:
333        """
334        Look up a dialect in the global dialect registry and return it if it exists.
335
336        Args:
337            dialect: The target dialect. If this is a string, it can be optionally followed by
338                additional key-value pairs that are separated by commas and are used to specify
339                dialect settings, such as whether the dialect's identifiers are case-sensitive.
340
341        Example:
342            >>> dialect = get_or_raise("duckdb")
343            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
344
345        Returns:
346            The corresponding Dialect instance.
347        """
348
349        if not dialect:
350            return cls()
351        if isinstance(dialect, _Dialect):
352            return dialect()
353        if isinstance(dialect, Dialect):
354            return dialect
355        if isinstance(dialect, str):
356            try:
357                dialect_name, *kv_pairs = dialect.split(",")
358                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
359            except ValueError:
360                raise ValueError(
361                    f"Invalid dialect format: '{dialect}'. "
362                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
363                )
364
365            result = cls.get(dialect_name.strip())
366            if not result:
367                from difflib import get_close_matches
368
369                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
370                if similar:
371                    similar = f" Did you mean {similar}?"
372
373                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
374
375            return result(**kwargs)
376
377        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

Look up a dialect in the global dialect registry and return it if it exists.

Arguments:
  • dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:

The corresponding Dialect instance.
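
Example (a minimal sketch of the settings syntax; the class name printed below assumes the DuckDB dialect shipped with sqlglot):

from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

duckdb = Dialect.get_or_raise("duckdb")
mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

print(type(duckdb).__name__)  # DuckDB
print(mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE)  # True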

@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
379    @classmethod
380    def format_time(
381        cls, expression: t.Optional[str | exp.Expression]
382    ) -> t.Optional[exp.Expression]:
383        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
384        if isinstance(expression, str):
385            return exp.Literal.string(
386                # the time formats are quoted
387                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
388            )
389
390        if expression and expression.is_string:
391            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
392
393        return expression

Converts a time format in this dialect to its equivalent Python strftime format.
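
Example (a sketch using Hive's time mappings; note that string inputs are expected to carry their surrounding quotes, as the slice in the source above implies):

from sqlglot.dialects.dialect import Dialect

hive = Dialect["hive"]
print(hive.format_time("'yyyy-MM-dd'").sql())  # '%Y-%m-%d'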

def normalize_identifier(self, expression: ~E) -> ~E:
411    def normalize_identifier(self, expression: E) -> E:
412        """
413        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
414
415        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
416        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
417        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
418        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
419
420        There are also dialects like Spark, which are case-insensitive even when quotes are
421        present, and dialects like MySQL, whose resolution rules match those employed by the
422        underlying operating system, for example they may always be case-sensitive in Linux.
423
424        Finally, the normalization behavior of some engines can even be controlled through flags,
425        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
426
427        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
428        that it can analyze queries in the optimizer and successfully capture their semantics.
429        """
430        if (
431            isinstance(expression, exp.Identifier)
432            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
433            and (
434                not expression.quoted
435                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
436            )
437        ):
438            expression.set(
439                "this",
440                (
441                    expression.this.upper()
442                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
443                    else expression.this.lower()
444                ),
445            )
446
447        return expression

Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
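
Example (a sketch contrasting the Postgres and Snowflake strategies described above):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

ident = exp.to_identifier("FoO")  # unquoted, so it may be normalized

print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name)   # foo
print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name)  # FOO

# Quoted identifiers are left untouched under case-sensitive resolution rules.
quoted = exp.to_identifier("FoO", quoted=True)
print(Dialect.get_or_raise("postgres").normalize_identifier(quoted).name)  # FoO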

def case_sensitive(self, text: str) -> bool:
449    def case_sensitive(self, text: str) -> bool:
450        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
451        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
452            return False
453
454        unsafe = (
455            str.islower
456            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
457            else str.isupper
458        )
459        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.

def can_identify(self, text: str, identify: str | bool = 'safe') -> bool:
461    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
462        """Checks if text can be identified given an identify option.
463
464        Args:
465            text: The text to check.
466            identify:
467                `"always"` or `True`: Always returns `True`.
468                `"safe"`: Only returns `True` if the identifier is case-insensitive.
469
470        Returns:
471            Whether the given text can be identified.
472        """
473        if identify is True or identify == "always":
474            return True
475
476        if identify == "safe":
477            return not self.case_sensitive(text)
478
479        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns True. "safe": Only returns True if the identifier is case-insensitive.
Returns:

Whether the given text can be identified.
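
Example (a sketch of how case_sensitive and can_identify interact under the default lowercase strategy):

from sqlglot.dialects.dialect import Dialect

d = Dialect.get_or_raise("duckdb")  # lowercases unquoted identifiers

print(d.case_sensitive("foo"))          # False: no uppercase characters
print(d.case_sensitive("FoO"))          # True
print(d.can_identify("FoO", "safe"))    # False: the text is case-sensitive
print(d.can_identify("FoO", "always"))  # True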

def quote_identifier(self, expression: ~E, identify: bool = True) -> ~E:
481    def quote_identifier(self, expression: E, identify: bool = True) -> E:
482        """
483        Adds quotes to a given identifier.
484
485        Args:
486            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
487            identify: If set to `False`, the quotes will only be added if the identifier is deemed
488                "unsafe", with respect to its characters and this dialect's normalization strategy.
489        """
490        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
491            name = expression.this
492            expression.set(
493                "quoted",
494                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
495            )
496
497        return expression

Adds quotes to a given identifier.

Arguments:
  • expression: The expression of interest. If it's not an Identifier, this method is a no-op.
  • identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
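
Example (a sketch; the rendered quoting below assumes the default double-quote identifiers):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

d = Dialect.get_or_raise("duckdb")

# "Safe" lowercase identifiers stay unquoted when identify=False...
print(d.quote_identifier(exp.to_identifier("foo"), identify=False).sql())  # foo

# ...but mixed-case ones are quoted to preserve their resolution.
print(d.quote_identifier(exp.to_identifier("FoO"), identify=False).sql())  # "FoO"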
def to_json_path( self, path: Optional[sqlglot.expressions.Expression]) -> Optional[sqlglot.expressions.Expression]:
499    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
500        if isinstance(path, exp.Literal):
501            path_text = path.name
502            if path.is_number:
503                path_text = f"[{path_text}]"
504
505            try:
506                return parse_json_path(path_text)
507            except ParseError as e:
508                logger.warning(f"Invalid JSON path syntax. {str(e)}")
509
510        return path
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
512    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
513        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
515    def parse_into(
516        self, expression_type: exp.IntoType, sql: str, **opts
517    ) -> t.List[t.Optional[exp.Expression]]:
518        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
520    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
521        return self.generator(**opts).generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> List[str]:
523    def transpile(self, sql: str, **opts) -> t.List[str]:
524        return [
525            self.generate(expression, copy=False, **opts) if expression else ""
526            for expression in self.parse(sql)
527        ]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
529    def tokenize(self, sql: str) -> t.List[Token]:
530        return self.tokenizer.tokenize(sql)
tokenizer: sqlglot.tokens.Tokenizer
532    @property
533    def tokenizer(self) -> Tokenizer:
534        if not hasattr(self, "_tokenizer"):
535            self._tokenizer = self.tokenizer_class(dialect=self)
536        return self._tokenizer
def parser(self, **opts) -> sqlglot.parser.Parser:
538    def parser(self, **opts) -> Parser:
539        return self.parser_class(dialect=self, **opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
541    def generator(self, **opts) -> Generator:
542        return self.generator_class(dialect=self, **opts)
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
548def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
549    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
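
Example (a sketch; the returned callable is normally registered in a generator's TRANSFORMS map, but it can also be invoked directly — MINIMUM is a made-up function name here):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, rename_func

gen = Dialect.get_or_raise("duckdb").generator()
node = exp.Min(this=exp.column("x"))
print(rename_func("MINIMUM")(gen, node))  # MINIMUM(x)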
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
552def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
553    if expression.args.get("accuracy"):
554        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
555    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
558def if_sql(
559    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
560) -> t.Callable[[Generator, exp.If], str]:
561    def _if_sql(self: Generator, expression: exp.If) -> str:
562        return self.func(
563            name,
564            expression.this,
565            expression.args.get("true"),
566            expression.args.get("false") or false_value,
567        )
568
569    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]) -> str:
572def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
573    this = expression.this
574    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
575        this.replace(exp.cast(this, exp.DataType.Type.JSON))
576
577    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
580def inline_array_sql(self: Generator, expression: exp.Array) -> str:
581    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"
def inline_array_unless_query( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
584def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
585    elem = seq_get(expression.expressions, 0)
586    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
587        return self.func("ARRAY", elem)
588    return inline_array_sql(self, expression)
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
591def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
592    return self.like_sql(
593        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
594    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
597def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
598    zone = self.sql(expression, "this")
599    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
602def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
603    if expression.args.get("recursive"):
604        self.unsupported("Recursive CTEs are unsupported")
605        expression.args["recursive"] = False
606    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
609def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
610    n = self.sql(expression, "this")
611    d = self.sql(expression, "expression")
612    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
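
Example (a sketch of the rewrite this helper produces):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, no_safe_divide_sql

gen = Dialect.get_or_raise("duckdb").generator()
node = exp.SafeDivide(this=exp.column("n"), expression=exp.column("d"))
print(no_safe_divide_sql(gen, node))  # IF((d) <> 0, (n) / (d), NULL)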
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
615def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
616    self.unsupported("TABLESAMPLE unsupported")
617    return self.sql(expression.this)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
620def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
621    self.unsupported("PIVOT unsupported")
622    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
625def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
626    return self.cast_sql(expression)
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
629def no_comment_column_constraint_sql(
630    self: Generator, expression: exp.CommentColumnConstraint
631) -> str:
632    self.unsupported("CommentColumnConstraint unsupported")
633    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
636def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
637    self.unsupported("MAP_FROM_ENTRIES unsupported")
638    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition, generate_instance: bool = False) -> str:
641def str_position_sql(
642    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
643) -> str:
644    this = self.sql(expression, "this")
645    substr = self.sql(expression, "substr")
646    position = self.sql(expression, "position")
647    instance = expression.args.get("instance") if generate_instance else None
648    position_offset = ""
649
650    if position:
651        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
652        this = self.func("SUBSTR", this, position)
653        position_offset = f" + {position} - 1"
654
655    return self.func("STRPOS", this, substr, instance) + position_offset
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
658def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
659    return (
660        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
661    )
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
664def var_map_sql(
665    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
666) -> str:
667    keys = expression.args["keys"]
668    values = expression.args["values"]
669
670    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
671        self.unsupported("Cannot convert array columns into map.")
672        return self.func(map_func_name, keys, values)
673
674    args = []
675    for key, value in zip(keys.expressions, values.expressions):
676        args.append(self.sql(key))
677        args.append(self.sql(value))
678
679    return self.func(map_func_name, *args)
def build_formatted_time( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
682def build_formatted_time(
683    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
684) -> t.Callable[[t.List], E]:
685    """Helper used for time expressions.
686
687    Args:
688        exp_class: the expression class to instantiate.
689        dialect: target sql dialect.
690        default: the default format, True being time.
691
692    Returns:
693        A callable that can be used to return the appropriately formatted time expression.
694    """
695
696    def _builder(args: t.List):
697        return exp_class(
698            this=seq_get(args, 0),
699            format=Dialect[dialect].format_time(
700                seq_get(args, 1)
701                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
702            ),
703        )
704
705    return _builder

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.
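
Example (a sketch wiring the builder to Hive's time mappings; exp.StrToTime is just one expression class that takes a format argument):

from sqlglot import exp
from sqlglot.dialects.dialect import build_formatted_time

builder = build_formatted_time(exp.StrToTime, "hive")
node = builder([exp.column("x"), exp.Literal.string("yyyy-MM-dd")])
print(node.args["format"].sql())  # '%Y-%m-%d'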

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
708def time_format(
709    dialect: DialectType = None,
710) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
711    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
712        """
713        Returns the time format for a given expression, unless it's equivalent
714        to the default time format of the dialect of interest.
715        """
716        time_format = self.format_time(expression)
717        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
718
719    return _time_format
def build_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
722def build_date_delta(
723    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
724) -> t.Callable[[t.List], E]:
725    def _builder(args: t.List) -> E:
726        unit_based = len(args) == 3
727        this = args[2] if unit_based else seq_get(args, 0)
728        unit = args[0] if unit_based else exp.Literal.string("DAY")
729        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
730        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
731
732    return _builder
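
Example (a sketch of both argument shapes the builder accepts; two-argument calls fall back to a DAY unit):

from sqlglot import exp
from sqlglot.dialects.dialect import build_date_delta

builder = build_date_delta(exp.DateAdd)

two_arg = builder([exp.column("d"), exp.Literal.number(3)])
print(two_arg.args["unit"].sql())  # 'DAY'

three_arg = builder([exp.var("WEEK"), exp.Literal.number(3), exp.column("d")])
print(three_arg.args["unit"].sql())  # WEEK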
def build_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
735def build_date_delta_with_interval(
736    expression_class: t.Type[E],
737) -> t.Callable[[t.List], t.Optional[E]]:
738    def _builder(args: t.List) -> t.Optional[E]:
739        if len(args) < 2:
740            return None
741
742        interval = args[1]
743
744        if not isinstance(interval, exp.Interval):
745            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
746
747        expression = interval.this
748        if expression and expression.is_string:
749            expression = exp.Literal.number(expression.this)
750
751        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
752
753    return _builder
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
756def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
757    unit = seq_get(args, 0)
758    this = seq_get(args, 1)
759
760    if isinstance(this, exp.Cast) and this.is_type("date"):
761        return exp.DateTrunc(unit=unit, this=this)
762    return exp.TimestampTrunc(this=this, unit=unit)
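
Example (a sketch of the dispatch on the DATE cast):

from sqlglot import exp
from sqlglot.dialects.dialect import date_trunc_to_time

args_date = [exp.Literal.string("month"), exp.cast(exp.column("d"), "date")]
args_ts = [exp.Literal.string("month"), exp.column("t")]

print(type(date_trunc_to_time(args_date)).__name__)  # DateTrunc
print(type(date_trunc_to_time(args_ts)).__name__)    # TimestampTrunc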
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
765def date_add_interval_sql(
766    data_type: str, kind: str
767) -> t.Callable[[Generator, exp.Expression], str]:
768    def func(self: Generator, expression: exp.Expression) -> str:
769        this = self.sql(expression, "this")
770        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
771        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
772
773    return func
def timestamptrunc_sql( zone: bool = False) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.TimestampTrunc], str]:
776def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
777    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
778        args = [unit_to_str(expression), expression.this]
779        if zone:
780            args.append(expression.args.get("zone"))
781        return self.func("DATE_TRUNC", *args)
782
783    return _timestamptrunc_sql
def no_timestamp_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Timestamp) -> str:
786def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
787    if not expression.expression:
788        from sqlglot.optimizer.annotate_types import annotate_types
789
790        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
791        return self.sql(exp.cast(expression.this, target_type))
792    if expression.text("expression").lower() in TIMEZONES:
793        return self.sql(
794            exp.AtTimeZone(
795                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
796                zone=expression.expression,
797            )
798        )
799    return self.func("TIMESTAMP", expression.this, expression.expression)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
802def locate_to_strposition(args: t.List) -> exp.Expression:
803    return exp.StrPosition(
804        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
805    )
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
808def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
809    return self.func(
810        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
811    )
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
814def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
815    return self.sql(
816        exp.Substring(
817            this=expression.this, start=exp.Literal.number(1), length=expression.expression
818        )
819    )
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Right) -> str:
822def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
823    return self.sql(
824        exp.Substring(
825            this=expression.this,
826            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
827        )
828    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
831def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
832    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
835def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
836    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
840def encode_decode_sql(
841    self: Generator, expression: exp.Expression, name: str, replace: bool = True
842) -> str:
843    charset = expression.args.get("charset")
844    if charset and charset.name.lower() != "utf-8":
845        self.unsupported(f"Expected utf-8 character set, got {charset}.")
846
847    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
850def min_or_least(self: Generator, expression: exp.Min) -> str:
851    name = "LEAST" if expression.expressions else "MIN"
852    return rename_func(name)(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
855def max_or_greatest(self: Generator, expression: exp.Max) -> str:
856    name = "GREATEST" if expression.expressions else "MAX"
857    return rename_func(name)(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
860def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
861    cond = expression.this
862
863    if isinstance(expression.this, exp.Distinct):
864        cond = expression.this.expressions[0]
865        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
866
867    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
870def trim_sql(self: Generator, expression: exp.Trim) -> str:
871    target = self.sql(expression, "this")
872    trim_type = self.sql(expression, "position")
873    remove_chars = self.sql(expression, "expression")
874    collation = self.sql(expression, "collation")
875
876    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
877    if not remove_chars and not collation:
878        return self.trim_sql(expression)
879
880    trim_type = f"{trim_type} " if trim_type else ""
881    remove_chars = f"{remove_chars} " if remove_chars else ""
882    from_part = "FROM " if trim_type or remove_chars else ""
883    collation = f" COLLATE {collation}" if collation else ""
884    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
887def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
888    return self.func("STRPTIME", expression.this, self.format_time(expression))
def concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat) -> str:
891def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
892    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
def concat_ws_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ConcatWs) -> str:
895def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
896    delim, *rest_args = expression.expressions
897    return self.sql(
898        reduce(
899            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
900            rest_args,
901        )
902    )
def regexp_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpExtract) -> str:
905def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
906    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
907    if bad_args:
908        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
909
910    return self.func(
911        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
912    )
def regexp_replace_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpReplace) -> str:
915def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
916    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
917    if bad_args:
918        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
919
920    return self.func(
921        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
922    )
def pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> List[str]:
925def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
926    names = []
927    for agg in aggregations:
928        if isinstance(agg, exp.Alias):
929            names.append(agg.alias)
930        else:
931            """
932            This case corresponds to aggregations without aliases being used as suffixes
933            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
934            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
935            Otherwise, we'd end up with `col_avg(`foo`)` (notice the nested backticks).
936            """
937            agg_all_unquoted = agg.transform(
938                lambda node: (
939                    exp.Identifier(this=node.name, quoted=False)
940                    if isinstance(node, exp.Identifier)
941                    else node
942                )
943            )
944            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
945
946    return names
def binary_from_function(expr_type: Type[~B]) -> Callable[[List], ~B]:
949def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
950    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
def build_timestamp_trunc(args: List) -> sqlglot.expressions.TimestampTrunc:
954def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
955    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
def any_value_to_max_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.AnyValue) -> str:
958def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
959    return self.func("MAX", expression.this)
def bool_xor_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Xor) -> str:
962def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
963    a = self.sql(expression.left)
964    b = self.sql(expression.right)
965    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
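
Example (a sketch of the boolean expansion emitted for engines without XOR):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, bool_xor_sql

gen = Dialect.get_or_raise("postgres").generator()
node = exp.Xor(this=exp.column("a"), expression=exp.column("b"))
print(bool_xor_sql(gen, node))  # (a AND (NOT b)) OR ((NOT a) AND b)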
def is_parse_json(expression: sqlglot.expressions.Expression) -> bool:
968def is_parse_json(expression: exp.Expression) -> bool:
969    return isinstance(expression, exp.ParseJSON) or (
970        isinstance(expression, exp.Cast) and expression.is_type("json")
971    )
def isnull_to_is_null(args: List) -> sqlglot.expressions.Expression:
974def isnull_to_is_null(args: t.List) -> exp.Expression:
975    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
def generatedasidentitycolumnconstraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.GeneratedAsIdentityColumnConstraint) -> str:
978def generatedasidentitycolumnconstraint_sql(
979    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
980) -> str:
981    start = self.sql(expression, "start") or "1"
982    increment = self.sql(expression, "increment") or "1"
983    return f"IDENTITY({start}, {increment})"
def arg_max_or_min_no_count( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.ArgMax | sqlglot.expressions.ArgMin], str]:
986def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
987    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
988        if expression.args.get("count"):
989            self.unsupported(f"Only two arguments are supported in function {name}.")
990
991        return self.func(name, expression.this, expression.expression)
992
993    return _arg_max_or_min_sql
def ts_or_ds_add_cast( expression: sqlglot.expressions.TsOrDsAdd) -> sqlglot.expressions.TsOrDsAdd:
 996def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
 997    this = expression.this.copy()
 998
 999    return_type = expression.return_type
1000    if return_type.is_type(exp.DataType.Type.DATE):
1001        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
1002        # can truncate timestamp strings, because some dialects can't cast them to DATE
1003        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
1004
1005    expression.this.replace(exp.cast(this, return_type))
1006    return expression
def date_delta_sql( name: str, cast: bool = False) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.DateAdd, sqlglot.expressions.TsOrDsAdd, sqlglot.expressions.DateDiff, sqlglot.expressions.TsOrDsDiff]], str]:
1009def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
1010    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
1011        if cast and isinstance(expression, exp.TsOrDsAdd):
1012            expression = ts_or_ds_add_cast(expression)
1013
1014        return self.func(
1015            name,
1016            unit_to_var(expression),
1017            expression.expression,
1018            expression.this,
1019        )
1020
1021    return _delta_sql
def unit_to_str( expression: sqlglot.expressions.Expression, default: str = 'DAY') -> Optional[sqlglot.expressions.Expression]:
1024def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1025    unit = expression.args.get("unit")
1026
1027    if isinstance(unit, exp.Placeholder):
1028        return unit
1029    if unit:
1030        return exp.Literal.string(unit.name)
1031    return exp.Literal.string(default) if default else None
def unit_to_var( expression: sqlglot.expressions.Expression, default: str = 'DAY') -> Optional[sqlglot.expressions.Expression]:
1034def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
1035    unit = expression.args.get("unit")
1036
1037    if isinstance(unit, (exp.Var, exp.Placeholder)):
1038        return unit
1039    return exp.Var(this=default) if default else None
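
Example (a sketch contrasting the two helpers above; unit_to_str yields a string literal, unit_to_var a bare keyword):

from sqlglot import exp
from sqlglot.dialects.dialect import unit_to_str, unit_to_var

node = exp.DateAdd(this=exp.column("d"), expression=exp.Literal.number(1), unit=exp.var("WEEK"))
print(unit_to_str(node).sql())  # 'WEEK'
print(unit_to_var(node).sql())  # WEEK

no_unit = exp.DateAdd(this=exp.column("d"), expression=exp.Literal.number(1))
print(unit_to_str(no_unit).sql())  # 'DAY' (the default)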
def no_last_day_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.LastDay) -> str:
1042def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
1043    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
1044    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
1045    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
1046
1047    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
def merge_without_target_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Merge) -> str:
1050def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
1051    """Remove table refs from columns in when statements."""
1052    alias = expression.this.args.get("alias")
1053
1054    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
1055        return self.dialect.normalize_identifier(identifier).name if identifier else None
1056
1057    targets = {normalize(expression.this.this)}
1058
1059    if alias:
1060        targets.add(normalize(alias.this))
1061
1062    for when in expression.expressions:
1063        when.transform(
1064            lambda node: (
1065                exp.column(node.this)
1066                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1067                else node
1068            ),
1069            copy=False,
1070        )
1071
1072    return self.merge_sql(expression)

Remove table refs from columns in when statements.

def build_json_extract_path( expr_type: Type[~F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False) -> Callable[[List], ~F]:
1075def build_json_extract_path(
1076    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
1077) -> t.Callable[[t.List], F]:
1078    def _builder(args: t.List) -> F:
1079        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
1080        for arg in args[1:]:
1081            if not isinstance(arg, exp.Literal):
1082                # We use the fallback parser because we can't really transpile non-literals safely
1083                return expr_type.from_arg_list(args)
1084
1085            text = arg.name
1086            if is_int(text):
1087                index = int(text)
1088                segments.append(
1089                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1090                )
1091            else:
1092                segments.append(exp.JSONPathKey(this=text))
1093
1094        # This is done to avoid failing in the expression validator due to the arg count
1095        del args[2:]
1096        return expr_type(
1097            this=seq_get(args, 0),
1098            expression=exp.JSONPath(expressions=segments),
1099            only_json_types=arrow_req_json_type,
1100        )
1101
1102    return _builder
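
Example (a sketch of how literal arguments are folded into a JSONPath; the printed path assumes the default JSONPath rendering):

from sqlglot import exp
from sqlglot.dialects.dialect import build_json_extract_path

builder = build_json_extract_path(exp.JSONExtract)
node = builder([exp.column("doc"), exp.Literal.string("a"), exp.Literal.number(0)])
print(node.expression.sql())  # $.a[0]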
def json_extract_segments( name: str, quoted_index: bool = True, op: Optional[str] = None) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]], str]:
1105def json_extract_segments(
1106    name: str, quoted_index: bool = True, op: t.Optional[str] = None
1107) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
1108    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
1109        path = expression.expression
1110        if not isinstance(path, exp.JSONPath):
1111            return rename_func(name)(self, expression)
1112
1113        segments = []
1114        for segment in path.expressions:
1115            path = self.sql(segment)
1116            if path:
1117                if isinstance(segment, exp.JSONPathPart) and (
1118                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
1119                ):
1120                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
1121
1122                segments.append(path)
1123
1124        if op:
1125            return f" {op} ".join([self.sql(expression.this), *segments])
1126        return self.func(name, expression.this, *segments)
1127
1128    return _json_extract_segments
def json_path_key_only_name( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONPathKey) -> str:
1131def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
1132    if isinstance(expression.this, exp.JSONPathWildcard):
1133        self.unsupported("Unsupported wildcard in JSONPathKey expression")
1134
1135    return expression.name
def filter_array_using_unnest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ArrayFilter) -> str:
1138def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
1139    cond = expression.expression
1140    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
1141        alias = cond.expressions[0]
1142        cond = cond.this
1143    elif isinstance(cond, exp.Predicate):
1144        alias = "_u"
1145    else:
1146        self.unsupported("Unsupported filter condition")
1147        return ""
1148
1149    unnest = exp.Unnest(expressions=[expression.this])
1150    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
1151    return self.sql(exp.Array(expressions=[filtered]))
def to_number_with_nls_param( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ToNumber) -> str:
1154def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
1155    return self.func(
1156        "TO_NUMBER",
1157        expression.this,
1158        expression.args.get("format"),
1159        expression.args.get("nlsparam"),
1160    )
def build_default_decimal_type( precision: Optional[int] = None, scale: Optional[int] = None) -> Callable[[sqlglot.expressions.DataType], sqlglot.expressions.DataType]:
1163def build_default_decimal_type(
1164    precision: t.Optional[int] = None, scale: t.Optional[int] = None
1165) -> t.Callable[[exp.DataType], exp.DataType]:
1166    def _builder(dtype: exp.DataType) -> exp.DataType:
1167        if dtype.expressions or precision is None:
1168            return dtype
1169
1170        params = f"{precision}{f', {scale}' if scale is not None else ''}"
1171        return exp.DataType.build(f"DECIMAL({params})")
1172
1173    return _builder
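
Example (a sketch; the builder only fills in parameters when the type has none):

from sqlglot import exp
from sqlglot.dialects.dialect import build_default_decimal_type

builder = build_default_decimal_type(precision=38, scale=9)
print(builder(exp.DataType.build("DECIMAL")).sql())         # DECIMAL(38, 9)
print(builder(exp.DataType.build("DECIMAL(10, 2)")).sql())  # DECIMAL(10, 2) (unchanged)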