Edit on GitHub

sqlglot.dialects.dialect

  1from __future__ import annotations
  2
  3import typing as t
  4from enum import Enum
  5
  6from sqlglot import exp
  7from sqlglot.generator import Generator
  8from sqlglot.helper import flatten, seq_get
  9from sqlglot.parser import Parser
 10from sqlglot.time import format_time
 11from sqlglot.tokens import Token, Tokenizer, TokenType
 12from sqlglot.trie import new_trie
 13
 14if t.TYPE_CHECKING:
 15    from sqlglot._typing import E
 16
 17
# Dialects that resolve unquoted identifiers as uppercase.
# Only Snowflake is currently known to resolve unquoted identifiers as uppercase.
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
RESOLVES_IDENTIFIERS_AS_UPPERCASE = {"snowflake"}
 21
 22
class Dialects(str, Enum):
    """Names of all built-in SQL dialects; DIALECT ("") denotes the default/base dialect."""

    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TRINO = "trino"
    TSQL = "tsql"
    DATABRICKS = "databricks"
    DRILL = "drill"
    TERADATA = "teradata"
 46
 47
class _Dialect(type):
    """Metaclass for Dialect.

    Registers every Dialect subclass in a global name -> class registry and
    autofills class-level metadata (time tries, tokenizer/parser/generator
    classes, quote and format-string delimiters) at class-creation time.
    """

    # Registry of all created dialect classes, keyed by the matching Dialects
    # enum value or the lowercased class name.
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        # A dialect class compares equal to itself, to its registered string
        # name, and to instances of itself.
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        # Raises KeyError for unknown dialect names.
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        """Look up a registered dialect class by name, returning *default* if absent."""
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Register under the Dialects enum value when one matches the class
        # name, otherwise under the lowercased class name.
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Precompute time-format lookup structures in both directions.
        klass.time_trie = new_trie(klass.time_mapping)
        klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()}
        klass.inverse_time_trie = new_trie(klass.inverse_time_mapping)

        # Nested Tokenizer/Parser/Generator classes on the dialect override
        # the library defaults.
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first configured quote / identifier pair is taken as canonical.
        klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.identifier_start, klass.identifier_end = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            # Delimiters of the first tokenizer format string mapped to
            # token_type, or (None, None) if the tokenizer defines none.
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.bit_start, klass.bit_end = get_start_end(TokenType.BIT_STRING)
        klass.hex_start, klass.hex_end = get_start_end(TokenType.HEX_STRING)
        klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING)
        klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING)

        # NOTE(review): this mutates the (possibly shared) tokenizer class —
        # dialects sharing a Tokenizer would see the last-written value; confirm
        # each dialect defines its own Tokenizer subclass where this matters.
        klass.tokenizer_class.identifiers_can_start_with_digit = (
            klass.identifiers_can_start_with_digit
        )

        return klass
112
113
class Dialect(metaclass=_Dialect):
    """Base class for SQL dialects.

    Subclasses are auto-registered by the ``_Dialect`` metaclass, which also
    fills in the attributes marked "autofilled" below from the dialect's nested
    Tokenizer/Parser/Generator classes.
    """

    # Offset applied to array indexes (presumably 1 for 1-based dialects —
    # verify against parser/generator usage).
    index_offset = 0
    unnest_column_only = False
    alias_post_tablesample = False
    identifiers_can_start_with_digit = False
    # Casing applied to function names; passed through to the Generator.
    normalize_functions: t.Optional[str] = "upper"
    null_ordering = "nulls_are_small"

    date_format = "'%Y-%m-%d'"
    dateint_format = "'%Y%m%d'"
    time_format = "'%Y-%m-%d %H:%M:%S'"
    # Dialect-specific time-format tokens mapped to canonical ones.
    time_mapping: t.Dict[str, str] = {}

    # autofilled by the _Dialect metaclass
    quote_start = None
    quote_end = None
    identifier_start = None
    identifier_end = None

    time_trie = None
    inverse_time_mapping = None
    inverse_time_trie = None
    tokenizer_class = None
    parser_class = None
    generator_class = None

    def __eq__(self, other: t.Any) -> bool:
        # Delegates to _Dialect.__eq__, so instances also compare equal to
        # their class and to the registered dialect name.
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve *dialect* (name, instance, class or None) to a Dialect class.

        Raises:
            ValueError: if *dialect* is a string naming no registered dialect.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Map a (quoted) time-format string or string literal through this
        dialect's time_mapping; non-string expressions pass through unchanged."""
        if isinstance(expression, str):
            return exp.Literal.string(
                format_time(
                    expression[1:-1],  # the time formats are quoted
                    cls.time_mapping,
                    cls.time_trie,
                )
            )
        if expression and expression.is_string:
            return exp.Literal.string(
                format_time(
                    expression.this,
                    cls.time_mapping,
                    cls.time_trie,
                )
            )
        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse *sql* into a list of syntax trees."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse *sql* into trees of the given expression type(s)."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Render a syntax tree as SQL text in this dialect."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse *sql* and re-render each statement in this dialect."""
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily instantiate and cache one tokenizer per Dialect instance.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()  # type: ignore
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Build a Parser preconfigured with this dialect's settings; *opts* override them."""
        return self.parser_class(  # type: ignore
            **{
                "index_offset": self.index_offset,
                "unnest_column_only": self.unnest_column_only,
                "alias_post_tablesample": self.alias_post_tablesample,
                "null_ordering": self.null_ordering,
                **opts,
            },
        )

    def generator(self, **opts) -> Generator:
        """Build a Generator preconfigured with this dialect's settings; *opts* override them."""
        return self.generator_class(  # type: ignore
            **{
                "quote_start": self.quote_start,
                "quote_end": self.quote_end,
                "bit_start": self.bit_start,
                "bit_end": self.bit_end,
                "hex_start": self.hex_start,
                "hex_end": self.hex_end,
                "byte_start": self.byte_start,
                "byte_end": self.byte_end,
                "raw_start": self.raw_start,
                "raw_end": self.raw_end,
                "identifier_start": self.identifier_start,
                "identifier_end": self.identifier_end,
                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
                "index_offset": self.index_offset,
                "time_mapping": self.inverse_time_mapping,
                "time_trie": self.inverse_time_trie,
                "unnest_column_only": self.unnest_column_only,
                "alias_post_tablesample": self.alias_post_tablesample,
                "identifiers_can_start_with_digit": self.identifiers_can_start_with_digit,
                "normalize_functions": self.normalize_functions,
                "null_ordering": self.null_ordering,
                **opts,
            }
        )
245
246
247DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
248
249
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator transform that renders an expression as a call to *name*,
    forwarding all of the expression's (flattened) arguments."""

    def _rename(self: Generator, expression: exp.Expression) -> str:
        return self.func(name, *flatten(expression.args.values()))

    return _rename
252
253
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render APPROX_COUNT_DISTINCT, warning when an unsupported accuracy arg is present."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)
258
259
def if_sql(self: Generator, expression: exp.If) -> str:
    """Render a conditional as an IF(condition, true, false) function call."""
    branches = (expression.args.get("true"), expression.args.get("false"))
    return self.func("IF", expression.this, *branches)
264
265
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render JSON extraction with the arrow operator."""
    operator = "->"
    return self.binary(expression, operator)
268
269
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render scalar JSON extraction with the double-arrow operator."""
    operator = "->>"
    return self.binary(expression, operator)
274
275
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array literal using square-bracket syntax."""
    elements = self.expressions(expression)
    return f"[{elements}]"
278
279
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ILIKE by lowercasing the left-hand side and using plain LIKE."""
    lowered = exp.Lower(this=expression.this)
    like = exp.Like(this=lowered, expression=expression.args["expression"])
    return self.like_sql(like)
287
288
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, with an optional AT TIME ZONE clause."""
    zone = self.sql(expression, "this")
    if zone:
        return f"CURRENT_DATE AT TIME ZONE {zone}"
    return "CURRENT_DATE"
292
293
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Render WITH for dialects without RECURSIVE, warning and dropping the keyword."""
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        # Mutate in place so the base with_sql omits RECURSIVE.
        expression.args["recursive"] = False
    return self.with_sql(expression)
299
300
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE with an IF that yields NULL on a zero denominator."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return f"IF({denominator} <> 0, {numerator} / {denominator}, NULL)"
305
306
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Warn and render just the sampled table for dialects without TABLESAMPLE."""
    table = expression.this
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(table)
310
311
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Warn and emit nothing for dialects lacking PIVOT support."""
    self.unsupported("PIVOT unsupported")
    return ""
315
316
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Fall back to a plain CAST for dialects without TRY_CAST."""
    return self.cast_sql(expression)
319
320
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Warn and emit nothing for dialects lacking table properties."""
    self.unsupported("Properties unsupported")
    return ""
324
325
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Warn and emit nothing for dialects lacking COMMENT column constraints."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""
331
332
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render STRPOS; a starting position is emulated with SUBSTR plus an offset."""
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if not position:
        return f"STRPOS({this}, {substr})"
    # Search the substring starting at `position`, then translate the result
    # back into an index relative to the whole string.
    return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
340
341
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access as dotted notation with a quoted field name."""
    struct_sql = self.sql(expression, "this")
    key_identifier = exp.Identifier(this=expression.expression, quoted=True)
    return f"{struct_sql}.{self.sql(key_identifier)}"
346
347
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map constructor as map_func_name(k1, v1, k2, v2, ...).

    Falls back to map_func_name(keys, values) with a warning when the keys or
    values are not array literals.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not (isinstance(keys, exp.Array) and isinstance(values, exp.Array)):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    interleaved = []
    for key, value in zip(keys.expressions, values.expressions):
        interleaved.extend((self.sql(key), self.sql(value)))
    return self.func(map_func_name, *interleaved)
363
364
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        fmt = seq_get(args, 1)
        if not fmt:
            # True means "use the dialect's time format"; otherwise use the
            # provided default format (or None).
            fmt = Dialect[dialect].time_format if default is True else default or None
        return exp_class(this=seq_get(args, 0), format=Dialect[dialect].format_time(fmt))

    return _builder
389
390
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        # Work on a copy so the caller's tree is left untouched.
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        # Only rewrite when PARTITIONED BY is a plain column list (not already a schema).
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Case-insensitive match between partition names and schema columns.
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            # Move the partition columns out of the table schema and into the property.
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)
412
413
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date-delta functions.

    With three args the call is (unit, delta, this); with fewer, the unit
    defaults to 'DAY' and the args are (this, delta). Units are optionally
    normalized through *unit_mapping*.
    """

    def _builder(args: t.List) -> E:
        has_unit = len(args) == 3
        this = args[2] if has_unit else seq_get(args, 0)
        unit = args[0] if has_unit else exp.Literal.string("DAY")
        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder
425
426
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for date-delta functions whose second argument is an INTERVAL."""

    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        delta = interval.this
        # Normalize a string interval quantity into a numeric literal.
        if delta and delta.is_string:
            delta = exp.Literal.number(delta.this)

        return expression_class(
            this=args[0],
            expression=delta,
            unit=exp.Literal.string(interval.text("unit")),
        )

    return _builder
446
447
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Map DATE_TRUNC(unit, value) onto DateTrunc for date casts, else TimestampTrunc."""
    unit, this = seq_get(args, 0), seq_get(args, 1)

    is_date = isinstance(this, exp.Cast) and this.is_type("date")
    trunc_class = exp.DateTrunc if is_date else exp.TimestampTrunc
    return trunc_class(this=this, unit=unit)
455
456
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC('unit', value), defaulting the unit to 'day'."""
    unit = expression.text("unit") or "day"
    return self.func("DATE_TRUNC", exp.Literal.string(unit), expression.this)
461
462
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Map LOCATE(substr, string[, position]) args onto a StrPosition node."""
    substr, string, position = seq_get(args, 0), seq_get(args, 1), seq_get(args, 2)
    return exp.StrPosition(this=string, substr=substr, position=position)
469
470
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(substr, string[, position])."""
    substr = expression.args.get("substr")
    position = expression.args.get("position")
    return self.func("LOCATE", substr, expression.this, position)
475
476
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Render LEFT(s, n) as SUBSTRING(s, 1, n)."""
    expression = expression.copy()
    substring = exp.Substring(
        this=expression.this, start=exp.Literal.number(1), length=expression.expression
    )
    return self.sql(substring)
484
485
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Render RIGHT(s, n) as SUBSTRING(s, LENGTH(s) - (n - 1)).

    Fix: the parameter was annotated ``exp.Left`` although this helper handles
    RIGHT; the annotation now matches the expression it receives.
    """
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            # LENGTH(s) - (n - 1) positions the substring at the last n chars;
            # the arithmetic builds expression nodes via exp operator overloads.
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )
494
495
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time-string-to-time conversion as a CAST to TIMESTAMP."""
    inner = self.sql(expression, "this")
    return f"CAST({inner} AS TIMESTAMP)"
498
499
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date-string-to-date conversion as a CAST to DATE."""
    inner = self.sql(expression, "this")
    return f"CAST({inner} AS DATE)"
502
503
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render MIN, or LEAST when multiple expressions are aggregated together."""
    func_name = "LEAST" if expression.expressions else "MIN"
    render = rename_func(func_name)
    return render(self, expression)
507
508
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render MAX, or GREATEST when multiple expressions are aggregated together."""
    func_name = "GREATEST" if expression.expressions else "MAX"
    render = rename_func(func_name)
    return render(self, expression)
512
513
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as sum(if(cond, 1, 0)); DISTINCT is dropped with a warning."""
    condition = expression.this

    if isinstance(condition, exp.Distinct):
        condition = condition.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", condition, 1, 0))
522
523
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM with the full [position] [chars] FROM target [COLLATE] syntax
    when removal characters or a collation are present."""
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    # NOTE(review): `self.trim_sql` here resolves to the Generator's own method,
    # not this helper — assumes this helper is never installed as that method
    # itself, or this would recurse; confirm at the call sites.
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
539
540
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as STRPTIME(value, format)."""
    fmt = self.format_time(expression)
    return self.func("STRPTIME", expression.this, fmt)
543
544
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a TsOrDsToDate renderer for *dialect*: parse via STRPTIME when a
    non-standard time format is present, otherwise CAST directly to DATE."""

    def _render(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target_dialect = Dialect.get_or_raise(dialect)
        fmt = self.format_time(expression)
        known_formats = (target_dialect.time_format, target_dialect.date_format)
        if fmt and fmt not in known_formats:
            return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
        return f"CAST({self.sql(expression, 'this')} AS DATE)"

    return _render
554
555
# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Return the output column names for PIVOT aggregations.

    Aliased aggregations contribute their alias; unaliased ones contribute
    their unquoted SQL text rendered in *dialect* with lowercased functions.

    Fix: the original used a triple-quoted string literal as a comment inside
    the else branch — an expression statement that is evaluated and discarded
    at runtime — now replaced with real comments.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # Aggregations without aliases are used as suffixes (e.g. col_avg(foo)).
            # Unquote identifiers because they're going to be quoted again in the
            # base parser's `_parse_pivot` method, due to `to_identifier`;
            # otherwise we'd end up with `col_avg(`foo`)` (notice the double quotes).
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
class Dialects(builtins.str, enum.Enum):
24class Dialects(str, Enum):
25    DIALECT = ""
26
27    BIGQUERY = "bigquery"
28    CLICKHOUSE = "clickhouse"
29    DUCKDB = "duckdb"
30    HIVE = "hive"
31    MYSQL = "mysql"
32    ORACLE = "oracle"
33    POSTGRES = "postgres"
34    PRESTO = "presto"
35    REDSHIFT = "redshift"
36    SNOWFLAKE = "snowflake"
37    SPARK = "spark"
38    SPARK2 = "spark2"
39    SQLITE = "sqlite"
40    STARROCKS = "starrocks"
41    TABLEAU = "tableau"
42    TRINO = "trino"
43    TSQL = "tsql"
44    DATABRICKS = "databricks"
45    DRILL = "drill"
46    TERADATA = "teradata"

An enumeration.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DRILL = <Dialects.DRILL: 'drill'>
TERADATA = <Dialects.TERADATA: 'teradata'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
115class Dialect(metaclass=_Dialect):
116    index_offset = 0
117    unnest_column_only = False
118    alias_post_tablesample = False
119    identifiers_can_start_with_digit = False
120    normalize_functions: t.Optional[str] = "upper"
121    null_ordering = "nulls_are_small"
122
123    date_format = "'%Y-%m-%d'"
124    dateint_format = "'%Y%m%d'"
125    time_format = "'%Y-%m-%d %H:%M:%S'"
126    time_mapping: t.Dict[str, str] = {}
127
128    # autofilled
129    quote_start = None
130    quote_end = None
131    identifier_start = None
132    identifier_end = None
133
134    time_trie = None
135    inverse_time_mapping = None
136    inverse_time_trie = None
137    tokenizer_class = None
138    parser_class = None
139    generator_class = None
140
141    def __eq__(self, other: t.Any) -> bool:
142        return type(self) == other
143
144    def __hash__(self) -> int:
145        return hash(type(self))
146
147    @classmethod
148    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
149        if not dialect:
150            return cls
151        if isinstance(dialect, _Dialect):
152            return dialect
153        if isinstance(dialect, Dialect):
154            return dialect.__class__
155
156        result = cls.get(dialect)
157        if not result:
158            raise ValueError(f"Unknown dialect '{dialect}'")
159
160        return result
161
162    @classmethod
163    def format_time(
164        cls, expression: t.Optional[str | exp.Expression]
165    ) -> t.Optional[exp.Expression]:
166        if isinstance(expression, str):
167            return exp.Literal.string(
168                format_time(
169                    expression[1:-1],  # the time formats are quoted
170                    cls.time_mapping,
171                    cls.time_trie,
172                )
173            )
174        if expression and expression.is_string:
175            return exp.Literal.string(
176                format_time(
177                    expression.this,
178                    cls.time_mapping,
179                    cls.time_trie,
180                )
181            )
182        return expression
183
184    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
185        return self.parser(**opts).parse(self.tokenize(sql), sql)
186
187    def parse_into(
188        self, expression_type: exp.IntoType, sql: str, **opts
189    ) -> t.List[t.Optional[exp.Expression]]:
190        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
191
192    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
193        return self.generator(**opts).generate(expression)
194
195    def transpile(self, sql: str, **opts) -> t.List[str]:
196        return [self.generate(expression, **opts) for expression in self.parse(sql)]
197
198    def tokenize(self, sql: str) -> t.List[Token]:
199        return self.tokenizer.tokenize(sql)
200
201    @property
202    def tokenizer(self) -> Tokenizer:
203        if not hasattr(self, "_tokenizer"):
204            self._tokenizer = self.tokenizer_class()  # type: ignore
205        return self._tokenizer
206
207    def parser(self, **opts) -> Parser:
208        return self.parser_class(  # type: ignore
209            **{
210                "index_offset": self.index_offset,
211                "unnest_column_only": self.unnest_column_only,
212                "alias_post_tablesample": self.alias_post_tablesample,
213                "null_ordering": self.null_ordering,
214                **opts,
215            },
216        )
217
218    def generator(self, **opts) -> Generator:
219        return self.generator_class(  # type: ignore
220            **{
221                "quote_start": self.quote_start,
222                "quote_end": self.quote_end,
223                "bit_start": self.bit_start,
224                "bit_end": self.bit_end,
225                "hex_start": self.hex_start,
226                "hex_end": self.hex_end,
227                "byte_start": self.byte_start,
228                "byte_end": self.byte_end,
229                "raw_start": self.raw_start,
230                "raw_end": self.raw_end,
231                "identifier_start": self.identifier_start,
232                "identifier_end": self.identifier_end,
233                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
234                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
235                "index_offset": self.index_offset,
236                "time_mapping": self.inverse_time_mapping,
237                "time_trie": self.inverse_time_trie,
238                "unnest_column_only": self.unnest_column_only,
239                "alias_post_tablesample": self.alias_post_tablesample,
240                "identifiers_can_start_with_digit": self.identifiers_can_start_with_digit,
241                "normalize_functions": self.normalize_functions,
242                "null_ordering": self.null_ordering,
243                **opts,
244            }
245        )
@classmethod
def get_or_raise( cls, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> Type[sqlglot.dialects.dialect.Dialect]:
147    @classmethod
148    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
149        if not dialect:
150            return cls
151        if isinstance(dialect, _Dialect):
152            return dialect
153        if isinstance(dialect, Dialect):
154            return dialect.__class__
155
156        result = cls.get(dialect)
157        if not result:
158            raise ValueError(f"Unknown dialect '{dialect}'")
159
160        return result
@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
162    @classmethod
163    def format_time(
164        cls, expression: t.Optional[str | exp.Expression]
165    ) -> t.Optional[exp.Expression]:
166        if isinstance(expression, str):
167            return exp.Literal.string(
168                format_time(
169                    expression[1:-1],  # the time formats are quoted
170                    cls.time_mapping,
171                    cls.time_trie,
172                )
173            )
174        if expression and expression.is_string:
175            return exp.Literal.string(
176                format_time(
177                    expression.this,
178                    cls.time_mapping,
179                    cls.time_trie,
180                )
181            )
182        return expression
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
184    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
185        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
187    def parse_into(
188        self, expression_type: exp.IntoType, sql: str, **opts
189    ) -> t.List[t.Optional[exp.Expression]]:
190        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: Optional[sqlglot.expressions.Expression], **opts) -> str:
192    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
193        return self.generator(**opts).generate(expression)
def transpile(self, sql: str, **opts) -> List[str]:
195    def transpile(self, sql: str, **opts) -> t.List[str]:
196        return [self.generate(expression, **opts) for expression in self.parse(sql)]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
198    def tokenize(self, sql: str) -> t.List[Token]:
199        return self.tokenizer.tokenize(sql)
def parser(self, **opts) -> sqlglot.parser.Parser:
207    def parser(self, **opts) -> Parser:
208        return self.parser_class(  # type: ignore
209            **{
210                "index_offset": self.index_offset,
211                "unnest_column_only": self.unnest_column_only,
212                "alias_post_tablesample": self.alias_post_tablesample,
213                "null_ordering": self.null_ordering,
214                **opts,
215            },
216        )
def generator(self, **opts) -> sqlglot.generator.Generator:
218    def generator(self, **opts) -> Generator:
219        return self.generator_class(  # type: ignore
220            **{
221                "quote_start": self.quote_start,
222                "quote_end": self.quote_end,
223                "bit_start": self.bit_start,
224                "bit_end": self.bit_end,
225                "hex_start": self.hex_start,
226                "hex_end": self.hex_end,
227                "byte_start": self.byte_start,
228                "byte_end": self.byte_end,
229                "raw_start": self.raw_start,
230                "raw_end": self.raw_end,
231                "identifier_start": self.identifier_start,
232                "identifier_end": self.identifier_end,
233                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
234                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
235                "index_offset": self.index_offset,
236                "time_mapping": self.inverse_time_mapping,
237                "time_trie": self.inverse_time_trie,
238                "unnest_column_only": self.unnest_column_only,
239                "alias_post_tablesample": self.alias_post_tablesample,
240                "identifiers_can_start_with_digit": self.identifiers_can_start_with_digit,
241                "normalize_functions": self.normalize_functions,
242                "null_ordering": self.null_ordering,
243                **opts,
244            }
245        )
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator transform rendering an expression as ``name(arg1, arg2, ...)``."""

    def _rename(self: Generator, expression: exp.Expression) -> str:
        return self.func(name, *flatten(expression.args.values()))

    return _rename
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render ApproxDistinct as APPROX_COUNT_DISTINCT, dropping any accuracy argument."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql(self: Generator, expression: exp.If) -> str:
    """Render an If expression as IF(condition, true_value, false_value)."""
    true_branch = expression.args.get("true")
    false_branch = expression.args.get("false")
    return self.func("IF", expression.this, true_branch, false_branch)
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render JSON(B) extraction with the arrow operator, e.g. ``col -> 'key'``."""
    return self.binary(expression, "->")
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render scalar JSON(B) extraction with the double-arrow operator (``->>``)."""
    return self.binary(expression, "->>")
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array literal with bracket syntax: ``[a, b, c]``."""
    items = self.expressions(expression)
    return f"[{items}]"
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate case-insensitive ILIKE with LIKE by lowering BOTH operands.

    The previous version lowered only the matched value; any uppercase
    character in the pattern could then never match, so the pattern must be
    lowered as well for the emulation to be faithful.
    """
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this),
            expression=exp.Lower(this=expression.args["expression"]),
        )
    )
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, with an optional time zone."""
    zone = self.sql(expression, "this")
    if zone:
        return f"CURRENT_DATE AT TIME ZONE {zone}"
    return "CURRENT_DATE"
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Drop the RECURSIVE keyword from WITH clauses (unsupported), then render."""
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        # Mutate the flag off so the default WITH rendering omits RECURSIVE.
        expression.args["recursive"] = False
    return self.with_sql(expression)
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE with IF, yielding NULL on division by zero."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return f"IF({denominator} <> 0, {numerator} / {denominator}, NULL)"
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Drop TABLESAMPLE clauses (unsupported), keeping only the sampled relation."""
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Emit nothing for PIVOT clauses (unsupported), after warning."""
    self.unsupported("PIVOT unsupported")
    return ""
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Fall back to a plain CAST for dialects without TRY_CAST."""
    return self.cast_sql(expression)
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Emit nothing for property lists (unsupported), after warning."""
    self.unsupported("Properties unsupported")
    return ""
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Emit nothing for COMMENT column constraints (unsupported), after warning."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition with STRPOS, emulating a start offset via SUBSTR.

    STRPOS has no native start-position argument, so when one is given we
    search the suffix and shift the result back into whole-string coordinates.
    """
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")

    if not position:
        return f"STRPOS({this}, {substr})"

    return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access as ``struct."field"``, quoting the key."""
    this = self.sql(expression, "this")
    key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
    return f"{this}.{key}"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map constructor as ``MAP(k1, v1, k2, v2, ...)``.

    Interleaving requires literal key/value arrays; column-valued arrays are
    passed through as ``MAP(keys, values)`` with an unsupported warning.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not (isinstance(keys, exp.Array) and isinstance(values, exp.Array)):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    interleaved = []
    for key, value in zip(keys.expressions, values.expressions):
        interleaved.extend((self.sql(key), self.sql(value)))

    return self.func(map_func_name, *interleaved)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        target = Dialect[dialect]
        fmt = seq_get(args, 1)
        if not fmt:
            # Fall back to the dialect's standard time format, or the supplied default.
            fmt = target.time_format if default is True else default or None
        return exp_class(this=seq_get(args, 0), format=target.format_time(fmt))

    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.

Returns:
  A callable that can be used to return the appropriately formatted time expression.
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)

        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            partition_names = {col.name.upper() for col in prop.this.expressions}
            partition_cols = [
                col for col in schema.expressions if col.name.upper() in partition_names
            ]

            # Move the partition columns out of the table schema and into the property.
            schema.set(
                "expressions", [col for col in schema.expressions if col not in partition_cols]
            )
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_cols)))
            expression.set("this", schema)

    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for DATE_ADD/DATE_SUB-style functions.

    Supports both the two-argument form (this, expression) and the three
    argument, unit-first form (unit, expression, this); the unit defaults to
    DAY and can be remapped through ``unit_mapping``.
    """

    def _parse(args: t.List) -> E:
        has_unit = len(args) == 3
        this = args[2] if has_unit else seq_get(args, 0)
        unit = args[0] if has_unit else exp.Literal.string("DAY")

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _parse
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for date arithmetic of the form ``f(date, INTERVAL n unit)``."""

    def _parse(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        value = interval.this

        # String interval quantities (e.g. INTERVAL '5' DAY) become numeric literals.
        if value and value.is_string:
            value = exp.Literal.number(value.this)

        return expression_class(
            this=args[0], expression=value, unit=exp.Literal.string(interval.text("unit"))
        )

    return _parse
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Build DateTrunc when the input is cast to DATE, TimestampTrunc otherwise."""
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)

    return exp.TimestampTrunc(this=this, unit=unit)
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as ``DATE_TRUNC('<unit>', <value>)``, defaulting to 'day'."""
    unit = expression.text("unit") or "day"
    return self.func("DATE_TRUNC", exp.Literal.string(unit), expression.this)
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Map LOCATE(substr, string[, position]) arguments onto a StrPosition node."""
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(substr, string[, position])."""
    substr = expression.args.get("substr")
    position = expression.args.get("position")
    return self.func("LOCATE", substr, expression.this, position)
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Rewrite LEFT(s, n) as SUBSTRING(s, 1, n) for dialects without LEFT."""
    expression = expression.copy()
    substring = exp.Substring(
        this=expression.this, start=exp.Literal.number(1), length=expression.expression
    )
    return self.sql(substring)
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Rewrite RIGHT(s, n) as SUBSTRING(s, LENGTH(s) - (n - 1)) for dialects without RIGHT.

    The parameter was previously annotated ``exp.Left``, which is misleading:
    this transform handles ``exp.Right`` (mirroring ``left_to_substring_sql``).
    The annotation is lazy (``from __future__ import annotations``), so this
    change has no runtime effect.
    """
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time-string conversion as a CAST to TIMESTAMP."""
    value = self.sql(expression, "this")
    return f"CAST({value} AS TIMESTAMP)"
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date-string conversion as a CAST to DATE."""
    value = self.sql(expression, "this")
    return f"CAST({value} AS DATE)"
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Use LEAST for the multi-argument form of Min, plain MIN for the aggregate."""
    if expression.expressions:
        return rename_func("LEAST")(self, expression)
    return rename_func("MIN")(self, expression)
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Use GREATEST for the multi-argument form of Max, plain MAX for the aggregate."""
    if expression.expressions:
        return rename_func("GREATEST")(self, expression)
    return rename_func("MAX")(self, expression)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as SUM(IF(cond, 1, 0)) for dialects without COUNT_IF."""
    cond = expression.this

    if isinstance(cond, exp.Distinct):
        # SUM-based emulation cannot express DISTINCT; warn and use the bare condition.
        cond = cond.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM with optional position, removal characters and collation.

    When the expression carries no dialect-specific pieces (chars to remove,
    collation), defer to the generator's own ``trim_sql`` so the plain
    TRIM/LTRIM/RTRIM syntax is used.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    if not remove_chars and not collation:
        return self.trim_sql(expression)

    parts = [
        f"{trim_type} " if trim_type else "",
        f"{remove_chars} " if remove_chars else "",
        "FROM " if trim_type or remove_chars else "",
        target,
        f" COLLATE {collation}" if collation else "",
    ]
    return f"TRIM({''.join(parts)})"
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion via STRPTIME with the mapped format."""
    return self.func("STRPTIME", expression.this, self.format_time(expression))
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a TsOrDsToDate renderer for the given dialect name.

    The value is cast directly to DATE when its format matches the dialect's
    standard time or date format; otherwise it is parsed first via STRPTIME.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)

        if time_format and time_format not in (target.time_format, target.date_format):
            return f"CAST({str_to_time_sql(self, expression)} AS DATE)"

        return f"CAST({self.sql(expression, 'this')} AS DATE)"

    return _ts_or_ds_to_date_sql
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the output column names produced by a PIVOT's aggregations."""
    names = []

    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # Aggregations without aliases are used as name suffixes (e.g. col_avg(foo)).
        # Unquote identifiers here because the base parser's `_parse_pivot` will quote
        # them again via `to_identifier`; otherwise we'd end up with `col_avg(`foo`)`
        # (notice the double quotes).
        def _unquote(node: exp.Expression) -> exp.Expression:
            if isinstance(node, exp.Identifier):
                return exp.Identifier(this=node.name, quoted=False)
            return node

        unquoted = agg.transform(_unquote)
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names