sqlglot.dialects.dialect
"""Dialect base machinery for sqlglot.

Defines the ``Dialects`` name enum, the ``_Dialect`` metaclass that registers every
dialect class and wires it to its tokenizer/parser/generator, the ``Dialect`` base
class, and a collection of small parser/generator helpers shared by concrete dialects.
"""
from __future__ import annotations

import typing as t
from enum import Enum
from functools import reduce

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

# Type variable for helpers that build binary expression nodes (see binary_from_function).
B = t.TypeVar("B", bound=exp.Binary)


class Dialects(str, Enum):
    # Built-in dialect names. As a str enum, members compare equal to their plain
    # string values, and _Dialect.__new__ uses a member's value as the registry key.
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    # NOTE(review): unlike every other member this name is not upper-case, so the
    # `clsname.upper()` lookup in _Dialect.__new__ will not find it for a class named
    # "Doris"; registration then falls back to clsname.lower(), which is also "doris".
    Doris = "doris"


class _Dialect(type):
    """Metaclass for all dialects.

    On class creation it registers the dialect by name, precomputes time-format tries,
    resolves the tokenizer/parser/generator classes and copies dialect settings onto them.
    """

    # Registry of every dialect class created through this metaclass, keyed by name.
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        """Compare a dialect class to itself, a registry name, or a Dialect instance."""
        if cls is other:
            return True
        if isinstance(other, str):
            # A string equals the class registered under that name.
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        # NOTE(review): __eq__ matches strings via the registry key, while the hash uses
        # the lower-cased class name; the two agree only when they coincide.
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        """Support ``Dialect["name"]`` lookups; raises KeyError for an unknown name."""
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        """Return the dialect class registered under ``key``, or ``default``."""
        return cls.classes.get(key, default)

    def __new__(cls, clsname: str, bases: t.Tuple[type, ...], attrs: t.Dict[str, t.Any]):
        klass = super().__new__(cls, clsname, bases, attrs)
        # Register under the enum value when the upper-cased class name is a Dialects
        # member, otherwise under the lower-cased class name.
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Precompute lookup tries for time-format translation in both directions.
        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        # Use nested Tokenizer/Parser/Generator classes when the dialect defines (or
        # inherits) them, falling back to the base implementations.
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first configured quote / identifier delimiter pair becomes the dialect's
        # canonical one (assumes the tokenizer defines at least one of each).
        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            """Return the (start, end) delimiters of the first format string producing
            ``token_type``, or (None, None) when the tokenizer has none."""
            # `t` below is the generator's loop variable (a token type); it shadows the
            # `typing` alias only inside this generator expression.
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)

        # Settings to propagate: non-callable, non-classmethod, non-dunder entries of
        # klass's own namespace (vars() excludes inherited attributes, but includes
        # everything assigned on klass above). Classes are callable, so TOKENIZER_CLASS
        # has to be added explicitly.
        dialect_properties = {
            **{
                k: v
                for k, v in vars(klass).items()
                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
            },
            "TOKENIZER_CLASS": klass.tokenizer_class,
        }

        # Only the base dialect ("") and BigQuery keep SELECT_KINDS; every other dialect,
        # including ones without a Dialects member (enum is None), gets an empty tuple.
        if enum not in ("", "bigquery"):
            dialect_properties["SELECT_KINDS"] = ()

        # Pass required dialect properties to the tokenizer, parser and generator classes,
        # but only for attributes those classes already declare.
        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
            for name, value in dialect_properties.items():
                if hasattr(subclass, name):
                    setattr(subclass, name, value)

        # NOTE(review): this writes into the BITWISE dict in place; if parser_class does
        # not define its own BITWISE, the write lands in an inherited (shared) dict.
        if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

        # Generators reuse this dialect's identifier-safety check (bound to klass).
        klass.generator_class.can_identify = klass.can_identify

        return klass


class Dialect(metaclass=_Dialect):
    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    # Default formats, stored with their surrounding SQL quotes
    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled (overwritten by _Dialect.__new__)
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        # An instance compares like its class (delegates to _Dialect.__eq__).
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve ``dialect`` (name, class, instance or falsy) to a dialect class.

        A falsy value yields ``cls``; a dialect class is returned as-is; an instance yields
        its class; a name is looked up in the registry.

        Raises:
            ValueError: if a name is not registered.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Translate a dialect time format into a python time format string literal.

        A plain ``str`` is assumed to include its surrounding quotes; a string literal
        expression is translated from its text; anything else is returned unchanged.
        """
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized regardless of being quoted or not.
        """
        # Mutates the identifier in place; non-identifiers are returned untouched.
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        # Characters whose case differs from the dialect's resolved case are "unsafe".
        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Set the ``quoted`` flag of an identifier in place and return it.

        The identifier is quoted when ``identify`` is truthy, when its text is case
        sensitive for this dialect, or when it does not match exp.SAFE_IDENTIFIER_RE.
        """
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse ``sql``; ``opts`` are passed to the parser constructor."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse ``sql`` into the requested expression type(s)."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Render ``expression`` as SQL; ``opts`` go to the generator constructor."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse ``sql`` and regenerate each statement in this dialect.

        Note that ``opts`` are forwarded to the generator only; parsing uses defaults.
        """
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Created lazily and cached per Dialect instance.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        # A fresh parser is built on every call.
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        # A fresh generator is built on every call.
        return self.generator_class(**opts)


# Anything get_or_raise accepts: a registry name, an instance, a class, or None.
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator transform that emits ``name(...)`` with all of the expression's
    argument values (flattened) as call arguments."""
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Emit APPROX_COUNT_DISTINCT, flagging an accuracy argument as unsupported (it is dropped)."""
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(self: Generator, expression: exp.If) -> str:
    """Emit IF(condition, true, false)."""
    return self.func(
        "IF", expression.this, expression.args.get("true"), expression.args.get("false")
    )


def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Emit JSON extraction using the ``->`` operator."""
    return self.binary(expression, "->")


def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Emit scalar JSON extraction using the ``->>`` operator."""
    return self.binary(expression, "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Emit an array as a bracketed literal: ``[a, b, ...]``."""
    return f"[{self.expressions(expression, flat=True)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ILIKE as ``LOWER(this) LIKE pattern`` for dialects without ILIKE."""
    # NOTE(review): only the left operand is lowered; a pattern containing upper-case
    # characters will not match as ILIKE would — confirm this is intended.
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this.copy()), expression=expression.expression.copy()
        )
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Emit CURRENT_DATE without parentheses, with an AT TIME ZONE suffix if a zone is set."""
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Emit a WITH clause, dropping (and flagging) the RECURSIVE modifier."""
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        # Note: mutates the caller's expression in place (no copy is made).
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE: NULL when the divisor is zero."""
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Drop TABLESAMPLE (flagged as unsupported) and emit only the sampled table."""
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Drop PIVOT entirely, flagging it as unsupported."""
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Emit TRY_CAST as a plain CAST."""
    return self.cast_sql(expression)


def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Drop table properties, flagging them as unsupported."""
    self.unsupported("Properties unsupported")
    return ""


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Drop a column COMMENT constraint, flagging it as unsupported."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Drop MAP_FROM_ENTRIES, flagging it as unsupported."""
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Emit a string-position search with STRPOS, emulating a start position via SUBSTR."""
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        # Search the suffix, then shift the result back into the full string's coordinates.
        # NOTE(review): if STRPOS returns 0 for "not found", this yields position - 1
        # rather than 0 — confirm against the target engines.
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Emit struct field access as ``struct.field``."""
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Emit a map constructor as ``MAP(k1, v1, k2, v2, ...)``.

    When keys or values are not array literals, they cannot be interleaved, so the
    function is emitted with the two raw arguments and the case is flagged.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave keys and values; zip() silently stops at the shorter of the two lists.
    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: registry name of the dialect whose time mapping translates the format
            (looked up with ``Dialect[dialect]``, so an unknown name raises KeyError).
        default: fallback format when the call has no (or a falsy) second argument:
            ``True`` selects the dialect's TIME_FORMAT, a string is used as given, and a
            falsy value means no format.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List) -> E:
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Build a generator helper that returns an expression's time format, or None when it
    equals the TIME_FORMAT of ``dialect`` (the base Dialect when ``dialect`` is None)."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        # Work on a copy so the caller's tree is left untouched.
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Column names are matched case-insensitively.
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)


def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date-delta functions called as (this, delta) or (unit, delta, this).

    The two-argument form defaults the unit to 'DAY'. When ``unit_mapping`` is given, the
    unit name (lower-cased) is translated through it and wrapped in an exp.var.
    """

    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func


def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for functions called as (this, INTERVAL n unit).

    The returned callable yields None for fewer than two arguments and raises ParseError
    when the second argument is not an INTERVAL. A string interval amount is converted to
    a number literal.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse DATE_TRUNC(unit, this): DateTrunc when ``this`` is a CAST to DATE, else
    TimestampTrunc."""
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Emit DATE_TRUNC('unit', this), defaulting the unit to 'day'."""
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
    )


def locate_to_strposition(args: t.List) -> exp.Expression:
    """Parse LOCATE(substr, this[, position]) into StrPosition."""
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Emit StrPosition as LOCATE(substr, this[, position])."""
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Emulate LEFT(s, n) as SUBSTRING(s, 1, n)."""
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Emulate RIGHT(s, n) as SUBSTRING(s, LENGTH(s) - (n - 1))."""
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Emit a time-string conversion as CAST(... AS TIMESTAMP)."""
    return self.sql(exp.cast(expression.this, "timestamp"))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Emit a date-string conversion as CAST(... AS DATE)."""
    return self.sql(exp.cast(expression.this, "date"))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Emit ``name(this[, replace])``, flagging any non-utf-8 charset (which is dropped)."""
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Emit LEAST when MIN has extra arguments, otherwise the MIN aggregate."""
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Emit GREATEST when MAX has extra arguments, otherwise the MAX aggregate."""
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Emulate COUNT_IF(cond) as SUM(IF(cond, 1, 0)); DISTINCT is dropped and flagged."""
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond.copy(), 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Emit TRIM in the ANSI ``TRIM([type] [chars] FROM target [COLLATE c])`` form when
    characters or a collation are given; otherwise defer to the generator's trim_sql."""
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Emit STRPTIME(this, format) using the generator's translated format."""
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a generator transform for TsOrDsToDate.

    A plain CAST to DATE is emitted unless the expression carries a format that differs
    from ``dialect``'s TIME_FORMAT and DATE_FORMAT, in which case the value is parsed with
    STRPTIME first.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return self.sql(exp.cast(str_to_time_sql(self, expression), "date"))

        return self.sql(exp.cast(self.sql(expression, "this"), "date"))

    return _ts_or_ds_to_date_sql


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Emit CONCAT(a, b, c) as ``a || b || c``."""
    expression = expression.copy()
    # NOTE(review): reduce() without an initializer raises TypeError for an argument-less CONCAT.
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Emit CONCAT_WS(delim, a, b, ...) as ``a || delim || b || ...``."""
    expression = expression.copy()
    delim, *rest_args = expression.expressions
    # NOTE(review): the same `delim` node is inserted at every join point, and reduce()
    # raises TypeError when only the delimiter is given — confirm both are acceptable.
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Emit REGEXP_EXTRACT(this, pattern[, group]), flagging position/occurrence/parameters,
    which are dropped."""
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Emit REGEXP_REPLACE(this, pattern, replacement), flagging position/occurrence/parameters,
    which are dropped."""
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Return the name each PIVOT aggregation contributes to the output column names:
    its alias when aliased, otherwise its SQL text (identifiers unquoted, functions
    lower-cased)."""
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # The string literal below is a no-op statement that serves as a comment.
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def simplify_literal(expression: E) -> E:
    """Run the optimizer's simplify on ``expression.expression`` unless it is already a literal."""
    if not isinstance(expression.expression, exp.Literal):
        # Imported lazily, presumably to avoid an import cycle with the optimizer.
        from sqlglot.optimizer.simplify import simplify

        # NOTE(review): the return value is discarded, so this relies on simplify()
        # mutating the node in place — confirm it never returns a replacement root.
        simplify(expression.expression)

    return expression


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Build a parser that maps a two-argument function call onto a binary expression."""
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Emulate ANY_VALUE(x) with MAX(x)."""
    return self.func("MAX", expression.this)


# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"
class Dialects(str, Enum):
    """Names of the SQL dialects bundled with sqlglot.

    Because this is a ``str`` enum, every member compares equal to its plain string
    value. The value is the key under which the dialect class is registered. Member
    names follow the upper-case convention, because dialect registration looks members
    up by the upper-cased dialect class name.
    """

    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    DORIS = "doris"

    # Backward-compatible alias: this member used to be spelled `Doris`. Sharing the
    # value makes it an Enum alias, so `Dialects.Doris is Dialects.DORIS`.
    Doris = "doris"
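Enumeration of the SQL dialect names known to sqlglot; each member's value is the key under which the corresponding dialect class is registered.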
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
class Dialect(metaclass=_Dialect):
    """Base class for SQL dialects.

    Concrete dialects override the class-level settings below and may nest their own
    ``Tokenizer``, ``Parser`` and ``Generator`` classes. When a subclass is created,
    the ``_Dialect`` metaclass picks those nested classes up and fills in the
    attributes marked as autofilled.
    """

    # Offset of the first element when indexing arrays
    INDEX_OFFSET = 0

    # When True, UNNEST table aliases act only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Whether the table alias is written after TABLESAMPLE
    ALIAS_POST_TABLESAMPLE = False

    # Case that unquoted identifiers resolve to: True -> upper, False -> lower.
    # None means every identifier, quoted or not, is treated as case-insensitive.
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Whether an unquoted identifier may begin with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Whether '||' (DPIPE) concatenates strings
    DPIPE_IS_STRING_CONCAT = True

    # Whether CONCAT requires all of its arguments to be strings
    STRICT_STRING_CONCAT = False

    # How function names get normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # NULL ordering assumed when a query does not specify one; one of
    # "nulls_are_small", "nulls_are_large" or "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    # Default formats, kept with their surrounding SQL quotes
    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Dialect time-format tokens (keys) mapped to python time-format tokens (values)
    TIME_MAPPING: t.Dict[str, str] = {}

    # Mapping for the `CAST(x AS DATE FORMAT 'yyyy')` syntax; when left empty the
    # TIME_MAPPING trie is reused instead.
    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Engine-generated columns for this dialect, e.g. candidates to leave out of SELECT *
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled by the metaclass
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # Tries built by the metaclass from the mappings above
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        # Instances compare exactly like their class does.
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Turn a dialect name, class, instance or falsy value into a dialect class.

        Raises:
            ValueError: when a name is not registered.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        resolved = cls.get(dialect)
        if resolved:
            return resolved
        raise ValueError(f"Unknown dialect '{dialect}'")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Convert a dialect time format into a python-format string literal.

        Non-string inputs (and falsy ones) are handed back untouched.
        """
        if isinstance(expression, str):
            # Raw strings still carry their SQL quotes: strip first and last character.
            raw_format = expression[1:-1]
        elif expression and expression.is_string:
            raw_format = expression.this
        else:
            return expression

        return exp.Literal.string(format_time(raw_format, cls.TIME_MAPPING, cls.TIME_TRIE))

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Lower- or upper-case an unquoted identifier in place so it behaves
        case-insensitively. Dialects that treat every identifier as case-insensitive
        normalize quoted identifiers as well.
        """
        if not isinstance(expression, exp.Identifier):
            return expression

        to_upper = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
        if expression.quoted and to_upper is not None:
            # Quoted identifiers keep their case in case-sensitive dialects.
            return expression

        text = expression.this
        expression.set("this", text.upper() if to_upper else text.lower())
        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Report whether ``text`` has characters whose case this dialect would change."""
        to_upper = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
        if to_upper is None:
            return False

        is_unsafe = str.islower if to_upper else str.isupper
        return any(map(is_unsafe, text))

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Decide whether ``text`` may be quoted as an identifier.

        Args:
            text: The candidate identifier text.
            identify: ``True`` / "always" to always allow it; "safe" to allow it only
                when the text is not case sensitive; anything else disallows it.

        Returns:
            True when the text can be identified.
        """
        if identify is True or identify == "always":
            return True
        return identify == "safe" and not cls.case_sensitive(text)

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Mark an identifier as quoted when required or requested; other nodes pass through."""
        if isinstance(expression, exp.Identifier):
            name = expression.this
            must_quote = (
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name)
            )
            expression.set("quoted", must_quote)

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Parse ``sql`` into syntax trees; ``opts`` configure the parser."""
        parser = self.parser(**opts)
        return parser.parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse ``sql`` into the requested expression type(s)."""
        parser = self.parser(**opts)
        return parser.parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Render ``expression`` as SQL; ``opts`` configure the generator."""
        generator = self.generator(**opts)
        return generator.generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse ``sql`` and regenerate each statement; ``opts`` reach the generator only."""
        rendered: t.List[str] = []
        for tree in self.parse(sql):
            rendered.append(self.generate(tree, **opts))
        return rendered

    def tokenize(self, sql: str) -> t.List[Token]:
        tokenizer = self.tokenizer
        return tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Built on first access, then reused for the lifetime of this instance.
        try:
            return self._tokenizer
        except AttributeError:
            self._tokenizer = self.tokenizer_class()
            return self._tokenizer

    def parser(self, **opts) -> Parser:
        parser_cls = self.parser_class
        return parser_cls(**opts)

    def generator(self, **opts) -> Generator:
        generator_cls = self.generator_class
        return generator_cls(**opts)
@classmethod
def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
    """Resolves `dialect` into a Dialect class.

    Args:
        dialect: a falsy value (resolves to `cls` itself), a class built by the `_Dialect`
            metaclass (returned as-is), a Dialect instance (its class is returned), or a key
            that is looked up through `cls.get`.

    Returns:
        The resolved Dialect class.

    Raises:
        ValueError: if `dialect` is a key that `cls.get` cannot resolve.
    """
    if not dialect:
        # Nothing was requested: fall back to the class this was invoked on.
        return cls

    if isinstance(dialect, Dialect):
        return dialect.__class__

    if isinstance(dialect, _Dialect):
        return dialect

    resolved = cls.get(dialect)
    if resolved:
        return resolved

    raise ValueError(f"Unknown dialect '{dialect}'")
@classmethod
def format_time(
    cls, expression: t.Optional[str | exp.Expression]
) -> t.Optional[exp.Expression]:
    """Rewrites a time-format string using this dialect's `TIME_MAPPING` / `TIME_TRIE`.

    Args:
        expression: a raw string that still carries its surrounding quote characters,
            an expression (only string literals are rewritten), or None.

    Returns:
        A new string literal holding the rewritten format. If `expression` is neither a
        str nor a string literal (including None), it is returned unchanged.
    """
    if isinstance(expression, str):
        # Raw formats arrive quoted, so drop the first and last character.
        raw_format = expression[1:-1]
    elif expression and expression.is_string:
        raw_format = expression.this
    else:
        return expression

    return exp.Literal.string(format_time(raw_format, cls.TIME_MAPPING, cls.TIME_TRIE))
@classmethod
def normalize_identifier(cls, expression: E) -> E:
    """Case-normalizes an identifier according to this dialect's resolution rules.

    Unquoted identifiers are upper-cased when `RESOLVES_IDENTIFIERS_AS_UPPERCASE` is truthy
    and lower-cased otherwise. When that setting is None, the dialect treats every identifier
    as case-insensitive, so quoted identifiers are normalized (lower-cased) as well.
    Non-identifier expressions are returned untouched.

    The identifier is mutated in place and also returned.
    """
    if not isinstance(expression, exp.Identifier):
        return expression

    uppercase = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE

    # Quoted identifiers keep their case unless the dialect is fully case-insensitive.
    if expression.quoted and uppercase is not None:
        return expression

    name = expression.this
    expression.set("this", name.upper() if uppercase else name.lower())
    return expression
Normalizes an unquoted identifier to either lowercase or uppercase, which effectively makes it case-insensitive. If a dialect treats all identifiers as case-insensitive, identifiers are normalized whether or not they are quoted.
@classmethod
def case_sensitive(cls, text: str) -> bool:
    """Reports whether `text` has characters whose case this dialect would not preserve.

    For dialects that resolve identifiers as uppercase, any lowercase character counts.
    Otherwise, any uppercase character counts. Dialects whose setting is None are fully
    case-insensitive, so the answer is always False.
    """
    uppercase = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
    if uppercase is None:
        return False

    return any((char.islower() if uppercase else char.isupper()) for char in text)
Checks whether text contains any case-sensitive characters, according to the dialect's identifier-resolution rules.
@classmethod
def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
    """Decides whether `text` can be identified (quoted) under the given `identify` option.

    Args:
        text: The text to check.
        identify: `True` or "always" always allows it; "safe" allows it only when
            `cls.case_sensitive(text)` is False; any other value disallows it.

    Returns:
        Whether or not the given text can be identified.
    """
    if identify is True or identify == "always":
        return True

    # Only the "safe" mode consults the dialect's case rules; everything else is a no.
    return identify == "safe" and not cls.case_sensitive(text)
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: "always" or True: always returns True. "safe": returns True only if the identifier is case-insensitive. Any other value returns False.
Returns:
Whether or not the given text can be identified.
@classmethod
def quote_identifier(cls, expression: E, identify: bool = True) -> E:
    """Sets the `quoted` flag of an identifier in place and returns it.

    The flag becomes `identify` when that is truthy. Otherwise the identifier is quoted only
    if leaving it bare would be unsafe: it contains case-sensitive characters for this
    dialect, or it does not match `exp.SAFE_IDENTIFIER_RE`. Non-identifiers pass through
    unchanged.
    """
    if not isinstance(expression, exp.Identifier):
        return expression

    text = expression.this
    if identify:
        quoted = identify
    else:
        quoted = cls.case_sensitive(text) or not exp.SAFE_IDENTIFIER_RE.match(text)

    expression.set("quoted", quoted)
    return expression
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Generates STRPOS for a StrPosition node, emulating an optional start position.

    Without a position this is simply `STRPOS(this, substr)`. With a position, the search
    runs over `SUBSTR(this, position)` and the 1-based result is shifted back by
    `position - 1` so it refers to the full string.

    STRPOS yields 0 when there is no match. That 0 must not be shifted, otherwise "not found"
    would be reported as `position - 1`, a plausible-looking index whenever position > 1.
    The shifted form is therefore wrapped in a CASE that preserves the 0.
    """
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")

    if not position:
        return f"STRPOS({this}, {substr})"

    strpos = f"STRPOS(SUBSTR({this}, {position}), {substr})"
    return f"CASE WHEN {strpos} = 0 THEN 0 ELSE {strpos} + {position} - 1 END"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Generates `map_func_name(k1, v1, k2, v2, ...)` from a Map/VarMap node.

    When both `keys` and `values` are Array expressions, their elements are interleaved
    pairwise (extra elements of the longer array are dropped by `zip`). Otherwise the two
    arguments are passed through as-is and the conversion is reported as unsupported.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if isinstance(keys, exp.Array) and isinstance(values, exp.Array):
        interleaved = [
            self.sql(node)
            for pair in zip(keys.expressions, values.expressions)
            for node in pair
        ]
        return self.func(map_func_name, *interleaved)

    self.unsupported("Cannot convert array columns into map.")
    return self.func(map_func_name, keys, values)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Builds a parser callback for time functions of the form FUNC(value[, format]).

    Args:
        exp_class: the expression class to instantiate.
        dialect: key of the dialect whose time notation the format is written in.
        default: format to fall back to when the call has no format argument; `True`
            selects the dialect's `TIME_FORMAT`, and any falsy value means no format.

    Returns:
        A callable that maps the parsed argument list to an `exp_class` instance, with the
        format converted through the dialect's `format_time`.
    """

    def _format_time(args: t.List):
        raw_format = seq_get(args, 1)
        if not raw_format:
            raw_format = Dialect[dialect].TIME_FORMAT if default is True else default or None

        return exp_class(this=seq_get(args, 0), format=Dialect[dialect].format_time(raw_format))

    return _format_time
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the format to use when no format argument is given; True means the dialect's default TIME_FORMAT.
Returns:
A callable that can be used to return the appropriately formatted time expression.
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Builds a generator helper that suppresses a time format equal to the dialect default.

    Args:
        dialect: the dialect whose `TIME_FORMAT` is the default (resolved lazily via
            `Dialect.get_or_raise` each time the helper runs).
    """

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, or None when it matches the
        default time format of the dialect of interest.
        """
        rendered = self.format_time(expression)
        if rendered == Dialect.get_or_raise(dialect).TIME_FORMAT:
            return None
        return rendered

    return _time_format
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    Generates a CREATE statement for Hive/Spark-style partitioning.

    In Hive and Spark, PARTITIONED BY extends the table's schema. For a TABLE or VIEW that
    has a schema, if the PARTITIONED BY value is not already a Schema (e.g. it is a list of
    column names), the matching column definitions are moved out of the table schema and
    become a Schema inside PARTITIONED BY. Names are matched case-insensitively. The input
    expression is copied before any of this happens, so the caller's tree is not mutated.
    """
    if not isinstance(expression.this, exp.Schema) or expression.args.get("kind") not in (
        "TABLE",
        "VIEW",
    ):
        return self.create_sql(expression)

    expression = expression.copy()
    partitioned_by = expression.find(exp.PartitionedByProperty)

    if partitioned_by and partitioned_by.this and not isinstance(partitioned_by.this, exp.Schema):
        table_schema = expression.this
        wanted = {node.name.upper() for node in partitioned_by.this.expressions}

        moved = [col for col in table_schema.expressions if col.name.upper() in wanted]
        table_schema.set("expressions", [col for col in table_schema.expressions if col not in moved])

        partitioned_by.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=moved)))
        # Re-attach the (now trimmed) schema node, exactly as before.
        expression.set("this", table_schema)

    return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, those names are transformed into a schema: the corresponding column definitions are removed from the CREATE statement's schema and placed in the PARTITIONED BY schema instead.
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Builds a parser callback for date-delta functions.

    The callback accepts either FUNC(unit, expression, this) with three arguments, or
    FUNC(this, expression) where the unit defaults to the string literal 'DAY'.

    Args:
        exp_class: the expression class to instantiate.
        unit_mapping: optional lookup keyed by lower-cased unit name. When provided, the unit
            becomes a `exp.var` of the mapped name (or the original name if unmapped).
    """

    def inner_func(args: t.List) -> E:
        if len(args) == 3:
            unit, this = args[0], args[2]
        else:
            unit, this = exp.Literal.string("DAY"), seq_get(args, 0)

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Builds a parser callback for FUNC(this, INTERVAL ...) style date arithmetic.

    The callback returns None when fewer than two arguments are given. It raises ParseError
    when the second argument is not an Interval. A quoted interval amount is turned into a
    numeric literal, and the interval's unit text becomes a string-literal `unit`.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=args[0], expression=amount, unit=unit)

    return func
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Generates `name(this[, replace])` for ENCODE/DECODE-like expressions.

    Only a utf-8 charset (or none at all) is expressible. Any other charset is reported via
    `self.unsupported` and then ignored. The `replace` argument of the expression is
    forwarded only when the `replace` flag is set.
    """
    charset = expression.args.get("charset")
    is_utf8 = not charset or charset.name.lower() == "utf-8"
    if not is_utf8:
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replacement = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replacement)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrites COUNT_IF(cond) as a `sum` over `if(cond, 1, 0)`.

    A DISTINCT argument cannot be carried over. Its first expression is used as the
    condition, and the loss is reported via `self.unsupported`.
    """
    condition = expression.this
    if isinstance(condition, exp.Distinct):
        condition = condition.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    # Copy so the new IF node does not steal the condition from the original tree.
    as_flag = exp.func("if", condition.copy(), 1, 0)
    return self.func("sum", as_flag)
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Generates the full `TRIM([type] [chars] FROM target [COLLATE c])` form when needed.

    If the expression has neither characters to remove nor a collation, generation is
    delegated to the generator's own `trim_sql`. Otherwise the explicit form is built, with
    `FROM` emitted only when a trim type or a character set precedes the target.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    if not (remove_chars or collation):
        return self.trim_sql(expression)

    pieces = [f"{piece} " for piece in (trim_type, remove_chars) if piece]
    if pieces:
        pieces.append("FROM ")
    pieces.append(target)
    if collation:
        pieces.append(f" COLLATE {collation}")

    return f"TRIM({''.join(pieces)})"
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Builds a generator helper that casts a TsOrDsToDate expression to DATE.

    When the expression carries a time format that differs from the dialect's `TIME_FORMAT`
    and `DATE_FORMAT`, the value is first converted with `str_to_time_sql` and then cast.
    Otherwise `this` is cast directly.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target_dialect = Dialect.get_or_raise(dialect)
        fmt = self.format_time(expression)
        default_formats = (target_dialect.TIME_FORMAT, target_dialect.DATE_FORMAT)

        if fmt and fmt not in default_formats:
            value = str_to_time_sql(self, expression)
        else:
            value = self.sql(expression, "this")

        return self.sql(exp.cast(value, "date"))

    return _ts_or_ds_to_date_sql
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Emulates CONCAT_WS(sep, a, b, ...) with DPipe (`||`) nodes.

    The remaining arguments are left-folded, inserting the separator between consecutive
    arguments. The input is copied first so the caller's tree is not modified.

    NOTE(review): with only a separator argument, `reduce` receives an empty sequence and
    raises TypeError; CONCAT_WS's NULL-skipping behaviour (in engines that have it) is also
    not reproduced by `||` — confirm both are acceptable for the target dialects.
    """
    expression = expression.copy()
    delim, *rest_args = expression.expressions

    def _join(left, right):
        return exp.DPipe(this=left, expression=exp.DPipe(this=delim, expression=right))

    return self.sql(reduce(_join, rest_args))
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Generates REGEXP_EXTRACT(this, expression[, group]).

    Position, occurrence and parameters cannot be expressed in this form. If any of them is
    set, the names are reported through `self.unsupported` and the arguments are dropped.
    """
    args = expression.args
    ignored = [name for name in ("position", "occurrence", "parameters") if args.get(name)]
    if ignored:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {ignored}")

    group = args.get("group")
    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Generates REGEXP_REPLACE(this, expression, replacement).

    Position, occurrence and parameters cannot be expressed in this form. If any of them is
    set, the names are reported through `self.unsupported` and the arguments are dropped.
    The `replacement` argument is required (a missing one raises KeyError).
    """
    args = expression.args
    ignored = [name for name in ("position", "occurrence", "parameters") if args.get(name)]
    if ignored:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {ignored}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Computes the name fragment contributed by each pivot aggregation.

    Args:
        aggregations: the pivot's aggregation expressions.
        dialect: the dialect used to render unaliased aggregations as SQL.

    Returns:
        One name per aggregation. Aliased aggregations contribute their alias. Any other
        aggregation contributes its SQL text (rendered with `normalize_functions="lower"`)
        after all of its identifiers have been unquoted.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # Unaliased aggregations are used as suffixes (e.g. col_avg(foo)). Identifiers are
        # unquoted here because the base parser's `_parse_pivot` quotes the final name again
        # via `to_identifier`; keeping the inner quotes would produce e.g. col_avg(`foo`),
        # i.e. double quoting.
        agg_all_unquoted = agg.transform(
            lambda node: exp.Identifier(this=node.name, quoted=False)
            if isinstance(node, exp.Identifier)
            else node
        )
        names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names