sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

logger = logging.getLogger("sqlglot")


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                "\\a": "\a",
                "\\b": "\b",
                "\\f": "\f",
                "\\n": "\n",
                "\\r": "\r",
                "\\t": "\t",
                "\\v": "\v",
                "\\\\": "\\",
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass


class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, "json"))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, flat=True)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder


def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))

    return _builder


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, to=target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, "timestamp"))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, "date"))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql


def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    return exp.Literal.string(default) if default else None


def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, (exp.Var, exp.Placeholder)):
        return unit
    return exp.Var(this=default) if default else None


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, "date"))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder


def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name


def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))


def to_number_with_nls_param(self, expression: exp.ToNumber) -> str:
    return self.func(
        "TO_NUMBER",
        expression.this,
        expression.args.get("format"),
        expression.args.get("nlsparam"),
    )
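A few usage sketches follow. They exercise only names defined in the source above; printed outputs are indicative and may vary slightly across sqlglot versions.

First, the dialect registry and the tokenize/parse/generate round-trip implemented by `Dialect`:

from sqlglot.dialects.dialect import Dialect

# Plain lookup: an instance of the registered DuckDB dialect class.
duckdb = Dialect.get_or_raise("duckdb")

# Settings may be appended to the dialect string; get_or_raise splits on
# commas and passes the key-value pairs to the dialect's __init__.
mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

# Round-trip a statement through this dialect's parser and generator.
tree = duckdb.parse("SELECT 1 AS x")[0]
print(duckdb.generate(tree))  # e.g. SELECT 1 AS x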
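`normalize_identifier` and `quote_identifier` are the workhorses behind identifier resolution in the optimizer. A small sketch of how the strategies differ (Postgres lowercases unquoted identifiers and Snowflake uppercases them, as the docstrings above describe):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

ident = exp.to_identifier("FoO")  # unquoted identifier

print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name)   # foo
print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name)  # FOO

# Quoted identifiers are preserved unless the strategy is CASE_INSENSITIVE.
quoted = exp.to_identifier("FoO", quoted=True)
print(Dialect.get_or_raise("postgres").normalize_identifier(quoted.copy()).name)  # FoO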
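The `_Dialect` metaclass autofills a fair amount of state at class-creation time: the registry entry, the time-format tries, the inverse mappings, and the tokenizer/parser/generator classes. A toy dialect makes this visible; `MyDialect` and its `TIME_MAPPING` are invented for this sketch, not taken from any real dialect:

from sqlglot.dialects.dialect import Dialect

class MyDialect(Dialect):  # hypothetical; registered under "mydialect"
    TIME_MAPPING = {"yyyy": "%Y", "mm": "%m", "dd": "%d"}

assert Dialect.get("mydialect") is MyDialect           # registered by _Dialect.__new__
assert MyDialect.INVERSE_TIME_MAPPING["%Y"] == "yyyy"  # autofilled inverse mapping

# format_time strips the surrounding quotes and converts via TIME_TRIE.
print(MyDialect.format_time("'yyyy-mm-dd'").this)  # %Y-%m-%d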
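Generator helpers such as `rename_func` and `if_sql` are factories meant to be dropped into a dialect Generator's `TRANSFORMS` map. A hedged sketch of that wiring; `ToyDialect` and its renames are invented for illustration:

from sqlglot import exp, generator
from sqlglot.dialects.dialect import Dialect, if_sql, rename_func

class ToyDialect(Dialect):  # hypothetical dialect
    class Generator(generator.Generator):
        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,
            exp.ArraySize: rename_func("SIZE"),  # render ArraySize as SIZE(...)
            exp.If: if_sql(name="IFF"),          # render If as IFF(cond, true, false)
        }

print(ToyDialect().generate(exp.ArraySize(this=exp.column("arr"))))  # SIZE(arr)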
322 @classmethod 323 def get_or_raise(cls, dialect: DialectType) -> Dialect: 324 """ 325 Look up a dialect in the global dialect registry and return it if it exists. 326 327 Args: 328 dialect: The target dialect. If this is a string, it can be optionally followed by 329 additional key-value pairs that are separated by commas and are used to specify 330 dialect settings, such as whether the dialect's identifiers are case-sensitive. 331 332 Example: 333 >>> dialect = dialect_class = get_or_raise("duckdb") 334 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 335 336 Returns: 337 The corresponding Dialect instance. 338 """ 339 340 if not dialect: 341 return cls() 342 if isinstance(dialect, _Dialect): 343 return dialect() 344 if isinstance(dialect, Dialect): 345 return dialect 346 if isinstance(dialect, str): 347 try: 348 dialect_name, *kv_pairs = dialect.split(",") 349 kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)} 350 except ValueError: 351 raise ValueError( 352 f"Invalid dialect format: '{dialect}'. " 353 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 354 ) 355 356 result = cls.get(dialect_name.strip()) 357 if not result: 358 from difflib import get_close_matches 359 360 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 361 if similar: 362 similar = f" Did you mean {similar}?" 363 364 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 365 366 return result(**kwargs) 367 368 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = Dialect.get_or_raise("duckdb")
>>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
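Concretely, the settings string is just the dialect name followed by comma-separated key = value pairs:

from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
assert dialect.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE

# Unknown names raise a ValueError, with a close-match suggestion when one exists.
try:
    Dialect.get_or_raise("duckbd")
except ValueError as e:
    print(e)  # roughly: Unknown dialect 'duckbd'. Did you mean duckdb?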
@classmethod
def format_time(
    cls, expression: t.Optional[str | exp.Expression]
) -> t.Optional[exp.Expression]:
    """Converts a time format in this dialect to its equivalent Python `strftime` format."""
    if isinstance(expression, str):
        return exp.Literal.string(
            # the time formats are quoted
            format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
        )

    if expression and expression.is_string:
        return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

    return expression
Converts a time format in this dialect to its equivalent Python `strftime` format.
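For example, Hive-style tokens such as yyyy-MM-dd should come back in strftime notation (a sketch, assuming Hive's TIME_MAPPING covers these tokens):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

fmt = Dialect["hive"].format_time(exp.Literal.string("yyyy-MM-dd"))
print(fmt.this)  # expected: %Y-%m-%d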
def normalize_identifier(self, expression: E) -> E:
    """
    Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

    For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
    lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
    it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
    and so any normalization would be prohibited in order to avoid "breaking" the identifier.

    There are also dialects like Spark, which are case-insensitive even when quotes are
    present, and dialects like MySQL, whose resolution rules match those employed by the
    underlying operating system, for example they may always be case-sensitive in Linux.

    Finally, the normalization behavior of some engines can even be controlled through flags,
    like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

    SQLGlot aims to understand and handle all of these different behaviors gracefully, so
    that it can analyze queries in the optimizer and successfully capture their semantics.
    """
    if (
        isinstance(expression, exp.Identifier)
        and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
        and (
            not expression.quoted
            or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
        )
    ):
        expression.set(
            "this",
            (
                expression.this.upper()
                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                else expression.this.lower()
            ),
        )

    return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
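The Postgres/Snowflake contrast described above can be reproduced directly:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

# Unquoted identifiers pick up the dialect's default casing.
print(Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).name)   # foo
print(Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).name)  # FOO

# Quoted identifiers are treated as case-sensitive and left untouched.
quoted = exp.to_identifier("FoO", quoted=True)
print(Dialect.get_or_raise("postgres").normalize_identifier(quoted).name)  # FoO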
def case_sensitive(self, text: str) -> bool:
    """Checks if `text` contains any case-sensitive characters, based on the dialect's rules."""
    if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
        return False

    unsafe = (
        str.islower
        if self.normalization_strategy is NormalizationStrategy.UPPERCASE
        else str.isupper
    )
    return any(unsafe(char) for char in text)
Checks if `text` contains any case-sensitive characters, based on the dialect's rules.
def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
    """Checks if text can be identified given an identify option.

    Args:
        text: The text to check.
        identify:
            `"always"` or `True`: Always returns `True`.
            `"safe"`: Only returns `True` if the identifier is case-insensitive.

    Returns:
        Whether the given text can be identified.
    """
    if identify is True or identify == "always":
        return True

    if identify == "safe":
        return not self.case_sensitive(text)

    return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: `"always"` or `True`: Always returns `True`. `"safe"`: Only returns `True` if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
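A short illustration against the base dialect, which lowercases unquoted identifiers:

from sqlglot.dialects.dialect import Dialect

d = Dialect()
print(d.can_identify("foo", "safe"))    # True: all-lowercase text survives normalization
print(d.can_identify("Foo", "safe"))    # False: the uppercase F is case-sensitive here
print(d.can_identify("Foo", "always"))  # True: identification is unconditional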
def quote_identifier(self, expression: E, identify: bool = True) -> E:
    """
    Adds quotes to a given identifier.

    Args:
        expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
        identify: If set to `False`, the quotes will only be added if the identifier is deemed
            "unsafe", with respect to its characters and this dialect's normalization strategy.
    """
    if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
        name = expression.this
        expression.set(
            "quoted",
            identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
        )

    return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
- identify: If set to `False`, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
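A sketch against the base dialect, where plain lowercase names match SAFE_IDENTIFIER_RE and stay bare:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

d = Dialect()

# With identify=False, quotes are added only when leaving the name bare would be unsafe.
print(d.quote_identifier(exp.to_identifier("foo"), identify=False).sql())     # foo
print(d.quote_identifier(exp.to_identifier("my col"), identify=False).sql())  # "my col"
print(d.quote_identifier(exp.to_identifier("Foo"), identify=False).sql())     # "Foo"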
def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
    if isinstance(path, exp.Literal):
        path_text = path.name
        if path.is_number:
            path_text = f"[{path_text}]"

        try:
            return parse_json_path(path_text)
        except ParseError as e:
            logger.warning(f"Invalid JSON path syntax. {str(e)}")

    return path
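A quick sketch of the coercion for literals versus other expressions:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

d = Dialect()

# A string literal is parsed into a structured exp.JSONPath tree.
path = d.to_json_path(exp.Literal.string("$.a.b"))
print(type(path).__name__)  # JSONPath

# Non-literal expressions are returned unchanged.
col = exp.column("c")
assert d.to_json_path(col) is col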
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql
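Helpers like this are factories meant to be wired into a dialect generator's TRANSFORMS map. A minimal sketch with a hypothetical dialect that spells IF as IIF:

from sqlglot import exp, transpile
from sqlglot.generator import Generator
from sqlglot.dialects.dialect import Dialect, if_sql

class IifDialect(Dialect):  # hypothetical dialect, for illustration only
    class Generator(Generator):
        TRANSFORMS = {
            **Generator.TRANSFORMS,
            exp.If: if_sql(name="IIF"),
        }

print(transpile("SELECT IF(x > 0, 1, 0)", write="iifdialect")[0])
# expected: SELECT IIF(x > 0, 1, 0)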
def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, "json"))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)
def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format; if True, the dialect's TIME_FORMAT is used.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format; if True, the dialect's TIME_FORMAT is used.
Returns:
A callable that can be used to return the appropriately formatted time expression.
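Dialects typically register these builders in their parser's FUNCTIONS map; invoking one directly shows the format conversion (a sketch, assuming Hive's time tokens):

from sqlglot import exp
from sqlglot.dialects.dialect import build_formatted_time

builder = build_formatted_time(exp.StrToTime, "hive")
node = builder([exp.column("x"), exp.Literal.string("yyyy-MM-dd")])
print(node.args["format"].this)  # expected: %Y-%m-%d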
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format
def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder
def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))

    return _builder
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, to=target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the stray backticks around foo).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql
def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    return exp.Literal.string(default) if default else None
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, "date"))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)
Remove table refs from columns in when statements.
def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder
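A sketch of registering the builder in a hypothetical dialect so that JSON_EXTRACT_PATH(col, 'a', '0') parses into a structured JSON path (the dialect and function name are illustrative):

from sqlglot import exp, parse_one
from sqlglot.parser import Parser
from sqlglot.dialects.dialect import Dialect, build_json_extract_path

class JsonPathDialect(Dialect):  # hypothetical dialect, for illustration only
    class Parser(Parser):
        FUNCTIONS = {
            **Parser.FUNCTIONS,
            "JSON_EXTRACT_PATH": build_json_extract_path(exp.JSONExtract),
        }

node = parse_one("SELECT JSON_EXTRACT_PATH(col, 'a', '0')", read="jsonpathdialect")
print(node.find(exp.JSONExtract).expression.sql())  # expected, roughly: $.a[0]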
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments
def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))