sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

logger = logging.getLogger("sqlglot")


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                "\\a": "\a",
                "\\b": "\b",
                "\\f": "\f",
                "\\n": "\n",
                "\\r": "\r",
                "\\t": "\t",
                "\\v": "\v",
                "\\\\": "\\",
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "athena", "presto", "trino"):
            klass.generator_class.TRY_SUPPORTED = False

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass


class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
224 """ 225 226 LOG_BASE_FIRST: t.Optional[bool] = True 227 """ 228 Whether the base comes first in the `LOG` function. 229 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 230 """ 231 232 NULL_ORDERING = "nulls_are_small" 233 """ 234 Default `NULL` ordering method to use if not explicitly set. 235 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 236 """ 237 238 TYPED_DIVISION = False 239 """ 240 Whether the behavior of `a / b` depends on the types of `a` and `b`. 241 False means `a / b` is always float division. 242 True means `a / b` is integer division if both `a` and `b` are integers. 243 """ 244 245 SAFE_DIVISION = False 246 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 247 248 CONCAT_COALESCE = False 249 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 250 251 HEX_LOWERCASE = False 252 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 253 254 DATE_FORMAT = "'%Y-%m-%d'" 255 DATEINT_FORMAT = "'%Y%m%d'" 256 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 257 258 TIME_MAPPING: t.Dict[str, str] = {} 259 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 260 261 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 262 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 263 FORMAT_MAPPING: t.Dict[str, str] = {} 264 """ 265 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 266 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 267 """ 268 269 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 270 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 271 272 PSEUDOCOLUMNS: t.Set[str] = set() 273 """ 274 Columns that are auto-generated by the engine corresponding to this dialect. 275 For example, such columns may be excluded from `SELECT *` queries. 276 """ 277 278 PREFER_CTE_ALIAS_COLUMN = False 279 """ 280 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 281 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 282 any projection aliases in the subquery. 

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    # Separator of COPY statement parameters
    COPY_PARAMS_ARE_CSV = True

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql
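
# The `*_sql` helpers and factories in the rest of this module (e.g. `rename_func`,
# `if_sql`) produce generator transforms. A dialect typically wires them into its
# Generator's TRANSFORMS mapping; an illustrative sketch (the dialect name and the
# target function spellings here are assumptions, not taken from this module):
#
#     class MyDialect(Dialect):
#         class Generator(Generator):
#             TRANSFORMS = {
#                 **Generator.TRANSFORMS,
#                 exp.If: if_sql(name="IFF"),
#                 exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
#             }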


def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, exp.DataType.Type.JSON))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"


def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
    elem = seq_get(expression.expressions, 0)
    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
        return self.func("ARRAY", elem)
    return inline_array_sql(self, expression)


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder


def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))

    return _builder


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
        args = [unit_to_str(expression), expression.this]
        if zone:
            args.append(expression.args.get("zone"))
        return self.func("DATE_TRUNC", *args)

    return _timestamptrunc_sql


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
"REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") 911 ) 912 913 914def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 915 bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers"))) 916 if bad_args: 917 self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") 918 919 return self.func( 920 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 921 ) 922 923 924def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 925 names = [] 926 for agg in aggregations: 927 if isinstance(agg, exp.Alias): 928 names.append(agg.alias) 929 else: 930 """ 931 This case corresponds to aggregations without aliases being used as suffixes 932 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 933 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 934 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 935 """ 936 agg_all_unquoted = agg.transform( 937 lambda node: ( 938 exp.Identifier(this=node.name, quoted=False) 939 if isinstance(node, exp.Identifier) 940 else node 941 ) 942 ) 943 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 944 945 return names 946 947 948def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: 949 return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) 950 951 952# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects 953def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: 954 return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) 955 956 957def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: 958 return self.func("MAX", expression.this) 959 960 961def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: 962 a = self.sql(expression.left) 963 b = self.sql(expression.right) 964 return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" 965 966 967def is_parse_json(expression: exp.Expression) -> bool: 968 return isinstance(expression, exp.ParseJSON) or ( 969 isinstance(expression, exp.Cast) and expression.is_type("json") 970 ) 971 972 973def isnull_to_is_null(args: t.List) -> exp.Expression: 974 return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) 975 976 977def generatedasidentitycolumnconstraint_sql( 978 self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint 979) -> str: 980 start = self.sql(expression, "start") or "1" 981 increment = self.sql(expression, "increment") or "1" 982 return f"IDENTITY({start}, {increment})" 983 984 985def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 986 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 987 if expression.args.get("count"): 988 self.unsupported(f"Only two arguments are supported in function {name}.") 989 990 return self.func(name, expression.this, expression.expression) 991 992 return _arg_max_or_min_sql 993 994 995def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 996 this = expression.this.copy() 997 998 return_type = expression.return_type 999 if return_type.is_type(exp.DataType.Type.DATE): 1000 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1001 # can truncate timestamp strings, because some dialects can't cast them to DATE 1002 this = exp.cast(this, 
exp.DataType.Type.TIMESTAMP) 1003 1004 expression.this.replace(exp.cast(this, return_type)) 1005 return expression 1006 1007 1008def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1009 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1010 if cast and isinstance(expression, exp.TsOrDsAdd): 1011 expression = ts_or_ds_add_cast(expression) 1012 1013 return self.func( 1014 name, 1015 unit_to_var(expression), 1016 expression.expression, 1017 expression.this, 1018 ) 1019 1020 return _delta_sql 1021 1022 1023def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1024 unit = expression.args.get("unit") 1025 1026 if isinstance(unit, exp.Placeholder): 1027 return unit 1028 if unit: 1029 return exp.Literal.string(unit.name) 1030 return exp.Literal.string(default) if default else None 1031 1032 1033def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1034 unit = expression.args.get("unit") 1035 1036 if isinstance(unit, (exp.Var, exp.Placeholder)): 1037 return unit 1038 return exp.Var(this=default) if default else None 1039 1040 1041def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1042 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1043 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1044 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1045 1046 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE)) 1047 1048 1049def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1050 """Remove table refs from columns in when statements.""" 1051 alias = expression.this.args.get("alias") 1052 1053 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1054 return self.dialect.normalize_identifier(identifier).name if identifier else None 1055 1056 targets = {normalize(expression.this.this)} 1057 1058 if alias: 1059 targets.add(normalize(alias.this)) 1060 1061 for when in expression.expressions: 1062 when.transform( 1063 lambda node: ( 1064 exp.column(node.this) 1065 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1066 else node 1067 ), 1068 copy=False, 1069 ) 1070 1071 return self.merge_sql(expression) 1072 1073 1074def build_json_extract_path( 1075 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1076) -> t.Callable[[t.List], F]: 1077 def _builder(args: t.List) -> F: 1078 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1079 for arg in args[1:]: 1080 if not isinstance(arg, exp.Literal): 1081 # We use the fallback parser because we can't really transpile non-literals safely 1082 return expr_type.from_arg_list(args) 1083 1084 text = arg.name 1085 if is_int(text): 1086 index = int(text) 1087 segments.append( 1088 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1089 ) 1090 else: 1091 segments.append(exp.JSONPathKey(this=text)) 1092 1093 # This is done to avoid failing in the expression validator due to the arg count 1094 del args[2:] 1095 return expr_type( 1096 this=seq_get(args, 0), 1097 expression=exp.JSONPath(expressions=segments), 1098 only_json_types=arrow_req_json_type, 1099 ) 1100 1101 return _builder 1102 1103 1104def json_extract_segments( 1105 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1106) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1107 def _json_extract_segments(self: Generator, expression: 
JSON_EXTRACT_TYPE) -> str: 1108 path = expression.expression 1109 if not isinstance(path, exp.JSONPath): 1110 return rename_func(name)(self, expression) 1111 1112 segments = [] 1113 for segment in path.expressions: 1114 path = self.sql(segment) 1115 if path: 1116 if isinstance(segment, exp.JSONPathPart) and ( 1117 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1118 ): 1119 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1120 1121 segments.append(path) 1122 1123 if op: 1124 return f" {op} ".join([self.sql(expression.this), *segments]) 1125 return self.func(name, expression.this, *segments) 1126 1127 return _json_extract_segments 1128 1129 1130def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str: 1131 if isinstance(expression.this, exp.JSONPathWildcard): 1132 self.unsupported("Unsupported wildcard in JSONPathKey expression") 1133 1134 return expression.name 1135 1136 1137def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1138 cond = expression.expression 1139 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1140 alias = cond.expressions[0] 1141 cond = cond.this 1142 elif isinstance(cond, exp.Predicate): 1143 alias = "_u" 1144 else: 1145 self.unsupported("Unsupported filter condition") 1146 return "" 1147 1148 unnest = exp.Unnest(expressions=[expression.this]) 1149 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1150 return self.sql(exp.Array(expressions=[filtered])) 1151 1152 1153def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str: 1154 return self.func( 1155 "TO_NUMBER", 1156 expression.this, 1157 expression.args.get("format"), 1158 expression.args.get("nlsparam"), 1159 ) 1160 1161 1162def build_default_decimal_type( 1163 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1164) -> t.Callable[[exp.DataType], exp.DataType]: 1165 def _builder(dtype: exp.DataType) -> exp.DataType: 1166 if dtype.expressions or precision is None: 1167 return dtype 1168 1169 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1170 return exp.DataType.build(f"DECIMAL({params})") 1171 1172 return _builder
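
Every `Dialect` subclass is registered in the global registry by the `_Dialect`
metaclass when its module is imported, and the class ties a tokenizer, parser and
generator together. A minimal usage sketch (the SQL snippet is illustrative):

    import sqlglot.dialects  # importing the package registers the bundled dialects
    from sqlglot.dialects.dialect import Dialect

    duckdb = Dialect.get_or_raise("duckdb")
    expressions = duckdb.parse("SELECT 1 AS x")     # tokenize + parse
    sql = duckdb.generate(expressions[0])           # render an AST back to SQL
    statements = duckdb.transpile("SELECT 1 AS x")  # parse + generate in one call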
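As the `get_or_raise` classmethod above shows, dialect settings can ride along in
the string form as comma-separated `key = value` pairs. A small sketch (assumes the
bundled dialects have been registered, e.g. via `import sqlglot.dialects`):

    from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

    d = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
    assert d.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE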
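`normalize_identifier` and `case_sensitive` implement the resolution rules described
in their docstrings. Under the default LOWERCASE strategy, an unquoted identifier is
folded to lowercase while a quoted one is left untouched (illustrative sketch):

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    d = Dialect()
    assert d.normalize_identifier(exp.to_identifier("FoO")).name == "foo"
    assert d.normalize_identifier(exp.to_identifier("FoO", quoted=True)).name == "FoO"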
184class Dialect(metaclass=_Dialect): 185 INDEX_OFFSET = 0 186 """The base index offset for arrays.""" 187 188 WEEK_OFFSET = 0 189 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 190 191 UNNEST_COLUMN_ONLY = False 192 """Whether `UNNEST` table aliases are treated as column aliases.""" 193 194 ALIAS_POST_TABLESAMPLE = False 195 """Whether the table alias comes after tablesample.""" 196 197 TABLESAMPLE_SIZE_IS_PERCENT = False 198 """Whether a size in the table sample clause represents percentage.""" 199 200 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 201 """Specifies the strategy according to which identifiers should be normalized.""" 202 203 IDENTIFIERS_CAN_START_WITH_DIGIT = False 204 """Whether an unquoted identifier can start with a digit.""" 205 206 DPIPE_IS_STRING_CONCAT = True 207 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 208 209 STRICT_STRING_CONCAT = False 210 """Whether `CONCAT`'s arguments must be strings.""" 211 212 SUPPORTS_USER_DEFINED_TYPES = True 213 """Whether user-defined data types are supported.""" 214 215 SUPPORTS_SEMI_ANTI_JOIN = True 216 """Whether `SEMI` or `ANTI` joins are supported.""" 217 218 NORMALIZE_FUNCTIONS: bool | str = "upper" 219 """ 220 Determines how function names are going to be normalized. 221 Possible values: 222 "upper" or True: Convert names to uppercase. 223 "lower": Convert names to lowercase. 224 False: Disables function name normalization. 225 """ 226 227 LOG_BASE_FIRST: t.Optional[bool] = True 228 """ 229 Whether the base comes first in the `LOG` function. 230 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 231 """ 232 233 NULL_ORDERING = "nulls_are_small" 234 """ 235 Default `NULL` ordering method to use if not explicitly set. 236 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 237 """ 238 239 TYPED_DIVISION = False 240 """ 241 Whether the behavior of `a / b` depends on the types of `a` and `b`. 242 False means `a / b` is always float division. 243 True means `a / b` is integer division if both `a` and `b` are integers. 244 """ 245 246 SAFE_DIVISION = False 247 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 248 249 CONCAT_COALESCE = False 250 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 251 252 HEX_LOWERCASE = False 253 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 254 255 DATE_FORMAT = "'%Y-%m-%d'" 256 DATEINT_FORMAT = "'%Y%m%d'" 257 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 258 259 TIME_MAPPING: t.Dict[str, str] = {} 260 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 261 262 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 263 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 264 FORMAT_MAPPING: t.Dict[str, str] = {} 265 """ 266 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 267 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 
268 """ 269 270 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 271 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 272 273 PSEUDOCOLUMNS: t.Set[str] = set() 274 """ 275 Columns that are auto-generated by the engine corresponding to this dialect. 276 For example, such columns may be excluded from `SELECT *` queries. 277 """ 278 279 PREFER_CTE_ALIAS_COLUMN = False 280 """ 281 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 282 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 283 any projection aliases in the subquery. 284 285 For example, 286 WITH y(c) AS ( 287 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 288 ) SELECT c FROM y; 289 290 will be rewritten as 291 292 WITH y(c) AS ( 293 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 294 ) SELECT c FROM y; 295 """ 296 297 # --- Autofilled --- 298 299 tokenizer_class = Tokenizer 300 parser_class = Parser 301 generator_class = Generator 302 303 # A trie of the time_mapping keys 304 TIME_TRIE: t.Dict = {} 305 FORMAT_TRIE: t.Dict = {} 306 307 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 308 INVERSE_TIME_TRIE: t.Dict = {} 309 310 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 311 312 # Delimiters for string literals and identifiers 313 QUOTE_START = "'" 314 QUOTE_END = "'" 315 IDENTIFIER_START = '"' 316 IDENTIFIER_END = '"' 317 318 # Delimiters for bit, hex, byte and unicode literals 319 BIT_START: t.Optional[str] = None 320 BIT_END: t.Optional[str] = None 321 HEX_START: t.Optional[str] = None 322 HEX_END: t.Optional[str] = None 323 BYTE_START: t.Optional[str] = None 324 BYTE_END: t.Optional[str] = None 325 UNICODE_START: t.Optional[str] = None 326 UNICODE_END: t.Optional[str] = None 327 328 # Separator of COPY statement parameters 329 COPY_PARAMS_ARE_CSV = True 330 331 @classmethod 332 def get_or_raise(cls, dialect: DialectType) -> Dialect: 333 """ 334 Look up a dialect in the global dialect registry and return it if it exists. 335 336 Args: 337 dialect: The target dialect. If this is a string, it can be optionally followed by 338 additional key-value pairs that are separated by commas and are used to specify 339 dialect settings, such as whether the dialect's identifiers are case-sensitive. 340 341 Example: 342 >>> dialect = dialect_class = get_or_raise("duckdb") 343 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 344 345 Returns: 346 The corresponding Dialect instance. 347 """ 348 349 if not dialect: 350 return cls() 351 if isinstance(dialect, _Dialect): 352 return dialect() 353 if isinstance(dialect, Dialect): 354 return dialect 355 if isinstance(dialect, str): 356 try: 357 dialect_name, *kv_pairs = dialect.split(",") 358 kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)} 359 except ValueError: 360 raise ValueError( 361 f"Invalid dialect format: '{dialect}'. " 362 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 363 ) 364 365 result = cls.get(dialect_name.strip()) 366 if not result: 367 from difflib import get_close_matches 368 369 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 370 if similar: 371 similar = f" Did you mean {similar}?" 
372 373 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 374 375 return result(**kwargs) 376 377 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 378 379 @classmethod 380 def format_time( 381 cls, expression: t.Optional[str | exp.Expression] 382 ) -> t.Optional[exp.Expression]: 383 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 384 if isinstance(expression, str): 385 return exp.Literal.string( 386 # the time formats are quoted 387 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 388 ) 389 390 if expression and expression.is_string: 391 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 392 393 return expression 394 395 def __init__(self, **kwargs) -> None: 396 normalization_strategy = kwargs.get("normalization_strategy") 397 398 if normalization_strategy is None: 399 self.normalization_strategy = self.NORMALIZATION_STRATEGY 400 else: 401 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 402 403 def __eq__(self, other: t.Any) -> bool: 404 # Does not currently take dialect state into account 405 return type(self) == other 406 407 def __hash__(self) -> int: 408 # Does not currently take dialect state into account 409 return hash(type(self)) 410 411 def normalize_identifier(self, expression: E) -> E: 412 """ 413 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 414 415 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 416 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 417 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 418 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 419 420 There are also dialects like Spark, which are case-insensitive even when quotes are 421 present, and dialects like MySQL, whose resolution rules match those employed by the 422 underlying operating system, for example they may always be case-sensitive in Linux. 423 424 Finally, the normalization behavior of some engines can even be controlled through flags, 425 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 426 427 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 428 that it can analyze queries in the optimizer and successfully capture their semantics. 429 """ 430 if ( 431 isinstance(expression, exp.Identifier) 432 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 433 and ( 434 not expression.quoted 435 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 436 ) 437 ): 438 expression.set( 439 "this", 440 ( 441 expression.this.upper() 442 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 443 else expression.this.lower() 444 ), 445 ) 446 447 return expression 448 449 def case_sensitive(self, text: str) -> bool: 450 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 451 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 452 return False 453 454 unsafe = ( 455 str.islower 456 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 457 else str.isupper 458 ) 459 return any(unsafe(char) for char in text) 460 461 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 462 """Checks if text can be identified given an identify option. 
463 464 Args: 465 text: The text to check. 466 identify: 467 `"always"` or `True`: Always returns `True`. 468 `"safe"`: Only returns `True` if the identifier is case-insensitive. 469 470 Returns: 471 Whether the given text can be identified. 472 """ 473 if identify is True or identify == "always": 474 return True 475 476 if identify == "safe": 477 return not self.case_sensitive(text) 478 479 return False 480 481 def quote_identifier(self, expression: E, identify: bool = True) -> E: 482 """ 483 Adds quotes to a given identifier. 484 485 Args: 486 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 487 identify: If set to `False`, the quotes will only be added if the identifier is deemed 488 "unsafe", with respect to its characters and this dialect's normalization strategy. 489 """ 490 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 491 name = expression.this 492 expression.set( 493 "quoted", 494 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 495 ) 496 497 return expression 498 499 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 500 if isinstance(path, exp.Literal): 501 path_text = path.name 502 if path.is_number: 503 path_text = f"[{path_text}]" 504 505 try: 506 return parse_json_path(path_text) 507 except ParseError as e: 508 logger.warning(f"Invalid JSON path syntax. {str(e)}") 509 510 return path 511 512 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 513 return self.parser(**opts).parse(self.tokenize(sql), sql) 514 515 def parse_into( 516 self, expression_type: exp.IntoType, sql: str, **opts 517 ) -> t.List[t.Optional[exp.Expression]]: 518 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 519 520 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 521 return self.generator(**opts).generate(expression, copy=copy) 522 523 def transpile(self, sql: str, **opts) -> t.List[str]: 524 return [ 525 self.generate(expression, copy=False, **opts) if expression else "" 526 for expression in self.parse(sql) 527 ] 528 529 def tokenize(self, sql: str) -> t.List[Token]: 530 return self.tokenizer.tokenize(sql) 531 532 @property 533 def tokenizer(self) -> Tokenizer: 534 if not hasattr(self, "_tokenizer"): 535 self._tokenizer = self.tokenizer_class(dialect=self) 536 return self._tokenizer 537 538 def parser(self, **opts) -> Parser: 539 return self.parser_class(dialect=self, **opts) 540 541 def generator(self, **opts) -> Generator: 542 return self.generator_class(dialect=self, **opts)
395 def __init__(self, **kwargs) -> None: 396 normalization_strategy = kwargs.get("normalization_strategy") 397 398 if normalization_strategy is None: 399 self.normalization_strategy = self.NORMALIZATION_STRATEGY 400 else: 401 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized.
Possible values:
"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.
Whether the base comes first in the LOG
function.
Possible values: True
, False
, None
(two arguments are not supported by LOG
)
Default NULL
ordering method to use if not explicitly set.
Possible values: "nulls_are_small"
, "nulls_are_large"
, "nulls_are_last"
Whether the behavior of `a / b` depends on the types of `a` and `b`.
False means `a / b` is always float division.
True means `a / b` is integer division if both `a` and `b` are integers.
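As an illustration, the flag can be inspected on a dialect instance. The attribute name TYPED_DIVISION below is taken from sqlglot's source and is an assumption that may vary across versions:

    from sqlglot.dialects.dialect import Dialect

    # T-SQL's `/` performs integer division on integer operands, so its flag is
    # expected to be True, while DuckDB's `/` always does float division.
    for name in ("duckdb", "tsql"):
        print(name, Dialect.get_or_raise(name).TYPED_DIVISION)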
A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python `strftime` formats.
Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
Mapping of an escaped sequence (`\n`) to its unescaped version (a literal newline).
Columns that are auto-generated by the engine corresponding to this dialect.
For example, such columns may be excluded from `SELECT *` queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example,
    WITH y(c) AS (
        SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
    ) SELECT c FROM y;
will be rewritten as
    WITH y(c) AS (
        SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
    ) SELECT c FROM y;
    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
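A small sketch of the settings-string form, which forwards the key-value pairs to the dialect's __init__ (shown above):

    from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

    d = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
    print(d.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE)  # True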
    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression
Converts a time format in this dialect to its equivalent Python `strftime` format.
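For instance, assuming Hive's TIME_MAPPING covers these tokens, a format literal can be converted like so (note the surrounding quotes: a plain string argument is treated as a quoted SQL literal):

    from sqlglot.dialects.dialect import Dialect

    fmt = Dialect.get_or_raise("hive").format_time("'yyyy-MM-dd'")
    print(fmt.sql())  # '%Y-%m-%d'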
    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
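A minimal sketch of the behaviors described above:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    ident = exp.to_identifier("FoO")  # unquoted
    print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name)   # foo
    print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name)  # FOO

    quoted = exp.to_identifier("FoO", quoted=True)
    print(Dialect.get_or_raise("postgres").normalize_identifier(quoted).name)  # FoO, left as-is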
    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
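For example, under Postgres' lowercase strategy only characters that would be folded by normalization count as case-sensitive:

    from sqlglot.dialects.dialect import Dialect

    pg = Dialect.get_or_raise("postgres")
    print(pg.case_sensitive("foo"))  # False: already lowercase, folding is a no-op
    print(pg.case_sensitive("Foo"))  # True: the uppercase 'F' would be folded away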
    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify:
  `"always"` or `True`: Always returns `True`.
  `"safe"`: Only returns `True` if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
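A short sketch against Snowflake, whose unquoted identifiers are uppercased:

    from sqlglot.dialects.dialect import Dialect

    snow = Dialect.get_or_raise("snowflake")
    print(snow.can_identify("FOO", "safe"))    # True: quoting it cannot change resolution
    print(snow.can_identify("foo", "safe"))    # False: lowercase chars are case-sensitive here
    print(snow.can_identify("foo", "always"))  # True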
    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
- identify: If set to `False`, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
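A minimal sketch, using Postgres' lowercase normalization strategy:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    pg = Dialect.get_or_raise("postgres")

    safe = pg.quote_identifier(exp.to_identifier("foo"), identify=False)
    print(safe.sql(dialect="postgres"))  # foo (deemed safe, so no quotes added)

    unsafe = pg.quote_identifier(exp.to_identifier("FoO"), identify=False)
    print(unsafe.sql(dialect="postgres"))  # "FoO" (case-sensitive, so quoted)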
    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql
def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, exp.DataType.Type.JSON))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)
def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target SQL dialect.
- default: the default format; `True` means the dialect's `TIME_FORMAT` is used.
Returns:
A callable that can be used to return the appropriately formatted time expression.
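A hedged sketch of wiring this builder up by hand; `exp.StrToTime` and the Hive-style format tokens are used purely for illustration:

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_formatted_time

    builder = build_formatted_time(exp.StrToTime, "hive")
    node = builder([exp.column("dt"), exp.Literal.string("yyyy-MM-dd")])
    print(node.args["format"].sql())  # '%Y-%m-%d'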
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format
def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder
def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))

    return _builder
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
        args = [unit_to_str(expression), expression.this]
        if zone:
            args.append(expression.args.get("zone"))
        return self.func("DATE_TRUNC", *args)

    return _timestamptrunc_sql
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql
def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    return exp.Literal.string(default) if default else None
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)
Remove table refs from columns in when statements.
def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments
def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))
def build_default_decimal_type(
    precision: t.Optional[int] = None, scale: t.Optional[int] = None
) -> t.Callable[[exp.DataType], exp.DataType]:
    def _builder(dtype: exp.DataType) -> exp.DataType:
        if dtype.expressions or precision is None:
            return dtype

        params = f"{precision}{f', {scale}' if scale is not None else ''}"
        return exp.DataType.build(f"DECIMAL({params})")

    return _builder