sqlglot.dialects.dialect
from __future__ import annotations

import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

B = t.TypeVar("B", bound=exp.Binary)

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]


class Dialects(str, Enum):
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    Doris = "doris"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()  # Unquoted identifiers are lowercased
    UPPERCASE = auto()  # Unquoted identifiers are uppercased
    CASE_SENSITIVE = auto()  # Always case-sensitive, regardless of quotes
    CASE_INSENSITIVE = auto()  # Always case-insensitive, regardless of quotes


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}

        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass


class Dialect(metaclass=_Dialect):
    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Specifies the strategy according to which identifiers should be normalized.
    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines whether or not user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Determines whether or not SEMI/ANTI JOINs are supported
    SUPPORTS_SEMI_ANTI_JOIN = True

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Determines whether the base comes first in the LOG function
    LOG_BASE_FIRST = True

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    # Whether the behavior of a / b depends on the types of a and b.
    # False means a / b is always float division.
    # True means a / b is integer division if both a and b are integers.
    TYPED_DIVISION = False

    # False means 1 / 0 throws an error.
    # True means 1 / 0 returns null.
    SAFE_DIVISION = False

    # A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string
    CONCAT_COALESCE = False

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Mapping of an unescaped escape sequence to the corresponding character
    ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for quotes, identifiers and the corresponding escape characters
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex and byte literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                raise ValueError(f"Unknown dialect '{dialect_name}'.")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like FoO would be resolved as foo in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                expression.this.upper()
                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                else expression.this.lower(),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    return self.binary(expression, "->")


def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    return self.binary(expression, "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, flat=True)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    self.unsupported("Properties unsupported")
    return ""


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)


def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func


def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
    )


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.function_fallback_sql(expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, "timestamp"))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, "date"))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return self.sql(
                exp.cast(
                    exp.StrToTime(this=expression.this, format=expression.args["format"]),
                    "date",
                )
            )
        return self.sql(exp.cast(expression.this, "date"))

    return _ts_or_ds_to_date_sql


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name, exp.var(expression.text("unit") or "day"), expression.expression, expression.this
        )

    return _delta_sql
class Dialects(str, Enum):
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    Doris = "doris"
An enumeration.
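A quick sketch of how this enum behaves. These are standard Python str-enum semantics, using members from the listing above:

import sqlglot.dialects  # not strictly required for the enum itself
from sqlglot.dialects.dialect import Dialects

# Each member doubles as its string value, since Dialects subclasses str.
assert Dialects.DUCKDB.value == "duckdb"
assert Dialects.DUCKDB == "duckdb"

# Lookup by value works as with any Enum.
assert Dialects("snowflake") is Dialects.SNOWFLAKE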
class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()  # Unquoted identifiers are lowercased
    UPPERCASE = auto()  # Unquoted identifiers are uppercased
    CASE_SENSITIVE = auto()  # Always case-sensitive, regardless of quotes
    CASE_INSENSITIVE = auto()  # Always case-insensitive, regardless of quotes
Specifies the strategy according to which identifiers should be normalized.
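As an illustration of how a strategy is picked up, Dialect.__init__ (shown further down) accepts a normalization_strategy keyword and upper-cases it before constructing the enum member, so lowercase spellings also work:

from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

# The string is upper-cased by Dialect.__init__ before the enum lookup.
dialect = Dialect(normalization_strategy="case_sensitive")
assert dialect.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE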
class Dialect(metaclass=_Dialect):
    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Specifies the strategy according to which identifiers should be normalized.
    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines whether or not user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Determines whether or not SEMI/ANTI JOINs are supported
    SUPPORTS_SEMI_ANTI_JOIN = True

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Determines whether the base comes first in the LOG function
    LOG_BASE_FIRST = True

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    # Whether the behavior of a / b depends on the types of a and b.
    # False means a / b is always float division.
    # True means a / b is integer division if both a and b are integers.
    TYPED_DIVISION = False

    # False means 1 / 0 throws an error.
    # True means 1 / 0 returns null.
    SAFE_DIVISION = False

    # A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string
    CONCAT_COALESCE = False

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Mapping of an unescaped escape sequence to the corresponding character
    ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for quotes, identifiers and the corresponding escape characters
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex and byte literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                raise ValueError(f"Unknown dialect '{dialect_name}'.")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like FoO would be resolved as foo in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                expression.this.upper()
                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                else expression.this.lower(),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)
    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                raise ValueError(f"Unknown dialect '{dialect_name}'.")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
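Building on the docstring's examples, a small sketch of the accepted argument shapes (None, registered name, and name plus settings). Importing sqlglot.dialects first ensures every dialect module is loaded and registered:

import sqlglot.dialects
from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

Dialect.get_or_raise(None)      # falls back to an instance of the base Dialect
Dialect.get_or_raise("duckdb")  # registry lookup by name

# Key-value pairs after the name become constructor kwargs.
mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
assert mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE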
    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression
    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like FoO would be resolved as foo in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                expression.this.upper()
                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                else expression.this.lower(),
            )

        return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
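A minimal sketch of the behavior described above, using exp.to_identifier to build the input; the outputs follow the Postgres-lowercase and Snowflake-uppercase rules the docstring mentions:

import sqlglot.dialects  # importing the package registers every dialect
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

postgres = Dialect.get_or_raise("postgres")
snowflake = Dialect.get_or_raise("snowflake")

assert postgres.normalize_identifier(exp.to_identifier("FoO")).name == "foo"
assert snowflake.normalize_identifier(exp.to_identifier("FoO")).name == "FOO"

# Quoted identifiers are left alone under either strategy.
assert postgres.normalize_identifier(exp.to_identifier("FoO", quoted=True)).name == "FoO"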
    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: "always" or
True
: Always returns true. "safe": True if the identifier is case-insensitive.
Returns:
Whether or not the given text can be identified.
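A short sketch tying can_identify to case_sensitive under the default lowercase strategy, where identifiers containing uppercase characters are the case-sensitive ones:

from sqlglot.dialects.dialect import Dialect

dialect = Dialect()  # default strategy: unquoted identifiers are lowercased

assert dialect.can_identify("foo")            # "safe": no case-sensitive chars
assert not dialect.can_identify("Foo")        # the uppercase F makes it unsafe
assert dialect.can_identify("Foo", "always")  # "always"/True short-circuits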
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
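To make the factory concrete, here is a hedged sketch of how a dialect's parser might register it. The TO_DATE mapping below is illustrative rather than a claim about any specific dialect's FUNCTIONS table:

from sqlglot import exp
from sqlglot.dialects.dialect import format_time_lambda

# Hypothetical parser registration; args are the parsed function arguments.
FUNCTIONS = {
    "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive", default=True),
}

# The returned callable turns TO_DATE(x) into exp.TsOrDsToDate using the
# dialect's default TIME_FORMAT, while TO_DATE(x, fmt) converts fmt via
# Dialect["hive"].format_time.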
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
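As a sketch of the transform, assuming a dialect whose generator wires exp.Create to this helper, as the Hive family does, a partition list of bare column names is expected to pull the typed columns out of the schema (illustrative expectation, not verified against a specific sqlglot version):

import sqlglot

# Expected, roughly:
#   CREATE TABLE x (cola INT, colb INT) PARTITIONED BY (colb)
#   -> CREATE TABLE x (cola INT) PARTITIONED BY (colb INT)
sql = "CREATE TABLE x (cola INT, colb INT) PARTITIONED BY (colb)"
print(sqlglot.transpile(sql, read="hive", write="hive")[0])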
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func
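A hedged sketch of what this parser callback produces for a MySQL-style call such as DATE_ADD(x, INTERVAL 1 DAY): the args list holds the parsed column and an exp.Interval, and the result is the given expression class with a plain string unit:

from sqlglot import exp
from sqlglot.dialects.dialect import parse_date_delta_with_interval

parse_date_add = parse_date_delta_with_interval(exp.DateAdd)

# Simulate the parsed arguments of DATE_ADD(x, INTERVAL 1 DAY).
args = [exp.column("x"), exp.Interval(this=exp.Literal.number(1), unit=exp.var("DAY"))]
node = parse_date_add(args)

assert isinstance(node, exp.DateAdd)
assert node.text("unit") == "DAY"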
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.function_fallback_sql(expression)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return self.sql(
                exp.cast(
                    exp.StrToTime(this=expression.this, format=expression.args["format"]),
                    "date",
                )
            )
        return self.sql(exp.cast(expression.this, "date"))

    return _ts_or_ds_to_date_sql
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name, exp.var(expression.text("unit") or "day"), expression.expression, expression.this
        )

    return _delta_sql
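Factories like date_delta_sql and date_add_interval_sql are designed to be dropped into a generator's TRANSFORMS mapping. A hedged sketch of that wiring, with a hypothetical dialect class:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, date_delta_sql
from sqlglot.generator import Generator


class MyDialect(Dialect):
    class Generator(Generator):
        TRANSFORMS = {
            **Generator.TRANSFORMS,
            # Render DateAdd/DateDiff as DATE_ADD(unit, amount, expr)-style calls.
            exp.DateAdd: date_delta_sql("DATE_ADD"),
            exp.DateDiff: date_delta_sql("DATE_DIFF"),
        }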