Edit on GitHub

sqlglot.dialects.dialect

   1from __future__ import annotations
   2
   3import logging
   4import typing as t
   5from enum import Enum, auto
   6from functools import reduce
   7
   8from sqlglot import exp
   9from sqlglot.errors import ParseError
  10from sqlglot.generator import Generator
  11from sqlglot.helper import AutoName, flatten, is_int, seq_get
  12from sqlglot.jsonpath import parse as parse_json_path
  13from sqlglot.parser import Parser
  14from sqlglot.time import TIMEZONES, format_time
  15from sqlglot.tokens import Token, Tokenizer, TokenType
  16from sqlglot.trie import new_trie
  17
# Type aliases grouping related expression classes, for use in helper annotations
# (e.g. JSON_EXTRACT_TYPE annotates arrow_json_extract_sql below).
DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]
  21
  22
  23if t.TYPE_CHECKING:
  24    from sqlglot._typing import B, E, F
  25
logger = logging.getLogger("sqlglot")

# Backslash escape sequences as written in SQL source (a backslash followed by a letter)
# mapped to the character they denote. The _Dialect metaclass merges these into a
# dialect's UNESCAPED_SEQUENCES when that dialect's tokenizer lists "\\" among its
# STRING_ESCAPES; the dialect's own entries win on conflicts.
UNESCAPED_SEQUENCES = {
    "\\a": "\a",
    "\\b": "\b",
    "\\f": "\f",
    "\\n": "\n",
    "\\r": "\r",
    "\\t": "\t",
    "\\v": "\v",
    "\\\\": "\\",
}
  38
  39
class Dialects(str, Enum):
    """Dialects supported by SQLGlot.

    The member values are the keys under which the _Dialect metaclass registers dialect
    classes: a class is looked up here by its upper-cased class name.
    """

    # The base, dialect-agnostic Dialect class registers under the empty string.
    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MATERIALIZE = "materialize"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    RISINGWAVE = "risingwave"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
  70
  71
class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized.

    Dialect.__init__ builds a member from a user-supplied string via
    ``NormalizationStrategy(value.upper())``. That relies on AutoName giving each member a
    value equal to its name (e.g. "case_sensitive" -> CASE_SENSITIVE).
    """

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""
  86
  87
  88class _Dialect(type):
  89    classes: t.Dict[str, t.Type[Dialect]] = {}
  90
  91    def __eq__(cls, other: t.Any) -> bool:
  92        if cls is other:
  93            return True
  94        if isinstance(other, str):
  95            return cls is cls.get(other)
  96        if isinstance(other, Dialect):
  97            return cls is type(other)
  98
  99        return False
 100
 101    def __hash__(cls) -> int:
 102        return hash(cls.__name__.lower())
 103
 104    @classmethod
 105    def __getitem__(cls, key: str) -> t.Type[Dialect]:
 106        return cls.classes[key]
 107
 108    @classmethod
 109    def get(
 110        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
 111    ) -> t.Optional[t.Type[Dialect]]:
 112        return cls.classes.get(key, default)
 113
 114    def __new__(cls, clsname, bases, attrs):
 115        klass = super().__new__(cls, clsname, bases, attrs)
 116        enum = Dialects.__members__.get(clsname.upper())
 117        cls.classes[enum.value if enum is not None else clsname.lower()] = klass
 118
 119        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
 120        klass.FORMAT_TRIE = (
 121            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
 122        )
 123        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
 124        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 125
 126        base = seq_get(bases, 0)
 127        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
 128        base_parser = (getattr(base, "parser_class", Parser),)
 129        base_generator = (getattr(base, "generator_class", Generator),)
 130
 131        klass.tokenizer_class = klass.__dict__.get(
 132            "Tokenizer", type("Tokenizer", base_tokenizer, {})
 133        )
 134        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
 135        klass.generator_class = klass.__dict__.get(
 136            "Generator", type("Generator", base_generator, {})
 137        )
 138
 139        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
 140        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
 141            klass.tokenizer_class._IDENTIFIERS.items()
 142        )[0]
 143
 144        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
 145            return next(
 146                (
 147                    (s, e)
 148                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
 149                    if t == token_type
 150                ),
 151                (None, None),
 152            )
 153
 154        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
 155        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
 156        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
 157        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
 158
 159        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
 160            klass.UNESCAPED_SEQUENCES = {
 161                **UNESCAPED_SEQUENCES,
 162                **klass.UNESCAPED_SEQUENCES,
 163            }
 164
 165        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}
 166
 167        if enum not in ("", "bigquery"):
 168            klass.generator_class.SELECT_KINDS = ()
 169
 170        if enum not in ("", "athena", "presto", "trino"):
 171            klass.generator_class.TRY_SUPPORTED = False
 172
 173        if enum not in ("", "databricks", "hive", "spark", "spark2"):
 174            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
 175            for modifier in ("cluster", "distribute", "sort"):
 176                modifier_transforms.pop(modifier, None)
 177
 178            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms
 179
 180        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
 181            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
 182                TokenType.ANTI,
 183                TokenType.SEMI,
 184            }
 185
 186        return klass
 187
 188
 189class Dialect(metaclass=_Dialect):
 190    INDEX_OFFSET = 0
 191    """The base index offset for arrays."""
 192
 193    WEEK_OFFSET = 0
 194    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
 195
 196    UNNEST_COLUMN_ONLY = False
 197    """Whether `UNNEST` table aliases are treated as column aliases."""
 198
 199    ALIAS_POST_TABLESAMPLE = False
 200    """Whether the table alias comes after tablesample."""
 201
 202    TABLESAMPLE_SIZE_IS_PERCENT = False
 203    """Whether a size in the table sample clause represents percentage."""
 204
 205    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
 206    """Specifies the strategy according to which identifiers should be normalized."""
 207
 208    IDENTIFIERS_CAN_START_WITH_DIGIT = False
 209    """Whether an unquoted identifier can start with a digit."""
 210
 211    DPIPE_IS_STRING_CONCAT = True
 212    """Whether the DPIPE token (`||`) is a string concatenation operator."""
 213
 214    STRICT_STRING_CONCAT = False
 215    """Whether `CONCAT`'s arguments must be strings."""
 216
 217    SUPPORTS_USER_DEFINED_TYPES = True
 218    """Whether user-defined data types are supported."""
 219
 220    SUPPORTS_SEMI_ANTI_JOIN = True
 221    """Whether `SEMI` or `ANTI` joins are supported."""
 222
 223    NORMALIZE_FUNCTIONS: bool | str = "upper"
 224    """
 225    Determines how function names are going to be normalized.
 226    Possible values:
 227        "upper" or True: Convert names to uppercase.
 228        "lower": Convert names to lowercase.
 229        False: Disables function name normalization.
 230    """
 231
 232    LOG_BASE_FIRST: t.Optional[bool] = True
 233    """
 234    Whether the base comes first in the `LOG` function.
 235    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
 236    """
 237
 238    NULL_ORDERING = "nulls_are_small"
 239    """
 240    Default `NULL` ordering method to use if not explicitly set.
 241    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
 242    """
 243
 244    TYPED_DIVISION = False
 245    """
 246    Whether the behavior of `a / b` depends on the types of `a` and `b`.
 247    False means `a / b` is always float division.
 248    True means `a / b` is integer division if both `a` and `b` are integers.
 249    """
 250
 251    SAFE_DIVISION = False
 252    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""
 253
 254    CONCAT_COALESCE = False
 255    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
 256
 257    HEX_LOWERCASE = False
 258    """Whether the `HEX` function returns a lowercase hexadecimal string."""
 259
 260    DATE_FORMAT = "'%Y-%m-%d'"
 261    DATEINT_FORMAT = "'%Y%m%d'"
 262    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
 263
 264    TIME_MAPPING: t.Dict[str, str] = {}
 265    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""
 266
 267    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
 268    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
 269    FORMAT_MAPPING: t.Dict[str, str] = {}
 270    """
 271    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
 272    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
 273    """
 274
 275    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
 276    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
 277
 278    PSEUDOCOLUMNS: t.Set[str] = set()
 279    """
 280    Columns that are auto-generated by the engine corresponding to this dialect.
 281    For example, such columns may be excluded from `SELECT *` queries.
 282    """
 283
 284    PREFER_CTE_ALIAS_COLUMN = False
 285    """
 286    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
 287    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
 288    any projection aliases in the subquery.
 289
 290    For example,
 291        WITH y(c) AS (
 292            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
 293        ) SELECT c FROM y;
 294
 295        will be rewritten as
 296
 297        WITH y(c) AS (
 298            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
 299        ) SELECT c FROM y;
 300    """
 301
 302    # --- Autofilled ---
 303
 304    tokenizer_class = Tokenizer
 305    parser_class = Parser
 306    generator_class = Generator
 307
 308    # A trie of the time_mapping keys
 309    TIME_TRIE: t.Dict = {}
 310    FORMAT_TRIE: t.Dict = {}
 311
 312    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
 313    INVERSE_TIME_TRIE: t.Dict = {}
 314
 315    ESCAPED_SEQUENCES: t.Dict[str, str] = {}
 316
 317    # Delimiters for string literals and identifiers
 318    QUOTE_START = "'"
 319    QUOTE_END = "'"
 320    IDENTIFIER_START = '"'
 321    IDENTIFIER_END = '"'
 322
 323    # Delimiters for bit, hex, byte and unicode literals
 324    BIT_START: t.Optional[str] = None
 325    BIT_END: t.Optional[str] = None
 326    HEX_START: t.Optional[str] = None
 327    HEX_END: t.Optional[str] = None
 328    BYTE_START: t.Optional[str] = None
 329    BYTE_END: t.Optional[str] = None
 330    UNICODE_START: t.Optional[str] = None
 331    UNICODE_END: t.Optional[str] = None
 332
 333    # Separator of COPY statement parameters
 334    COPY_PARAMS_ARE_CSV = True
 335
 336    @classmethod
 337    def get_or_raise(cls, dialect: DialectType) -> Dialect:
 338        """
 339        Look up a dialect in the global dialect registry and return it if it exists.
 340
 341        Args:
 342            dialect: The target dialect. If this is a string, it can be optionally followed by
 343                additional key-value pairs that are separated by commas and are used to specify
 344                dialect settings, such as whether the dialect's identifiers are case-sensitive.
 345
 346        Example:
 347            >>> dialect = dialect_class = get_or_raise("duckdb")
 348            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
 349
 350        Returns:
 351            The corresponding Dialect instance.
 352        """
 353
 354        if not dialect:
 355            return cls()
 356        if isinstance(dialect, _Dialect):
 357            return dialect()
 358        if isinstance(dialect, Dialect):
 359            return dialect
 360        if isinstance(dialect, str):
 361            try:
 362                dialect_name, *kv_pairs = dialect.split(",")
 363                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
 364            except ValueError:
 365                raise ValueError(
 366                    f"Invalid dialect format: '{dialect}'. "
 367                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
 368                )
 369
 370            result = cls.get(dialect_name.strip())
 371            if not result:
 372                from difflib import get_close_matches
 373
 374                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
 375                if similar:
 376                    similar = f" Did you mean {similar}?"
 377
 378                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
 379
 380            return result(**kwargs)
 381
 382        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
 383
 384    @classmethod
 385    def format_time(
 386        cls, expression: t.Optional[str | exp.Expression]
 387    ) -> t.Optional[exp.Expression]:
 388        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
 389        if isinstance(expression, str):
 390            return exp.Literal.string(
 391                # the time formats are quoted
 392                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
 393            )
 394
 395        if expression and expression.is_string:
 396            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
 397
 398        return expression
 399
 400    def __init__(self, **kwargs) -> None:
 401        normalization_strategy = kwargs.get("normalization_strategy")
 402
 403        if normalization_strategy is None:
 404            self.normalization_strategy = self.NORMALIZATION_STRATEGY
 405        else:
 406            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
 407
 408    def __eq__(self, other: t.Any) -> bool:
 409        # Does not currently take dialect state into account
 410        return type(self) == other
 411
 412    def __hash__(self) -> int:
 413        # Does not currently take dialect state into account
 414        return hash(type(self))
 415
 416    def normalize_identifier(self, expression: E) -> E:
 417        """
 418        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
 419
 420        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
 421        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
 422        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
 423        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
 424
 425        There are also dialects like Spark, which are case-insensitive even when quotes are
 426        present, and dialects like MySQL, whose resolution rules match those employed by the
 427        underlying operating system, for example they may always be case-sensitive in Linux.
 428
 429        Finally, the normalization behavior of some engines can even be controlled through flags,
 430        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
 431
 432        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
 433        that it can analyze queries in the optimizer and successfully capture their semantics.
 434        """
 435        if (
 436            isinstance(expression, exp.Identifier)
 437            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
 438            and (
 439                not expression.quoted
 440                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
 441            )
 442        ):
 443            expression.set(
 444                "this",
 445                (
 446                    expression.this.upper()
 447                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 448                    else expression.this.lower()
 449                ),
 450            )
 451
 452        return expression
 453
 454    def case_sensitive(self, text: str) -> bool:
 455        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
 456        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
 457            return False
 458
 459        unsafe = (
 460            str.islower
 461            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 462            else str.isupper
 463        )
 464        return any(unsafe(char) for char in text)
 465
 466    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
 467        """Checks if text can be identified given an identify option.
 468
 469        Args:
 470            text: The text to check.
 471            identify:
 472                `"always"` or `True`: Always returns `True`.
 473                `"safe"`: Only returns `True` if the identifier is case-insensitive.
 474
 475        Returns:
 476            Whether the given text can be identified.
 477        """
 478        if identify is True or identify == "always":
 479            return True
 480
 481        if identify == "safe":
 482            return not self.case_sensitive(text)
 483
 484        return False
 485
 486    def quote_identifier(self, expression: E, identify: bool = True) -> E:
 487        """
 488        Adds quotes to a given identifier.
 489
 490        Args:
 491            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
 492            identify: If set to `False`, the quotes will only be added if the identifier is deemed
 493                "unsafe", with respect to its characters and this dialect's normalization strategy.
 494        """
 495        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
 496            name = expression.this
 497            expression.set(
 498                "quoted",
 499                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
 500            )
 501
 502        return expression
 503
 504    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
 505        if isinstance(path, exp.Literal):
 506            path_text = path.name
 507            if path.is_number:
 508                path_text = f"[{path_text}]"
 509
 510            try:
 511                return parse_json_path(path_text)
 512            except ParseError as e:
 513                logger.warning(f"Invalid JSON path syntax. {str(e)}")
 514
 515        return path
 516
 517    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
 518        return self.parser(**opts).parse(self.tokenize(sql), sql)
 519
 520    def parse_into(
 521        self, expression_type: exp.IntoType, sql: str, **opts
 522    ) -> t.List[t.Optional[exp.Expression]]:
 523        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
 524
 525    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
 526        return self.generator(**opts).generate(expression, copy=copy)
 527
 528    def transpile(self, sql: str, **opts) -> t.List[str]:
 529        return [
 530            self.generate(expression, copy=False, **opts) if expression else ""
 531            for expression in self.parse(sql)
 532        ]
 533
 534    def tokenize(self, sql: str) -> t.List[Token]:
 535        return self.tokenizer.tokenize(sql)
 536
 537    @property
 538    def tokenizer(self) -> Tokenizer:
 539        if not hasattr(self, "_tokenizer"):
 540            self._tokenizer = self.tokenizer_class(dialect=self)
 541        return self._tokenizer
 542
 543    def parser(self, **opts) -> Parser:
 544        return self.parser_class(dialect=self, **opts)
 545
 546    def generator(self, **opts) -> Generator:
 547        return self.generator_class(dialect=self, **opts)
 548
 549
# Anything Dialect.get_or_raise accepts: a dialect name (optionally followed by
# ", key = value" settings), a Dialect instance, a Dialect subclass, or a falsy value
# (None / "") which resolves to the class get_or_raise is called on.
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
 551
 552
 553def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
 554    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
 555
 556
 557def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
 558    if expression.args.get("accuracy"):
 559        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
 560    return self.func("APPROX_COUNT_DISTINCT", expression.this)
 561
 562
 563def if_sql(
 564    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
 565) -> t.Callable[[Generator, exp.If], str]:
 566    def _if_sql(self: Generator, expression: exp.If) -> str:
 567        return self.func(
 568            name,
 569            expression.this,
 570            expression.args.get("true"),
 571            expression.args.get("false") or false_value,
 572        )
 573
 574    return _if_sql
 575
 576
 577def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
 578    this = expression.this
 579    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
 580        this.replace(exp.cast(this, exp.DataType.Type.JSON))
 581
 582    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
 583
 584
 585def inline_array_sql(self: Generator, expression: exp.Array) -> str:
 586    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"
 587
 588
 589def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
 590    elem = seq_get(expression.expressions, 0)
 591    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
 592        return self.func("ARRAY", elem)
 593    return inline_array_sql(self, expression)
 594
 595
 596def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
 597    return self.like_sql(
 598        exp.Like(
 599            this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression)
 600        )
 601    )
 602
 603
 604def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
 605    zone = self.sql(expression, "this")
 606    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
 607
 608
 609def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
 610    if expression.args.get("recursive"):
 611        self.unsupported("Recursive CTEs are unsupported")
 612        expression.args["recursive"] = False
 613    return self.with_sql(expression)
 614
 615
 616def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
 617    n = self.sql(expression, "this")
 618    d = self.sql(expression, "expression")
 619    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
 620
 621
 622def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
 623    self.unsupported("TABLESAMPLE unsupported")
 624    return self.sql(expression.this)
 625
 626
 627def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
 628    self.unsupported("PIVOT unsupported")
 629    return ""
 630
 631
 632def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
 633    return self.cast_sql(expression)
 634
 635
 636def no_comment_column_constraint_sql(
 637    self: Generator, expression: exp.CommentColumnConstraint
 638) -> str:
 639    self.unsupported("CommentColumnConstraint unsupported")
 640    return ""
 641
 642
 643def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
 644    self.unsupported("MAP_FROM_ENTRIES unsupported")
 645    return ""
 646
 647
 648def str_position_sql(
 649    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
 650) -> str:
 651    this = self.sql(expression, "this")
 652    substr = self.sql(expression, "substr")
 653    position = self.sql(expression, "position")
 654    instance = expression.args.get("instance") if generate_instance else None
 655    position_offset = ""
 656
 657    if position:
 658        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
 659        this = self.func("SUBSTR", this, position)
 660        position_offset = f" + {position} - 1"
 661
 662    return self.func("STRPOS", this, substr, instance) + position_offset
 663
 664
 665def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
 666    return (
 667        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
 668    )
 669
 670
 671def var_map_sql(
 672    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
 673) -> str:
 674    keys = expression.args["keys"]
 675    values = expression.args["values"]
 676
 677    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
 678        self.unsupported("Cannot convert array columns into map.")
 679        return self.func(map_func_name, keys, values)
 680
 681    args = []
 682    for key, value in zip(keys.expressions, values.expressions):
 683        args.append(self.sql(key))
 684        args.append(self.sql(value))
 685
 686    return self.func(map_func_name, *args)
 687
 688
 689def build_formatted_time(
 690    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
 691) -> t.Callable[[t.List], E]:
 692    """Helper used for time expressions.
 693
 694    Args:
 695        exp_class: the expression class to instantiate.
 696        dialect: target sql dialect.
 697        default: the default format, True being time.
 698
 699    Returns:
 700        A callable that can be used to return the appropriately formatted time expression.
 701    """
 702
 703    def _builder(args: t.List):
 704        return exp_class(
 705            this=seq_get(args, 0),
 706            format=Dialect[dialect].format_time(
 707                seq_get(args, 1)
 708                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
 709            ),
 710        )
 711
 712    return _builder
 713
 714
 715def time_format(
 716    dialect: DialectType = None,
 717) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
 718    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
 719        """
 720        Returns the time format for a given expression, unless it's equivalent
 721        to the default time format of the dialect of interest.
 722        """
 723        time_format = self.format_time(expression)
 724        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
 725
 726    return _time_format
 727
 728
 729def build_date_delta(
 730    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
 731) -> t.Callable[[t.List], E]:
 732    def _builder(args: t.List) -> E:
 733        unit_based = len(args) == 3
 734        this = args[2] if unit_based else seq_get(args, 0)
 735        unit = args[0] if unit_based else exp.Literal.string("DAY")
 736        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
 737        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
 738
 739    return _builder
 740
 741
 742def build_date_delta_with_interval(
 743    expression_class: t.Type[E],
 744) -> t.Callable[[t.List], t.Optional[E]]:
 745    def _builder(args: t.List) -> t.Optional[E]:
 746        if len(args) < 2:
 747            return None
 748
 749        interval = args[1]
 750
 751        if not isinstance(interval, exp.Interval):
 752            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
 753
 754        expression = interval.this
 755        if expression and expression.is_string:
 756            expression = exp.Literal.number(expression.this)
 757
 758        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
 759
 760    return _builder
 761
 762
 763def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
 764    unit = seq_get(args, 0)
 765    this = seq_get(args, 1)
 766
 767    if isinstance(this, exp.Cast) and this.is_type("date"):
 768        return exp.DateTrunc(unit=unit, this=this)
 769    return exp.TimestampTrunc(this=this, unit=unit)
 770
 771
 772def date_add_interval_sql(
 773    data_type: str, kind: str
 774) -> t.Callable[[Generator, exp.Expression], str]:
 775    def func(self: Generator, expression: exp.Expression) -> str:
 776        this = self.sql(expression, "this")
 777        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
 778        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
 779
 780    return func
 781
 782
 783def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
 784    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
 785        args = [unit_to_str(expression), expression.this]
 786        if zone:
 787            args.append(expression.args.get("zone"))
 788        return self.func("DATE_TRUNC", *args)
 789
 790    return _timestamptrunc_sql
 791
 792
 793def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
 794    if not expression.expression:
 795        from sqlglot.optimizer.annotate_types import annotate_types
 796
 797        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
 798        return self.sql(exp.cast(expression.this, target_type))
 799    if expression.text("expression").lower() in TIMEZONES:
 800        return self.sql(
 801            exp.AtTimeZone(
 802                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
 803                zone=expression.expression,
 804            )
 805        )
 806    return self.func("TIMESTAMP", expression.this, expression.expression)
 807
 808
 809def locate_to_strposition(args: t.List) -> exp.Expression:
 810    return exp.StrPosition(
 811        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
 812    )
 813
 814
 815def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
 816    return self.func(
 817        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
 818    )
 819
 820
 821def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
 822    return self.sql(
 823        exp.Substring(
 824            this=expression.this, start=exp.Literal.number(1), length=expression.expression
 825        )
 826    )
 827
 828
 829def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
 830    return self.sql(
 831        exp.Substring(
 832            this=expression.this,
 833            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
 834        )
 835    )
 836
 837
 838def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
 839    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))
 840
 841
 842def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
 843    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))
 844
 845
 846# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
 847def encode_decode_sql(
 848    self: Generator, expression: exp.Expression, name: str, replace: bool = True
 849) -> str:
 850    charset = expression.args.get("charset")
 851    if charset and charset.name.lower() != "utf-8":
 852        self.unsupported(f"Expected utf-8 character set, got {charset}.")
 853
 854    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
 855
 856
 857def min_or_least(self: Generator, expression: exp.Min) -> str:
 858    name = "LEAST" if expression.expressions else "MIN"
 859    return rename_func(name)(self, expression)
 860
 861
 862def max_or_greatest(self: Generator, expression: exp.Max) -> str:
 863    name = "GREATEST" if expression.expressions else "MAX"
 864    return rename_func(name)(self, expression)
 865
 866
 867def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
 868    cond = expression.this
 869
 870    if isinstance(expression.this, exp.Distinct):
 871        cond = expression.this.expressions[0]
 872        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
 873
 874    return self.func("sum", exp.func("if", cond, 1, 0))
 875
 876
 877def trim_sql(self: Generator, expression: exp.Trim) -> str:
 878    target = self.sql(expression, "this")
 879    trim_type = self.sql(expression, "position")
 880    remove_chars = self.sql(expression, "expression")
 881    collation = self.sql(expression, "collation")
 882
 883    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
 884    if not remove_chars and not collation:
 885        return self.trim_sql(expression)
 886
 887    trim_type = f"{trim_type} " if trim_type else ""
 888    remove_chars = f"{remove_chars} " if remove_chars else ""
 889    from_part = "FROM " if trim_type or remove_chars else ""
 890    collation = f" COLLATE {collation}" if collation else ""
 891    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
 892
 893
 894def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
 895    return self.func("STRPTIME", expression.this, self.format_time(expression))
 896
 897
 898def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
 899    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
 900
 901
 902def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
 903    delim, *rest_args = expression.expressions
 904    return self.sql(
 905        reduce(
 906            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
 907            rest_args,
 908        )
 909    )
 910
 911
 912def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
 913    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
 914    if bad_args:
 915        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
 916
 917    return self.func(
 918        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
 919    )
 920
 921
 922def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
 923    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
 924    if bad_args:
 925        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
 926
 927    return self.func(
 928        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
 929    )
 930
 931
 932def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
 933    names = []
 934    for agg in aggregations:
 935        if isinstance(agg, exp.Alias):
 936            names.append(agg.alias)
 937        else:
 938            """
 939            This case corresponds to aggregations without aliases being used as suffixes
 940            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
 941            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
 942            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
 943            """
 944            agg_all_unquoted = agg.transform(
 945                lambda node: (
 946                    exp.Identifier(this=node.name, quoted=False)
 947                    if isinstance(node, exp.Identifier)
 948                    else node
 949                )
 950            )
 951            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
 952
 953    return names
 954
 955
 956def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
 957    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
 958
 959
 960# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
 961def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
 962    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
 963
 964
 965def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
 966    return self.func("MAX", expression.this)
 967
 968
 969def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
 970    a = self.sql(expression.left)
 971    b = self.sql(expression.right)
 972    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
 973
 974
 975def is_parse_json(expression: exp.Expression) -> bool:
 976    return isinstance(expression, exp.ParseJSON) or (
 977        isinstance(expression, exp.Cast) and expression.is_type("json")
 978    )
 979
 980
 981def isnull_to_is_null(args: t.List) -> exp.Expression:
 982    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
 983
 984
 985def generatedasidentitycolumnconstraint_sql(
 986    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
 987) -> str:
 988    start = self.sql(expression, "start") or "1"
 989    increment = self.sql(expression, "increment") or "1"
 990    return f"IDENTITY({start}, {increment})"
 991
 992
 993def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
 994    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
 995        if expression.args.get("count"):
 996            self.unsupported(f"Only two arguments are supported in function {name}.")
 997
 998        return self.func(name, expression.this, expression.expression)
 999
1000    return _arg_max_or_min_sql
1001
1002
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    """Wrap the operand of a `TsOrDsAdd` in a cast to the node's return type, in place.

    Returns:
        The same node, after its `this` has been replaced by the cast.
    """
    return_type = expression.return_type
    operand = expression.this.copy()

    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        operand = exp.cast(operand, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(operand, return_type))
    return expression
1014
1015
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    """Build a transform rendering date add/diff nodes as `name(unit, expression, this)`.

    Args:
        name: Function name to emit.
        cast: When True, `TsOrDsAdd` operands are first cast via `ts_or_ds_add_cast`.
    """

    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        needs_cast = cast and isinstance(expression, exp.TsOrDsAdd)
        node = ts_or_ds_add_cast(expression) if needs_cast else expression
        return self.func(name, unit_to_var(node), node.expression, node.this)

    return _delta_sql
1029
1030
def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    """Return the node's `unit` as a string literal.

    Placeholders pass through unchanged. With no unit, `default` is used, and an empty
    `default` yields None.
    """
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if not unit:
        return exp.Literal.string(default) if default else None
    return exp.Literal.string(unit.name)
1039
1040
def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    """Return the node's `unit` as a `Var` (an unquoted keyword).

    - `Var` and `Placeholder` units are returned unchanged.
    - Any other unit node (e.g. a string literal such as 'month') becomes a `Var` carrying its
      name, consistent with `unit_to_str`.
    - When there is no unit (or it has an empty name), `default` is used; an empty `default`
      yields None.
    """
    unit = expression.args.get("unit")

    if isinstance(unit, (exp.Var, exp.Placeholder)):
        return unit

    # Don't collapse a non-Var unit to `default` (that would turn e.g. 'month' into DAY)
    value = (unit.name if unit else None) or default
    return exp.Var(this=value) if value else None
1047
1048
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    """Emulate `LAST_DAY(d)`.

    Truncates to the month, adds one month, subtracts one day, then casts the result to DATE.
    """
    month_start = exp.func("date_trunc", "month", expression.this)
    next_month_start = exp.func("date_add", month_start, 1, "month")
    month_end = exp.func("date_sub", next_month_start, 1, "day")
    return self.sql(exp.cast(month_end, exp.DataType.Type.DATE))
1055
1056
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove target-table refs from columns in the THEN action of each MERGE WHEN branch.

    A column is unqualified when its table name matches the merge target's name or alias,
    compared after `self.dialect.normalize_identifier`. The WHEN condition is left untouched.
    Qualifiers are valid there, and stripping them could make a column ambiguous between the
    source and the target. (Presumably this transform exists for dialects that reject
    target-qualified columns in MERGE actions, such as `UPDATE SET t.col = ...`.)

    Note: `normalize_identifier` is applied to the identifiers themselves, so they are
    normalized in place.
    """
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        # Only rewrite the THEN action; the WHEN condition may need the qualifier
        then = when.args.get("then")
        if not then:
            continue

        then.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)
1080
1081
def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    """Build a parser builder turning `f(json, p1, p2, ...)` into `expr_type` with a JSONPath.

    Args:
        expr_type: Expression class to instantiate.
        zero_based_indexing: If False, integer path parts are shifted down by one.
        arrow_req_json_type: Value forwarded as the `only_json_types` arg.
    """

    def _builder(args: t.List) -> F:
        path_args = args[1:]

        # We use the fallback parser because we can't really transpile non-literals safely
        if not all(isinstance(arg, exp.Literal) for arg in path_args):
            return expr_type.from_arg_list(args)

        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in path_args:
            text = arg.name
            if not is_int(text):
                segments.append(exp.JSONPathKey(this=text))
                continue

            position = int(text)
            if not zero_based_indexing:
                position -= 1
            segments.append(exp.JSONPathSubscript(this=position))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder
1110
1111
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    """Build a transform rendering a JSON extraction as one argument per path segment.

    Args:
        name: Function name to emit (also used when the path is not a parsed JSONPath).
        quoted_index: Whether subscript segments are wrapped in string quotes too.
        op: If given, segments are joined with this infix operator instead of a function call.
    """

    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        json_path = expression.expression
        if not isinstance(json_path, exp.JSONPath):
            return rename_func(name)(self, expression)

        rendered = []
        for part in json_path.expressions:
            part_sql = self.sql(part)
            if not part_sql:
                continue

            should_quote = isinstance(part, exp.JSONPathPart) and (
                quoted_index or not isinstance(part, exp.JSONPathSubscript)
            )
            if should_quote:
                part_sql = f"{self.dialect.QUOTE_START}{part_sql}{self.dialect.QUOTE_END}"

            rendered.append(part_sql)

        if op:
            return f" {op} ".join([self.sql(expression.this), *rendered])
        return self.func(name, expression.this, *rendered)

    return _json_extract_segments
1136
1137
def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    """Render a JSON path key as its bare name; wildcard keys are reported as unsupported."""
    key = expression.this
    if isinstance(key, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")
    return expression.name
1143
1144
def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    """Emulate an array filter as `ARRAY(SELECT ... FROM UNNEST(arr) ... WHERE cond)`.

    A single-parameter lambda supplies the element alias and its body is the condition.
    A bare predicate uses the alias `_u`. Anything else is reported as unsupported and
    renders as an empty string.
    """
    condition = expression.expression

    if isinstance(condition, exp.Lambda) and len(condition.expressions) == 1:
        alias, condition = condition.expressions[0], condition.this
    elif isinstance(condition, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnested = exp.Unnest(expressions=[expression.this])
    source = exp.alias_(unnested, None, table=[alias])
    subquery = exp.select(alias).from_(source).where(condition)
    return self.sql(exp.Array(expressions=[subquery]))
1159
1160
def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
    """Render `TO_NUMBER(this[, format[, nlsparam]])`."""
    fmt, nls_param = (expression.args.get(key) for key in ("format", "nlsparam"))
    return self.func("TO_NUMBER", expression.this, fmt, nls_param)
1168
1169
def build_default_decimal_type(
    precision: t.Optional[int] = None, scale: t.Optional[int] = None
) -> t.Callable[[exp.DataType], exp.DataType]:
    """Build a type builder that fills in default `DECIMAL(precision[, scale])` parameters.

    The returned builder leaves a type unchanged when it already has parameters or when no
    default `precision` was given. Otherwise it returns a freshly built DECIMAL type.
    (Presumably it is registered only for DECIMAL-like types; confirm at call sites.)
    """

    def _builder(dtype: exp.DataType) -> exp.DataType:
        # Explicit type parameters win over the defaults
        if dtype.expressions or precision is None:
            return dtype

        params = str(precision) if scale is None else f"{precision}, {scale}"
        return exp.DataType.build(f"DECIMAL({params})")

    return _builder
logger = <Logger sqlglot (WARNING)>
UNESCAPED_SEQUENCES = {'\\a': '\x07', '\\b': '\x08', '\\f': '\x0c', '\\n': '\n', '\\r': '\r', '\\t': '\t', '\\v': '\x0b', '\\\\': '\\'}
class Dialects(builtins.str, enum.Enum):
class Dialects(str, Enum):
    """Dialects supported by SQLGlot.

    Members subclass `str`, so each compares equal to its plain-string value
    (e.g. `Dialects.DUCKDB == "duckdb"`).
    """

    # The generic base dialect; its value is the empty string
    DIALECT = ""

    # Concrete dialects, listed alphabetically; values are lowercase names
    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MATERIALIZE = "materialize"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    RISINGWAVE = "risingwave"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"

Dialects supported by SQLGlot.

DIALECT = <Dialects.DIALECT: ''>
ATHENA = <Dialects.ATHENA: 'athena'>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DORIS = <Dialects.DORIS: 'doris'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MATERIALIZE = <Dialects.MATERIALIZE: 'materialize'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
PRQL = <Dialects.PRQL: 'prql'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
RISINGWAVE = <Dialects.RISINGWAVE: 'risingwave'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class NormalizationStrategy(builtins.str, sqlglot.helper.AutoName):
class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized.

    Each member's value is its own name (e.g. "LOWERCASE"). This lets `Dialect.__init__`
    look up a member from an upper-cased setting string such as "case_sensitive".
    """

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""

Specifies the strategy according to which identifiers should be normalized.

LOWERCASE = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Unquoted identifiers are lowercased.

UPPERCASE = <NormalizationStrategy.UPPERCASE: 'UPPERCASE'>

Unquoted identifiers are uppercased.

CASE_SENSITIVE = <NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>

Always case-sensitive, regardless of quotes.

CASE_INSENSITIVE = <NormalizationStrategy.CASE_INSENSITIVE: 'CASE_INSENSITIVE'>

Always case-insensitive, regardless of quotes.

Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    HEX_LOWERCASE = False
    """Whether the `HEX` function returns a lowercase hexadecimal string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

        will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    # Separator of COPY statement parameters
    COPY_PARAMS_ARE_CSV = True

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.
                A falsy value yields a default instance of `cls`, a `Dialect` class is
                instantiated, and a `Dialect` instance is returned as-is.

        Example:
            >>> dialect = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.

        Raises:
            ValueError: If the string is malformed, names an unknown dialect, or `dialect`
                has an unsupported type.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            # Strip once so the lookup, the suggestion and the error message all agree
            dialect_name = dialect_name.strip()
            result = cls.get(dialect_name)
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        """Create a dialect instance.

        Args:
            normalization_strategy: Optional keyword; a case-insensitive `NormalizationStrategy`
                name overriding the class-level `NORMALIZATION_STRATEGY` for this instance.
                Other keyword arguments are ignored.

        Raises:
            ValueError: If `normalization_strategy` does not name a `NormalizationStrategy`.
        """
        requested_strategy = kwargs.get("normalization_strategy")
        self.normalization_strategy = (
            self.NORMALIZATION_STRATEGY
            if requested_strategy is None
            else NormalizationStrategy(requested_strategy.upper())
        )

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account.
        # NOTE(review): comparing the class against `other` relies on the `_Dialect`
        # metaclass's equality to handle instances/strings; confirm there.
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.

        Note: the identifier is modified in place and also returned.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.

        Identifiers whose parent is a function call are left untouched.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        """Parse a literal JSON path argument into a JSON path expression.

        Numeric literals are treated as an array subscript (`[n]`). Non-literal paths, and
        literals that fail to parse (a warning is logged), are returned unchanged.
        """
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning("Invalid JSON path syntax. %s", e)

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse `sql`; `opts` are forwarded to the parser."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse `sql` into the given expression type(s); `opts` go to the parser."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        """Generate SQL for `expression` in this dialect; `opts` are forwarded to the generator."""
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and regenerate each statement in this dialect ("" for empty statements).

        `opts` are forwarded to the generator only.
        """
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenize `sql` with this dialect's (cached) tokenizer."""
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        """Lazily created tokenizer, cached on the instance."""
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Create a new parser bound to this dialect."""
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        """Create a new generator bound to this dialect."""
        return self.generator_class(dialect=self, **opts)
Dialect(**kwargs)
401    def __init__(self, **kwargs) -> None:
402        normalization_strategy = kwargs.get("normalization_strategy")
403
404        if normalization_strategy is None:
405            self.normalization_strategy = self.NORMALIZATION_STRATEGY
406        else:
407            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
INDEX_OFFSET = 0

The base index offset for arrays.

WEEK_OFFSET = 0

First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

UNNEST_COLUMN_ONLY = False

Whether UNNEST table aliases are treated as column aliases.

ALIAS_POST_TABLESAMPLE = False

Whether the table alias comes after tablesample.

TABLESAMPLE_SIZE_IS_PERCENT = False

Whether a size in the table sample clause represents percentage.

NORMALIZATION_STRATEGY = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Specifies the strategy according to which identifiers should be normalized.

IDENTIFIERS_CAN_START_WITH_DIGIT = False

Whether an unquoted identifier can start with a digit.

DPIPE_IS_STRING_CONCAT = True

Whether the DPIPE token (||) is a string concatenation operator.

STRICT_STRING_CONCAT = False

Whether CONCAT's arguments must be strings.

SUPPORTS_USER_DEFINED_TYPES = True

Whether user-defined data types are supported.

SUPPORTS_SEMI_ANTI_JOIN = True

Whether SEMI or ANTI joins are supported.

NORMALIZE_FUNCTIONS: bool | str = 'upper'

Determines how function names are going to be normalized.

Possible values:

  • "upper" or True: Convert names to uppercase.
  • "lower": Convert names to lowercase.
  • False: Disable function name normalization.

LOG_BASE_FIRST: Optional[bool] = True

Whether the base comes first in the LOG function. Possible values: True, False, or None (None means the two-argument form of LOG is not supported).

NULL_ORDERING = 'nulls_are_small'

Default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last".

TYPED_DIVISION = False

Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division. True means a / b is integer division if both a and b are integers.

SAFE_DIVISION = False

Whether division by zero throws an error (False) or returns NULL (True).

CONCAT_COALESCE = False

A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.

HEX_LOWERCASE = False

Whether the HEX function returns a lowercase hexadecimal string.

DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}

Associates this dialect's time formats with their equivalent Python strftime formats.

FORMAT_MAPPING: Dict[str, str] = {}

Helper used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie is constructed from TIME_MAPPING.

UNESCAPED_SEQUENCES: Dict[str, str] = {}

Mapping of an escaped sequence (e.g. the two characters \n) to its unescaped version (e.g. an actual newline character).

PSEUDOCOLUMNS: Set[str] = set()

Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.

PREFER_CTE_ALIAS_COLUMN = False

Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
ESCAPED_SEQUENCES: Dict[str, str] = {}
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START: Optional[str] = None
BIT_END: Optional[str] = None
HEX_START: Optional[str] = None
HEX_END: Optional[str] = None
BYTE_START: Optional[str] = None
BYTE_END: Optional[str] = None
UNICODE_START: Optional[str] = None
UNICODE_END: Optional[str] = None
COPY_PARAMS_ARE_CSV = True
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Dialect:
337    @classmethod
338    def get_or_raise(cls, dialect: DialectType) -> Dialect:
339        """
340        Look up a dialect in the global dialect registry and return it if it exists.
341
342        Args:
343            dialect: The target dialect. If this is a string, it can be optionally followed by
344                additional key-value pairs that are separated by commas and are used to specify
345                dialect settings, such as whether the dialect's identifiers are case-sensitive.
346
347        Example:
348            >>> dialect = dialect_class = get_or_raise("duckdb")
349            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
350
351        Returns:
352            The corresponding Dialect instance.
353        """
354
355        if not dialect:
356            return cls()
357        if isinstance(dialect, _Dialect):
358            return dialect()
359        if isinstance(dialect, Dialect):
360            return dialect
361        if isinstance(dialect, str):
362            try:
363                dialect_name, *kv_pairs = dialect.split(",")
364                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
365            except ValueError:
366                raise ValueError(
367                    f"Invalid dialect format: '{dialect}'. "
368                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
369                )
370
371            result = cls.get(dialect_name.strip())
372            if not result:
373                from difflib import get_close_matches
374
375                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
376                if similar:
377                    similar = f" Did you mean {similar}?"
378
379                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
380
381            return result(**kwargs)
382
383        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

Look up a dialect in the global dialect registry and return it if it exists.

Arguments:
  • dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = Dialect.get_or_raise("duckdb")
>>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:

The corresponding Dialect instance.

@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
385    @classmethod
386    def format_time(
387        cls, expression: t.Optional[str | exp.Expression]
388    ) -> t.Optional[exp.Expression]:
389        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
390        if isinstance(expression, str):
391            return exp.Literal.string(
392                # the time formats are quoted
393                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
394            )
395
396        if expression and expression.is_string:
397            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
398
399        return expression

Converts a time format in this dialect to its equivalent Python strftime format.

def normalize_identifier(self, expression: ~E) -> ~E:
417    def normalize_identifier(self, expression: E) -> E:
418        """
419        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
420
421        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
422        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
423        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
424        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
425
426        There are also dialects like Spark, which are case-insensitive even when quotes are
427        present, and dialects like MySQL, whose resolution rules match those employed by the
428        underlying operating system, for example they may always be case-sensitive in Linux.
429
430        Finally, the normalization behavior of some engines can even be controlled through flags,
431        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
432
433        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
434        that it can analyze queries in the optimizer and successfully capture their semantics.
435        """
436        if (
437            isinstance(expression, exp.Identifier)
438            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
439            and (
440                not expression.quoted
441                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
442            )
443        ):
444            expression.set(
445                "this",
446                (
447                    expression.this.upper()
448                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
449                    else expression.this.lower()
450                ),
451            )
452
453        return expression

Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system; for example, they may always be case-sensitive on Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.

def case_sensitive(self, text: str) -> bool:
455    def case_sensitive(self, text: str) -> bool:
456        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
457        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
458            return False
459
460        unsafe = (
461            str.islower
462            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
463            else str.isupper
464        )
465        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.

def can_identify(self, text: str, identify: str | bool = 'safe') -> bool:
467    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
468        """Checks if text can be identified given an identify option.
469
470        Args:
471            text: The text to check.
472            identify:
473                `"always"` or `True`: Always returns `True`.
474                `"safe"`: Only returns `True` if the identifier is case-insensitive.
475
476        Returns:
477            Whether the given text can be identified.
478        """
479        if identify is True or identify == "always":
480            return True
481
482        if identify == "safe":
483            return not self.case_sensitive(text)
484
485        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify:
      "always" or True: Always returns True.
      "safe": Only returns True if the identifier is case-insensitive.
Returns:

Whether the given text can be identified.

def quote_identifier(self, expression: ~E, identify: bool = True) -> ~E:
487    def quote_identifier(self, expression: E, identify: bool = True) -> E:
488        """
489        Adds quotes to a given identifier.
490
491        Args:
492            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
493            identify: If set to `False`, the quotes will only be added if the identifier is deemed
494                "unsafe", with respect to its characters and this dialect's normalization strategy.
495        """
496        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
497            name = expression.this
498            expression.set(
499                "quoted",
500                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
501            )
502
503        return expression

Adds quotes to a given identifier.

Arguments:
  • expression: The expression of interest. If it's not an Identifier, this method is a no-op.
  • identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
def to_json_path( self, path: Optional[sqlglot.expressions.Expression]) -> Optional[sqlglot.expressions.Expression]:
505    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
506        if isinstance(path, exp.Literal):
507            path_text = path.name
508            if path.is_number:
509                path_text = f"[{path_text}]"
510
511            try:
512                return parse_json_path(path_text)
513            except ParseError as e:
514                logger.warning(f"Invalid JSON path syntax. {str(e)}")
515
516        return path
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
518    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
519        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
521    def parse_into(
522        self, expression_type: exp.IntoType, sql: str, **opts
523    ) -> t.List[t.Optional[exp.Expression]]:
524        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
526    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
527        return self.generator(**opts).generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> List[str]:
529    def transpile(self, sql: str, **opts) -> t.List[str]:
530        return [
531            self.generate(expression, copy=False, **opts) if expression else ""
532            for expression in self.parse(sql)
533        ]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
535    def tokenize(self, sql: str) -> t.List[Token]:
536        return self.tokenizer.tokenize(sql)
tokenizer: sqlglot.tokens.Tokenizer
538    @property
539    def tokenizer(self) -> Tokenizer:
540        if not hasattr(self, "_tokenizer"):
541            self._tokenizer = self.tokenizer_class(dialect=self)
542        return self._tokenizer
def parser(self, **opts) -> sqlglot.parser.Parser:
544    def parser(self, **opts) -> Parser:
545        return self.parser_class(dialect=self, **opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
547    def generator(self, **opts) -> Generator:
548        return self.generator_class(dialect=self, **opts)
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
554def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
555    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
558def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
559    if expression.args.get("accuracy"):
560        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
561    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
564def if_sql(
565    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
566) -> t.Callable[[Generator, exp.If], str]:
567    def _if_sql(self: Generator, expression: exp.If) -> str:
568        return self.func(
569            name,
570            expression.this,
571            expression.args.get("true"),
572            expression.args.get("false") or false_value,
573        )
574
575    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]) -> str:
578def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
579    this = expression.this
580    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
581        this.replace(exp.cast(this, exp.DataType.Type.JSON))
582
583    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
586def inline_array_sql(self: Generator, expression: exp.Array) -> str:
587    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"
def inline_array_unless_query( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
590def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
591    elem = seq_get(expression.expressions, 0)
592    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
593        return self.func("ARRAY", elem)
594    return inline_array_sql(self, expression)
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
597def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
598    return self.like_sql(
599        exp.Like(
600            this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression)
601        )
602    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
605def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
606    zone = self.sql(expression, "this")
607    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
610def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
611    if expression.args.get("recursive"):
612        self.unsupported("Recursive CTEs are unsupported")
613        expression.args["recursive"] = False
614    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
617def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
618    n = self.sql(expression, "this")
619    d = self.sql(expression, "expression")
620    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
623def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
624    self.unsupported("TABLESAMPLE unsupported")
625    return self.sql(expression.this)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
628def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
629    self.unsupported("PIVOT unsupported")
630    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
633def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
634    return self.cast_sql(expression)
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
637def no_comment_column_constraint_sql(
638    self: Generator, expression: exp.CommentColumnConstraint
639) -> str:
640    self.unsupported("CommentColumnConstraint unsupported")
641    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
644def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
645    self.unsupported("MAP_FROM_ENTRIES unsupported")
646    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition, generate_instance: bool = False) -> str:
649def str_position_sql(
650    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
651) -> str:
652    this = self.sql(expression, "this")
653    substr = self.sql(expression, "substr")
654    position = self.sql(expression, "position")
655    instance = expression.args.get("instance") if generate_instance else None
656    position_offset = ""
657
658    if position:
659        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
660        this = self.func("SUBSTR", this, position)
661        position_offset = f" + {position} - 1"
662
663    return self.func("STRPOS", this, substr, instance) + position_offset
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
666def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
667    return (
668        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
669    )
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
672def var_map_sql(
673    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
674) -> str:
675    keys = expression.args["keys"]
676    values = expression.args["values"]
677
678    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
679        self.unsupported("Cannot convert array columns into map.")
680        return self.func(map_func_name, keys, values)
681
682    args = []
683    for key, value in zip(keys.expressions, values.expressions):
684        args.append(self.sql(key))
685        args.append(self.sql(value))
686
687    return self.func(map_func_name, *args)
def build_formatted_time( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
690def build_formatted_time(
691    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
692) -> t.Callable[[t.List], E]:
693    """Helper used for time expressions.
694
695    Args:
696        exp_class: the expression class to instantiate.
697        dialect: target sql dialect.
698        default: the default format, True being time.
699
700    Returns:
701        A callable that can be used to return the appropriately formatted time expression.
702    """
703
704    def _builder(args: t.List):
705        return exp_class(
706            this=seq_get(args, 0),
707            format=Dialect[dialect].format_time(
708                seq_get(args, 1)
709                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
710            ),
711        )
712
713    return _builder

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
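  • dialect: name of the SQL dialect whose time format mapping is used to convert the format.
  • default: the format to fall back to when no format argument is given; True means the dialect's TIME_FORMAT.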
Returns:

A callable that can be used to return the appropriately formatted time expression.

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
716def time_format(
717    dialect: DialectType = None,
718) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
719    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
720        """
721        Returns the time format for a given expression, unless it's equivalent
722        to the default time format of the dialect of interest.
723        """
724        time_format = self.format_time(expression)
725        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
726
727    return _time_format
def build_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
730def build_date_delta(
731    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
732) -> t.Callable[[t.List], E]:
733    def _builder(args: t.List) -> E:
734        unit_based = len(args) == 3
735        this = args[2] if unit_based else seq_get(args, 0)
736        unit = args[0] if unit_based else exp.Literal.string("DAY")
737        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
738        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
739
740    return _builder
def build_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
743def build_date_delta_with_interval(
744    expression_class: t.Type[E],
745) -> t.Callable[[t.List], t.Optional[E]]:
746    def _builder(args: t.List) -> t.Optional[E]:
747        if len(args) < 2:
748            return None
749
750        interval = args[1]
751
752        if not isinstance(interval, exp.Interval):
753            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
754
755        expression = interval.this
756        if expression and expression.is_string:
757            expression = exp.Literal.number(expression.this)
758
759        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
760
761    return _builder
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
764def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
765    unit = seq_get(args, 0)
766    this = seq_get(args, 1)
767
768    if isinstance(this, exp.Cast) and this.is_type("date"):
769        return exp.DateTrunc(unit=unit, this=this)
770    return exp.TimestampTrunc(this=this, unit=unit)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
773def date_add_interval_sql(
774    data_type: str, kind: str
775) -> t.Callable[[Generator, exp.Expression], str]:
776    def func(self: Generator, expression: exp.Expression) -> str:
777        this = self.sql(expression, "this")
778        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
779        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
780
781    return func
def timestamptrunc_sql( zone: bool = False) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.TimestampTrunc], str]:
784def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
785    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
786        args = [unit_to_str(expression), expression.this]
787        if zone:
788            args.append(expression.args.get("zone"))
789        return self.func("DATE_TRUNC", *args)
790
791    return _timestamptrunc_sql
def no_timestamp_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Timestamp) -> str:
794def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
795    if not expression.expression:
796        from sqlglot.optimizer.annotate_types import annotate_types
797
798        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
799        return self.sql(exp.cast(expression.this, target_type))
800    if expression.text("expression").lower() in TIMEZONES:
801        return self.sql(
802            exp.AtTimeZone(
803                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
804                zone=expression.expression,
805            )
806        )
807    return self.func("TIMESTAMP", expression.this, expression.expression)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
810def locate_to_strposition(args: t.List) -> exp.Expression:
811    return exp.StrPosition(
812        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
813    )
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
816def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
817    return self.func(
818        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
819    )
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
822def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
823    return self.sql(
824        exp.Substring(
825            this=expression.this, start=exp.Literal.number(1), length=expression.expression
826        )
827    )
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Rewrite RIGHT(str, n) as SUBSTRING(str, LENGTH(str) - (n - 1)).

    Starting at 1-based position LENGTH(str) - n + 1 and omitting the length
    selects the last n characters.

    Note: the parameter was previously annotated as `exp.Left`; this function
    handles RIGHT, so the annotation is `exp.Right`.
    """
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time-string-to-time conversion as CAST(x AS TIMESTAMP)."""
    target_type = exp.DataType.Type.TIMESTAMP
    return self.sql(exp.cast(expression.this, target_type))
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date-string-to-date conversion as CAST(x AS DATE)."""
    target_type = exp.DataType.Type.DATE
    return self.sql(exp.cast(expression.this, target_type))
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render an encode/decode node as `name(this, replace)`.

    Only UTF-8 is expected. Any other explicit charset is reported via
    `self.unsupported`, but the call is still generated.

    Args:
        name: SQL function name to emit.
        replace: when True, pass the node's "replace" arg as the second argument;
            when False, drop it.
    """
    charset = expression.args.get("charset")
    is_utf8 = not charset or charset.name.lower() == "utf-8"
    if not is_utf8:
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    second_arg = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, second_arg)
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render MIN(x), or LEAST(...) when the node carries additional operands."""
    if expression.expressions:
        return rename_func("LEAST")(self, expression)
    return rename_func("MIN")(self, expression)
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render MAX(x), or GREATEST(...) when the node carries additional operands."""
    if expression.expressions:
        return rename_func("GREATEST")(self, expression)
    return rename_func("MAX")(self, expression)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as SUM(IF(cond, 1, 0)).

    For COUNT_IF(DISTINCT ...), the first operand of the DISTINCT is used as the
    condition. The DISTINCT itself is dropped and reported as unsupported.
    """
    argument = expression.this
    condition = argument

    if isinstance(argument, exp.Distinct):
        condition = argument.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    as_flag = exp.func("if", condition, 1, 0)
    return self.func("sum", as_flag)
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM as `TRIM([position] [chars] FROM target [COLLATE collation])`.

    If the node has neither removal characters nor a collation, rendering is
    delegated to the generator's own `trim_sql` (the plain TRIM/LTRIM/RTRIM syntax).
    """
    target = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    if not chars and not collation:
        # Nothing database-specific to express; use the generator's default form.
        return self.trim_sql(expression)

    pieces = []
    if position:
        pieces.append(f"{position} ")
    if chars:
        pieces.append(f"{chars} ")
    if position or chars:
        pieces.append("FROM ")
    pieces.append(target)
    if collation:
        pieces.append(f" COLLATE {collation}")

    body = "".join(pieces)
    return f"TRIM({body})"
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time node as STRPTIME(this, <dialect time format>)."""
    time_format = self.format_time(expression)
    return self.func("STRPTIME", expression.this, time_format)
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    """Rewrite CONCAT(a, b, c, ...) as the left-nested chain a || b || c ...."""

    def _pipe(left: exp.Expression, right: exp.Expression) -> exp.DPipe:
        return exp.DPipe(this=left, expression=right)

    chained = reduce(_pipe, expression.expressions)
    return self.sql(chained)
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Rewrite CONCAT_WS(sep, a, b, ...) as a || (sep || b) || (sep || ...).

    The node's first operand is the separator. It is placed before every
    remaining operand except the first.
    """
    separator, *operands = expression.expressions

    def _join(acc: exp.Expression, item: exp.Expression) -> exp.DPipe:
        return exp.DPipe(this=acc, expression=exp.DPipe(this=separator, expression=item))

    return self.sql(reduce(_join, operands))
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render REGEXP_EXTRACT(this, pattern, group).

    Set position/occurrence/parameters args are reported as unsupported and left
    out of the generated call.
    """
    unsupported_args = [
        key for key in ("position", "occurrence", "parameters") if expression.args.get(key)
    ]
    if unsupported_args:
        self.unsupported(
            f"REGEXP_EXTRACT does not support the following arg(s): {unsupported_args}"
        )

    group = expression.args.get("group")
    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render REGEXP_REPLACE(this, pattern, replacement).

    Set position/occurrence/modifiers args are reported as unsupported and left
    out of the generated call.
    """
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    # `.get` rather than `[...]`: a node built without a replacement arg (absent key)
    # would otherwise raise KeyError. When present, the output is unchanged.
    return self.func(
        "REGEXP_REPLACE",
        expression.this,
        expression.expression,
        expression.args.get("replacement"),
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the name suffix contributed by each PIVOT aggregation.

    Aliased aggregations contribute their alias. Unaliased ones are rendered as SQL
    in `dialect`, with lower-cased function names and every identifier unquoted.

    Args:
        aggregations: the aggregation expressions of a PIVOT clause.
        dialect: dialect used to render unaliased aggregations.

    Returns:
        One name per aggregation, in input order.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # Unaliased aggregations are used as suffixes (e.g. col_avg(foo)). Identifiers
        # are unquoted because the result is quoted again in the base parser's
        # `_parse_pivot` method (via `to_identifier`); otherwise we'd end up with
        # nested quoting such as `col_avg(`foo`)`.
        agg_all_unquoted = agg.transform(
            lambda node: (
                exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
        )
        names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Return a builder mapping a two-argument call onto the binary node `expr_type`.

    The first argument becomes `this` and the second `expression`; missing ones are None.
    """

    def _builder(args: t.List) -> B:
        left, right = seq_get(args, 0), seq_get(args, 1)
        return expr_type(this=left, expression=right)

    return _builder
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Build a TimestampTrunc from (unit, value) arguments, unit first."""
    unit, value = seq_get(args, 0), seq_get(args, 1)
    return exp.TimestampTrunc(this=value, unit=unit)
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Render ANY_VALUE(x) as MAX(x); only the node's `this` argument is kept."""
    value = expression.this
    return self.func("MAX", value)
def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    """Expand `a XOR b` into `(a AND (NOT b)) OR ((NOT a) AND b)`.

    Operands that are themselves connectors (AND/OR/XOR) are wrapped in parentheses
    first. Without this, `NOT {a}` with a = `x AND y` renders as `NOT x AND y`, which
    binds as `(NOT x) AND y`. Simple operands render exactly as before.
    """

    def _operand(node: t.Optional[exp.Expression]) -> str:
        # exp.paren copies its input by default, so the original tree is not re-parented.
        return self.sql(exp.paren(node) if isinstance(node, exp.Connector) else node)

    a = _operand(expression.left)
    b = _operand(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
def is_parse_json(expression: exp.Expression) -> bool:
    """Return whether `expression` is a PARSE_JSON call or a CAST to JSON."""
    if isinstance(expression, exp.ParseJSON):
        return True
    return isinstance(expression, exp.Cast) and expression.is_type("json")
def isnull_to_is_null(args: t.List) -> exp.Expression:
    """Build `(x IS NULL)` from ISNULL(x) arguments; any extra arguments are ignored."""
    is_null = exp.Is(this=seq_get(args, 0), expression=exp.null())
    return exp.Paren(this=is_null)
def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    """Render an identity column constraint as IDENTITY(start, increment).

    A start or increment that is missing (or renders as empty SQL) defaults to 1.
    """
    start, increment = (self.sql(expression, key) or "1" for key in ("start", "increment"))
    return f"IDENTITY({start}, {increment})"
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    """Return a transform rendering ArgMax/ArgMin as the two-argument function `name`.

    A "count" argument cannot be expressed; it is reported as unsupported and dropped.
    """

    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        count = expression.args.get("count")
        if count:
            self.unsupported(f"Only two arguments are supported in function {name}.")
        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    """Cast the operand of a TsOrDsAdd node to the node's return type.

    The node is modified in place and also returned. For a DATE return type, the
    operand is cast to TIMESTAMP first so timestamp strings can be truncated; some
    dialects cannot cast them straight to DATE.
    """
    operand = expression.this.copy()
    target_type = expression.return_type

    if target_type.is_type(exp.DataType.Type.DATE):
        operand = exp.cast(operand, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(operand, target_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    """Return a transform rendering date add/diff nodes as `name(unit, expression, this)`.

    Args:
        name: SQL function name to emit.
        cast: when True, TsOrDsAdd nodes are first passed through `ts_or_ds_add_cast`.
    """

    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        node = (
            ts_or_ds_add_cast(expression)
            if cast and isinstance(expression, exp.TsOrDsAdd)
            else expression
        )
        call_args = (unit_to_var(node), node.expression, node.this)
        return self.func(name, *call_args)

    return _delta_sql
def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    """Return the node's "unit" arg as a string literal.

    Placeholders are returned unchanged. A missing unit falls back to `default`,
    or None when `default` is empty.
    """
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if not unit:
        return exp.Literal.string(default) if default else None
    return exp.Literal.string(unit.name)
def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    """Return the node's "unit" arg as a Var node.

    Vars and placeholders are returned unchanged. Any other unit node (e.g. a string
    literal 'month') is converted to a Var of its name, mirroring `unit_to_str`. A
    missing or nameless unit falls back to `default`, or None when `default` is empty.
    """
    unit = expression.args.get("unit")

    if isinstance(unit, (exp.Var, exp.Placeholder)):
        return unit

    # Keep the unit's name instead of discarding it; previously any non-Var unit
    # was silently replaced by `default` (e.g. 'month' became DAY).
    value = (unit.name if unit else "") or default
    return exp.Var(this=value) if value else None
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    """Emulate LAST_DAY(d) for dialects without it.

    The date is truncated to its month, one month is added, one day is subtracted,
    and the result is cast to DATE.
    """
    month_start = exp.func("date_trunc", "month", expression.this)
    next_month_start = exp.func("date_add", month_start, 1, "month")
    last_day = exp.func("date_sub", next_month_start, 1, "day")
    return self.sql(exp.cast(last_day, exp.DataType.Type.DATE))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements.

    Columns in the WHEN clauses that are qualified with the MERGE target's table name
    or alias are replaced by unqualified columns. Names are compared after dialect
    identifier normalization. The WHEN nodes are modified in place before the
    statement is rendered with `merge_sql`.
    """
    target = expression.this
    alias = target.args.get("alias")

    def _normalized(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        if not identifier:
            return None
        return self.dialect.normalize_identifier(identifier).name

    target_names = {_normalized(target.this)}
    if alias:
        target_names.add(_normalized(alias.this))

    def _strip_target(node: exp.Expression) -> exp.Expression:
        if isinstance(node, exp.Column) and _normalized(node.args.get("table")) in target_names:
            return exp.column(node.this)
        return node

    for when in expression.expressions:
        when.transform(_strip_target, copy=False)

    return self.merge_sql(expression)
def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    """Return a builder turning `(json, path_arg1, path_arg2, ...)` into `expr_type`.

    Each literal path argument becomes one JSONPath segment after a root segment.
    Integer-like text becomes a subscript, shifted down by one when
    `zero_based_indexing` is False; any other text becomes a key. If any path argument
    is not a literal, `expr_type.from_arg_list(args)` is returned instead.

    Args:
        expr_type: JSON extraction node class to build.
        zero_based_indexing: whether integer path arguments are already 0-based.
        arrow_req_json_type: stored on the node as `only_json_types`.
    """

    def _builder(args: t.List) -> F:
        path_args = args[1:]
        if not all(isinstance(arg, exp.Literal) for arg in path_args):
            # We use the fallback parser because we can't really transpile non-literals safely.
            return expr_type.from_arg_list(args)

        offset = 0 if zero_based_indexing else 1
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in path_args:
            text = arg.name
            if is_int(text):
                segments.append(exp.JSONPathSubscript(this=int(text) - offset))
            else:
                segments.append(exp.JSONPathKey(this=text))

        # Trim the caller's list in place to avoid failing in the expression validator
        # due to the arg count.
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    """Return a transform rendering JSON extraction with one argument per path segment.

    Every path segment that renders to non-empty SQL is wrapped in the dialect's
    QUOTE_START/QUOTE_END; subscripts are left unwrapped when `quoted_index` is False.
    With `op` set, the result is `this <op> seg1 <op> seg2 ...`; otherwise it is
    `name(this, seg1, seg2, ...)`. Paths that are not a JSONPath are rendered by
    renaming the function to `name`.
    """

    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        json_path = expression.expression
        if not isinstance(json_path, exp.JSONPath):
            return rename_func(name)(self, expression)

        rendered = []
        for segment in json_path.expressions:
            text = self.sql(segment)
            if not text:
                continue
            needs_quotes = isinstance(segment, exp.JSONPathPart) and (
                quoted_index or not isinstance(segment, exp.JSONPathSubscript)
            )
            if needs_quotes:
                text = f"{self.dialect.QUOTE_START}{text}{self.dialect.QUOTE_END}"
            rendered.append(text)

        if op:
            return f" {op} ".join([self.sql(expression.this), *rendered])
        return self.func(name, expression.this, *rendered)

    return _json_extract_segments
def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    """Render a JSONPathKey as its bare name; wildcard keys are reported as unsupported."""
    key = expression.this
    if isinstance(key, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")
    return expression.name
def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    """Emulate an array filter as ARRAY(SELECT ... FROM UNNEST(array) ... WHERE cond).

    A single-parameter lambda supplies both the element alias and the condition. A
    bare predicate is used as the condition with the alias "_u". Any other condition
    is reported as unsupported and renders as an empty string.
    """
    condition = expression.expression

    if isinstance(condition, exp.Lambda) and len(condition.expressions) == 1:
        element = condition.expressions[0]
        condition = condition.this
    elif isinstance(condition, exp.Predicate):
        element = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnested = exp.alias_(exp.Unnest(expressions=[expression.this]), None, table=[element])
    subquery = exp.select(element).from_(unnested).where(condition)
    return self.sql(exp.Array(expressions=[subquery]))
def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
    """Render a ToNumber node as TO_NUMBER(this, format, nlsparam)."""
    fmt, nls_param = (expression.args.get(key) for key in ("format", "nlsparam"))
    return self.func("TO_NUMBER", expression.this, fmt, nls_param)
def build_default_decimal_type(
    precision: t.Optional[int] = None, scale: t.Optional[int] = None
) -> t.Callable[[exp.DataType], exp.DataType]:
    """Return a builder that fills in a default precision/scale for a DataType.

    A type that already has parameters is returned unchanged, as is every type when
    `precision` is None. Otherwise a new DECIMAL(precision) or
    DECIMAL(precision, scale) type is built. Presumably registered for DECIMAL
    types; verify at the call site.
    """

    def _builder(dtype: exp.DataType) -> exp.DataType:
        if dtype.expressions or precision is None:
            return dtype

        params = str(precision) if scale is None else f"{precision}, {scale}"
        return exp.DataType.build(f"DECIMAL({params})")

    return _builder