sqlglot.dialects.dialect
from __future__ import annotations

import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

B = t.TypeVar("B", bound=exp.Binary)

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]


class Dialects(str, Enum):
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    Doris = "doris"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()  # Unquoted identifiers are lowercased
    UPPERCASE = auto()  # Unquoted identifiers are uppercased
    CASE_SENSITIVE = auto()  # Always case-sensitive, regardless of quotes
    CASE_INSENSITIVE = auto()  # Always case-insensitive, regardless of quotes


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}

        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass


class Dialect(metaclass=_Dialect):
    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Specifies the strategy according to which identifiers should be normalized.
    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines whether or not user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Determines whether or not SEMI/ANTI JOINs are supported
    SUPPORTS_SEMI_ANTI_JOIN = True

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Determines whether the base comes first in the LOG function
    LOG_BASE_FIRST = True

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    # Whether the behavior of a / b depends on the types of a and b.
    # False means a / b is always float division.
    # True means a / b is integer division if both a and b are integers.
    TYPED_DIVISION = False

    # False means 1 / 0 throws an error.
    # True means 1 / 0 returns null.
    SAFE_DIVISION = False

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Mapping of an unescaped escape sequence to the corresponding character
    ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for quotes, identifiers and the corresponding escape characters
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex and byte literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                raise ValueError(f"Unknown dialect '{dialect_name}'.")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like FoO would be resolved as foo in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                expression.this.upper()
                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                else expression.this.lower(),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    return self.binary(expression, "->")


def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    return self.binary(expression, "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, flat=True)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    self.unsupported("Properties unsupported")
    return ""


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)


def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func


def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
    )


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.function_fallback_sql(expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, "timestamp"))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, "date"))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return self.sql(
                exp.cast(
                    exp.StrToTime(this=expression.this, format=expression.args["format"]),
                    "date",
                )
            )
        return self.sql(exp.cast(expression.this, "date"))

    return _ts_or_ds_to_date_sql


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name, exp.var(expression.text("unit") or "day"), expression.expression, expression.this
        )

    return _delta_sql
class Dialects(str, Enum):
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    Doris = "doris"
An enumeration of the SQL dialect names recognized by SQLGlot. The empty-valued DIALECT member corresponds to the base, dialect-agnostic Dialect class.
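Usage sketch (not part of the source; assumes the public sqlglot API): because Dialects subclasses str, a member can be passed anywhere a dialect name string is accepted, for example the read/write arguments of sqlglot.transpile.

    import sqlglot
    from sqlglot.dialects.dialect import Dialects

    # These two calls should be equivalent: the enum member is itself the string.
    sqlglot.transpile("SELECT 1", read="duckdb", write="bigquery")
    sqlglot.transpile("SELECT 1", read=Dialects.DUCKDB, write=Dialects.BIGQUERY)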
class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()  # Unquoted identifiers are lowercased
    UPPERCASE = auto()  # Unquoted identifiers are uppercased
    CASE_SENSITIVE = auto()  # Always case-sensitive, regardless of quotes
    CASE_INSENSITIVE = auto()  # Always case-insensitive, regardless of quotes
Specifies the strategy according to which identifiers should be normalized.
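A hedged sketch of how the strategy is typically set: the class-level default can be overridden per instance via the normalization_strategy keyword handled in Dialect.__init__ (documented below). The choice of "redshift" here is purely illustrative.

    from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

    d = Dialect.get("redshift")(normalization_strategy="case_sensitive")
    assert d.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE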
class Dialect(metaclass=_Dialect):
    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Specifies the strategy according to which identifiers should be normalized.
    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines whether or not user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Determines whether or not SEMI/ANTI JOINs are supported
    SUPPORTS_SEMI_ANTI_JOIN = True

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Determines whether the base comes first in the LOG function
    LOG_BASE_FIRST = True

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    # Whether the behavior of a / b depends on the types of a and b.
    # False means a / b is always float division.
    # True means a / b is integer division if both a and b are integers.
    TYPED_DIVISION = False

    # False means 1 / 0 throws an error.
    # True means 1 / 0 returns null.
    SAFE_DIVISION = False

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Mapping of an unescaped escape sequence to the corresponding character
    ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for quotes, identifiers and the corresponding escape characters
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex and byte literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                raise ValueError(f"Unknown dialect '{dialect_name}'.")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like FoO would be resolved as foo in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                expression.this.upper()
                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                else expression.this.lower(),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)
    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                raise ValueError(f"Unknown dialect '{dialect_name}'.")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
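A self-contained version of the example above (hedged; the classmethod must be called on Dialect, and kwargs after the comma feed Dialect.__init__):

    from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

    duckdb = Dialect.get_or_raise("duckdb")  # returns a Dialect instance
    mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
    assert mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE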
    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression
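A hedged illustration: format_time converts a quoted, dialect-specific format string into its Python strftime equivalent via TIME_MAPPING. Hive is used here only because its mapping covers yyyy/MM/dd; the exact output depends on the dialect's mapping.

    from sqlglot.dialects.dialect import Dialect

    lit = Dialect["hive"].format_time("'yyyy-MM-dd'")  # input is quoted, as in SQL
    print(lit.this)  # expected: %Y-%m-%d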
    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like FoO would be resolved as foo in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                expression.this.upper()
                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                else expression.this.lower(),
            )

        return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
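A minimal sketch of the rules above: unquoted identifiers are normalized per the dialect's strategy, while quoted ones are left alone (unless the strategy is CASE_INSENSITIVE).

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    snow = Dialect.get_or_raise("snowflake")
    print(snow.normalize_identifier(exp.to_identifier("FoO")).this)               # FOO
    print(snow.normalize_identifier(exp.to_identifier("FoO", quoted=True)).this)  # FoO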
    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
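Sketch: under the default LOWERCASE strategy, an uppercase character is what makes text case-sensitive, since it would not survive the dialect's lowercasing. Postgres is an illustrative choice here.

    from sqlglot.dialects.dialect import Dialect

    pg = Dialect.get_or_raise("postgres")
    assert pg.case_sensitive("Foo") and not pg.case_sensitive("foo")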
    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: "always" or True: Always returns true. "safe": True if the identifier is case-insensitive.
Returns:
Whether or not the given text can be identified.
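A brief hedged example of both options (Postgres chosen for illustration, with its default LOWERCASE strategy):

    from sqlglot.dialects.dialect import Dialect

    pg = Dialect.get_or_raise("postgres")
    assert pg.can_identify("foo", "safe")      # lowercase round-trips unchanged
    assert not pg.can_identify("Foo", "safe")  # uppercase is unsafe when lowercasing
    assert pg.can_identify("Foo", "always")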
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format; if True, the dialect's default TIME_FORMAT is used.
Returns:
A callable that can be used to return the appropriately formatted time expression.
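A hedged usage sketch: dialects typically register the returned callable in their parser's FUNCTIONS mapping; exp.StrToDate and "hive" are illustrative choices here, not prescribed by the source.

    from sqlglot import exp
    from sqlglot.dialects.dialect import format_time_lambda

    to_date = format_time_lambda(exp.StrToDate, "hive", default=True)
    node = to_date([exp.column("x")])  # no explicit format, so the default kicks in
    print(node.args["format"])         # hive's TIME_FORMAT, mapped to strftime codes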
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
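A hedged end-to-end illustration of the rewrite described above, assuming a generator that uses this transform (as Hive-family dialects do); the exact output may vary by sqlglot version:

    import sqlglot

    sql = "CREATE TABLE t (a INT, b STRING) PARTITIONED BY (b)"
    print(sqlglot.transpile(sql, read="spark", write="hive")[0])
    # expected, approximately: CREATE TABLE t (a INT) PARTITIONED BY (b STRING)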
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.function_fallback_sql(expression)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return self.sql(
                exp.cast(
                    exp.StrToTime(this=expression.this, format=expression.args["format"]),
                    "date",
                )
            )
        return self.sql(exp.cast(expression.this, "date"))

    return _ts_or_ds_to_date_sql
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name, exp.var(expression.text("unit") or "day"), expression.expression, expression.this
        )

    return _delta_sql
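A hedged sketch of the generated shape: a dialect can map exp.DateAdd through date_delta_sql("DATEADD") to render unit-first date arithmetic, as Snowflake-style SQL does. Whether Snowflake itself uses this exact mapping depends on the sqlglot version.

    from sqlglot import exp

    node = exp.DateAdd(this=exp.column("x"), expression=exp.Literal.number(1), unit=exp.var("DAY"))
    print(node.sql("snowflake"))  # expected, approximately: DATEADD(DAY, 1, x)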