Edit on GitHub

sqlglot.dialects.dialect

  1from __future__ import annotations
  2
  3import typing as t
  4from enum import Enum
  5from functools import reduce
  6
  7from sqlglot import exp
  8from sqlglot._typing import E
  9from sqlglot.errors import ParseError
 10from sqlglot.generator import Generator
 11from sqlglot.helper import flatten, seq_get
 12from sqlglot.parser import Parser
 13from sqlglot.time import TIMEZONES, format_time
 14from sqlglot.tokens import Token, Tokenizer, TokenType
 15from sqlglot.trie import new_trie
 16
# Type variable bound to binary expressions (used by binary_from_function below).
B = t.TypeVar("B", bound=exp.Binary)
 18
 19
class Dialects(str, Enum):
    """Registry of the supported SQL dialect names.

    Mixing in ``str`` makes each member compare equal to its lowercase
    string value, which is how ``_Dialect`` keys its class registry.
    """

    # The empty string is the key of the base, dialect-agnostic ``Dialect``.
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    # NOTE(review): member name breaks the UPPER_CASE convention of its
    # siblings; renaming it would break external references, so it stays.
    Doris = "doris"
 44
 45
class _Dialect(type):
    """Metaclass for ``Dialect``.

    At class-creation time it registers every subclass in ``classes``
    (keyed by its ``Dialects`` enum value, or its lowercase class name),
    precomputes the time-format tries, and copies dialect-level settings
    onto the subclass' tokenizer, parser and generator classes.
    """

    # Registry of all known dialects, keyed by their canonical string name.
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        # A dialect class compares equal to itself, to its registered name,
        # and to instances of itself.
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        # Hash by lowercase class name, consistent with string equality above.
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Register under the enum value when the class name matches a
        # Dialects member, otherwise under the lowercase class name.
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Precompute tries and inverse mappings used to translate between the
        # dialect's time formats and python strftime formats.
        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}

        # Nested Tokenizer/Parser/Generator classes override the defaults.
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first registered quote/identifier delimiter pair is canonical.
        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            # Find the start/end delimiters registered for a literal token
            # type (bit, hex, byte strings); (None, None) when unsupported.
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)

        # Collect the non-callable, non-dunder class attributes so they can
        # be propagated to the tokenizer/parser/generator classes below.
        dialect_properties = {
            **{
                k: v
                for k, v in vars(klass).items()
                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
            },
            "TOKENIZER_CLASS": klass.tokenizer_class,
        }

        # Dialects is a str Enum, so `enum` compares against plain strings;
        # presumably only the base dialect and BigQuery keep SELECT kinds —
        # everything else gets an empty tuple (TODO confirm against Generator).
        if enum not in ("", "bigquery"):
            dialect_properties["SELECT_KINDS"] = ()

        # Pass required dialect properties to the tokenizer, parser and generator classes
        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
            for name, value in dialect_properties.items():
                if hasattr(subclass, name):
                    setattr(subclass, name, value)

        # Parse `||` as a (safe) string concatenation rather than bitwise OR.
        if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

        # Dialects without SEMI/ANTI JOIN can use those words as table aliases.
        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        klass.generator_class.can_identify = klass.can_identify

        return klass
139
140
class Dialect(metaclass=_Dialect):
    """Base class for all SQL dialects.

    Subclasses are registered by the ``_Dialect`` metaclass and configure
    tokenization, parsing and generation via the class-level settings below.
    """

    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines whether or not user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Determines whether or not SEMI/ANTI JOINs are supported
    SUPPORTS_SEMI_ANTI_JOIN = True

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Determines whether the base comes first in the LOG function
    LOG_BASE_FIRST = True

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    # Default (quoted) time/date format strings for this dialect.
    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Mapping of an unescaped escape sequence to the corresponding character
    ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled by the _Dialect metaclass (from nested Tokenizer/Parser/
    # Generator classes when a subclass defines them)
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    def __eq__(self, other: t.Any) -> bool:
        # Instances compare equal to their class; combined with
        # _Dialect.__eq__ this also makes them comparable to dialect names.
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve `dialect` (name, instance, class, or falsy) to a Dialect class.

        Raises:
            ValueError: if `dialect` is a string naming no known dialect.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Translate a dialect time format (quoted string or string literal)
        into a string literal using TIME_MAPPING; other inputs pass through."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized to lowercase regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        # A character is "unsafe" when its case differs from the dialect's
        # normalization target and would be lost by normalize_identifier.
        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Mark an Identifier as quoted when quoting is requested, when its
        name is case-sensitive, or when the name is not a safe identifier."""
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse `sql` into a list of syntax trees."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse `sql` into trees of the given expression type."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        """Generate SQL in this dialect for the given syntax tree."""
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and re-generate each statement in this dialect."""
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily instantiate and cache one tokenizer per Dialect instance.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(**opts)
341
342
# Anything that can designate a dialect: its name, an instance, the class
# itself, or None for the default dialect.
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
344
345
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator handler that emits ``name`` over all of the
    expression's flattened argument values."""

    def _rename(self: Generator, expression: exp.Expression) -> str:
        return self.func(name, *flatten(expression.args.values()))

    return _rename
348
349
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render APPROX_COUNT_DISTINCT, warning when an accuracy argument is
    present since the target function cannot honor it."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)
354
355
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """Build a handler that renders exp.If as ``name(cond, true, false)``,
    falling back to ``false_value`` when no explicit false branch exists."""

    def _if_sql(self: Generator, expression: exp.If) -> str:
        false_branch = expression.args.get("false") or false_value
        return self.func(name, expression.this, expression.args.get("true"), false_branch)

    return _if_sql
368
369
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render JSON extraction with the arrow operator (e.g. Postgres/SQLite)."""
    operator = "->"
    return self.binary(expression, operator)
372
373
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render scalar JSON extraction with the double-arrow operator."""
    operator = "->>"
    return self.binary(expression, operator)
378
379
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array literal with bracket syntax: ``[a, b, c]``."""
    elements = self.expressions(expression, flat=True)
    return f"[{elements}]"
382
383
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ILIKE by lowercasing the subject and using plain LIKE."""
    lowered = exp.Lower(this=expression.this)
    return self.like_sql(exp.Like(this=lowered, expression=expression.expression))
388
389
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, with an optional time zone."""
    zone = self.sql(expression, "this")
    if zone:
        return f"CURRENT_DATE AT TIME ZONE {zone}"
    return "CURRENT_DATE"
393
394
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Drop the RECURSIVE keyword (with a warning) for dialects lacking it."""
    recursive = expression.args.get("recursive")
    if recursive:
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)
400
401
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE with an IF that yields NULL on a zero divisor."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return f"IF({denominator} <> 0, {numerator} / {denominator}, NULL)"
406
407
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Drop TABLESAMPLE (with a warning) and render only the sampled table."""
    self.unsupported("TABLESAMPLE unsupported")
    table = expression.this
    return self.sql(table)
411
412
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """PIVOT has no equivalent here; warn and emit nothing."""
    self.unsupported("PIVOT unsupported")
    return ""
416
417
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Fall back to a plain CAST where TRY_CAST is unavailable."""
    return self.cast_sql(expression)
420
421
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Properties have no equivalent here; warn and emit nothing."""
    self.unsupported("Properties unsupported")
    return ""
425
426
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Column COMMENT constraints have no equivalent here; warn and emit nothing."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""
432
433
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """MAP_FROM_ENTRIES has no equivalent here; warn and emit nothing."""
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""
437
438
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition via STRPOS, shifting by the start position when given."""
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")
    if not start:
        return f"STRPOS({haystack}, {needle})"
    # Search only the suffix, then translate the hit back to a full-string offset.
    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
446
447
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct member access as ``struct.field``."""
    struct = self.sql(expression, "this")
    field = self.sql(exp.to_identifier(expression.expression.name))
    return f"{struct}.{field}"
452
453
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map as ``MAP(k1, v1, k2, v2, ...)`` from its key/value arrays."""
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not (isinstance(keys, exp.Array) and isinstance(values, exp.Array)):
        # Column references (rather than literal arrays) can't be interleaved.
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    interleaved = []
    for key, value in zip(keys.expressions, values.expressions):
        interleaved.extend((self.sql(key), self.sql(value)))

    return self.func(map_func_name, *interleaved)
470
471
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        fmt = seq_get(args, 1)
        if not fmt:
            # True selects the dialect's TIME_FORMAT; otherwise fall back to
            # the caller-supplied default (which may itself be falsy).
            fmt = Dialect[dialect].TIME_FORMAT if default is True else default or None
        return exp_class(this=seq_get(args, 0), format=Dialect[dialect].format_time(fmt))

    return _format_time
496
497
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Build a handler that returns an expression's time format, or None when
    it equals the default TIME_FORMAT of the dialect of interest."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        fmt = self.format_time(expression)
        if fmt == Dialect.get_or_raise(dialect).TIME_FORMAT:
            return None
        return fmt

    return _time_format
510
511
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        # Only rewrite when PARTITIONED BY holds bare column names, not an
        # already-typed schema.
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Partition columns are matched against the schema case-insensitively.
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            # Move the partition columns out of the schema and into the
            # PARTITIONED BY property (mutates the tree in place).
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)
532
533
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date-delta functions that accepts both the 2-arg
    form (unit defaults to DAY) and the 3-arg unit-first form."""

    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
545
546
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for date-delta functions whose second argument must be
    an INTERVAL (e.g. DATE_ADD(x, INTERVAL 1 DAY))."""

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        # Normalize a quoted amount (e.g. '1') into a number literal.
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        return expression_class(
            this=args[0], expression=amount, unit=exp.Literal.string(interval.text("unit"))
        )

    return func
568
569
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse DATE_TRUNC(unit, value): DateTrunc for values cast to DATE,
    TimestampTrunc otherwise."""
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    truncates_date = isinstance(this, exp.Cast) and this.is_type("date")
    klass = exp.DateTrunc if truncates_date else exp.TimestampTrunc
    return klass(this=this, unit=unit)
577
578
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a handler emitting ``{data_type}_{kind}(this, INTERVAL n unit)``,
    with the unit uppercased and defaulting to DAY."""

    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit_arg = expression.args.get("unit")
        unit = exp.var(unit_arg.name.upper() if unit_arg else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
590
591
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC('unit', value), defaulting to 'day'."""
    unit = expression.text("unit") or "day"
    return self.func("DATE_TRUNC", exp.Literal.string(unit), expression.this)
596
597
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    """Lower TIMESTAMP(x[, zone]) into CAST / AT TIME ZONE when the function
    itself is unavailable; unknown zones use the generic fallback."""
    zone = expression.expression
    if not zone:
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))

    if expression.text("expression").lower() in TIMEZONES:
        casted = exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP)
        return self.sql(exp.AtTimeZone(this=casted, zone=zone))

    return self.function_fallback_sql(expression)
609
610
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Parse LOCATE(substr, str[, pos]) into StrPosition — note that the
    first two arguments are swapped relative to StrPosition."""
    substr, this, position = seq_get(args, 0), seq_get(args, 1), seq_get(args, 2)
    return exp.StrPosition(this=this, substr=substr, position=position)
615
616
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(substr, str[, pos])."""
    substr = expression.args.get("substr")
    position = expression.args.get("position")
    return self.func("LOCATE", substr, expression.this, position)
621
622
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Rewrite LEFT(x, n) as SUBSTRING(x, 1, n)."""
    substring = exp.Substring(
        this=expression.this, start=exp.Literal.number(1), length=expression.expression
    )
    return self.sql(substring)
629
630
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Rewrite RIGHT(x, n) as SUBSTRING(x, LENGTH(x) - (n - 1)).

    Fix: the parameter was annotated ``exp.Left``, but this handler is the
    RIGHT counterpart of ``left_to_substring_sql`` and receives ``exp.Right``.
    """
    return self.sql(
        exp.Substring(
            this=expression.this,
            # Start n - 1 characters before the end of the string.
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )
638
639
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time-string-to-time conversion as a CAST to timestamp."""
    casted = exp.cast(expression.this, "timestamp")
    return self.sql(casted)
642
643
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date-string-to-date conversion as a CAST to date."""
    casted = exp.cast(expression.this, "date")
    return self.sql(casted)
646
647
# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render an ENCODE/DECODE-style call, warning when a non-UTF-8 charset
    is requested (the target functions assume utf-8)."""
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replace_arg = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replace_arg)
657
658
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Use LEAST for the variadic form of MIN, plain MIN for the aggregate."""
    func_name = "LEAST" if expression.expressions else "MIN"
    return rename_func(func_name)(self, expression)
662
663
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Use GREATEST for the variadic form of MAX, plain MAX for the aggregate."""
    func_name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(func_name)(self, expression)
667
668
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as SUM(IF(cond, 1, 0)); DISTINCT is dropped
    with a warning since the rewrite cannot express it."""
    cond = expression.this
    if isinstance(cond, exp.Distinct):
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
        cond = cond.expressions[0]

    return self.func("sum", exp.func("if", cond, 1, 0))
677
678
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM with the full LEADING/TRAILING ... FROM syntax when the
    expression carries removal chars or a collation; otherwise defer to the
    generator's basic TRIM/LTRIM/RTRIM handling."""
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    if not (remove_chars or collation):
        # Nothing database-specific: use the plain TRIM family.
        return self.trim_sql(expression)

    parts = [
        f"{trim_type} " if trim_type else "",
        f"{remove_chars} " if remove_chars else "",
        "FROM " if trim_type or remove_chars else "",
        target,
        f" COLLATE {collation}" if collation else "",
    ]
    return f"TRIM({''.join(parts)})"
694
695
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion via STRPTIME."""
    fmt = self.format_time(expression)
    return self.func("STRPTIME", expression.this, fmt)
698
699
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a handler that casts a timestamp-or-date string to DATE, going
    through STR_TO_TIME first when a non-default format is attached."""

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if not time_format or time_format in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            # Default formats cast cleanly without an explicit parse step.
            return self.sql(exp.cast(expression.this, "date"))
        parsed = exp.StrToTime(this=expression.this, format=expression.args["format"])
        return self.sql(exp.cast(parsed, "date"))

    return _ts_or_ds_to_date_sql
714
715
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Fold CONCAT(a, b, c, ...) into a left-nested chain of || operators."""
    chained = reduce(lambda acc, arg: exp.DPipe(this=acc, expression=arg), expression.expressions)
    return self.sql(chained)
718
719
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Fold CONCAT_WS(sep, a, b, ...) into a || sep || b || sep || ... chain."""
    delim, *rest_args = expression.expressions
    chained = reduce(
        lambda acc, arg: exp.DPipe(this=acc, expression=exp.DPipe(this=delim, expression=arg)),
        rest_args,
    )
    return self.sql(chained)
728
729
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render REGEXP_EXTRACT, warning about args the target can't express."""
    bad_args = [a for a in ("position", "occurrence", "parameters") if expression.args.get(a)]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
738
739
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render REGEXP_REPLACE, warning about args the target can't express."""
    bad_args = [
        a for a in ("position", "occurrence", "parameters", "modifiers") if expression.args.get(a)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
750
751
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Derive output column names for a PIVOT's aggregations: the alias when
    present, otherwise the aggregation's own SQL with identifiers unquoted."""
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # Unaliased aggregations (e.g. col_avg(foo)) are used as name
        # suffixes. Identifiers must be unquoted here because the base
        # parser's `_parse_pivot` quotes them again via `to_identifier`,
        # which would otherwise yield doubled quotes like col_avg(`foo`).
        unquoted = agg.transform(
            lambda node: exp.Identifier(this=node.name, quoted=False)
            if isinstance(node, exp.Identifier)
            else node
        )
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
772
773
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Build a parser that maps FUNC(a, b) onto a binary expression type."""

    def _parse(args: t.List) -> B:
        return expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))

    return _parse
776
777
# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Parse DATE_TRUNC(unit, value) into TimestampTrunc."""
    unit, this = seq_get(args, 0), seq_get(args, 1)
    return exp.TimestampTrunc(this=this, unit=unit)
781
782
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Approximate ANY_VALUE with MAX for dialects that lack it."""
    return self.func("MAX", expression.this)
785
786
def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    """Expand boolean XOR into AND/OR/NOT for dialects without an operator."""
    left = self.sql(expression.left)
    right = self.sql(expression.right)
    return f"({left} AND (NOT {right})) OR ((NOT {left}) AND {right})"
791
792
# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    """Render a JSON key/value pair as ``key, value``."""
    key = self.sql(expression, "this")
    value = self.sql(expression, "expression")
    return f"{key}, {value}"
796
797
def is_parse_json(expression: exp.Expression) -> bool:
    """True for explicit PARSE_JSON calls and for casts to the JSON type."""
    if isinstance(expression, exp.ParseJSON):
        return True
    return isinstance(expression, exp.Cast) and expression.is_type("json")
802
803
def isnull_to_is_null(args: t.List) -> exp.Expression:
    """Parse ISNULL(x) into the parenthesized predicate (x IS NULL)."""
    operand = seq_get(args, 0)
    return exp.Paren(this=exp.Is(this=operand, expression=exp.null()))
806
807
def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    """Render an identity constraint as IDENTITY(start, increment); both
    values default to 1 when absent."""
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"
814
815
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    """Build an ARG_MAX/ARG_MIN handler for dialects whose variant takes no
    third "count" argument; a present count only triggers a warning."""

    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
class Dialects(str, Enum):
    """String-valued enumeration of all SQL dialects supported by sqlglot.

    The empty-string DIALECT member denotes the default (dialect-agnostic) dialect.
    """

    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    # NOTE(review): member name is mixed-case unlike its siblings — presumably intentional
    # for display purposes; confirm before normalizing.
    Doris = "doris"

An enumeration.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Doris = <Dialects.Doris: 'doris'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect(metaclass=_Dialect):
    """Base class for all SQL dialects.

    Dialect behavior is driven by class-level settings (flags, format maps) that
    subclasses override; the _Dialect metaclass registers each subclass so it can
    be looked up by name via get_or_raise.
    """

    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines whether or not user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Determines whether or not SEMI/ANTI JOINs are supported
    SUPPORTS_SEMI_ANTI_JOIN = True

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Determines whether the base comes first in the LOG function
    LOG_BASE_FIRST = True

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Mapping of an unescaped escape sequence to the corresponding character
    ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    def __eq__(self, other: t.Any) -> bool:
        # Instances compare equal to their class; _Dialect.__eq__ handles the reverse direction.
        return type(self) == other

    def __hash__(self) -> int:
        # Hash by class so instance and class hash consistently with __eq__.
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve a name, class, or instance to a Dialect class.

        Falsy input yields `cls` itself; an unknown name raises ValueError.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Translate a dialect time-format string into a python-format string literal.

        Non-string expressions are returned unchanged; None passes through.
        """
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized to lowercase regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        # A character is "unsafe" when its case differs from the dialect's resolution case.
        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Mark an Identifier as quoted when forced, case-sensitive, or not a safe identifier."""
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse `sql` into a list of expression trees."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse `sql`, coercing the result into the given expression type."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        """Render an expression tree back into SQL for this dialect."""
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and re-generate each statement; unparsable entries become ''."""
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily instantiate and cache a tokenizer per Dialect instance.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(**opts)
INDEX_OFFSET = 0
UNNEST_COLUMN_ONLY = False
ALIAS_POST_TABLESAMPLE = False
RESOLVES_IDENTIFIERS_AS_UPPERCASE: Optional[bool] = False
IDENTIFIERS_CAN_START_WITH_DIGIT = False
DPIPE_IS_STRING_CONCAT = True
STRICT_STRING_CONCAT = False
SUPPORTS_USER_DEFINED_TYPES = True
SUPPORTS_SEMI_ANTI_JOIN = True
NORMALIZE_FUNCTIONS: bool | str = 'upper'
LOG_BASE_FIRST = True
NULL_ORDERING = 'nulls_are_small'
DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}
FORMAT_MAPPING: Dict[str, str] = {}
ESCAPE_SEQUENCES: Dict[str, str] = {}
PSEUDOCOLUMNS: Set[str] = set()
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
INVERSE_ESCAPE_SEQUENCES: Dict[str, str] = {}
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Type[Dialect]:
221    @classmethod
222    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
223        if not dialect:
224            return cls
225        if isinstance(dialect, _Dialect):
226            return dialect
227        if isinstance(dialect, Dialect):
228            return dialect.__class__
229
230        result = cls.get(dialect)
231        if not result:
232            raise ValueError(f"Unknown dialect '{dialect}'")
233
234        return result
@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
236    @classmethod
237    def format_time(
238        cls, expression: t.Optional[str | exp.Expression]
239    ) -> t.Optional[exp.Expression]:
240        if isinstance(expression, str):
241            return exp.Literal.string(
242                # the time formats are quoted
243                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
244            )
245
246        if expression and expression.is_string:
247            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
248
249        return expression
@classmethod
def normalize_identifier(cls, expression: ~E) -> ~E:
251    @classmethod
252    def normalize_identifier(cls, expression: E) -> E:
253        """
254        Normalizes an unquoted identifier to either lower or upper case, thus essentially
255        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
256        they will be normalized to lowercase regardless of being quoted or not.
257        """
258        if isinstance(expression, exp.Identifier) and (
259            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
260        ):
261            expression.set(
262                "this",
263                expression.this.upper()
264                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
265                else expression.this.lower(),
266            )
267
268        return expression

Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, they will be normalized to lowercase regardless of being quoted or not.

@classmethod
def case_sensitive(cls, text: str) -> bool:
270    @classmethod
271    def case_sensitive(cls, text: str) -> bool:
272        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
273        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
274            return False
275
276        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
277        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.

@classmethod
def can_identify(cls, text: str, identify: str | bool = 'safe') -> bool:
279    @classmethod
280    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
281        """Checks if text can be identified given an identify option.
282
283        Args:
284            text: The text to check.
285            identify:
286                "always" or `True`: Always returns true.
287                "safe": True if the identifier is case-insensitive.
288
289        Returns:
290            Whether or not the given text can be identified.
291        """
292        if identify is True or identify == "always":
293            return True
294
295        if identify == "safe":
296            return not cls.case_sensitive(text)
297
298        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns true. "safe": True if the identifier is case-insensitive.
Returns:

Whether or not the given text can be identified.

@classmethod
def quote_identifier(cls, expression: ~E, identify: bool = True) -> ~E:
300    @classmethod
301    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
302        if isinstance(expression, exp.Identifier):
303            name = expression.this
304            expression.set(
305                "quoted",
306                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
307            )
308
309        return expression
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
311    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
312        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
314    def parse_into(
315        self, expression_type: exp.IntoType, sql: str, **opts
316    ) -> t.List[t.Optional[exp.Expression]]:
317        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
319    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
320        return self.generator(**opts).generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> List[str]:
322    def transpile(self, sql: str, **opts) -> t.List[str]:
323        return [
324            self.generate(expression, copy=False, **opts) if expression else ""
325            for expression in self.parse(sql)
326        ]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
328    def tokenize(self, sql: str) -> t.List[Token]:
329        return self.tokenizer.tokenize(sql)
def parser(self, **opts) -> sqlglot.parser.Parser:
337    def parser(self, **opts) -> Parser:
338        return self.parser_class(**opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
340    def generator(self, **opts) -> Generator:
341        return self.generator_class(**opts)
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START = None
BIT_END = None
HEX_START = None
HEX_END = None
BYTE_START = None
BYTE_END = None
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator callback that renders the expression as `name(...)`."""

    def _rename(self: Generator, expression: exp.Expression) -> str:
        return self.func(name, *flatten(expression.args.values()))

    return _rename
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render APPROX_COUNT_DISTINCT, dropping any accuracy argument with a warning."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")

    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """Build a renderer for IF-style conditionals with an optional default ELSE value."""

    def _render_if(self: Generator, expression: exp.If) -> str:
        branches = expression.args
        # Fall back to `false_value` when the expression has no explicit ELSE branch.
        return self.func(
            name,
            expression.this,
            branches.get("true"),
            branches.get("false") or false_value,
        )

    return _render_if
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render JSON extraction with the `->` arrow operator."""
    operator = "->"
    return self.binary(expression, operator)
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render scalar JSON extraction with the `->>` arrow operator."""
    operator = "->>"
    return self.binary(expression, operator)
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array literal using bracket syntax: [a, b, ...]."""
    elements = self.expressions(expression, flat=True)
    return f"[{elements}]"
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ILIKE by lowercasing the subject and emitting a plain LIKE."""
    lowered = exp.Lower(this=expression.this)
    return self.like_sql(exp.Like(this=lowered, expression=expression.expression))
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, with an optional AT TIME ZONE clause."""
    zone = self.sql(expression, "this")
    if zone:
        return f"CURRENT_DATE AT TIME ZONE {zone}"
    return "CURRENT_DATE"
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Strip the RECURSIVE flag (with a warning) before rendering a WITH clause."""
    recursive = expression.args.get("recursive")
    if recursive:
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False

    return self.with_sql(expression)
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE with IF: NULL when the denominator is zero."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return f"IF({denominator} <> 0, {numerator} / {denominator}, NULL)"
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Drop TABLESAMPLE (with a warning), rendering only the sampled relation."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled = expression.this
    return self.sql(sampled)
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """PIVOT has no equivalent in this dialect; warn and emit nothing."""
    self.unsupported("PIVOT unsupported")
    return ""
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render TRY_CAST as a plain CAST for dialects without TRY_CAST."""
    return self.cast_sql(expression)
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Properties clauses are unsupported here; warn and emit nothing."""
    self.unsupported("Properties unsupported")
    return ""
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Column COMMENT constraints are unsupported here; warn and emit nothing."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """MAP_FROM_ENTRIES is unsupported here; warn and emit nothing."""
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition via STRPOS, emulating the optional start position."""
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        # Search the suffix starting at `position`, then shift the index back to 1-based.
        return f"STRPOS(SUBSTR({haystack}, {position}), {needle}) + {position} - 1"
    return f"STRPOS({haystack}, {needle})"
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access as dotted notation: struct.field."""
    struct_sql = self.sql(expression, "this")
    field_sql = self.sql(exp.to_identifier(expression.expression.name))
    return f"{struct_sql}.{field_sql}"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a MAP from key/value arrays as MAP(k1, v1, k2, v2, ...)."""
    keys = expression.args["keys"]
    values = expression.args["values"]

    # Only literal arrays can be interleaved; columns fall back to MAP(keys, values).
    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    interleaved = []
    for key, value in zip(keys.expressions, values.expressions):
        interleaved.extend((self.sql(key), self.sql(value)))

    return self.func(map_func_name, *interleaved)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        target = Dialect[dialect]
        fmt = seq_get(args, 1)
        if not fmt:
            # Fall back to the dialect's TIME_FORMAT (default=True) or the given default.
            fmt = target.TIME_FORMAT if default is True else default or None
        return exp_class(this=seq_get(args, 0), format=target.format_time(fmt))

    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.

def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Build a callback returning an expression's time format, or None when it is the dialect default."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        fmt = self.format_time(expression)
        if fmt == Dialect.get_or_raise(dialect).TIME_FORMAT:
            # Omitting the default format lets the target use its implicit formatting.
            return None
        return fmt

    return _time_format
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Partition columns move out of the table schema and into the property's schema.
            partition_names = {v.name.upper() for v in prop.this.expressions}
            partition_cols = [c for c in schema.expressions if c.name.upper() in partition_names]
            remaining = [e for e in schema.expressions if e not in partition_cols]
            schema.set("expressions", remaining)
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_cols)))
            expression.set("this", schema)

    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser callback for DATE_ADD/DATE_SUB-style functions."""

    def inner_func(args: t.List) -> E:
        # Three arguments means the unit leads: (unit, amount, target); otherwise DAY is implied.
        has_unit = len(args) == 3
        this = args[2] if has_unit else seq_get(args, 0)
        unit = args[0] if has_unit else exp.Literal.string("DAY")
        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser callback for date arithmetic whose second argument is an INTERVAL."""

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        if amount and amount.is_string:
            # Normalize quoted amounts like INTERVAL '5' DAY into numeric literals.
            amount = exp.Literal.number(amount.this)

        return expression_class(
            this=args[0], expression=amount, unit=exp.Literal.string(interval.text("unit"))
        )

    return func
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Map DATE_TRUNC(unit, value) to DateTrunc for date operands, TimestampTrunc otherwise."""
    unit, operand = seq_get(args, 0), seq_get(args, 1)

    if isinstance(operand, exp.Cast) and operand.is_type("date"):
        return exp.DateTrunc(unit=unit, this=operand)
    return exp.TimestampTrunc(this=operand, unit=unit)
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a renderer producing `{data_type}_{kind}(x, INTERVAL n UNIT)` SQL."""

    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit_arg = expression.args.get("unit")
        # Units default to DAY and are uppercased for the INTERVAL literal.
        unit = exp.var(unit_arg.name.upper() if unit_arg else "DAY")
        interval_sql = self.sql(exp.Interval(this=expression.expression, unit=unit))
        return f"{data_type}_{kind}({this}, {interval_sql})"

    return func
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC('unit', value), defaulting the unit to 'day'."""
    unit = exp.Literal.string(expression.text("unit") or "day")
    return self.func("DATE_TRUNC", unit, expression.this)
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    """Rewrite TIMESTAMP(x[, tz]) using CAST and AT TIME ZONE where possible."""
    if not expression.expression:
        # Bare TIMESTAMP(x) becomes a plain cast.
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))

    if expression.text("expression").lower() in TIMEZONES:
        # A recognized timezone name becomes CAST(x AS TIMESTAMP) AT TIME ZONE tz.
        at_tz = exp.AtTimeZone(
            this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
            zone=expression.expression,
        )
        return self.sql(at_tz)

    return self.function_fallback_sql(expression)
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Map LOCATE(substr, string[, position]) onto StrPosition's argument order."""
    substr, this, position = seq_get(args, 0), seq_get(args, 1), seq_get(args, 2)
    return exp.StrPosition(this=this, substr=substr, position=position)
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(substr, string[, position])."""
    extras = expression.args
    return self.func("LOCATE", extras.get("substr"), expression.this, extras.get("position"))
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Rewrite LEFT(x, n) as SUBSTRING(x, 1, n)."""
    substring = exp.Substring(
        this=expression.this, start=exp.Literal.number(1), length=expression.expression
    )
    return self.sql(substring)
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Rewrite RIGHT(s, n) as SUBSTRING(s, LENGTH(s) - (n - 1)).

    Fix: the parameter was previously annotated ``exp.Left`` — a copy/paste
    slip from `left_to_substring_sql`; this transform handles RIGHT, so the
    annotation is corrected to ``exp.Right``. Behavior is unchanged.
    """
    return self.sql(
        exp.Substring(
            this=expression.this,
            # Start index counts back n characters from the end (1-based).
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render `TimeStrToTime` as a plain CAST to TIMESTAMP."""
    casted = exp.cast(expression.this, "timestamp")
    return self.sql(casted)
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render `DateStrToDate` as a plain CAST to DATE."""
    casted = exp.cast(expression.this, "date")
    return self.sql(casted)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render ENCODE/DECODE-style functions, warning on non-UTF-8 charsets.

    When `replace` is False the optional "replace" argument is dropped from
    the rendered call.
    """
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replace_arg = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replace_arg)
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render MIN(x) as MIN, but the multi-argument form MIN(x, y, ...) as LEAST."""
    func_name = "LEAST" if expression.expressions else "MIN"
    return rename_func(func_name)(self, expression)
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render MAX(x) as MAX, but the multi-argument form MAX(x, y, ...) as GREATEST."""
    func_name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(func_name)(self, expression)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as SUM(IF(cond, 1, 0))."""
    cond = expression.this

    if isinstance(cond, exp.Distinct):
        # DISTINCT cannot be expressed through SUM/IF; warn and unwrap it.
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
        cond = cond.expressions[0]

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM, emitting the full spec syntax only when position/characters/collation require it."""
    target = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Without characters or collation, the generator's default TRIM/LTRIM/RTRIM suffices.
    if not chars and not collation:
        return self.trim_sql(expression)

    pieces = []
    if position:
        pieces.append(f"{position} ")
    if chars:
        pieces.append(f"{chars} ")
    if pieces:
        pieces.append("FROM ")
    pieces.append(target)
    if collation:
        pieces.append(f" COLLATE {collation}")

    return f"TRIM({''.join(pieces)})"
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as STRPTIME(value, format)."""
    fmt = self.format_time(expression)
    return self.func("STRPTIME", expression.this, fmt)
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a TS_OR_DS_TO_DATE renderer bound to *dialect*'s default time formats."""

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target_dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)

        # A non-default format means the string must be parsed before casting.
        if time_format and time_format not in (target_dialect.TIME_FORMAT, target_dialect.DATE_FORMAT):
            parsed = exp.StrToTime(this=expression.this, format=expression.args["format"])
            return self.sql(exp.cast(parsed, "date"))

        return self.sql(exp.cast(expression.this, "date"))

    return _ts_or_ds_to_date_sql
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Rewrite CONCAT(a, b, ...) as the operator chain a || b || ... ."""
    chained = reduce(
        lambda acc, operand: exp.DPipe(this=acc, expression=operand),
        expression.expressions,
    )
    return self.sql(chained)
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Rewrite CONCAT_WS(sep, a, b, ...) as a || sep || b || sep || ... ."""
    separator, *operands = expression.expressions
    chained = reduce(
        lambda acc, operand: exp.DPipe(
            this=acc, expression=exp.DPipe(this=separator, expression=operand)
        ),
        operands,
    )
    return self.sql(chained)
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render REGEXP_EXTRACT, warning about arguments the target cannot express."""
    bad_args = [
        arg_name
        for arg_name in ("position", "occurrence", "parameters")
        if expression.args.get(arg_name)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render REGEXP_REPLACE, warning about arguments the target cannot express."""
    bad_args = [
        arg_name
        for arg_name in ("position", "occurrence", "parameters", "modifiers")
        if expression.args.get(arg_name)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Produce output column names for a PIVOT's aggregation expressions.

    Aliased aggregations contribute their alias. Unaliased aggregations are
    rendered as SQL suffixes (e.g. ``col_avg(foo)``) with identifiers
    unquoted, because the base parser's `_parse_pivot` re-quotes them via
    `to_identifier`; leaving them quoted would double the quotes, as in
    ``col_avg(`foo`)``.

    Fix (idiom): the explanation above previously lived as a triple-quoted
    string *statement* inside the loop — a runtime no-op, not a docstring —
    and is now real documentation. Behavior is unchanged.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # Unquote every identifier so the later re-quoting is idempotent.
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Build a parser that maps a two-element argument list onto a binary expression type."""

    def _build(args: t.List) -> B:
        return expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))

    return _build
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Map TIMESTAMP_TRUNC(unit, value)-style args onto a `TimestampTrunc` node."""
    return exp.TimestampTrunc(
        this=seq_get(args, 1),
        unit=seq_get(args, 0),
    )
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Render ANY_VALUE(x) as MAX(x) for dialects that lack ANY_VALUE."""
    return self.func("MAX", expression.this)
def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    """Expand boolean XOR as (a AND NOT b) OR (NOT a AND b)."""
    left = self.sql(expression.left)
    right = self.sql(expression.right)
    return f"({left} AND (NOT {right})) OR ((NOT {left}) AND {right})"
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    """Render a JSON key/value pair as "key, value" (comma-separated form)."""
    key_sql = self.sql(expression, "this")
    value_sql = self.sql(expression, "expression")
    return f"{key_sql}, {value_sql}"
def is_parse_json(expression: exp.Expression) -> bool:
    """Return True for PARSE_JSON nodes and for casts to the JSON type."""
    if isinstance(expression, exp.ParseJSON):
        return True
    return isinstance(expression, exp.Cast) and expression.is_type("json")
def isnull_to_is_null(args: t.List) -> exp.Expression:
    """Map ISNULL(x) onto the parenthesized predicate (x IS NULL)."""
    predicate = exp.Is(this=seq_get(args, 0), expression=exp.null())
    return exp.Paren(this=predicate)
def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    """Render GENERATED AS IDENTITY as IDENTITY(start, increment), both defaulting to 1."""
    start = self.sql(expression, "start") or "1"
    step = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {step})"
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    """Build an ARG_MAX/ARG_MIN renderer for dialects without the three-argument form."""

    def _render(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        # The optional "count" argument has no equivalent here; warn and drop it.
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _render