Edit on GitHub

sqlglot.dialects.dialect

  1from __future__ import annotations
  2
  3import typing as t
  4from enum import Enum
  5from functools import reduce
  6
  7from sqlglot import exp
  8from sqlglot._typing import E
  9from sqlglot.errors import ParseError
 10from sqlglot.generator import Generator
 11from sqlglot.helper import flatten, seq_get
 12from sqlglot.parser import Parser
 13from sqlglot.time import format_time
 14from sqlglot.tokens import Token, Tokenizer, TokenType
 15from sqlglot.trie import new_trie
 16
# Type variable bound to binary expressions; used by `binary_from_function` below.
B = t.TypeVar("B", bound=exp.Binary)
 18
 19
 20class Dialects(str, Enum):
 21    DIALECT = ""
 22
 23    BIGQUERY = "bigquery"
 24    CLICKHOUSE = "clickhouse"
 25    DATABRICKS = "databricks"
 26    DRILL = "drill"
 27    DUCKDB = "duckdb"
 28    HIVE = "hive"
 29    MYSQL = "mysql"
 30    ORACLE = "oracle"
 31    POSTGRES = "postgres"
 32    PRESTO = "presto"
 33    REDSHIFT = "redshift"
 34    SNOWFLAKE = "snowflake"
 35    SPARK = "spark"
 36    SPARK2 = "spark2"
 37    SQLITE = "sqlite"
 38    STARROCKS = "starrocks"
 39    TABLEAU = "tableau"
 40    TERADATA = "teradata"
 41    TRINO = "trino"
 42    TSQL = "tsql"
 43    Doris = "doris"
 44
 45
class _Dialect(type):
    """Metaclass that registers every `Dialect` subclass and wires its dialect
    settings into the associated tokenizer, parser and generator classes."""

    # Registry of all known dialects, keyed by lower-case dialect name.
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        # A dialect class compares equal to itself, to its registry name (str),
        # and to instances of itself.
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        # Hash by lower-cased class name so it stays consistent with the
        # string-based equality above.
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        # Registry lookup, e.g. ``Dialect["duckdb"]``; raises KeyError if unknown.
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        # Registry lookup that returns `default` instead of raising.
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        """Create a Dialect subclass, register it, and autofill derived settings."""
        klass = super().__new__(cls, clsname, bases, attrs)
        # Register under the matching Dialects enum value when one exists,
        # otherwise under the lower-cased class name.
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Precompute tries for fast time/format token matching.
        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}

        # Pick up nested Tokenizer/Parser/Generator overrides declared on the
        # dialect class, falling back to the base implementations.
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first configured pair is treated as the canonical delimiters.
        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            # Find the (start, end) delimiters registered for a format-string
            # token type, or (None, None) when the dialect has none.
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)

        # All non-callable, non-dunder class attributes are dialect settings.
        dialect_properties = {
            **{
                k: v
                for k, v in vars(klass).items()
                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
            },
            "TOKENIZER_CLASS": klass.tokenizer_class,
        }

        # NOTE(review): relies on Dialects being a str-enum so `enum` compares
        # equal to "" / "bigquery"; `None` (unregistered class names) also takes
        # this branch and gets empty SELECT_KINDS.
        if enum not in ("", "bigquery"):
            dialect_properties["SELECT_KINDS"] = ()

        # Pass required dialect properties to the tokenizer, parser and generator classes
        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
            for name, value in dialect_properties.items():
                if hasattr(subclass, name):
                    setattr(subclass, name, value)

        # '||' parses as a (safe) string concatenation unless configured otherwise.
        if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

        # Without SEMI/ANTI join support, those keywords are usable as aliases.
        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        klass.generator_class.can_identify = klass.can_identify

        return klass
139
140
class Dialect(metaclass=_Dialect):
    """Base class for all SQL dialects; subclasses are auto-registered by `_Dialect`."""

    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines whether or not user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Determines whether or not SEMI/ANTI JOINs are supported
    SUPPORTS_SEMI_ANTI_JOIN = True

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Determines whether the base comes first in the LOG function
    LOG_BASE_FIRST = True

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    # Default (quoted) strftime-style formats for this dialect
    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Mapping of an unescaped escape sequence to the corresponding character
    ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled by the _Dialect metaclass from nested class overrides
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys (autofilled by the metaclass)
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    # Inverse mappings/tries, also autofilled by the metaclass
    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    def __eq__(self, other: t.Any) -> bool:
        # Instances compare equal through their class; combined with
        # _Dialect.__eq__ this makes instances, classes and names interchangeable.
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve a name / class / instance / None to a Dialect class, raising
        ValueError for unknown names."""
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Translate a dialect time-format string (or string literal node) into a
        python-strftime string literal, using this dialect's TIME_MAPPING."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized to lowercase regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        # A character whose case differs from the dialect's normalization target
        # would be lost by normalization, hence "unsafe".
        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Mark an Identifier as quoted when requested, case-sensitive, or unsafe."""
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse `sql` into a list of expression trees."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse `sql` into the given expression type(s)."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Render an expression tree as SQL in this dialect."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and re-generate each statement in this dialect."""
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily created and cached per instance.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(**opts)
338
339
# Anything that can designate a dialect: a registry name, an instance, a class,
# or None (meaning the default dialect).
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
341
342
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator callback that renders the expression's args as a call to *name*."""

    def _rename(self: Generator, expression: exp.Expression) -> str:
        return self.func(name, *flatten(expression.args.values()))

    return _rename
345
346
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render ApproxDistinct as APPROX_COUNT_DISTINCT; the accuracy arg is unsupported."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)
351
352
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """Build a callback rendering `If` as `name(cond, true, false)`.

    `false_value` is substituted when the expression lacks a false branch.
    """

    def _if_sql(self: Generator, expression: exp.If) -> str:
        true_part = expression.args.get("true")
        false_part = expression.args.get("false") or false_value
        return self.func(name, expression.this, true_part, false_part)

    return _if_sql
365
366
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render JSON(B) extraction with the arrow operator (`->`)."""
    return self.binary(expression, "->")
369
370
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render scalar JSON(B) extraction with the double-arrow operator (`->>`)."""
    return self.binary(expression, "->>")
375
376
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array literal using bracket syntax: [a, b, ...]."""
    elements = self.expressions(expression, flat=True)
    return f"[{elements}]"
379
380
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ILIKE as LIKE over a lower-cased operand, for dialects without ILIKE."""
    lowered = exp.Lower(this=expression.this.copy())
    return self.like_sql(exp.Like(this=lowered, expression=expression.expression.copy()))
387
388
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, with an optional AT TIME ZONE clause."""
    zone = self.sql(expression, "this")
    if zone:
        return f"CURRENT_DATE AT TIME ZONE {zone}"
    return "CURRENT_DATE"
392
393
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Drop the RECURSIVE flag (with a warning) for dialects without recursive CTEs."""
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        # Clear the flag in place so with_sql renders a plain WITH clause.
        expression.args["recursive"] = False
    return self.with_sql(expression)
399
400
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE with an IF guard that yields NULL on division by zero."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return f"IF({denominator} <> 0, {numerator} / {denominator}, NULL)"
405
406
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Drop an unsupported TABLESAMPLE clause, rendering just the sampled table."""
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)
410
411
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Emit nothing for PIVOT in dialects that don't support it (with a warning)."""
    self.unsupported("PIVOT unsupported")
    return ""
415
416
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render TRY_CAST as a plain CAST for dialects lacking TRY_CAST."""
    return self.cast_sql(expression)
419
420
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Emit nothing for table properties in dialects that don't support them."""
    self.unsupported("Properties unsupported")
    return ""
424
425
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Emit nothing for COMMENT column constraints in dialects without them."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""
431
432
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Emit nothing for MAP_FROM_ENTRIES in dialects that don't support it."""
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""
436
437
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as STRPOS, emulating the optional start position."""
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if not position:
        return f"STRPOS({this}, {substr})"
    # Search only the suffix, then shift the result back to full-string indexing.
    return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
445
446
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access with dot notation: struct.field."""
    struct = self.sql(expression, "this")
    field = self.sql(exp.to_identifier(expression.expression.name))
    return f"{struct}.{field}"
451
452
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a MAP built from parallel key/value arrays as interleaved arguments."""
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        # Non-literal arrays can't be interleaved; fall back to MAP(keys, values).
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    interleaved = [
        self.sql(item)
        for pair in zip(keys.expressions, values.expressions)
        for item in pair
    ]

    return self.func(map_func_name, *interleaved)
469
470
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        fallback = Dialect[dialect].TIME_FORMAT if default is True else default or None
        fmt = seq_get(args, 1) or fallback
        return exp_class(this=seq_get(args, 0), format=Dialect[dialect].format_time(fmt))

    return _format_time
495
496
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Build a callback returning an expression's time format, or None when it
    equals the target dialect's default TIME_FORMAT."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        fmt = self.format_time(expression)
        if fmt == Dialect.get_or_raise(dialect).TIME_FORMAT:
            return None
        return fmt

    return _time_format
509
510
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        # Work on a copy so the caller's tree is left untouched.
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Match partition columns by case-insensitive name.
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            # Move the partition columns out of the main schema and into the
            # PARTITIONED BY property's own schema.
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)
532
533
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date-delta functions, supporting both the 2-arg form
    (this, expression) and the unit-based 3-arg form (unit, expression, this)."""

    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
545
546
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for date arithmetic whose second argument must be an INTERVAL."""

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        if amount and amount.is_string:
            # Normalize string interval amounts (e.g. '5') into number literals.
            amount = exp.Literal.number(amount.this)

        return expression_class(
            this=args[0], expression=amount, unit=exp.Literal.string(interval.text("unit"))
        )

    return func
568
569
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse DATE_TRUNC(unit, this): DateTrunc for DATE casts, else TimestampTrunc."""
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    is_date = isinstance(this, exp.Cast) and this.is_type("date")
    klass = exp.DateTrunc if is_date else exp.TimestampTrunc
    return klass(this=this, unit=unit)
577
578
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a callback rendering e.g. DATE_ADD(this, INTERVAL n UNIT)."""

    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit_arg = expression.args.get("unit")
        # Default the unit to DAY when none was given.
        unit = exp.var(unit_arg.name.upper() if unit_arg else "DAY")
        interval = exp.Interval(this=expression.expression.copy(), unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
590
591
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC('unit', this), defaulting the unit to 'day'."""
    unit = exp.Literal.string(expression.text("unit") or "day")
    return self.func("DATE_TRUNC", unit, expression.this)
596
597
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Parse LOCATE(substr, this[, position]) into a StrPosition node."""
    substr, this, position = seq_get(args, 0), seq_get(args, 1), seq_get(args, 2)
    return exp.StrPosition(this=this, substr=substr, position=position)
602
603
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition back into LOCATE(substr, this[, position])."""
    substr = expression.args.get("substr")
    position = expression.args.get("position")
    return self.func("LOCATE", substr, expression.this, position)
608
609
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Render LEFT(s, n) as SUBSTRING(s, 1, n) for dialects without LEFT."""
    expression = expression.copy()
    substring = exp.Substring(
        this=expression.this, start=exp.Literal.number(1), length=expression.expression
    )
    return self.sql(substring)
617
618
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Render RIGHT(s, n) as SUBSTRING(s, LENGTH(s) - (n - 1)) for dialects without RIGHT.

    Note: the parameter annotation was previously `exp.Left` — a copy/paste slip
    from `left_to_substring_sql`; this handles `exp.Right` nodes.
    """
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            # Start at the index of the n-th character from the end.
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )
627
628
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time-string-to-time conversion as CAST(this AS TIMESTAMP)."""
    cast = exp.cast(expression.this, "timestamp")
    return self.sql(cast)
631
632
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date-string-to-date conversion as CAST(this AS DATE)."""
    cast = exp.cast(expression.this, "date")
    return self.sql(cast)
635
636
# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render an ENCODE/DECODE call, warning when a non-utf-8 charset is requested."""
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replace_arg = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replace_arg)
646
647
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render MIN, or LEAST when extra operands make it a variadic comparison."""
    if expression.expressions:
        return rename_func("LEAST")(self, expression)
    return rename_func("MIN")(self, expression)
651
652
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render MAX, or GREATEST when extra operands make it a variadic comparison."""
    if expression.expressions:
        return rename_func("GREATEST")(self, expression)
    return rename_func("MAX")(self, expression)
656
657
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as SUM(IF(cond, 1, 0))."""
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        # DISTINCT can't be expressed via SUM/IF; warn and emit the inner condition.
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond.copy(), 1, 0))
666
667
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM with explicit position/characters/collation when present."""
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    parts = [
        f"{trim_type} " if trim_type else "",
        f"{remove_chars} " if remove_chars else "",
        "FROM " if trim_type or remove_chars else "",
        target,
        f" COLLATE {collation}" if collation else "",
    ]
    return f"TRIM({''.join(parts)})"
683
684
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as STRPTIME(this, format)."""
    return self.func("STRPTIME", expression.this, self.format_time(expression))
687
688
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a callback casting a timestamp-or-date string to DATE, parsing custom formats."""

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (target.TIME_FORMAT, target.DATE_FORMAT):
            # A non-default format requires an explicit parse before the cast.
            return self.sql(exp.cast(str_to_time_sql(self, expression), "date"))

        return self.sql(exp.cast(self.sql(expression, "this"), "date"))

    return _ts_or_ds_to_date_sql
699
700
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Rewrite CONCAT(a, b, ...) as the operator chain a || b || ... ."""
    expression = expression.copy()
    joined = reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)
    return self.sql(joined)
704
705
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Rewrite CONCAT_WS(sep, a, b, ...) as a || sep || b || sep || ... ."""
    expression = expression.copy()
    delim, *rest_args = expression.expressions

    def _join(acc: exp.Expression, item: exp.Expression) -> exp.DPipe:
        # Interleave the delimiter before each subsequent operand.
        return exp.DPipe(this=acc, expression=exp.DPipe(this=delim, expression=item))

    return self.sql(reduce(_join, rest_args))
715
716
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render REGEXP_EXTRACT, warning about args the target function can't express."""
    bad_args = [a for a in ("position", "occurrence", "parameters") if expression.args.get(a)]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
725
726
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render REGEXP_REPLACE, warning about args the target function can't express."""
    bad_args = [a for a in ("position", "occurrence", "parameters") if expression.args.get(a)]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
735
736
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the output column name for each PIVOT aggregation.

    Args:
        aggregations: the aggregation expressions in the PIVOT clause.
        dialect: the dialect used to render unaliased aggregations.

    Returns:
        One name per aggregation: its alias when present, otherwise the rendered
        (unquoted) SQL of the aggregation itself.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as
            # suffixes (e.g. col_avg(foo)). We need to unquote identifiers because
            # they're going to be quoted in the base parser's `_parse_pivot` method,
            # due to `to_identifier`. Otherwise, we'd end up with `col_avg(`foo`)`
            # (notice the double quotes).
            # (Previously this note was a bare triple-quoted string — an expression
            # statement, not a comment.)
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
757
758
def simplify_literal(expression: E) -> E:
    """Simplify the node's child expression in place when it isn't already a literal."""
    if not isinstance(expression.expression, exp.Literal):
        # Imported lazily to avoid a circular import with the optimizer.
        from sqlglot.optimizer.simplify import simplify

        simplify(expression.expression)

    return expression
766
767
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Build a parser that maps f(a, b) onto a binary expression node."""

    def _builder(args: t.List) -> B:
        return expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))

    return _builder
770
771
# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Parse DATE_TRUNC(unit, this) into a TimestampTrunc node."""
    unit, this = seq_get(args, 0), seq_get(args, 1)
    return exp.TimestampTrunc(this=this, unit=unit)
775
776
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Render ANY_VALUE as MAX for dialects lacking ANY_VALUE."""
    return self.func("MAX", expression.this)
779
780
def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    """Expand boolean XOR into AND/OR/NOT for dialects without an XOR operator."""
    left = self.sql(expression.left)
    right = self.sql(expression.right)
    return f"({left} AND (NOT {right})) OR ((NOT {left}) AND {right})"
785
786
# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    """Render a JSON key/value pair separated by a comma instead of a colon."""
    key = self.sql(expression, "this")
    value = self.sql(expression, "expression")
    return f"{key}, {value}"
790
791
def is_parse_json(expression: exp.Expression) -> bool:
    """Whether the expression parses or casts its input to JSON."""
    if isinstance(expression, exp.ParseJSON):
        return True
    return isinstance(expression, exp.Cast) and expression.is_type("json")
796
797
def isnull_to_is_null(args: t.List) -> exp.Expression:
    """Rewrite ISNULL(x) as the parenthesized predicate (x IS NULL)."""
    predicate = exp.Is(this=seq_get(args, 0), expression=exp.null())
    return exp.Paren(this=predicate)
800
801
def move_insert_cte_sql(self: Generator, expression: exp.Insert) -> str:
    """Hoist a WITH clause from an INSERT's source query up to the INSERT itself."""
    if expression.expression.args.get("with"):
        # Copy first, then detach the clause from the copied source query.
        expression = expression.copy()
        expression.set("with", expression.expression.args["with"].pop())
    return self.insert_sql(expression)
class Dialects(builtins.str, enum.Enum):
21class Dialects(str, Enum):
22    DIALECT = ""
23
24    BIGQUERY = "bigquery"
25    CLICKHOUSE = "clickhouse"
26    DATABRICKS = "databricks"
27    DRILL = "drill"
28    DUCKDB = "duckdb"
29    HIVE = "hive"
30    MYSQL = "mysql"
31    ORACLE = "oracle"
32    POSTGRES = "postgres"
33    PRESTO = "presto"
34    REDSHIFT = "redshift"
35    SNOWFLAKE = "snowflake"
36    SPARK = "spark"
37    SPARK2 = "spark2"
38    SQLITE = "sqlite"
39    STARROCKS = "starrocks"
40    TABLEAU = "tableau"
41    TERADATA = "teradata"
42    TRINO = "trino"
43    TSQL = "tsql"
44    Doris = "doris"

An enumeration of the SQL dialect names supported by sqlglot.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Doris = <Dialects.Doris: 'doris'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
142class Dialect(metaclass=_Dialect):
143    # Determines the base index offset for arrays
144    INDEX_OFFSET = 0
145
146    # If true unnest table aliases are considered only as column aliases
147    UNNEST_COLUMN_ONLY = False
148
149    # Determines whether or not the table alias comes after tablesample
150    ALIAS_POST_TABLESAMPLE = False
151
152    # Determines whether or not unquoted identifiers are resolved as uppercase
153    # When set to None, it means that the dialect treats all identifiers as case-insensitive
154    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False
155
156    # Determines whether or not an unquoted identifier can start with a digit
157    IDENTIFIERS_CAN_START_WITH_DIGIT = False
158
159    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
160    DPIPE_IS_STRING_CONCAT = True
161
162    # Determines whether or not CONCAT's arguments must be strings
163    STRICT_STRING_CONCAT = False
164
165    # Determines whether or not user-defined data types are supported
166    SUPPORTS_USER_DEFINED_TYPES = True
167
168    # Determines whether or not SEMI/ANTI JOINs are supported
169    SUPPORTS_SEMI_ANTI_JOIN = True
170
171    # Determines how function names are going to be normalized
172    NORMALIZE_FUNCTIONS: bool | str = "upper"
173
174    # Determines whether the base comes first in the LOG function
175    LOG_BASE_FIRST = True
176
177    # Indicates the default null ordering method to use if not explicitly set
178    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
179    NULL_ORDERING = "nulls_are_small"
180
181    DATE_FORMAT = "'%Y-%m-%d'"
182    DATEINT_FORMAT = "'%Y%m%d'"
183    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
184
185    # Custom time mappings in which the key represents dialect time format
186    # and the value represents a python time format
187    TIME_MAPPING: t.Dict[str, str] = {}
188
189    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
190    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
191    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
192    FORMAT_MAPPING: t.Dict[str, str] = {}
193
194    # Mapping of an unescaped escape sequence to the corresponding character
195    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
196
197    # Columns that are auto-generated by the engine corresponding to this dialect
198    # Such columns may be excluded from SELECT * queries, for example
199    PSEUDOCOLUMNS: t.Set[str] = set()
200
201    # Autofilled
202    tokenizer_class = Tokenizer
203    parser_class = Parser
204    generator_class = Generator
205
206    # A trie of the time_mapping keys
207    TIME_TRIE: t.Dict = {}
208    FORMAT_TRIE: t.Dict = {}
209
210    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
211    INVERSE_TIME_TRIE: t.Dict = {}
212
213    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
214
215    def __eq__(self, other: t.Any) -> bool:
216        return type(self) == other
217
218    def __hash__(self) -> int:
219        return hash(type(self))
220
221    @classmethod
222    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
223        if not dialect:
224            return cls
225        if isinstance(dialect, _Dialect):
226            return dialect
227        if isinstance(dialect, Dialect):
228            return dialect.__class__
229
230        result = cls.get(dialect)
231        if not result:
232            raise ValueError(f"Unknown dialect '{dialect}'")
233
234        return result
235
236    @classmethod
237    def format_time(
238        cls, expression: t.Optional[str | exp.Expression]
239    ) -> t.Optional[exp.Expression]:
240        if isinstance(expression, str):
241            return exp.Literal.string(
242                # the time formats are quoted
243                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
244            )
245
246        if expression and expression.is_string:
247            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
248
249        return expression
250
251    @classmethod
252    def normalize_identifier(cls, expression: E) -> E:
253        """
254        Normalizes an unquoted identifier to either lower or upper case, thus essentially
255        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
256        they will be normalized to lowercase regardless of being quoted or not.
257        """
258        if isinstance(expression, exp.Identifier) and (
259            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
260        ):
261            expression.set(
262                "this",
263                expression.this.upper()
264                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
265                else expression.this.lower(),
266            )
267
268        return expression
269
270    @classmethod
271    def case_sensitive(cls, text: str) -> bool:
272        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
273        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
274            return False
275
276        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
277        return any(unsafe(char) for char in text)
278
279    @classmethod
280    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
281        """Checks if text can be identified given an identify option.
282
283        Args:
284            text: The text to check.
285            identify:
286                "always" or `True`: Always returns true.
287                "safe": True if the identifier is case-insensitive.
288
289        Returns:
290            Whether or not the given text can be identified.
291        """
292        if identify is True or identify == "always":
293            return True
294
295        if identify == "safe":
296            return not cls.case_sensitive(text)
297
298        return False
299
300    @classmethod
301    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
302        if isinstance(expression, exp.Identifier):
303            name = expression.this
304            expression.set(
305                "quoted",
306                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
307            )
308
309        return expression
310
311    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
312        return self.parser(**opts).parse(self.tokenize(sql), sql)
313
314    def parse_into(
315        self, expression_type: exp.IntoType, sql: str, **opts
316    ) -> t.List[t.Optional[exp.Expression]]:
317        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
318
319    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
320        return self.generator(**opts).generate(expression)
321
322    def transpile(self, sql: str, **opts) -> t.List[str]:
323        return [self.generate(expression, **opts) for expression in self.parse(sql)]
324
325    def tokenize(self, sql: str) -> t.List[Token]:
326        return self.tokenizer.tokenize(sql)
327
328    @property
329    def tokenizer(self) -> Tokenizer:
330        if not hasattr(self, "_tokenizer"):
331            self._tokenizer = self.tokenizer_class()
332        return self._tokenizer
333
334    def parser(self, **opts) -> Parser:
335        return self.parser_class(**opts)
336
337    def generator(self, **opts) -> Generator:
338        return self.generator_class(**opts)
INDEX_OFFSET = 0
UNNEST_COLUMN_ONLY = False
ALIAS_POST_TABLESAMPLE = False
RESOLVES_IDENTIFIERS_AS_UPPERCASE: Optional[bool] = False
IDENTIFIERS_CAN_START_WITH_DIGIT = False
DPIPE_IS_STRING_CONCAT = True
STRICT_STRING_CONCAT = False
SUPPORTS_USER_DEFINED_TYPES = True
SUPPORTS_SEMI_ANTI_JOIN = True
NORMALIZE_FUNCTIONS: bool | str = 'upper'
LOG_BASE_FIRST = True
NULL_ORDERING = 'nulls_are_small'
DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}
FORMAT_MAPPING: Dict[str, str] = {}
ESCAPE_SEQUENCES: Dict[str, str] = {}
PSEUDOCOLUMNS: Set[str] = set()
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
INVERSE_ESCAPE_SEQUENCES: Dict[str, str] = {}
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Type[Dialect]:
221    @classmethod
222    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
223        if not dialect:
224            return cls
225        if isinstance(dialect, _Dialect):
226            return dialect
227        if isinstance(dialect, Dialect):
228            return dialect.__class__
229
230        result = cls.get(dialect)
231        if not result:
232            raise ValueError(f"Unknown dialect '{dialect}'")
233
234        return result
@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
236    @classmethod
237    def format_time(
238        cls, expression: t.Optional[str | exp.Expression]
239    ) -> t.Optional[exp.Expression]:
240        if isinstance(expression, str):
241            return exp.Literal.string(
242                # the time formats are quoted
243                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
244            )
245
246        if expression and expression.is_string:
247            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
248
249        return expression
@classmethod
def normalize_identifier(cls, expression: ~E) -> ~E:
251    @classmethod
252    def normalize_identifier(cls, expression: E) -> E:
253        """
254        Normalizes an unquoted identifier to either lower or upper case, thus essentially
255        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
256        they will be normalized to lowercase regardless of being quoted or not.
257        """
258        if isinstance(expression, exp.Identifier) and (
259            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
260        ):
261            expression.set(
262                "this",
263                expression.this.upper()
264                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
265                else expression.this.lower(),
266            )
267
268        return expression

Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, they will be normalized to lowercase regardless of being quoted or not.

@classmethod
def case_sensitive(cls, text: str) -> bool:
270    @classmethod
271    def case_sensitive(cls, text: str) -> bool:
272        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
273        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
274            return False
275
276        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
277        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.

@classmethod
def can_identify(cls, text: str, identify: str | bool = 'safe') -> bool:
279    @classmethod
280    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
281        """Checks if text can be identified given an identify option.
282
283        Args:
284            text: The text to check.
285            identify:
286                "always" or `True`: Always returns true.
287                "safe": True if the identifier is case-insensitive.
288
289        Returns:
290            Whether or not the given text can be identified.
291        """
292        if identify is True or identify == "always":
293            return True
294
295        if identify == "safe":
296            return not cls.case_sensitive(text)
297
298        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: always returns True. "safe": True if the identifier is case-insensitive.
Returns:

Whether or not the given text can be identified.

@classmethod
def quote_identifier(cls, expression: ~E, identify: bool = True) -> ~E:
300    @classmethod
301    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
302        if isinstance(expression, exp.Identifier):
303            name = expression.this
304            expression.set(
305                "quoted",
306                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
307            )
308
309        return expression
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
311    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
312        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
314    def parse_into(
315        self, expression_type: exp.IntoType, sql: str, **opts
316    ) -> t.List[t.Optional[exp.Expression]]:
317        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: Optional[sqlglot.expressions.Expression], **opts) -> str:
319    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
320        return self.generator(**opts).generate(expression)
def transpile(self, sql: str, **opts) -> List[str]:
322    def transpile(self, sql: str, **opts) -> t.List[str]:
323        return [self.generate(expression, **opts) for expression in self.parse(sql)]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
325    def tokenize(self, sql: str) -> t.List[Token]:
326        return self.tokenizer.tokenize(sql)
def parser(self, **opts) -> sqlglot.parser.Parser:
334    def parser(self, **opts) -> Parser:
335        return self.parser_class(**opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
337    def generator(self, **opts) -> Generator:
338        return self.generator_class(**opts)
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START = None
BIT_END = None
HEX_START = None
HEX_END = None
BYTE_START = None
BYTE_END = None
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
344def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
345    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
348def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
349    if expression.args.get("accuracy"):
350        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
351    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
354def if_sql(
355    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
356) -> t.Callable[[Generator, exp.If], str]:
357    def _if_sql(self: Generator, expression: exp.If) -> str:
358        return self.func(
359            name,
360            expression.this,
361            expression.args.get("true"),
362            expression.args.get("false") or false_value,
363        )
364
365    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
368def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
369    return self.binary(expression, "->")
def arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
372def arrow_json_extract_scalar_sql(
373    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
374) -> str:
375    return self.binary(expression, "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
378def inline_array_sql(self: Generator, expression: exp.Array) -> str:
379    return f"[{self.expressions(expression, flat=True)}]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
382def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
383    return self.like_sql(
384        exp.Like(
385            this=exp.Lower(this=expression.this.copy()), expression=expression.expression.copy()
386        )
387    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
390def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
391    zone = self.sql(expression, "this")
392    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
395def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
396    if expression.args.get("recursive"):
397        self.unsupported("Recursive CTEs are unsupported")
398        expression.args["recursive"] = False
399    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
402def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
403    n = self.sql(expression, "this")
404    d = self.sql(expression, "expression")
405    return f"IF({d} <> 0, {n} / {d}, NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
408def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
409    self.unsupported("TABLESAMPLE unsupported")
410    return self.sql(expression.this)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
413def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
414    self.unsupported("PIVOT unsupported")
415    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
418def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
419    return self.cast_sql(expression)
def no_properties_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Properties) -> str:
422def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
423    self.unsupported("Properties unsupported")
424    return ""
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
427def no_comment_column_constraint_sql(
428    self: Generator, expression: exp.CommentColumnConstraint
429) -> str:
430    self.unsupported("CommentColumnConstraint unsupported")
431    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
434def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
435    self.unsupported("MAP_FROM_ENTRIES unsupported")
436    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
439def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
440    this = self.sql(expression, "this")
441    substr = self.sql(expression, "substr")
442    position = self.sql(expression, "position")
443    if position:
444        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
445    return f"STRPOS({this}, {substr})"
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
448def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
449    return (
450        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
451    )
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
454def var_map_sql(
455    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
456) -> str:
457    keys = expression.args["keys"]
458    values = expression.args["values"]
459
460    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
461        self.unsupported("Cannot convert array columns into map.")
462        return self.func(map_func_name, keys, values)
463
464    args = []
465    for key, value in zip(keys.expressions, values.expressions):
466        args.append(self.sql(key))
467        args.append(self.sql(value))
468
469    return self.func(map_func_name, *args)
def format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
472def format_time_lambda(
473    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
474) -> t.Callable[[t.List], E]:
475    """Helper used for time expressions.
476
477    Args:
478        exp_class: the expression class to instantiate.
479        dialect: target sql dialect.
480        default: the default format, True being time.
481
482    Returns:
483        A callable that can be used to return the appropriately formatted time expression.
484    """
485
486    def _format_time(args: t.List):
487        return exp_class(
488            this=seq_get(args, 0),
489            format=Dialect[dialect].format_time(
490                seq_get(args, 1)
491                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
492            ),
493        )
494
495    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
498def time_format(
499    dialect: DialectType = None,
500) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
501    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
502        """
503        Returns the time format for a given expression, unless it's equivalent
504        to the default time format of the dialect of interest.
505        """
506        time_format = self.format_time(expression)
507        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
508
509    return _time_format
def create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
512def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
513    """
514    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
515    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
516    columns are removed from the create statement.
517    """
518    has_schema = isinstance(expression.this, exp.Schema)
519    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
520
521    if has_schema and is_partitionable:
522        expression = expression.copy()
523        prop = expression.find(exp.PartitionedByProperty)
524        if prop and prop.this and not isinstance(prop.this, exp.Schema):
525            schema = expression.this
526            columns = {v.name.upper() for v in prop.this.expressions}
527            partitions = [col for col in schema.expressions if col.name.upper() in columns]
528            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
529            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
530            expression.set("this", schema)
531
532    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
535def parse_date_delta(
536    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
537) -> t.Callable[[t.List], E]:
538    def inner_func(args: t.List) -> E:
539        unit_based = len(args) == 3
540        this = args[2] if unit_based else seq_get(args, 0)
541        unit = args[0] if unit_based else exp.Literal.string("DAY")
542        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
543        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
544
545    return inner_func
def parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
548def parse_date_delta_with_interval(
549    expression_class: t.Type[E],
550) -> t.Callable[[t.List], t.Optional[E]]:
551    def func(args: t.List) -> t.Optional[E]:
552        if len(args) < 2:
553            return None
554
555        interval = args[1]
556
557        if not isinstance(interval, exp.Interval):
558            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
559
560        expression = interval.this
561        if expression and expression.is_string:
562            expression = exp.Literal.number(expression.this)
563
564        return expression_class(
565            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
566        )
567
568    return func
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
571def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
572    unit = seq_get(args, 0)
573    this = seq_get(args, 1)
574
575    if isinstance(this, exp.Cast) and this.is_type("date"):
576        return exp.DateTrunc(unit=unit, this=this)
577    return exp.TimestampTrunc(this=this, unit=unit)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
580def date_add_interval_sql(
581    data_type: str, kind: str
582) -> t.Callable[[Generator, exp.Expression], str]:
583    def func(self: Generator, expression: exp.Expression) -> str:
584        this = self.sql(expression, "this")
585        unit = expression.args.get("unit")
586        unit = exp.var(unit.name.upper() if unit else "DAY")
587        interval = exp.Interval(this=expression.expression.copy(), unit=unit)
588        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
589
590    return func
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
593def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
594    return self.func(
595        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
596    )
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
599def locate_to_strposition(args: t.List) -> exp.Expression:
600    return exp.StrPosition(
601        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
602    )
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
605def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
606    return self.func(
607        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
608    )
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
611def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
612    expression = expression.copy()
613    return self.sql(
614        exp.Substring(
615            this=expression.this, start=exp.Literal.number(1), length=expression.expression
616        )
617    )
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
620def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
621    expression = expression.copy()
622    return self.sql(
623        exp.Substring(
624            this=expression.this,
625            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
626        )
627    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
630def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
631    return self.sql(exp.cast(expression.this, "timestamp"))
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a DateStrToDate node as CAST(value AS DATE)."""
    cast_expr = exp.cast(expression.this, "date")
    return self.sql(cast_expr)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render an ENCODE/DECODE-style call, warning when the charset is not utf-8.

    Args:
        name: function name to emit (e.g. "ENCODE" or "DECODE").
        replace: whether to forward the expression's "replace" argument.
    """
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replace_arg = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replace_arg)
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Emit LEAST when extra operands are present, otherwise plain MIN."""
    func_name = "LEAST" if expression.expressions else "MIN"
    return rename_func(func_name)(self, expression)
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Emit GREATEST when extra operands are present, otherwise plain MAX."""
    func_name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(func_name)(self, expression)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as SUM(IF(cond, 1, 0)); DISTINCT cannot be expressed."""
    cond = expression.this

    if isinstance(cond, exp.Distinct):
        # Drop the DISTINCT wrapper but warn: SUM(IF(...)) has no equivalent for it.
        cond = cond.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond.copy(), 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM with optional position, characters to remove and collation.

    Falls back to the generator's default TRIM/LTRIM/RTRIM rendering when no
    dialect-specific pieces (remove characters, collation) are present.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    if not remove_chars and not collation:
        # Nothing database-specific: use the standard TRIM/LTRIM/RTRIM forms.
        return self.trim_sql(expression)

    pieces = [
        f"{trim_type} " if trim_type else "",
        f"{remove_chars} " if remove_chars else "",
        "FROM " if trim_type or remove_chars else "",
        target,
        f" COLLATE {collation}" if collation else "",
    ]
    return f"TRIM({''.join(pieces)})"
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as STRPTIME(value, format)."""
    fmt = self.format_time(expression)
    return self.func("STRPTIME", expression.this, fmt)
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a generator method that renders TsOrDsToDate as a CAST to DATE."""

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target_dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)

        # A non-default format means the value must be parsed before the cast.
        default_formats = (target_dialect.TIME_FORMAT, target_dialect.DATE_FORMAT)
        if time_format and time_format not in default_formats:
            return self.sql(exp.cast(str_to_time_sql(self, expression), "date"))

        return self.sql(exp.cast(self.sql(expression, "this"), "date"))

    return _ts_or_ds_to_date_sql
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Rewrite CONCAT(a, b, ...) as a || b || ... and render it."""
    expression = expression.copy()
    dpipe = reduce(
        lambda left, right: exp.DPipe(this=left, expression=right),
        expression.expressions,
    )
    return self.sql(dpipe)
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Rewrite CONCAT_WS(sep, a, b, ...) as a || sep || b || ... and render it."""
    expression = expression.copy()
    delim, *operands = expression.expressions

    def _join(acc: exp.Expression, operand: exp.Expression) -> exp.DPipe:
        # Interleave the delimiter between the accumulated SQL and each operand.
        return exp.DPipe(this=acc, expression=exp.DPipe(this=delim, expression=operand))

    return self.sql(reduce(_join, operands))
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render REGEXP_EXTRACT, warning about arguments this form cannot express."""
    bad_args = [name for name in ("position", "occurrence", "parameters") if expression.args.get(name)]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render REGEXP_REPLACE, warning about arguments this form cannot express."""
    bad_args = [name for name in ("position", "occurrence", "parameters") if expression.args.get(name)]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Derive output column names for a PIVOT from its aggregation expressions.

    Aliased aggregations contribute their alias directly; unaliased ones are
    rendered to SQL (lower-cased function names) and used as name suffixes.

    Fix: the explanatory text in the unaliased branch was a bare triple-quoted
    string — a no-op expression statement, not a docstring — now real comments.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # This case corresponds to aggregations without aliases being used as
        # suffixes (e.g. col_avg(foo)). We need to unquote identifiers because
        # they're going to be quoted in the base parser's `_parse_pivot` method,
        # due to `to_identifier`. Otherwise, we'd end up with `col_avg(`foo`)`
        # (notice the double quotes).
        agg_all_unquoted = agg.transform(
            lambda node: exp.Identifier(this=node.name, quoted=False)
            if isinstance(node, exp.Identifier)
            else node
        )
        names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def simplify_literal(expression: E) -> E:
    """Simplify the child expression in place unless it is already a literal."""
    child = expression.expression
    if isinstance(child, exp.Literal):
        return expression

    # Imported lazily to avoid a circular dependency with the optimizer.
    from sqlglot.optimizer.simplify import simplify

    simplify(child)
    return expression
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Build a parser that maps the first two call args onto a binary expression."""

    def _builder(args: t.List) -> B:
        return expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))

    return _builder
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Map (unit, value) call arguments onto a TimestampTrunc node."""
    unit = seq_get(args, 0)
    value = seq_get(args, 1)
    return exp.TimestampTrunc(this=value, unit=unit)
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Render ANY_VALUE(x) as MAX(x) for dialects lacking ANY_VALUE."""
    return self.func("MAX", expression.this)
def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    """Expand boolean XOR into AND/OR/NOT, which every dialect supports."""
    left = self.sql(expression.left)
    right = self.sql(expression.right)
    return f"({left} AND (NOT {right})) OR ((NOT {left}) AND {right})"
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    """Render a JSON key/value pair in the comma-separated `key, value` form."""
    key_sql = self.sql(expression, "this")
    value_sql = self.sql(expression, "expression")
    return f"{key_sql}, {value_sql}"
def is_parse_json(expression: exp.Expression) -> bool:
    """Return True for PARSE_JSON nodes or casts to the JSON type."""
    if isinstance(expression, exp.ParseJSON):
        return True
    return isinstance(expression, exp.Cast) and expression.is_type("json")
def isnull_to_is_null(args: t.List) -> exp.Expression:
    """Rewrite ISNULL(x) call arguments as the parenthesized expression (x IS NULL)."""
    is_null = exp.Is(this=seq_get(args, 0), expression=exp.null())
    return exp.Paren(this=is_null)
def move_insert_cte_sql(self: Generator, expression: exp.Insert) -> str:
    """Hoist a WITH clause from an INSERT's source query up to the INSERT itself."""
    if expression.expression.args.get("with"):
        # Copy first so the pop/set below never mutates the caller's tree.
        expression = expression.copy()
        expression.set("with", expression.expression.args["with"].pop())
    return self.insert_sql(expression)