sqlglot.dialects.dialect
```python
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]

if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

    JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]

logger = logging.getLogger("sqlglot")


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}

        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass


class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """Determines the base index offset for arrays."""

    WEEK_OFFSET = 0
    """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Determines whether or not `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Determines whether or not the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Determines whether or not a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Determines whether or not an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Determines whether or not the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Determines whether or not `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Determines whether or not user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Determines whether or not `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """Determines how function names are going to be normalized."""

    LOG_BASE_FIRST = True
    """Determines whether the base comes first in the `LOG` function."""

    NULL_ORDERING = "nulls_are_small"
    """
    Indicates the default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Determines whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` format."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an unescaped escape sequence to the corresponding character."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, "json"))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, flat=True)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func


def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
    )


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, to=target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, "timestamp"))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, "date"))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            exp.var(expression.text("unit").upper() or "DAY"),
            expression.expression,
            expression.this,
        )

    return _delta_sql


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, "date"))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)


def parse_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True
) -> t.Callable[[t.List], F]:
    def _parse_json_extract_path(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))

    return _parse_json_extract_path


def json_extract_segments(
    name: str, quoted_index: bool = True
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name
```
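The helpers in the second half of the module (`rename_func`, `if_sql`, the various `no_*_sql` fallbacks, the `parse_*` builders, and so on) are building blocks for dialect implementations: they are registered in a dialect's `Generator.TRANSFORMS` or the parser's function maps rather than called directly. A minimal sketch of that pattern, using a hypothetical dialect:

```python
from sqlglot import exp, generator
from sqlglot.dialects.dialect import Dialect, no_pivot_sql, rename_func

class Example(Dialect):  # hypothetical dialect, for illustration only
    class Generator(generator.Generator):
        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,
            # Render exp.ApproxDistinct under a different function name.
            exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
            # Warn and drop PIVOT, which this hypothetical engine lacks.
            exp.Pivot: no_pivot_sql,
        }
```

The `_Dialect` metaclass picks up the nested `Generator` as `generator_class` and registers the subclass, so `Dialect.get_or_raise("example")` would resolve to it.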
`class Dialects(str, Enum)`
Dialects supported by SQLGlot.
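Because `Dialects` subclasses `str`, its members can be passed anywhere a dialect name string is expected, for example in sqlglot's top-level API (a small sketch):

```python
import sqlglot
from sqlglot.dialects.dialect import Dialects

# Dialects.MYSQL == "mysql", so enum members work as read/write arguments.
print(sqlglot.transpile("SELECT 1", read=Dialects.MYSQL, write=Dialects.DUCKDB))
```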
`class NormalizationStrategy(str, AutoName)`
Specifies the strategy according to which identifiers should be normalized.

- LOWERCASE: Unquoted identifiers are lowercased.
- UPPERCASE: Unquoted identifiers are uppercased.
- CASE_SENSITIVE: Always case-sensitive, regardless of quotes.
- CASE_INSENSITIVE: Always case-insensitive, regardless of quotes.
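A dialect instance defaults to its class-level `NORMALIZATION_STRATEGY`, but the strategy can be overridden per instance, e.g. through the settings-string form accepted by `Dialect.get_or_raise` (documented below):

```python
from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

# Make a single MySQL instance treat all identifiers as case-sensitive.
mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
assert mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
```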
`class Dialect(metaclass=_Dialect)`
`Dialect(**kwargs)`

Accepts dialect settings as keyword arguments; currently `normalization_strategy` is recognized, defaulting to the class-level `NORMALIZATION_STRATEGY`.
Key class-level settings (the full set is in the source above):

- WEEK_OFFSET: Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
- UNNEST_COLUMN_ONLY: Determines whether or not `UNNEST` table aliases are treated as column aliases.
- TABLESAMPLE_SIZE_IS_PERCENT: Determines whether or not a size in the table sample clause represents percentage.
- NORMALIZATION_STRATEGY: Specifies the strategy according to which identifiers should be normalized.
- IDENTIFIERS_CAN_START_WITH_DIGIT: Determines whether or not an unquoted identifier can start with a digit.
- DPIPE_IS_STRING_CONCAT: Determines whether or not the DPIPE token (`||`) is a string concatenation operator.
- NULL_ORDERING: Indicates the default `NULL` ordering method to use if not explicitly set. Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`.
- TYPED_DIVISION: Whether the behavior of `a / b` depends on the types of `a` and `b`. False means `a / b` is always float division; True means `a / b` is integer division if both `a` and `b` are integers.
- SAFE_DIVISION: Determines whether division by zero throws an error (`False`) or returns NULL (`True`).
- CONCAT_COALESCE: A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.
- TIME_MAPPING: Associates this dialect's time formats with their equivalent Python `strftime` format.
- FORMAT_MAPPING: Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
- ESCAPE_SEQUENCES: Mapping of an unescaped escape sequence to the corresponding character.
- PSEUDOCOLUMNS: Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from `SELECT *` queries.
- PREFER_CTE_ALIAS_COLUMN: Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery. For example,

      WITH y(c) AS (
          SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
      ) SELECT c FROM y;

  will be rewritten as

      WITH y(c) AS (
          SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
      ) SELECT c FROM y;
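These attributes are what concrete dialect implementations override. A minimal sketch of a hypothetical custom dialect (the name and mappings are illustrative); the `_Dialect` metaclass registers the subclass and derives `TIME_TRIE` and the inverse mappings automatically:

```python
from sqlglot.dialects.dialect import Dialect

class Custom(Dialect):  # hypothetical dialect, for illustration only
    INDEX_OFFSET = 1  # arrays are 1-based in this engine
    NULL_ORDERING = "nulls_are_large"
    TIME_MAPPING = {"yyyy": "%Y", "MM": "%m", "dd": "%d"}

# The metaclass registered the class under its lowercased name.
assert isinstance(Dialect.get_or_raise("custom"), Custom)
```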
`Dialect.get_or_raise(dialect: DialectType) -> Dialect` (classmethod)
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
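Strings, `Dialect` subclasses, and `Dialect` instances are all accepted, and falsy input yields the base dialect:

```python
from sqlglot.dialects.dialect import Dialect

assert type(Dialect.get_or_raise(None)) is Dialect   # falsy input -> base dialect
duckdb = Dialect.get_or_raise("duckdb")              # registry lookup by name
assert Dialect.get_or_raise(duckdb) is duckdb        # instances pass through unchanged
assert isinstance(Dialect.get_or_raise(type(duckdb)), type(duckdb))  # classes get instantiated
```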
`Dialect.format_time(expression: t.Optional[str | exp.Expression]) -> t.Optional[exp.Expression]` (classmethod)
Converts a time format in this dialect to its equivalent Python `strftime` format.
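For instance, assuming Hive's `TIME_MAPPING` (which maps `yyyy`, `MM`, and `dd`), a dialect-native format is rewritten to its `strftime` equivalent; note that bare strings are expected to still carry their surrounding quotes:

```python
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

hive = Dialect["hive"]  # _Dialect.__getitem__ looks the class up in the registry
print(hive.format_time("'yyyy-MM-dd'").name)                    # %Y-%m-%d
print(hive.format_time(exp.Literal.string("yyyy-MM-dd")).name)  # %Y-%m-%d
```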
`Dialect.normalize_identifier(expression: E) -> E`
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system; for example, they may always be case-sensitive on Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, as in Redshift's case, where users can explicitly set `enable_case_sensitive_identifier`.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
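To make the behavior above concrete, a small sketch using `exp.to_identifier`, which builds an `Identifier` node:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

# Postgres lowercases unquoted identifiers...
print(Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).name)  # foo

# ...while Snowflake uppercases them.
print(Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).name)  # FOO

# Quoted identifiers are left untouched under case-sensitive resolution.
quoted = exp.to_identifier("FoO", quoted=True)
print(Dialect.get_or_raise("postgres").normalize_identifier(quoted).name)  # FoO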
    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
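In other words, a character counts as "case sensitive" when it wouldn't survive the dialect's case folding. A minimal sketch:

from sqlglot.dialects.dialect import Dialect

snowflake = Dialect.get_or_raise("snowflake")  # folds unquoted identifiers to uppercase
postgres = Dialect.get_or_raise("postgres")    # folds unquoted identifiers to lowercase

assert snowflake.case_sensitive("foo") and not snowflake.case_sensitive("FOO")
assert postgres.case_sensitive("Foo") and not postgres.case_sensitive("foo")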
    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify:
  `"always"` or `True`: Always returns `True`.
  `"safe"`: Only returns `True` if the identifier is case-insensitive.
Returns:
Whether or not the given text can be identified.
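A brief sketch of the two options, reusing the Postgres dialect from above:

from sqlglot.dialects.dialect import Dialect

postgres = Dialect.get_or_raise("postgres")

assert postgres.can_identify("Foo", identify="always")
# Under "safe", only text that survives the dialect's case folding qualifies.
assert postgres.can_identify("foo", identify="safe")
assert not postgres.can_identify("Foo", identify="safe")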
    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
- identify: If set to `False`, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
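A small sketch of the `identify=False` behavior under Postgres's lowercase folding:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

postgres = Dialect.get_or_raise("postgres")

# "Foo" would be folded to "foo" if left unquoted, so it gets quoted;
# "foo" survives folding unchanged, so it does not.
print(postgres.quote_identifier(exp.to_identifier("foo"), identify=False).sql(dialect="postgres"))  # foo
print(postgres.quote_identifier(exp.to_identifier("Foo"), identify=False).sql(dialect="postgres"))  # "Foo"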
    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql
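Factories like `if_sql` are meant to be registered under the relevant expression type in a dialect Generator's `TRANSFORMS` mapping. A minimal sketch with a hypothetical dialect; the `IIF` spelling and the `NULL` fallback are assumptions for illustration:

from sqlglot import exp, transpile
from sqlglot.dialects.dialect import Dialect, if_sql
from sqlglot.generator import Generator

class MyDialect(Dialect):  # hypothetical dialect, auto-registered as "mydialect"
    class Generator(Generator):
        TRANSFORMS = {
            **Generator.TRANSFORMS,
            # Render exp.If as IIF(cond, true, false), defaulting the
            # false branch to NULL when it's absent.
            exp.If: if_sql(name="IIF", false_value="NULL"),
        }

print(transpile("SELECT IF(x > 1, 'a')", write="mydialect")[0])
# SELECT IIF(x > 1, 'a', NULL)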
def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, "json"))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
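For instance, a parser callback built with this helper converts the format argument into strftime notation as it parses. A sketch, assuming Hive's `TIME_MAPPING` covers the tokens used; the argument list is hypothetical:

from sqlglot import exp
from sqlglot.dialects.dialect import format_time_lambda

# Build a parser callback that creates StrToDate nodes, translating the
# format from Hive notation into Python strftime notation.
parse_to_date = format_time_lambda(exp.StrToDate, "hive")

expr = parse_to_date([exp.column("x"), exp.Literal.string("yyyy-MM-dd")])
print(expr.args["format"].sql())  # '%Y-%m-%d'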
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, to=target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            exp.var(expression.text("unit").upper() or "DAY"),
            expression.expression,
            expression.this,
        )

    return _delta_sql
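As with the other factories, the returned callable is registered under the relevant expression types in a Generator's `TRANSFORMS`. A hypothetical sketch for an engine that spells date arithmetic `DATEADD(unit, amount, date)`; the dialect itself is an assumption for illustration:

from sqlglot import exp, transpile
from sqlglot.dialects.dialect import Dialect, date_delta_sql
from sqlglot.generator import Generator

class DateAddDialect(Dialect):  # hypothetical, auto-registered as "dateadddialect"
    class Generator(Generator):
        TRANSFORMS = {
            **Generator.TRANSFORMS,
            exp.DateAdd: date_delta_sql("DATEADD"),
        }

# The unit defaults to DAY when the parsed expression doesn't carry one.
print(transpile("SELECT DATE_ADD(d, 7)", write="dateadddialect")[0])
# SELECT DATEADD(DAY, 7, d)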
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, "date"))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)
Remove table refs from columns in when statements.
def parse_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True
) -> t.Callable[[t.List], F]:
    def _parse_json_extract_path(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))

    return _parse_json_extract_path
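A quick sketch of the resulting tree for a GET_PATH-style call with literal segments; the argument list below is hypothetical:

from sqlglot import exp
from sqlglot.dialects.dialect import parse_json_extract_path

parse_path = parse_json_extract_path(exp.JSONExtract)

# Each literal vararg becomes a key or subscript segment of a JSONPath.
expr = parse_path([exp.column("doc"), exp.Literal.string("a"), exp.Literal.number(0)])
print(expr.expression.sql())  # roughly: $.a[0]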
def json_extract_segments(
    name: str, quoted_index: bool = True
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        return self.func(name, expression.this, *segments)

    return _json_extract_segments