sqlglot.dialects.dialect
1from __future__ import annotations 2 3import typing as t 4from enum import Enum 5from functools import reduce 6 7from sqlglot import exp 8from sqlglot._typing import E 9from sqlglot.errors import ParseError 10from sqlglot.generator import Generator 11from sqlglot.helper import flatten, seq_get 12from sqlglot.parser import Parser 13from sqlglot.time import format_time 14from sqlglot.tokens import Token, Tokenizer, TokenType 15from sqlglot.trie import new_trie 16 17B = t.TypeVar("B", bound=exp.Binary) 18 19 20class Dialects(str, Enum): 21 DIALECT = "" 22 23 BIGQUERY = "bigquery" 24 CLICKHOUSE = "clickhouse" 25 DATABRICKS = "databricks" 26 DRILL = "drill" 27 DUCKDB = "duckdb" 28 HIVE = "hive" 29 MYSQL = "mysql" 30 ORACLE = "oracle" 31 POSTGRES = "postgres" 32 PRESTO = "presto" 33 REDSHIFT = "redshift" 34 SNOWFLAKE = "snowflake" 35 SPARK = "spark" 36 SPARK2 = "spark2" 37 SQLITE = "sqlite" 38 STARROCKS = "starrocks" 39 TABLEAU = "tableau" 40 TERADATA = "teradata" 41 TRINO = "trino" 42 TSQL = "tsql" 43 Doris = "doris" 44 45 46class _Dialect(type): 47 classes: t.Dict[str, t.Type[Dialect]] = {} 48 49 def __eq__(cls, other: t.Any) -> bool: 50 if cls is other: 51 return True 52 if isinstance(other, str): 53 return cls is cls.get(other) 54 if isinstance(other, Dialect): 55 return cls is type(other) 56 57 return False 58 59 def __hash__(cls) -> int: 60 return hash(cls.__name__.lower()) 61 62 @classmethod 63 def __getitem__(cls, key: str) -> t.Type[Dialect]: 64 return cls.classes[key] 65 66 @classmethod 67 def get( 68 cls, key: str, default: t.Optional[t.Type[Dialect]] = None 69 ) -> t.Optional[t.Type[Dialect]]: 70 return cls.classes.get(key, default) 71 72 def __new__(cls, clsname, bases, attrs): 73 klass = super().__new__(cls, clsname, bases, attrs) 74 enum = Dialects.__members__.get(clsname.upper()) 75 cls.classes[enum.value if enum is not None else clsname.lower()] = klass 76 77 klass.TIME_TRIE = new_trie(klass.TIME_MAPPING) 78 klass.FORMAT_TRIE = ( 79 
new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE 80 ) 81 klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()} 82 klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING) 83 84 klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer) 85 klass.parser_class = getattr(klass, "Parser", Parser) 86 klass.generator_class = getattr(klass, "Generator", Generator) 87 88 klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0] 89 klass.IDENTIFIER_START, klass.IDENTIFIER_END = list( 90 klass.tokenizer_class._IDENTIFIERS.items() 91 )[0] 92 93 def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]: 94 return next( 95 ( 96 (s, e) 97 for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items() 98 if t == token_type 99 ), 100 (None, None), 101 ) 102 103 klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING) 104 klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING) 105 klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING) 106 107 dialect_properties = { 108 **{ 109 k: v 110 for k, v in vars(klass).items() 111 if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__") 112 }, 113 "TOKENIZER_CLASS": klass.tokenizer_class, 114 } 115 116 if enum not in ("", "bigquery"): 117 dialect_properties["SELECT_KINDS"] = () 118 119 # Pass required dialect properties to the tokenizer, parser and generator classes 120 for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class): 121 for name, value in dialect_properties.items(): 122 if hasattr(subclass, name): 123 setattr(subclass, name, value) 124 125 if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT: 126 klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe 127 128 klass.generator_class.can_identify = klass.can_identify 129 130 return klass 131 132 133class Dialect(metaclass=_Dialect): 134 # Determines the base index 
offset for arrays 135 INDEX_OFFSET = 0 136 137 # If true unnest table aliases are considered only as column aliases 138 UNNEST_COLUMN_ONLY = False 139 140 # Determines whether or not the table alias comes after tablesample 141 ALIAS_POST_TABLESAMPLE = False 142 143 # Determines whether or not unquoted identifiers are resolved as uppercase 144 # When set to None, it means that the dialect treats all identifiers as case-insensitive 145 RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False 146 147 # Determines whether or not an unquoted identifier can start with a digit 148 IDENTIFIERS_CAN_START_WITH_DIGIT = False 149 150 # Determines whether or not the DPIPE token ('||') is a string concatenation operator 151 DPIPE_IS_STRING_CONCAT = True 152 153 # Determines whether or not CONCAT's arguments must be strings 154 STRICT_STRING_CONCAT = False 155 156 # Determines whether or not user-defined data types are supported 157 SUPPORTS_USER_DEFINED_TYPES = True 158 159 # Determines how function names are going to be normalized 160 NORMALIZE_FUNCTIONS: bool | str = "upper" 161 162 # Indicates the default null ordering method to use if not explicitly set 163 # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last" 164 NULL_ORDERING = "nulls_are_small" 165 166 DATE_FORMAT = "'%Y-%m-%d'" 167 DATEINT_FORMAT = "'%Y%m%d'" 168 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 169 170 # Custom time mappings in which the key represents dialect time format 171 # and the value represents a python time format 172 TIME_MAPPING: t.Dict[str, str] = {} 173 174 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 175 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 176 # special syntax cast(x as date format 'yyyy') defaults to time_mapping 177 FORMAT_MAPPING: 
t.Dict[str, str] = {} 178 179 # Columns that are auto-generated by the engine corresponding to this dialect 180 # Such columns may be excluded from SELECT * queries, for example 181 PSEUDOCOLUMNS: t.Set[str] = set() 182 183 # Autofilled 184 tokenizer_class = Tokenizer 185 parser_class = Parser 186 generator_class = Generator 187 188 # A trie of the time_mapping keys 189 TIME_TRIE: t.Dict = {} 190 FORMAT_TRIE: t.Dict = {} 191 192 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 193 INVERSE_TIME_TRIE: t.Dict = {} 194 195 def __eq__(self, other: t.Any) -> bool: 196 return type(self) == other 197 198 def __hash__(self) -> int: 199 return hash(type(self)) 200 201 @classmethod 202 def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: 203 if not dialect: 204 return cls 205 if isinstance(dialect, _Dialect): 206 return dialect 207 if isinstance(dialect, Dialect): 208 return dialect.__class__ 209 210 result = cls.get(dialect) 211 if not result: 212 raise ValueError(f"Unknown dialect '{dialect}'") 213 214 return result 215 216 @classmethod 217 def format_time( 218 cls, expression: t.Optional[str | exp.Expression] 219 ) -> t.Optional[exp.Expression]: 220 if isinstance(expression, str): 221 return exp.Literal.string( 222 # the time formats are quoted 223 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 224 ) 225 226 if expression and expression.is_string: 227 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 228 229 return expression 230 231 @classmethod 232 def normalize_identifier(cls, expression: E) -> E: 233 """ 234 Normalizes an unquoted identifier to either lower or upper case, thus essentially 235 making it case-insensitive. If a dialect treats all identifiers as case-insensitive, 236 they will be normalized regardless of being quoted or not. 
237 """ 238 if isinstance(expression, exp.Identifier) and ( 239 not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None 240 ): 241 expression.set( 242 "this", 243 expression.this.upper() 244 if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE 245 else expression.this.lower(), 246 ) 247 248 return expression 249 250 @classmethod 251 def case_sensitive(cls, text: str) -> bool: 252 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 253 if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None: 254 return False 255 256 unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper 257 return any(unsafe(char) for char in text) 258 259 @classmethod 260 def can_identify(cls, text: str, identify: str | bool = "safe") -> bool: 261 """Checks if text can be identified given an identify option. 262 263 Args: 264 text: The text to check. 265 identify: 266 "always" or `True`: Always returns true. 267 "safe": True if the identifier is case-insensitive. 268 269 Returns: 270 Whether or not the given text can be identified. 
271 """ 272 if identify is True or identify == "always": 273 return True 274 275 if identify == "safe": 276 return not cls.case_sensitive(text) 277 278 return False 279 280 @classmethod 281 def quote_identifier(cls, expression: E, identify: bool = True) -> E: 282 if isinstance(expression, exp.Identifier): 283 name = expression.this 284 expression.set( 285 "quoted", 286 identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 287 ) 288 289 return expression 290 291 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 292 return self.parser(**opts).parse(self.tokenize(sql), sql) 293 294 def parse_into( 295 self, expression_type: exp.IntoType, sql: str, **opts 296 ) -> t.List[t.Optional[exp.Expression]]: 297 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 298 299 def generate(self, expression: t.Optional[exp.Expression], **opts) -> str: 300 return self.generator(**opts).generate(expression) 301 302 def transpile(self, sql: str, **opts) -> t.List[str]: 303 return [self.generate(expression, **opts) for expression in self.parse(sql)] 304 305 def tokenize(self, sql: str) -> t.List[Token]: 306 return self.tokenizer.tokenize(sql) 307 308 @property 309 def tokenizer(self) -> Tokenizer: 310 if not hasattr(self, "_tokenizer"): 311 self._tokenizer = self.tokenizer_class() 312 return self._tokenizer 313 314 def parser(self, **opts) -> Parser: 315 return self.parser_class(**opts) 316 317 def generator(self, **opts) -> Generator: 318 return self.generator_class(**opts) 319 320 321DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 322 323 324def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 325 return lambda self, expression: self.func(name, *flatten(expression.args.values())) 326 327 328def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 329 if expression.args.get("accuracy"): 330 self.unsupported("APPROX_COUNT_DISTINCT does not support 
accuracy") 331 return self.func("APPROX_COUNT_DISTINCT", expression.this) 332 333 334def if_sql(self: Generator, expression: exp.If) -> str: 335 return self.func( 336 "IF", expression.this, expression.args.get("true"), expression.args.get("false") 337 ) 338 339 340def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str: 341 return self.binary(expression, "->") 342 343 344def arrow_json_extract_scalar_sql( 345 self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar 346) -> str: 347 return self.binary(expression, "->>") 348 349 350def inline_array_sql(self: Generator, expression: exp.Array) -> str: 351 return f"[{self.expressions(expression, flat=True)}]" 352 353 354def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 355 return self.like_sql( 356 exp.Like( 357 this=exp.Lower(this=expression.this.copy()), expression=expression.expression.copy() 358 ) 359 ) 360 361 362def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 363 zone = self.sql(expression, "this") 364 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 365 366 367def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 368 if expression.args.get("recursive"): 369 self.unsupported("Recursive CTEs are unsupported") 370 expression.args["recursive"] = False 371 return self.with_sql(expression) 372 373 374def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str: 375 n = self.sql(expression, "this") 376 d = self.sql(expression, "expression") 377 return f"IF({d} <> 0, {n} / {d}, NULL)" 378 379 380def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: 381 self.unsupported("TABLESAMPLE unsupported") 382 return self.sql(expression.this) 383 384 385def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: 386 self.unsupported("PIVOT unsupported") 387 return "" 388 389 390def no_trycast_sql(self: Generator, expression: exp.TryCast) -> 
str: 391 return self.cast_sql(expression) 392 393 394def no_properties_sql(self: Generator, expression: exp.Properties) -> str: 395 self.unsupported("Properties unsupported") 396 return "" 397 398 399def no_comment_column_constraint_sql( 400 self: Generator, expression: exp.CommentColumnConstraint 401) -> str: 402 self.unsupported("CommentColumnConstraint unsupported") 403 return "" 404 405 406def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: 407 self.unsupported("MAP_FROM_ENTRIES unsupported") 408 return "" 409 410 411def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: 412 this = self.sql(expression, "this") 413 substr = self.sql(expression, "substr") 414 position = self.sql(expression, "position") 415 if position: 416 return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1" 417 return f"STRPOS({this}, {substr})" 418 419 420def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: 421 return ( 422 f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" 423 ) 424 425 426def var_map_sql( 427 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 428) -> str: 429 keys = expression.args["keys"] 430 values = expression.args["values"] 431 432 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 433 self.unsupported("Cannot convert array columns into map.") 434 return self.func(map_func_name, keys, values) 435 436 args = [] 437 for key, value in zip(keys.expressions, values.expressions): 438 args.append(self.sql(key)) 439 args.append(self.sql(value)) 440 441 return self.func(map_func_name, *args) 442 443 444def format_time_lambda( 445 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 446) -> t.Callable[[t.List], E]: 447 """Helper used for time expressions. 448 449 Args: 450 exp_class: the expression class to instantiate. 451 dialect: target sql dialect. 
452 default: the default format, True being time. 453 454 Returns: 455 A callable that can be used to return the appropriately formatted time expression. 456 """ 457 458 def _format_time(args: t.List): 459 return exp_class( 460 this=seq_get(args, 0), 461 format=Dialect[dialect].format_time( 462 seq_get(args, 1) 463 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 464 ), 465 ) 466 467 return _format_time 468 469 470def time_format( 471 dialect: DialectType = None, 472) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 473 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 474 """ 475 Returns the time format for a given expression, unless it's equivalent 476 to the default time format of the dialect of interest. 477 """ 478 time_format = self.format_time(expression) 479 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 480 481 return _time_format 482 483 484def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: 485 """ 486 In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the 487 PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding 488 columns are removed from the create statement. 
489 """ 490 has_schema = isinstance(expression.this, exp.Schema) 491 is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW") 492 493 if has_schema and is_partitionable: 494 expression = expression.copy() 495 prop = expression.find(exp.PartitionedByProperty) 496 if prop and prop.this and not isinstance(prop.this, exp.Schema): 497 schema = expression.this 498 columns = {v.name.upper() for v in prop.this.expressions} 499 partitions = [col for col in schema.expressions if col.name.upper() in columns] 500 schema.set("expressions", [e for e in schema.expressions if e not in partitions]) 501 prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) 502 expression.set("this", schema) 503 504 return self.create_sql(expression) 505 506 507def parse_date_delta( 508 exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None 509) -> t.Callable[[t.List], E]: 510 def inner_func(args: t.List) -> E: 511 unit_based = len(args) == 3 512 this = args[2] if unit_based else seq_get(args, 0) 513 unit = args[0] if unit_based else exp.Literal.string("DAY") 514 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 515 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 516 517 return inner_func 518 519 520def parse_date_delta_with_interval( 521 expression_class: t.Type[E], 522) -> t.Callable[[t.List], t.Optional[E]]: 523 def func(args: t.List) -> t.Optional[E]: 524 if len(args) < 2: 525 return None 526 527 interval = args[1] 528 529 if not isinstance(interval, exp.Interval): 530 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 531 532 expression = interval.this 533 if expression and expression.is_string: 534 expression = exp.Literal.number(expression.this) 535 536 return expression_class( 537 this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit")) 538 ) 539 540 return func 541 542 543def date_trunc_to_time(args: t.List) -> exp.DateTrunc | 
exp.TimestampTrunc: 544 unit = seq_get(args, 0) 545 this = seq_get(args, 1) 546 547 if isinstance(this, exp.Cast) and this.is_type("date"): 548 return exp.DateTrunc(unit=unit, this=this) 549 return exp.TimestampTrunc(this=this, unit=unit) 550 551 552def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 553 return self.func( 554 "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this 555 ) 556 557 558def locate_to_strposition(args: t.List) -> exp.Expression: 559 return exp.StrPosition( 560 this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) 561 ) 562 563 564def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str: 565 return self.func( 566 "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position") 567 ) 568 569 570def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: 571 expression = expression.copy() 572 return self.sql( 573 exp.Substring( 574 this=expression.this, start=exp.Literal.number(1), length=expression.expression 575 ) 576 ) 577 578 579def right_to_substring_sql(self: Generator, expression: exp.Left) -> str: 580 expression = expression.copy() 581 return self.sql( 582 exp.Substring( 583 this=expression.this, 584 start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1), 585 ) 586 ) 587 588 589def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str: 590 return self.sql(exp.cast(expression.this, "timestamp")) 591 592 593def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: 594 return self.sql(exp.cast(expression.this, "date")) 595 596 597# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8 598def encode_decode_sql( 599 self: Generator, expression: exp.Expression, name: str, replace: bool = True 600) -> str: 601 charset = expression.args.get("charset") 602 if charset and charset.name.lower() != "utf-8": 603 
self.unsupported(f"Expected utf-8 character set, got {charset}.") 604 605 return self.func(name, expression.this, expression.args.get("replace") if replace else None) 606 607 608def min_or_least(self: Generator, expression: exp.Min) -> str: 609 name = "LEAST" if expression.expressions else "MIN" 610 return rename_func(name)(self, expression) 611 612 613def max_or_greatest(self: Generator, expression: exp.Max) -> str: 614 name = "GREATEST" if expression.expressions else "MAX" 615 return rename_func(name)(self, expression) 616 617 618def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 619 cond = expression.this 620 621 if isinstance(expression.this, exp.Distinct): 622 cond = expression.this.expressions[0] 623 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 624 625 return self.func("sum", exp.func("if", cond.copy(), 1, 0)) 626 627 628def trim_sql(self: Generator, expression: exp.Trim) -> str: 629 target = self.sql(expression, "this") 630 trim_type = self.sql(expression, "position") 631 remove_chars = self.sql(expression, "expression") 632 collation = self.sql(expression, "collation") 633 634 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 635 if not remove_chars and not collation: 636 return self.trim_sql(expression) 637 638 trim_type = f"{trim_type} " if trim_type else "" 639 remove_chars = f"{remove_chars} " if remove_chars else "" 640 from_part = "FROM " if trim_type or remove_chars else "" 641 collation = f" COLLATE {collation}" if collation else "" 642 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" 643 644 645def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: 646 return self.func("STRPTIME", expression.this, self.format_time(expression)) 647 648 649def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: 650 def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str: 651 _dialect = Dialect.get_or_raise(dialect) 652 time_format = 
self.format_time(expression) 653 if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT): 654 return self.sql(exp.cast(str_to_time_sql(self, expression), "date")) 655 656 return self.sql(exp.cast(self.sql(expression, "this"), "date")) 657 658 return _ts_or_ds_to_date_sql 659 660 661def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str: 662 expression = expression.copy() 663 return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)) 664 665 666def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str: 667 expression = expression.copy() 668 delim, *rest_args = expression.expressions 669 return self.sql( 670 reduce( 671 lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)), 672 rest_args, 673 ) 674 ) 675 676 677def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 678 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 679 if bad_args: 680 self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") 681 682 return self.func( 683 "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") 684 ) 685 686 687def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 688 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 689 if bad_args: 690 self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") 691 692 return self.func( 693 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 694 ) 695 696 697def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 698 names = [] 699 for agg in aggregations: 700 if isinstance(agg, exp.Alias): 701 names.append(agg.alias) 702 else: 703 """ 704 This case corresponds to aggregations without aliases being used as suffixes 705 (e.g. 
col_avg(foo)). We need to unquote identifiers because they're going to 706 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 707 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 708 """ 709 agg_all_unquoted = agg.transform( 710 lambda node: exp.Identifier(this=node.name, quoted=False) 711 if isinstance(node, exp.Identifier) 712 else node 713 ) 714 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 715 716 return names 717 718 719def simplify_literal(expression: E) -> E: 720 if not isinstance(expression.expression, exp.Literal): 721 from sqlglot.optimizer.simplify import simplify 722 723 simplify(expression.expression) 724 725 return expression 726 727 728def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: 729 return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) 730 731 732# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects 733def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: 734 return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) 735 736 737def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: 738 return self.func("MAX", expression.this) 739 740 741# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon 742def json_keyvalue_comma_sql(self, expression: exp.JSONKeyValue) -> str: 743 return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"
class Dialects(str, Enum):
    """Names of the known dialects; each member's value is the lowercase dialect name."""

    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    # Canonical upper-case name, consistent with every other member, so that
    # lookups by upper-cased name (`Dialects.__members__.get("DORIS")`) succeed.
    DORIS = "doris"
    # Backward-compatible alias for the original mixed-case spelling.
    Doris = "doris"
An enumeration.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
class Dialect(metaclass=_Dialect):
    """Base dialect: dialect-wide settings plus helpers to tokenize, parse and generate SQL."""

    # Base index offset used for arrays
    INDEX_OFFSET = 0

    # When true, unnest table aliases are treated only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Whether the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Whether unquoted identifiers are resolved as uppercase.
    # None means the dialect treats all identifiers as case-insensitive.
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Whether an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Whether the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Whether CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Whether user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # How function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Default null ordering method to use if not explicitly set.
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings: keys are dialect time formats, values are python time formats
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Columns auto-generated by the engine corresponding to this dialect.
    # Such columns may be excluded from SELECT * queries, for example.
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        # Compare through this instance's class
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve a name, dialect class, dialect instance or falsy value to a dialect class."""
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        found = cls.get(dialect)
        if found:
            return found

        raise ValueError(f"Unknown dialect '{dialect}'")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Convert a time format via TIME_MAPPING into a string literal; pass other input through."""
        if isinstance(expression, str):
            # the time formats are quoted, so drop the first and last character
            unquoted = expression[1:-1]
            return exp.Literal.string(format_time(unquoted, cls.TIME_MAPPING, cls.TIME_TRIE))

        if expression and expression.is_string:
            converted = format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)
            return exp.Literal.string(converted)

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Upper- or lower-cases an identifier so that it is effectively case-insensitive.
        Quoted identifiers are left alone unless the dialect treats all identifiers as
        case-insensitive (RESOLVES_IDENTIFIERS_AS_UPPERCASE is None).
        """
        if not isinstance(expression, exp.Identifier):
            return expression

        resolves_upper = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
        if expression.quoted and resolves_upper is not None:
            return expression

        name = expression.this
        expression.set("this", name.upper() if resolves_upper else name.lower())
        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        resolves_upper = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
        if resolves_upper is None:
            return False

        if resolves_upper:
            return any(char.islower() for char in text)
        return any(char.isupper() for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        return identify == "safe" and not cls.case_sensitive(text)

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Set an identifier's "quoted" flag and return the expression."""
        if not isinstance(expression, exp.Identifier):
            return expression

        name = expression.this
        needs_quotes = (
            identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name)
        )
        expression.set("quoted", needs_quotes)
        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse the tokens with a parser built from `opts`."""
        tokens = self.tokenize(sql)
        return self.parser(**opts).parse(tokens, sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse the tokens into `expression_type`."""
        tokens = self.tokenize(sql)
        return self.parser(**opts).parse_into(expression_type, tokens, sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Generate SQL for `expression` with a generator built from `opts`."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and generate one SQL string per parsed expression."""
        parsed = self.parse(sql)
        return [self.generate(node, **opts) for node in parsed]

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenize `sql` with this dialect's tokenizer."""
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Built on first access, then reused
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Return a new instance of this dialect's parser class."""
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        """Return a new instance of this dialect's generator class."""
        return self.generator_class(**opts)
@classmethod
def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
    """Resolve a dialect reference to a dialect class.

    Args:
        dialect: A dialect name, a dialect class, a dialect instance, or a falsy value.

    Returns:
        The calling class when `dialect` is falsy, `dialect` itself when it is already
        a dialect class, the class of `dialect` when it is an instance, and otherwise
        the class looked up by name.

    Raises:
        ValueError: If the name lookup yields nothing.
    """
    if not dialect:
        return cls

    if isinstance(dialect, _Dialect):
        # Already a dialect class (an instance of the metaclass).
        return dialect

    if isinstance(dialect, Dialect):
        return type(dialect)

    found = cls.get(dialect)
    if found:
        return found

    raise ValueError(f"Unknown dialect '{dialect}'")
@classmethod
def format_time(
    cls, expression: t.Optional[str | exp.Expression]
) -> t.Optional[exp.Expression]:
    """Run a time format through `format_time` with this dialect's time mapping.

    Args:
        expression: Either a raw, still-quoted format string, a string literal
            expression, any other expression, or `None`.

    Returns:
        A new string literal holding the converted format when the input is a raw
        string or a string literal expression; otherwise the input unchanged.
    """
    if isinstance(expression, str):
        # Raw time formats arrive wrapped in quotes, so drop the first and last character.
        unquoted = expression[1:-1]
        return exp.Literal.string(format_time(unquoted, cls.TIME_MAPPING, cls.TIME_TRIE))

    if not expression or not expression.is_string:
        return expression

    converted = format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)
    return exp.Literal.string(converted)
@classmethod
def normalize_identifier(cls, expression: E) -> E:
    """
    Rewrites an identifier's name to a single case so that it compares case-insensitively.

    Unquoted identifiers are always rewritten. Quoted identifiers are only rewritten when
    `RESOLVES_IDENTIFIERS_AS_UPPERCASE` is `None`, i.e. when the dialect treats every
    identifier as case-insensitive. Anything that is not an identifier is returned untouched.
    The identifier is modified in place and returned.
    """
    if not isinstance(expression, exp.Identifier):
        return expression

    if expression.quoted and cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is not None:
        # Quoted identifiers keep their case in case-sensitive dialects.
        return expression

    if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE:
        normalized = expression.this.upper()
    else:
        normalized = expression.this.lower()

    expression.set("this", normalized)
    return expression
Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, they will be normalized regardless of being quoted or not.
@classmethod
def case_sensitive(cls, text: str) -> bool:
    """Tells whether any character of text is case sensitive under the dialect's rules.

    A character counts as case sensitive when its case differs from the one the dialect
    resolves identifiers to: lowercase characters for uppercase-resolving dialects,
    uppercase characters otherwise. Dialects whose `RESOLVES_IDENTIFIERS_AS_UPPERCASE`
    is `None` never report case sensitive text.
    """
    resolves_upper = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
    if resolves_upper is None:
        return False

    for char in text:
        offending = char.islower() if resolves_upper else char.isupper()
        if offending:
            return True

    return False
Checks if text contains any case sensitive characters, based on the dialect's rules.
260 @classmethod 261 def can_identify(cls, text: str, identify: str | bool = "safe") -> bool: 262 """Checks if text can be identified given an identify option. 263 264 Args: 265 text: The text to check. 266 identify: 267 "always" or `True`: Always returns true. 268 "safe": True if the identifier is case-insensitive. 269 270 Returns: 271 Whether or not the given text can be identified. 272 """ 273 if identify is True or identify == "always": 274 return True 275 276 if identify == "safe": 277 return not cls.case_sensitive(text) 278 279 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: "always" or `True`: Always returns true. "safe": True if the identifier is case-insensitive.
Returns:
Whether or not the given text can be identified.
@classmethod
def quote_identifier(cls, expression: E, identify: bool = True) -> E:
    """Sets the `quoted` flag of an identifier, in place.

    Args:
        expression: The expression to process; non-identifiers are returned untouched.
        identify: When true, the identifier is always marked as quoted.

    Returns:
        The given expression. For identifiers, `quoted` is set to a truthy value when
        `identify` is true, when the name contains case sensitive characters, or when the
        name does not match `exp.SAFE_IDENTIFIER_RE`; otherwise it is set to a falsy value.
    """
    if not isinstance(expression, exp.Identifier):
        return expression

    name = expression.this
    needs_quotes = identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name)
    expression.set("quoted", needs_quotes)
    return expression
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Renders a StrPosition expression using STRPOS.

    When a start position is given, the search runs on the SUBSTR starting there and the
    offset is added back so that the result is relative to the whole string.
    """
    haystack, needle, start = (
        self.sql(expression, key) for key in ("this", "substr", "position")
    )

    if not start:
        return f"STRPOS({haystack}, {needle})"

    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Renders a map expression as `map_func_name(key1, value1, key2, value2, ...)`.

    If either the keys or the values are not an array literal, the conversion is reported
    as unsupported and the two arguments are passed to the function as they are.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not (isinstance(keys, exp.Array) and isinstance(values, exp.Array)):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave the generated SQL of each key with the one of its value.
    interleaved = [
        self.sql(item)
        for pair in zip(keys.expressions, values.expressions)
        for item in pair
    ]
    return self.func(map_func_name, *interleaved)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Builds a helper that instantiates time expressions with a dialect-aware format.

    Args:
        exp_class: the expression class to instantiate.
        dialect: name of the dialect whose time formatting rules are applied.
        default: the format used when the arguments carry none; `True` selects the
            dialect's `TIME_FORMAT`.

    Returns:
        A callable that takes the argument list and returns the formatted time expression.
    """

    def _format_time(args: t.List):
        # Looked up at call time, not when the helper is built.
        dialect_class = Dialect[dialect]

        fmt = seq_get(args, 1)
        if not fmt:
            fmt = dialect_class.TIME_FORMAT if default is True else default or None

        return exp_class(this=seq_get(args, 0), format=dialect_class.format_time(fmt))

    return _format_time
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Builds a generator helper that returns an expression's non-default time format.

    Args:
        dialect: the dialect whose `TIME_FORMAT` is considered the default.

    Returns:
        A callable taking a generator and an expression.
    """

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Gives back the expression's time format, or None when it matches the
        default time format of the dialect of interest.
        """
        fmt = self.format_time(expression)
        default_fmt = Dialect.get_or_raise(dialect).TIME_FORMAT

        if fmt != default_fmt:
            return fmt
        return None

    return _time_format
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    Hive and Spark treat the PARTITIONED BY property as an extension of the table's schema.
    If the PARTITIONED BY value is a list of column names rather than a schema, the matching
    column definitions are moved out of the create statement's schema and into a schema
    attached to the property. Names are matched case-insensitively, and the rewrite is done
    on a copy of the expression.
    """
    targets_schema = isinstance(expression.this, exp.Schema)
    if not targets_schema or expression.args.get("kind") not in ("TABLE", "VIEW"):
        return self.create_sql(expression)

    expression = expression.copy()
    prop = expression.find(exp.PartitionedByProperty)

    if prop and prop.this and not isinstance(prop.this, exp.Schema):
        schema = expression.this
        partition_names = {column.name.upper() for column in prop.this.expressions}

        partitions = [
            column for column in schema.expressions if column.name.upper() in partition_names
        ]
        remaining = [column for column in schema.expressions if column not in partitions]

        schema.set("expressions", remaining)
        prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
        expression.set("this", schema)

    return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Builds a parser helper for date delta functions.

    Args:
        exp_class: the expression class to instantiate.
        unit_mapping: optional mapping from lowercased unit names to the names to use
            instead; units missing from the mapping keep their name.

    Returns:
        A callable that turns an argument list into an `exp_class` instance. With exactly
        three arguments they are read as (unit, expression, this); otherwise as
        (this, expression) with the unit defaulting to the string literal "DAY".
    """

    def inner_func(args: t.List) -> E:
        if len(args) == 3:
            unit, this = args[0], args[2]
        else:
            unit, this = exp.Literal.string("DAY"), seq_get(args, 0)

        if unit_mapping:
            unit_name = unit.name
            unit = exp.var(unit_mapping.get(unit_name.lower(), unit_name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Builds a parser helper for date functions whose second argument is an INTERVAL.

    Args:
        expression_class: the expression class to instantiate.

    Returns:
        A callable that turns an argument list into an `expression_class` instance, or
        into `None` when fewer than two arguments are given. It raises `ParseError` when
        the second argument is not an interval.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        this, interval = args[0], args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        if amount and amount.is_string:
            # A quoted interval value is re-created as a number literal.
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=this, expression=amount, unit=unit)

    return func
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Renders an encode/decode expression as a call to the function `name`.

    Args:
        expression: the expression to render.
        name: the name of the function to generate.
        replace: whether the expression's "replace" argument is forwarded to the function.

    Returns:
        The generated function call. A charset other than utf-8 is reported as unsupported.
    """
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replace_arg = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replace_arg)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Renders COUNT_IF(cond) as sum(if(cond, 1, 0)).

    A DISTINCT wrapper cannot be carried over: it is reported as unsupported and its
    first expression is used as the condition.
    """
    cond = expression.this

    if isinstance(cond, exp.Distinct):
        cond = cond.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    conditional = exp.func("if", cond.copy(), 1, 0)
    return self.func("sum", conditional)
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Renders a Trim expression, using the `TRIM([position] [chars] FROM target [COLLATE c])`
    form only when trim characters or a collation are present."""
    target = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Without database-specific parts, defer to the generator's own TRIM/LTRIM/RTRIM rendering
    if not (chars or collation):
        return self.trim_sql(expression)

    prefix = ""
    if position:
        prefix += f"{position} "
    if chars:
        prefix += f"{chars} "
    if prefix:
        prefix += "FROM "

    suffix = f" COLLATE {collation}" if collation else ""
    return f"TRIM({prefix}{target}{suffix})"
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Builds a generator helper that renders TsOrDsToDate as a cast to date.

    Args:
        dialect: the dialect whose `TIME_FORMAT` and `DATE_FORMAT` count as defaults.

    Returns:
        A callable taking a generator and a TsOrDsToDate expression. When the expression
        has a format other than those defaults, the value is first rendered through
        `str_to_time_sql`; otherwise the operand itself is cast.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        dialect_class = Dialect.get_or_raise(dialect)
        fmt = self.format_time(expression)

        if fmt and fmt not in (dialect_class.TIME_FORMAT, dialect_class.DATE_FORMAT):
            source = str_to_time_sql(self, expression)
        else:
            source = self.sql(expression, "this")

        return self.sql(exp.cast(source, "date"))

    return _ts_or_ds_to_date_sql
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Renders CONCAT_WS(delim, a, b, ...) as a chain of `||` operations.

    The first expression is the delimiter; it is placed between each pair of the
    remaining expressions, i.e. `a || delim || b || ...`. A single value is rendered
    on its own.

    Args:
        expression: the ConcatWs expression to render; it is copied, not modified.

    Returns:
        The generated SQL. When there are no values to join (only a delimiter, or no
        arguments at all), an empty string literal is generated.
    """
    expression = expression.copy()
    args = expression.expressions

    if len(args) < 2:
        # reduce() without an initial value raises TypeError on an empty sequence, and
        # unpacking an empty argument list raises ValueError. Joining nothing yields
        # an empty string, so emit that instead of crashing.
        return self.sql(exp.Literal.string(""))

    delim, *rest_args = args
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Renders REGEXP_EXTRACT(this, pattern[, group]).

    The "position", "occurrence" and "parameters" arguments are not forwarded; any of
    them that is set is reported as unsupported.
    """
    bad_args = [
        arg for arg in ("position", "occurrence", "parameters") if expression.args.get(arg)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    group = expression.args.get("group")
    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Renders REGEXP_REPLACE(this, pattern, replacement).

    The "position", "occurrence" and "parameters" arguments are not forwarded; any of
    them that is set is reported as unsupported.
    """
    bad_args = [
        arg for arg in ("position", "occurrence", "parameters") if expression.args.get(arg)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    replacement = expression.args["replacement"]
    return self.func("REGEXP_REPLACE", expression.this, expression.expression, replacement)
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Computes the column name contributed by each pivot aggregation.

    Args:
        aggregations: the aggregation expressions of the pivot.
        dialect: the dialect used to generate names for aggregations without an alias.

    Returns:
        One name per aggregation: its alias if it has one, otherwise its SQL generated
        with all identifiers unquoted and function names lowercased.
    """

    def _unquote(node):
        if isinstance(node, exp.Identifier):
            return exp.Identifier(this=node.name, quoted=False)
        return node

    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # Aggregations without aliases are used as suffixes (e.g. col_avg(foo)). Their
        # identifiers must be unquoted because the base parser's `_parse_pivot` method
        # quotes the name again through `to_identifier`; otherwise we'd end up with
        # a doubly quoted name such as `col_avg(`foo`)`.
        unquoted = agg.transform(_unquote)
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names