sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator, unsupported_args
from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses, to_bool
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time, subsecond_precision
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[
    exp.DateAdd,
    exp.DateDiff,
    exp.DateSub,
    exp.TsOrDsAdd,
    exp.TsOrDsDiff,
]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

    from sqlglot.optimizer.annotate_types import TypeAnnotator

    AnnotatorsType = t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]

logger = logging.getLogger("sqlglot")

UNESCAPED_SEQUENCES = {
    "\\a": "\a",
    "\\b": "\b",
    "\\f": "\f",
    "\\n": "\n",
    "\\r": "\r",
    "\\t": "\t",
    "\\v": "\v",
    "\\\\": "\\",
}


def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
    return lambda self, e: self._annotate_with_type(e, data_type)


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DRUID = "druid"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MATERIALIZE = "materialize"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    RISINGWAVE = "risingwave"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""
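
# Illustrative example (not part of the module): the strategies above drive
# `Dialect.normalize_identifier`, defined below. A minimal sketch, assuming
# Snowflake's UPPERCASE and Postgres' LOWERCASE defaults, and that the concrete
# dialect modules have been imported (e.g. via `import sqlglot`):
#
#     >>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).name
#     'FOO'
#     >>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).name
#     'foo'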

class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
        klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
        klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)

        klass.INVERSE_CREATABLE_KIND_MAPPING = {
            v: k for k, v in klass.CREATABLE_KIND_MAPPING.items()
        }

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.jsonpath_tokenizer_class = klass.__dict__.get(
            "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                **UNESCAPED_SEQUENCES,
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "athena", "presto", "trino"):
            klass.generator_class.TRY_SUPPORTED = False
            klass.generator_class.SUPPORTS_UESCAPE = False

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if enum not in ("", "doris", "mysql"):
            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass
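
# Illustrative example (not part of the module): subclassing `Dialect` (below)
# runs `_Dialect.__new__`, which auto-registers the subclass in the global
# registry. The class name here is hypothetical:
#
#     >>> class Custom(Dialect):
#     ...     pass
#     >>> Dialect.get("custom") is Custom
#     True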
"""The base index offset for arrays.""" 228 229 WEEK_OFFSET = 0 230 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 231 232 UNNEST_COLUMN_ONLY = False 233 """Whether `UNNEST` table aliases are treated as column aliases.""" 234 235 ALIAS_POST_TABLESAMPLE = False 236 """Whether the table alias comes after tablesample.""" 237 238 TABLESAMPLE_SIZE_IS_PERCENT = False 239 """Whether a size in the table sample clause represents percentage.""" 240 241 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 242 """Specifies the strategy according to which identifiers should be normalized.""" 243 244 IDENTIFIERS_CAN_START_WITH_DIGIT = False 245 """Whether an unquoted identifier can start with a digit.""" 246 247 DPIPE_IS_STRING_CONCAT = True 248 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 249 250 STRICT_STRING_CONCAT = False 251 """Whether `CONCAT`'s arguments must be strings.""" 252 253 SUPPORTS_USER_DEFINED_TYPES = True 254 """Whether user-defined data types are supported.""" 255 256 SUPPORTS_SEMI_ANTI_JOIN = True 257 """Whether `SEMI` or `ANTI` joins are supported.""" 258 259 SUPPORTS_COLUMN_JOIN_MARKS = False 260 """Whether the old-style outer join (+) syntax is supported.""" 261 262 COPY_PARAMS_ARE_CSV = True 263 """Separator of COPY statement parameters.""" 264 265 NORMALIZE_FUNCTIONS: bool | str = "upper" 266 """ 267 Determines how function names are going to be normalized. 268 Possible values: 269 "upper" or True: Convert names to uppercase. 270 "lower": Convert names to lowercase. 271 False: Disables function name normalization. 272 """ 273 274 PRESERVE_ORIGINAL_NAMES: bool = False 275 """ 276 Whether the name of the function should be preserved inside the node's metadata, 277 can be useful for roundtripping deprecated vs new functions that share an AST node 278 e.g JSON_VALUE vs JSON_EXTRACT_SCALAR in BigQuery 279 """ 280 281 LOG_BASE_FIRST: t.Optional[bool] = True 282 """ 283 Whether the base comes first in the `LOG` function. 284 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 285 """ 286 287 NULL_ORDERING = "nulls_are_small" 288 """ 289 Default `NULL` ordering method to use if not explicitly set. 290 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 291 """ 292 293 TYPED_DIVISION = False 294 """ 295 Whether the behavior of `a / b` depends on the types of `a` and `b`. 296 False means `a / b` is always float division. 297 True means `a / b` is integer division if both `a` and `b` are integers. 
298 """ 299 300 SAFE_DIVISION = False 301 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 302 303 CONCAT_COALESCE = False 304 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 305 306 HEX_LOWERCASE = False 307 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 308 309 DATE_FORMAT = "'%Y-%m-%d'" 310 DATEINT_FORMAT = "'%Y%m%d'" 311 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 312 313 TIME_MAPPING: t.Dict[str, str] = {} 314 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 315 316 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 317 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 318 FORMAT_MAPPING: t.Dict[str, str] = {} 319 """ 320 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 321 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 322 """ 323 324 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 325 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 326 327 PSEUDOCOLUMNS: t.Set[str] = set() 328 """ 329 Columns that are auto-generated by the engine corresponding to this dialect. 330 For example, such columns may be excluded from `SELECT *` queries. 331 """ 332 333 PREFER_CTE_ALIAS_COLUMN = False 334 """ 335 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 336 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 337 any projection aliases in the subquery. 338 339 For example, 340 WITH y(c) AS ( 341 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 342 ) SELECT c FROM y; 343 344 will be rewritten as 345 346 WITH y(c) AS ( 347 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 348 ) SELECT c FROM y; 349 """ 350 351 COPY_PARAMS_ARE_CSV = True 352 """ 353 Whether COPY statement parameters are separated by comma or whitespace 354 """ 355 356 FORCE_EARLY_ALIAS_REF_EXPANSION = False 357 """ 358 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 

    HAS_DISTINCT_ARRAY_CONSTRUCTORS = False
    """
    Whether the ARRAY constructor is context-sensitive, i.e. in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3),
    as the former is of type INT[] vs the latter which is SUPER.
    """

    SUPPORTS_FIXED_SIZE_ARRAYS = False
    """
    Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts, e.g.
    in DuckDB. In dialects which don't support fixed-size arrays, such as Snowflake, this should
    be interpreted as a subscript/index operator.
    """

    STRICT_JSON_PATH_SYNTAX = True
    """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning."""

    ON_CONDITION_EMPTY_BEFORE_ERROR = True
    """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle)."""

    ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True
    """Whether ArrayAgg needs to filter NULL values."""

    PROMOTE_TO_INFERRED_DATETIME_TYPE = False
    """
    This flag is used in the optimizer's canonicalize rule and determines whether x will be promoted
    to the literal's type in x::DATE < '2020-01-01 12:05:03' (i.e., DATETIME). When false, the literal
    is cast to x's type to match it instead.
    """

    SUPPORTS_VALUES_DEFAULT = True
    """Whether the DEFAULT keyword is supported in the VALUES clause."""

    NUMBERS_CAN_BE_UNDERSCORE_SEPARATED = False
    """Whether number literals can include underscores for better readability."""

    REGEXP_EXTRACT_DEFAULT_GROUP = 0
    """The default value for the capturing group."""

    SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = {
        exp.Except: True,
        exp.Intersect: True,
        exp.Union: True,
    }
    """
    Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL`
    must be explicitly specified.
    """

    CREATABLE_KIND_MAPPING: dict[str, str] = {}
    """
    Helper for dialects that use a different name for the same creatable kind. For example, the ClickHouse
    equivalent of CREATE SCHEMA is CREATE DATABASE.
    """
444 """ 445 446 # --- Autofilled --- 447 448 tokenizer_class = Tokenizer 449 jsonpath_tokenizer_class = JSONPathTokenizer 450 parser_class = Parser 451 generator_class = Generator 452 453 # A trie of the time_mapping keys 454 TIME_TRIE: t.Dict = {} 455 FORMAT_TRIE: t.Dict = {} 456 457 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 458 INVERSE_TIME_TRIE: t.Dict = {} 459 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 460 INVERSE_FORMAT_TRIE: t.Dict = {} 461 462 INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {} 463 464 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 465 466 # Delimiters for string literals and identifiers 467 QUOTE_START = "'" 468 QUOTE_END = "'" 469 IDENTIFIER_START = '"' 470 IDENTIFIER_END = '"' 471 472 # Delimiters for bit, hex, byte and unicode literals 473 BIT_START: t.Optional[str] = None 474 BIT_END: t.Optional[str] = None 475 HEX_START: t.Optional[str] = None 476 HEX_END: t.Optional[str] = None 477 BYTE_START: t.Optional[str] = None 478 BYTE_END: t.Optional[str] = None 479 UNICODE_START: t.Optional[str] = None 480 UNICODE_END: t.Optional[str] = None 481 482 DATE_PART_MAPPING = { 483 "Y": "YEAR", 484 "YY": "YEAR", 485 "YYY": "YEAR", 486 "YYYY": "YEAR", 487 "YR": "YEAR", 488 "YEARS": "YEAR", 489 "YRS": "YEAR", 490 "MM": "MONTH", 491 "MON": "MONTH", 492 "MONS": "MONTH", 493 "MONTHS": "MONTH", 494 "D": "DAY", 495 "DD": "DAY", 496 "DAYS": "DAY", 497 "DAYOFMONTH": "DAY", 498 "DAY OF WEEK": "DAYOFWEEK", 499 "WEEKDAY": "DAYOFWEEK", 500 "DOW": "DAYOFWEEK", 501 "DW": "DAYOFWEEK", 502 "WEEKDAY_ISO": "DAYOFWEEKISO", 503 "DOW_ISO": "DAYOFWEEKISO", 504 "DW_ISO": "DAYOFWEEKISO", 505 "DAY OF YEAR": "DAYOFYEAR", 506 "DOY": "DAYOFYEAR", 507 "DY": "DAYOFYEAR", 508 "W": "WEEK", 509 "WK": "WEEK", 510 "WEEKOFYEAR": "WEEK", 511 "WOY": "WEEK", 512 "WY": "WEEK", 513 "WEEK_ISO": "WEEKISO", 514 "WEEKOFYEARISO": "WEEKISO", 515 "WEEKOFYEAR_ISO": "WEEKISO", 516 "Q": "QUARTER", 517 "QTR": "QUARTER", 518 "QTRS": "QUARTER", 519 "QUARTERS": "QUARTER", 520 "H": "HOUR", 521 "HH": "HOUR", 522 "HR": "HOUR", 523 "HOURS": "HOUR", 524 "HRS": "HOUR", 525 "M": "MINUTE", 526 "MI": "MINUTE", 527 "MIN": "MINUTE", 528 "MINUTES": "MINUTE", 529 "MINS": "MINUTE", 530 "S": "SECOND", 531 "SEC": "SECOND", 532 "SECONDS": "SECOND", 533 "SECS": "SECOND", 534 "MS": "MILLISECOND", 535 "MSEC": "MILLISECOND", 536 "MSECS": "MILLISECOND", 537 "MSECOND": "MILLISECOND", 538 "MSECONDS": "MILLISECOND", 539 "MILLISEC": "MILLISECOND", 540 "MILLISECS": "MILLISECOND", 541 "MILLISECON": "MILLISECOND", 542 "MILLISECONDS": "MILLISECOND", 543 "US": "MICROSECOND", 544 "USEC": "MICROSECOND", 545 "USECS": "MICROSECOND", 546 "MICROSEC": "MICROSECOND", 547 "MICROSECS": "MICROSECOND", 548 "USECOND": "MICROSECOND", 549 "USECONDS": "MICROSECOND", 550 "MICROSECONDS": "MICROSECOND", 551 "NS": "NANOSECOND", 552 "NSEC": "NANOSECOND", 553 "NANOSEC": "NANOSECOND", 554 "NSECOND": "NANOSECOND", 555 "NSECONDS": "NANOSECOND", 556 "NANOSECS": "NANOSECOND", 557 "EPOCH_SECOND": "EPOCH", 558 "EPOCH_SECONDS": "EPOCH", 559 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 560 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 561 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 562 "TZH": "TIMEZONE_HOUR", 563 "TZM": "TIMEZONE_MINUTE", 564 "DEC": "DECADE", 565 "DECS": "DECADE", 566 "DECADES": "DECADE", 567 "MIL": "MILLENIUM", 568 "MILS": "MILLENIUM", 569 "MILLENIA": "MILLENIUM", 570 "C": "CENTURY", 571 "CENT": "CENTURY", 572 "CENTS": "CENTURY", 573 "CENTURIES": "CENTURY", 574 } 575 576 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 577 exp.DataType.Type.BIGINT: 

    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.Datetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.ToDouble,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DatetimeDiff,
            exp.DateDiff,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Levenshtein,
            exp.Sign,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.JSON: {
            exp.ParseJSON,
        },
        exp.DataType.Type.TIME: {
            exp.Time,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
            exp.Quarter,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.Substring,
            exp.String,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }
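
    # Illustrative example (not part of the module): this mapping feeds the type
    # annotators below, so the optimizer can infer e.g. that LENGTH returns BIGINT.
    # A sketch, assuming the default dialect:
    #
    #     >>> from sqlglot import parse_one
    #     >>> from sqlglot.optimizer.annotate_types import annotate_types
    #     >>> annotate_types(parse_one("SELECT LENGTH('x')")).selects[0].type.sql()
    #     'BIGINT'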

    ANNOTATORS: AnnotatorsType = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Bracket: lambda self, e: self._annotate_bracket(e),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Count: lambda self, e: self._annotate_with_type(
            e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT
        ),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
        exp.DateSub: lambda self, e: self._annotate_timeunit(e),
        exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Div: lambda self, e: self._annotate_div(e),
        exp.Dot: lambda self, e: self._annotate_dot(e),
        exp.Explode: lambda self, e: self._annotate_explode(e),
        exp.Extract: lambda self, e: self._annotate_extract(e),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
            e, exp.DataType.build("ARRAY<DATE>")
        ),
        exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type(
            e, exp.DataType.build("ARRAY<TIMESTAMP>")
        ),
        exp.Greatest: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_map(e),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
        exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Struct: lambda self, e: self._annotate_struct(e),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.Timestamp: lambda self, e: self._annotate_with_type(
            e,
            exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
        ),
        exp.ToMap: lambda self, e: self._annotate_to_map(e),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Unnest: lambda self, e: self._annotate_unnest(e),
        exp.VarMap: lambda self, e: self._annotate_map(e),
    }

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_strings = dialect.split(",")
                kv_pairs = (kv.split("=") for kv in kv_strings)
                kwargs = {}
                for pair in kv_pairs:
                    key = pair[0].strip()
                    value: t.Union[bool, str, None] = None

                    if len(pair) == 1:
                        # Default initialize standalone settings to True
                        value = True
                    elif len(pair) == 2:
                        value = pair[1].strip()

                    kwargs[key] = to_bool(value)

            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
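
    # Illustrative example (not part of the module): settings in the string form are
    # forwarded to the dialect's constructor:
    #
    #     >>> Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive").normalization_strategy
    #     <NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>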

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression
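
    # Illustrative example (not part of the module): a sketch assuming Hive's
    # TIME_MAPPING, which maps e.g. "yyyy" to "%Y" and "dd" to "%d":
    #
    #     >>> Dialect["hive"].format_time("'yyyy-MM-dd'").sql()
    #     "'%Y-%m-%d'"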
857 """ 858 if ( 859 isinstance(expression, exp.Identifier) 860 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 861 and ( 862 not expression.quoted 863 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 864 ) 865 ): 866 expression.set( 867 "this", 868 ( 869 expression.this.upper() 870 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 871 else expression.this.lower() 872 ), 873 ) 874 875 return expression 876 877 def case_sensitive(self, text: str) -> bool: 878 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 879 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 880 return False 881 882 unsafe = ( 883 str.islower 884 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 885 else str.isupper 886 ) 887 return any(unsafe(char) for char in text) 888 889 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 890 """Checks if text can be identified given an identify option. 891 892 Args: 893 text: The text to check. 894 identify: 895 `"always"` or `True`: Always returns `True`. 896 `"safe"`: Only returns `True` if the identifier is case-insensitive. 897 898 Returns: 899 Whether the given text can be identified. 900 """ 901 if identify is True or identify == "always": 902 return True 903 904 if identify == "safe": 905 return not self.case_sensitive(text) 906 907 return False 908 909 def quote_identifier(self, expression: E, identify: bool = True) -> E: 910 """ 911 Adds quotes to a given identifier. 912 913 Args: 914 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 915 identify: If set to `False`, the quotes will only be added if the identifier is deemed 916 "unsafe", with respect to its characters and this dialect's normalization strategy. 917 """ 918 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 919 name = expression.this 920 expression.set( 921 "quoted", 922 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 923 ) 924 925 return expression 926 927 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 928 if isinstance(path, exp.Literal): 929 path_text = path.name 930 if path.is_number: 931 path_text = f"[{path_text}]" 932 try: 933 return parse_json_path(path_text, self) 934 except ParseError as e: 935 if self.STRICT_JSON_PATH_SYNTAX: 936 logger.warning(f"Invalid JSON path syntax. 
{str(e)}") 937 938 return path 939 940 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 941 return self.parser(**opts).parse(self.tokenize(sql), sql) 942 943 def parse_into( 944 self, expression_type: exp.IntoType, sql: str, **opts 945 ) -> t.List[t.Optional[exp.Expression]]: 946 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 947 948 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 949 return self.generator(**opts).generate(expression, copy=copy) 950 951 def transpile(self, sql: str, **opts) -> t.List[str]: 952 return [ 953 self.generate(expression, copy=False, **opts) if expression else "" 954 for expression in self.parse(sql) 955 ] 956 957 def tokenize(self, sql: str) -> t.List[Token]: 958 return self.tokenizer.tokenize(sql) 959 960 @property 961 def tokenizer(self) -> Tokenizer: 962 return self.tokenizer_class(dialect=self) 963 964 @property 965 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 966 return self.jsonpath_tokenizer_class(dialect=self) 967 968 def parser(self, **opts) -> Parser: 969 return self.parser_class(dialect=self, **opts) 970 971 def generator(self, **opts) -> Generator: 972 return self.generator_class(dialect=self, **opts) 973 974 975DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 976 977 978def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 979 return lambda self, expression: self.func(name, *flatten(expression.args.values())) 980 981 982@unsupported_args("accuracy") 983def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 984 return self.func("APPROX_COUNT_DISTINCT", expression.this) 985 986 987def if_sql( 988 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 989) -> t.Callable[[Generator, exp.If], str]: 990 def _if_sql(self: Generator, expression: exp.If) -> str: 991 return self.func( 992 name, 993 expression.this, 994 expression.args.get("true"), 995 expression.args.get("false") or false_value, 996 ) 997 998 return _if_sql 999 1000 1001def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1002 this = expression.this 1003 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 1004 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 1005 1006 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>") 1007 1008 1009def inline_array_sql(self: Generator, expression: exp.Array) -> str: 1010 return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]" 1011 1012 1013def inline_array_unless_query(self: Generator, expression: exp.Array) -> str: 1014 elem = seq_get(expression.expressions, 0) 1015 if isinstance(elem, exp.Expression) and elem.find(exp.Query): 1016 return self.func("ARRAY", elem) 1017 return inline_array_sql(self, expression) 1018 1019 1020def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 1021 return self.like_sql( 1022 exp.Like( 1023 this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression) 1024 ) 1025 ) 1026 1027 1028def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 1029 zone = self.sql(expression, "this") 1030 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 1031 1032 1033def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 1034 if expression.args.get("recursive"): 1035 self.unsupported("Recursive CTEs are 
unsupported") 1036 expression.args["recursive"] = False 1037 return self.with_sql(expression) 1038 1039 1040def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide, if_sql: str = "IF") -> str: 1041 n = self.sql(expression, "this") 1042 d = self.sql(expression, "expression") 1043 return f"{if_sql}(({d}) <> 0, ({n}) / ({d}), NULL)" 1044 1045 1046def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: 1047 self.unsupported("TABLESAMPLE unsupported") 1048 return self.sql(expression.this) 1049 1050 1051def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: 1052 self.unsupported("PIVOT unsupported") 1053 return "" 1054 1055 1056def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: 1057 return self.cast_sql(expression) 1058 1059 1060def no_comment_column_constraint_sql( 1061 self: Generator, expression: exp.CommentColumnConstraint 1062) -> str: 1063 self.unsupported("CommentColumnConstraint unsupported") 1064 return "" 1065 1066 1067def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: 1068 self.unsupported("MAP_FROM_ENTRIES unsupported") 1069 return "" 1070 1071 1072def property_sql(self: Generator, expression: exp.Property) -> str: 1073 return f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}" 1074 1075 1076def str_position_sql( 1077 self: Generator, 1078 expression: exp.StrPosition, 1079 generate_instance: bool = False, 1080 str_position_func_name: str = "STRPOS", 1081) -> str: 1082 this = self.sql(expression, "this") 1083 substr = self.sql(expression, "substr") 1084 position = self.sql(expression, "position") 1085 instance = expression.args.get("instance") if generate_instance else None 1086 position_offset = "" 1087 1088 if position: 1089 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 1090 this = self.func("SUBSTR", this, position) 1091 position_offset = f" + {position} - 1" 1092 1093 strpos_sql = self.func(str_position_func_name, this, substr, instance) 1094 1095 if position_offset: 1096 zero = exp.Literal.number(0) 1097 # If match is not found (returns 0) the position offset should not be applied 1098 case = exp.If( 1099 this=exp.EQ(this=strpos_sql, expression=zero), 1100 true=zero, 1101 false=strpos_sql + position_offset, 1102 ) 1103 strpos_sql = self.sql(case) 1104 1105 return strpos_sql 1106 1107 1108def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: 1109 return ( 1110 f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" 1111 ) 1112 1113 1114def var_map_sql( 1115 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1116) -> str: 1117 keys = expression.args["keys"] 1118 values = expression.args["values"] 1119 1120 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1121 self.unsupported("Cannot convert array columns into map.") 1122 return self.func(map_func_name, keys, values) 1123 1124 args = [] 1125 for key, value in zip(keys.expressions, values.expressions): 1126 args.append(self.sql(key)) 1127 args.append(self.sql(value)) 1128 1129 return self.func(map_func_name, *args) 1130 1131 1132def build_formatted_time( 1133 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1134) -> t.Callable[[t.List], E]: 1135 """Helper used for time expressions. 1136 1137 Args: 1138 exp_class: the expression class to instantiate. 1139 dialect: target sql dialect. 1140 default: the default format, True being time. 

def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


@unsupported_args("accuracy")
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, exp.DataType.Type.JSON))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"


def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
    elem = seq_get(expression.expressions, 0)
    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
        return self.func("ARRAY", elem)
    return inline_array_sql(self, expression)


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression)
        )
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)
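
# Illustrative example (not part of the module): generator helpers like
# `rename_func` and the `no_*_sql` fallbacks above are typically wired into a
# dialect's TRANSFORMS mapping. A sketch with a hypothetical dialect:
#
#     class Custom(Dialect):
#         class Generator(Generator):
#             TRANSFORMS = {
#                 **Generator.TRANSFORMS,
#                 exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
#                 exp.ILike: no_ilike_sql,
#             }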

def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide, if_sql: str = "IF") -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"{if_sql}(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def property_sql(self: Generator, expression: exp.Property) -> str:
    return f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}"


def str_position_sql(
    self: Generator,
    expression: exp.StrPosition,
    generate_instance: bool = False,
    str_position_func_name: str = "STRPOS",
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    strpos_sql = self.func(str_position_func_name, this, substr, instance)

    if position_offset:
        zero = exp.Literal.number(0)
        # If match is not found (returns 0) the position offset should not be applied
        case = exp.If(
            this=exp.EQ(this=strpos_sql, expression=zero),
            true=zero,
            false=strpos_sql + position_offset,
        )
        strpos_sql = self.sql(case)

    return strpos_sql


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder
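
# Illustrative example (not part of the module): parser-side builders like
# `build_formatted_time` are registered in a dialect's FUNCTIONS mapping, so that
# e.g. DATE_FORMAT(x, fmt) parses into exp.TimeToStr with a normalized format.
# A sketch with a hypothetical dialect:
#
#     class Custom(Dialect):
#         class Parser(Parser):
#             FUNCTIONS = {
#                 **Parser.FUNCTIONS,
#                 "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "hive"),
#             }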
expression.args.get("replace") if replace else None) 1341 1342 1343def min_or_least(self: Generator, expression: exp.Min) -> str: 1344 name = "LEAST" if expression.expressions else "MIN" 1345 return rename_func(name)(self, expression) 1346 1347 1348def max_or_greatest(self: Generator, expression: exp.Max) -> str: 1349 name = "GREATEST" if expression.expressions else "MAX" 1350 return rename_func(name)(self, expression) 1351 1352 1353def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1354 cond = expression.this 1355 1356 if isinstance(expression.this, exp.Distinct): 1357 cond = expression.this.expressions[0] 1358 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1359 1360 return self.func("sum", exp.func("if", cond, 1, 0)) 1361 1362 1363def trim_sql(self: Generator, expression: exp.Trim) -> str: 1364 target = self.sql(expression, "this") 1365 trim_type = self.sql(expression, "position") 1366 remove_chars = self.sql(expression, "expression") 1367 collation = self.sql(expression, "collation") 1368 1369 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1370 if not remove_chars: 1371 return self.trim_sql(expression) 1372 1373 trim_type = f"{trim_type} " if trim_type else "" 1374 remove_chars = f"{remove_chars} " if remove_chars else "" 1375 from_part = "FROM " if trim_type or remove_chars else "" 1376 collation = f" COLLATE {collation}" if collation else "" 1377 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" 1378 1379 1380def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: 1381 return self.func("STRPTIME", expression.this, self.format_time(expression)) 1382 1383 1384def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str: 1385 return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)) 1386 1387 1388def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str: 1389 delim, *rest_args = expression.expressions 1390 return self.sql( 1391 reduce( 1392 lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)), 1393 rest_args, 1394 ) 1395 ) 1396 1397 1398@unsupported_args("position", "occurrence", "parameters") 1399def regexp_extract_sql( 1400 self: Generator, expression: exp.RegexpExtract | exp.RegexpExtractAll 1401) -> str: 1402 group = expression.args.get("group") 1403 1404 # Do not render group if it's the default value for this dialect 1405 if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP): 1406 group = None 1407 1408 return self.func(expression.sql_name(), expression.this, expression.expression, group) 1409 1410 1411@unsupported_args("position", "occurrence", "modifiers") 1412def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 1413 return self.func( 1414 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 1415 ) 1416 1417 1418def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1419 names = [] 1420 for agg in aggregations: 1421 if isinstance(agg, exp.Alias): 1422 names.append(agg.alias) 1423 else: 1424 """ 1425 This case corresponds to aggregations without aliases being used as suffixes 1426 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1427 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1428 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 
1429 """ 1430 agg_all_unquoted = agg.transform( 1431 lambda node: ( 1432 exp.Identifier(this=node.name, quoted=False) 1433 if isinstance(node, exp.Identifier) 1434 else node 1435 ) 1436 ) 1437 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1438 1439 return names 1440 1441 1442def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: 1443 return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) 1444 1445 1446# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects 1447def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: 1448 return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) 1449 1450 1451def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: 1452 return self.func("MAX", expression.this) 1453 1454 1455def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: 1456 a = self.sql(expression.left) 1457 b = self.sql(expression.right) 1458 return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" 1459 1460 1461def is_parse_json(expression: exp.Expression) -> bool: 1462 return isinstance(expression, exp.ParseJSON) or ( 1463 isinstance(expression, exp.Cast) and expression.is_type("json") 1464 ) 1465 1466 1467def isnull_to_is_null(args: t.List) -> exp.Expression: 1468 return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) 1469 1470 1471def generatedasidentitycolumnconstraint_sql( 1472 self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint 1473) -> str: 1474 start = self.sql(expression, "start") or "1" 1475 increment = self.sql(expression, "increment") or "1" 1476 return f"IDENTITY({start}, {increment})" 1477 1478 1479def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1480 @unsupported_args("count") 1481 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1482 return self.func(name, expression.this, expression.expression) 1483 1484 return _arg_max_or_min_sql 1485 1486 1487def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1488 this = expression.this.copy() 1489 1490 return_type = expression.return_type 1491 if return_type.is_type(exp.DataType.Type.DATE): 1492 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1493 # can truncate timestamp strings, because some dialects can't cast them to DATE 1494 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1495 1496 expression.this.replace(exp.cast(this, return_type)) 1497 return expression 1498 1499 1500def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1501 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1502 if cast and isinstance(expression, exp.TsOrDsAdd): 1503 expression = ts_or_ds_add_cast(expression) 1504 1505 return self.func( 1506 name, 1507 unit_to_var(expression), 1508 expression.expression, 1509 expression.this, 1510 ) 1511 1512 return _delta_sql 1513 1514 1515def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1516 unit = expression.args.get("unit") 1517 1518 if isinstance(unit, exp.Placeholder): 1519 return unit 1520 if unit: 1521 return exp.Literal.string(unit.name) 1522 return exp.Literal.string(default) if default else None 1523 1524 1525def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1526 unit = expression.args.get("unit") 1527 1528 if isinstance(unit, (exp.Var, exp.Placeholder)): 1529 

# Used for Presto and DuckDB, which use functions that don't support charset and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


@unsupported_args("position", "occurrence", "parameters")
def regexp_extract_sql(
    self: Generator, expression: exp.RegexpExtract | exp.RegexpExtractAll
) -> str:
    group = expression.args.get("group")

    # Do not render group if it's the default value for this dialect
    if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP):
        group = None

    return self.func(expression.sql_name(), expression.this, expression.expression, group)


@unsupported_args("position", "occurrence", "modifiers")
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
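
# Illustrative example (not part of the module): `binary_from_function` builds a
# parser callback for plain two-argument functions. A sketch with a hypothetical
# dialect:
#
#     class Custom(Dialect):
#         class Parser(Parser):
#             FUNCTIONS = {
#                 **Parser.FUNCTIONS,
#                 "ARRAY_CONTAINS": binary_from_function(exp.ArrayContains),
#             }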

# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    @unsupported_args("count")
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql


def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    return exp.Literal.string(default) if default else None


def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, (exp.Var, exp.Placeholder)):
        return unit
    return exp.Var(this=default) if default else None


@t.overload
def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var:
    pass


@t.overload
def map_date_part(
    part: t.Optional[exp.Expression], dialect: DialectType = Dialect
) -> t.Optional[exp.Expression]:
    pass


def map_date_part(part, dialect: DialectType = Dialect):
    mapped = (
        Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None
    )
    return exp.var(mapped) if mapped else part


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.args["whens"].expressions:
        # only remove the target table names from certain parts of WHEN MATCHED / WHEN NOT MATCHED
        # they are still valid in the <condition>, the right hand side of each UPDATE and the VALUES part
        # (not the column list) of the INSERT
        then: exp.Insert | exp.Update | None = when.args.get("then")
        if then:
            if isinstance(then, exp.Update):
                for equals in then.find_all(exp.EQ):
                    equal_lhs = equals.this
                    if (
                        isinstance(equal_lhs, exp.Column)
                        and normalize(equal_lhs.args.get("table")) in targets
                    ):
                        equal_lhs.replace(exp.column(equal_lhs.this))
            if isinstance(then, exp.Insert):
                column_list = then.this
                if isinstance(column_list, exp.Tuple):
                    for column in column_list.expressions:
                        if normalize(column.args.get("table")) in targets:
                            column.replace(exp.column(column.this))

    return self.merge_sql(expression)


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder
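
# Illustrative example (not part of the module): `build_json_extract_path` folds
# variadic path arguments into a single exp.JSONPath:
#
#     builder = build_json_extract_path(exp.JSONExtract)
#     builder([exp.column("doc"), exp.Literal.string("a"), exp.Literal.number(0)])
#     # -> JSONExtract(this=doc, expression=JSONPath($.a[0]))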
else: 1731 start = exp.cast(start, target_type) 1732 1733 return self.func("SEQUENCE", start, end, step) 1734 1735 1736def build_regexp_extract(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]: 1737 def _builder(args: t.List, dialect: Dialect) -> E: 1738 return expr_type( 1739 this=seq_get(args, 0), 1740 expression=seq_get(args, 1), 1741 group=seq_get(args, 2) or exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP), 1742 parameters=seq_get(args, 3), 1743 ) 1744 1745 return _builder 1746 1747 1748def explode_to_unnest_sql(self: Generator, expression: exp.Lateral) -> str: 1749 if isinstance(expression.this, exp.Explode): 1750 return self.sql( 1751 exp.Join( 1752 this=exp.Unnest( 1753 expressions=[expression.this.this], 1754 alias=expression.args.get("alias"), 1755 offset=isinstance(expression.this, exp.Posexplode), 1756 ), 1757 kind="cross", 1758 ) 1759 ) 1760 return self.lateral_sql(expression) 1761 1762 1763def timestampdiff_sql(self: Generator, expression: exp.DatetimeDiff | exp.TimestampDiff) -> str: 1764 return self.func("TIMESTAMPDIFF", expression.unit, expression.expression, expression.this) 1765 1766 1767def no_make_interval_sql(self: Generator, expression: exp.MakeInterval, sep: str = ", ") -> str: 1768 args = [] 1769 for unit, value in expression.args.items(): 1770 if isinstance(value, exp.Kwarg): 1771 value = value.expression 1772 1773 args.append(f"{value} {unit}") 1774 1775 return f"INTERVAL '{self.format_args(*args, sep=sep)}'" 1776 1777 1778def length_or_char_length_sql(self: Generator, expression: exp.Length) -> str: 1779 length_func = "LENGTH" if expression.args.get("binary") else "CHAR_LENGTH" 1780 return self.func(length_func, expression.this)
55class Dialects(str, Enum): 56 """Dialects supported by SQLGLot.""" 57 58 DIALECT = "" 59 60 ATHENA = "athena" 61 BIGQUERY = "bigquery" 62 CLICKHOUSE = "clickhouse" 63 DATABRICKS = "databricks" 64 DORIS = "doris" 65 DRILL = "drill" 66 DRUID = "druid" 67 DUCKDB = "duckdb" 68 HIVE = "hive" 69 MATERIALIZE = "materialize" 70 MYSQL = "mysql" 71 ORACLE = "oracle" 72 POSTGRES = "postgres" 73 PRESTO = "presto" 74 PRQL = "prql" 75 REDSHIFT = "redshift" 76 RISINGWAVE = "risingwave" 77 SNOWFLAKE = "snowflake" 78 SPARK = "spark" 79 SPARK2 = "spark2" 80 SQLITE = "sqlite" 81 STARROCKS = "starrocks" 82 TABLEAU = "tableau" 83 TERADATA = "teradata" 84 TRINO = "trino" 85 TSQL = "tsql"
Dialects supported by SQLGlot.
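A brief usage sketch (grounded in this module's public API): each enum value doubles as the registry key under which the dialect class is stored, so a member can be used to obtain a dialect instance.

    import sqlglot  # noqa: F401 -- ensures the built-in dialects are registered
    from sqlglot.dialects.dialect import Dialect, Dialects

    # Look up the DuckDB dialect through its enum value;
    # equivalent to Dialect.get_or_raise("duckdb")
    duckdb = Dialect.get_or_raise(Dialects.DUCKDB.value)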
88class NormalizationStrategy(str, AutoName): 89 """Specifies the strategy according to which identifiers should be normalized.""" 90 91 LOWERCASE = auto() 92 """Unquoted identifiers are lowercased.""" 93 94 UPPERCASE = auto() 95 """Unquoted identifiers are uppercased.""" 96 97 CASE_SENSITIVE = auto() 98 """Always case-sensitive, regardless of quotes.""" 99 100 CASE_INSENSITIVE = auto() 101 """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.
Always case-sensitive, regardless of quotes.
Always case-insensitive, regardless of quotes.
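A minimal sketch of overriding the strategy per instance; `Dialect.__init__` (shown later on this page) upper-cases the string before coercing it into a `NormalizationStrategy` member.

    from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

    # The keyword is lowercased on purpose: __init__ calls .upper() before
    # constructing the enum member
    d = Dialect(normalization_strategy="case_insensitive")
    assert d.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE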
226class Dialect(metaclass=_Dialect): 227 INDEX_OFFSET = 0 228 """The base index offset for arrays.""" 229 230 WEEK_OFFSET = 0 231 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 232 233 UNNEST_COLUMN_ONLY = False 234 """Whether `UNNEST` table aliases are treated as column aliases.""" 235 236 ALIAS_POST_TABLESAMPLE = False 237 """Whether the table alias comes after tablesample.""" 238 239 TABLESAMPLE_SIZE_IS_PERCENT = False 240 """Whether a size in the table sample clause represents percentage.""" 241 242 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 243 """Specifies the strategy according to which identifiers should be normalized.""" 244 245 IDENTIFIERS_CAN_START_WITH_DIGIT = False 246 """Whether an unquoted identifier can start with a digit.""" 247 248 DPIPE_IS_STRING_CONCAT = True 249 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 250 251 STRICT_STRING_CONCAT = False 252 """Whether `CONCAT`'s arguments must be strings.""" 253 254 SUPPORTS_USER_DEFINED_TYPES = True 255 """Whether user-defined data types are supported.""" 256 257 SUPPORTS_SEMI_ANTI_JOIN = True 258 """Whether `SEMI` or `ANTI` joins are supported.""" 259 260 SUPPORTS_COLUMN_JOIN_MARKS = False 261 """Whether the old-style outer join (+) syntax is supported.""" 262 263 COPY_PARAMS_ARE_CSV = True 264 """Separator of COPY statement parameters.""" 265 266 NORMALIZE_FUNCTIONS: bool | str = "upper" 267 """ 268 Determines how function names are going to be normalized. 269 Possible values: 270 "upper" or True: Convert names to uppercase. 271 "lower": Convert names to lowercase. 272 False: Disables function name normalization. 273 """ 274 275 PRESERVE_ORIGINAL_NAMES: bool = False 276 """ 277 Whether the name of the function should be preserved inside the node's metadata, 278 can be useful for roundtripping deprecated vs new functions that share an AST node 279 e.g JSON_VALUE vs JSON_EXTRACT_SCALAR in BigQuery 280 """ 281 282 LOG_BASE_FIRST: t.Optional[bool] = True 283 """ 284 Whether the base comes first in the `LOG` function. 285 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 286 """ 287 288 NULL_ORDERING = "nulls_are_small" 289 """ 290 Default `NULL` ordering method to use if not explicitly set. 291 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 292 """ 293 294 TYPED_DIVISION = False 295 """ 296 Whether the behavior of `a / b` depends on the types of `a` and `b`. 297 False means `a / b` is always float division. 298 True means `a / b` is integer division if both `a` and `b` are integers. 
299 """ 300 301 SAFE_DIVISION = False 302 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 303 304 CONCAT_COALESCE = False 305 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 306 307 HEX_LOWERCASE = False 308 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 309 310 DATE_FORMAT = "'%Y-%m-%d'" 311 DATEINT_FORMAT = "'%Y%m%d'" 312 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 313 314 TIME_MAPPING: t.Dict[str, str] = {} 315 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 316 317 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 318 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 319 FORMAT_MAPPING: t.Dict[str, str] = {} 320 """ 321 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 322 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 323 """ 324 325 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 326 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 327 328 PSEUDOCOLUMNS: t.Set[str] = set() 329 """ 330 Columns that are auto-generated by the engine corresponding to this dialect. 331 For example, such columns may be excluded from `SELECT *` queries. 332 """ 333 334 PREFER_CTE_ALIAS_COLUMN = False 335 """ 336 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 337 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 338 any projection aliases in the subquery. 339 340 For example, 341 WITH y(c) AS ( 342 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 343 ) SELECT c FROM y; 344 345 will be rewritten as 346 347 WITH y(c) AS ( 348 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 349 ) SELECT c FROM y; 350 """ 351 352 COPY_PARAMS_ARE_CSV = True 353 """ 354 Whether COPY statement parameters are separated by comma or whitespace 355 """ 356 357 FORCE_EARLY_ALIAS_REF_EXPANSION = False 358 """ 359 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 
360 361 For example: 362 WITH data AS ( 363 SELECT 364 1 AS id, 365 2 AS my_id 366 ) 367 SELECT 368 id AS my_id 369 FROM 370 data 371 WHERE 372 my_id = 1 373 GROUP BY 374 my_id, 375 HAVING 376 my_id = 1 377 378 In most dialects, "my_id" would refer to "data.my_id" across the query, except: 379 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e 380 it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 381 - Clickhouse, which will forward the alias across the query i.e it resolves 382 to "WHERE id = 1 GROUP BY id HAVING id = 1" 383 """ 384 385 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 386 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 387 388 SUPPORTS_ORDER_BY_ALL = False 389 """ 390 Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks 391 """ 392 393 HAS_DISTINCT_ARRAY_CONSTRUCTORS = False 394 """ 395 Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) 396 as the former is of type INT[] vs the latter which is SUPER 397 """ 398 399 SUPPORTS_FIXED_SIZE_ARRAYS = False 400 """ 401 Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. 402 in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should 403 be interpreted as a subscript/index operator. 404 """ 405 406 STRICT_JSON_PATH_SYNTAX = True 407 """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.""" 408 409 ON_CONDITION_EMPTY_BEFORE_ERROR = True 410 """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).""" 411 412 ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True 413 """Whether ArrayAgg needs to filter NULL values.""" 414 415 PROMOTE_TO_INFERRED_DATETIME_TYPE = False 416 """ 417 This flag is used in the optimizer's canonicalize rule and determines whether x will be promoted 418 to the literal's type in x::DATE < '2020-01-01 12:05:03' (i.e., DATETIME). When false, the literal 419 is cast to x's type to match it instead. 420 """ 421 422 SUPPORTS_VALUES_DEFAULT = True 423 """Whether the DEFAULT keyword is supported in the VALUES clause.""" 424 425 NUMBERS_CAN_BE_UNDERSCORE_SEPARATED = False 426 """Whether number literals can include underscores for better readability""" 427 428 REGEXP_EXTRACT_DEFAULT_GROUP = 0 429 """The default value for the capturing group.""" 430 431 SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = { 432 exp.Except: True, 433 exp.Intersect: True, 434 exp.Union: True, 435 } 436 """ 437 Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL` 438 must be explicitly specified. 439 """ 440 441 CREATABLE_KIND_MAPPING: dict[str, str] = {} 442 """ 443 Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse 444 equivalent of CREATE SCHEMA is CREATE DATABASE. 
445 """ 446 447 # --- Autofilled --- 448 449 tokenizer_class = Tokenizer 450 jsonpath_tokenizer_class = JSONPathTokenizer 451 parser_class = Parser 452 generator_class = Generator 453 454 # A trie of the time_mapping keys 455 TIME_TRIE: t.Dict = {} 456 FORMAT_TRIE: t.Dict = {} 457 458 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 459 INVERSE_TIME_TRIE: t.Dict = {} 460 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 461 INVERSE_FORMAT_TRIE: t.Dict = {} 462 463 INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {} 464 465 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 466 467 # Delimiters for string literals and identifiers 468 QUOTE_START = "'" 469 QUOTE_END = "'" 470 IDENTIFIER_START = '"' 471 IDENTIFIER_END = '"' 472 473 # Delimiters for bit, hex, byte and unicode literals 474 BIT_START: t.Optional[str] = None 475 BIT_END: t.Optional[str] = None 476 HEX_START: t.Optional[str] = None 477 HEX_END: t.Optional[str] = None 478 BYTE_START: t.Optional[str] = None 479 BYTE_END: t.Optional[str] = None 480 UNICODE_START: t.Optional[str] = None 481 UNICODE_END: t.Optional[str] = None 482 483 DATE_PART_MAPPING = { 484 "Y": "YEAR", 485 "YY": "YEAR", 486 "YYY": "YEAR", 487 "YYYY": "YEAR", 488 "YR": "YEAR", 489 "YEARS": "YEAR", 490 "YRS": "YEAR", 491 "MM": "MONTH", 492 "MON": "MONTH", 493 "MONS": "MONTH", 494 "MONTHS": "MONTH", 495 "D": "DAY", 496 "DD": "DAY", 497 "DAYS": "DAY", 498 "DAYOFMONTH": "DAY", 499 "DAY OF WEEK": "DAYOFWEEK", 500 "WEEKDAY": "DAYOFWEEK", 501 "DOW": "DAYOFWEEK", 502 "DW": "DAYOFWEEK", 503 "WEEKDAY_ISO": "DAYOFWEEKISO", 504 "DOW_ISO": "DAYOFWEEKISO", 505 "DW_ISO": "DAYOFWEEKISO", 506 "DAY OF YEAR": "DAYOFYEAR", 507 "DOY": "DAYOFYEAR", 508 "DY": "DAYOFYEAR", 509 "W": "WEEK", 510 "WK": "WEEK", 511 "WEEKOFYEAR": "WEEK", 512 "WOY": "WEEK", 513 "WY": "WEEK", 514 "WEEK_ISO": "WEEKISO", 515 "WEEKOFYEARISO": "WEEKISO", 516 "WEEKOFYEAR_ISO": "WEEKISO", 517 "Q": "QUARTER", 518 "QTR": "QUARTER", 519 "QTRS": "QUARTER", 520 "QUARTERS": "QUARTER", 521 "H": "HOUR", 522 "HH": "HOUR", 523 "HR": "HOUR", 524 "HOURS": "HOUR", 525 "HRS": "HOUR", 526 "M": "MINUTE", 527 "MI": "MINUTE", 528 "MIN": "MINUTE", 529 "MINUTES": "MINUTE", 530 "MINS": "MINUTE", 531 "S": "SECOND", 532 "SEC": "SECOND", 533 "SECONDS": "SECOND", 534 "SECS": "SECOND", 535 "MS": "MILLISECOND", 536 "MSEC": "MILLISECOND", 537 "MSECS": "MILLISECOND", 538 "MSECOND": "MILLISECOND", 539 "MSECONDS": "MILLISECOND", 540 "MILLISEC": "MILLISECOND", 541 "MILLISECS": "MILLISECOND", 542 "MILLISECON": "MILLISECOND", 543 "MILLISECONDS": "MILLISECOND", 544 "US": "MICROSECOND", 545 "USEC": "MICROSECOND", 546 "USECS": "MICROSECOND", 547 "MICROSEC": "MICROSECOND", 548 "MICROSECS": "MICROSECOND", 549 "USECOND": "MICROSECOND", 550 "USECONDS": "MICROSECOND", 551 "MICROSECONDS": "MICROSECOND", 552 "NS": "NANOSECOND", 553 "NSEC": "NANOSECOND", 554 "NANOSEC": "NANOSECOND", 555 "NSECOND": "NANOSECOND", 556 "NSECONDS": "NANOSECOND", 557 "NANOSECS": "NANOSECOND", 558 "EPOCH_SECOND": "EPOCH", 559 "EPOCH_SECONDS": "EPOCH", 560 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 561 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 562 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 563 "TZH": "TIMEZONE_HOUR", 564 "TZM": "TIMEZONE_MINUTE", 565 "DEC": "DECADE", 566 "DECS": "DECADE", 567 "DECADES": "DECADE", 568 "MIL": "MILLENIUM", 569 "MILS": "MILLENIUM", 570 "MILLENIA": "MILLENIUM", 571 "C": "CENTURY", 572 "CENT": "CENTURY", 573 "CENTS": "CENTURY", 574 "CENTURIES": "CENTURY", 575 } 576 577 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 578 exp.DataType.Type.BIGINT: 
{ 579 exp.ApproxDistinct, 580 exp.ArraySize, 581 exp.Length, 582 }, 583 exp.DataType.Type.BOOLEAN: { 584 exp.Between, 585 exp.Boolean, 586 exp.In, 587 exp.RegexpLike, 588 }, 589 exp.DataType.Type.DATE: { 590 exp.CurrentDate, 591 exp.Date, 592 exp.DateFromParts, 593 exp.DateStrToDate, 594 exp.DiToDate, 595 exp.StrToDate, 596 exp.TimeStrToDate, 597 exp.TsOrDsToDate, 598 }, 599 exp.DataType.Type.DATETIME: { 600 exp.CurrentDatetime, 601 exp.Datetime, 602 exp.DatetimeAdd, 603 exp.DatetimeSub, 604 }, 605 exp.DataType.Type.DOUBLE: { 606 exp.ApproxQuantile, 607 exp.Avg, 608 exp.Exp, 609 exp.Ln, 610 exp.Log, 611 exp.Pow, 612 exp.Quantile, 613 exp.Round, 614 exp.SafeDivide, 615 exp.Sqrt, 616 exp.Stddev, 617 exp.StddevPop, 618 exp.StddevSamp, 619 exp.ToDouble, 620 exp.Variance, 621 exp.VariancePop, 622 }, 623 exp.DataType.Type.INT: { 624 exp.Ceil, 625 exp.DatetimeDiff, 626 exp.DateDiff, 627 exp.TimestampDiff, 628 exp.TimeDiff, 629 exp.DateToDi, 630 exp.Levenshtein, 631 exp.Sign, 632 exp.StrPosition, 633 exp.TsOrDiToDi, 634 }, 635 exp.DataType.Type.JSON: { 636 exp.ParseJSON, 637 }, 638 exp.DataType.Type.TIME: { 639 exp.Time, 640 }, 641 exp.DataType.Type.TIMESTAMP: { 642 exp.CurrentTime, 643 exp.CurrentTimestamp, 644 exp.StrToTime, 645 exp.TimeAdd, 646 exp.TimeStrToTime, 647 exp.TimeSub, 648 exp.TimestampAdd, 649 exp.TimestampSub, 650 exp.UnixToTime, 651 }, 652 exp.DataType.Type.TINYINT: { 653 exp.Day, 654 exp.Month, 655 exp.Week, 656 exp.Year, 657 exp.Quarter, 658 }, 659 exp.DataType.Type.VARCHAR: { 660 exp.ArrayConcat, 661 exp.Concat, 662 exp.ConcatWs, 663 exp.DateToDateStr, 664 exp.GroupConcat, 665 exp.Initcap, 666 exp.Lower, 667 exp.Substring, 668 exp.String, 669 exp.TimeToStr, 670 exp.TimeToTimeStr, 671 exp.Trim, 672 exp.TsOrDsToDateStr, 673 exp.UnixToStr, 674 exp.UnixToTimeStr, 675 exp.Upper, 676 }, 677 } 678 679 ANNOTATORS: AnnotatorsType = { 680 **{ 681 expr_type: lambda self, e: self._annotate_unary(e) 682 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 683 }, 684 **{ 685 expr_type: lambda self, e: self._annotate_binary(e) 686 for expr_type in subclasses(exp.__name__, exp.Binary) 687 }, 688 **{ 689 expr_type: _annotate_with_type_lambda(data_type) 690 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 691 for expr_type in expressions 692 }, 693 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 694 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 695 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 696 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 697 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 698 exp.Bracket: lambda self, e: self._annotate_bracket(e), 699 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 700 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 701 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 702 exp.Count: lambda self, e: self._annotate_with_type( 703 e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT 704 ), 705 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 706 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 707 exp.DateSub: lambda self, e: self._annotate_timeunit(e), 708 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 709 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 710 exp.Div: lambda self, e: self._annotate_div(e), 711 exp.Dot: lambda 
self, e: self._annotate_dot(e), 712 exp.Explode: lambda self, e: self._annotate_explode(e), 713 exp.Extract: lambda self, e: self._annotate_extract(e), 714 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 715 exp.GenerateDateArray: lambda self, e: self._annotate_with_type( 716 e, exp.DataType.build("ARRAY<DATE>") 717 ), 718 exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type( 719 e, exp.DataType.build("ARRAY<TIMESTAMP>") 720 ), 721 exp.Greatest: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 722 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 723 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 724 exp.Least: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 725 exp.Literal: lambda self, e: self._annotate_literal(e), 726 exp.Map: lambda self, e: self._annotate_map(e), 727 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 728 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 729 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 730 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 731 exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), 732 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 733 exp.Struct: lambda self, e: self._annotate_struct(e), 734 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 735 exp.Timestamp: lambda self, e: self._annotate_with_type( 736 e, 737 exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, 738 ), 739 exp.ToMap: lambda self, e: self._annotate_to_map(e), 740 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 741 exp.Unnest: lambda self, e: self._annotate_unnest(e), 742 exp.VarMap: lambda self, e: self._annotate_map(e), 743 } 744 745 @classmethod 746 def get_or_raise(cls, dialect: DialectType) -> Dialect: 747 """ 748 Look up a dialect in the global dialect registry and return it if it exists. 749 750 Args: 751 dialect: The target dialect. If this is a string, it can be optionally followed by 752 additional key-value pairs that are separated by commas and are used to specify 753 dialect settings, such as whether the dialect's identifiers are case-sensitive. 754 755 Example: 756 >>> dialect = dialect_class = get_or_raise("duckdb") 757 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 758 759 Returns: 760 The corresponding Dialect instance. 761 """ 762 763 if not dialect: 764 return cls() 765 if isinstance(dialect, _Dialect): 766 return dialect() 767 if isinstance(dialect, Dialect): 768 return dialect 769 if isinstance(dialect, str): 770 try: 771 dialect_name, *kv_strings = dialect.split(",") 772 kv_pairs = (kv.split("=") for kv in kv_strings) 773 kwargs = {} 774 for pair in kv_pairs: 775 key = pair[0].strip() 776 value: t.Union[bool | str | None] = None 777 778 if len(pair) == 1: 779 # Default initialize standalone settings to True 780 value = True 781 elif len(pair) == 2: 782 value = pair[1].strip() 783 784 kwargs[key] = to_bool(value) 785 786 except ValueError: 787 raise ValueError( 788 f"Invalid dialect format: '{dialect}'. " 789 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 
790 ) 791 792 result = cls.get(dialect_name.strip()) 793 if not result: 794 from difflib import get_close_matches 795 796 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 797 if similar: 798 similar = f" Did you mean {similar}?" 799 800 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 801 802 return result(**kwargs) 803 804 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 805 806 @classmethod 807 def format_time( 808 cls, expression: t.Optional[str | exp.Expression] 809 ) -> t.Optional[exp.Expression]: 810 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 811 if isinstance(expression, str): 812 return exp.Literal.string( 813 # the time formats are quoted 814 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 815 ) 816 817 if expression and expression.is_string: 818 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 819 820 return expression 821 822 def __init__(self, **kwargs) -> None: 823 normalization_strategy = kwargs.pop("normalization_strategy", None) 824 825 if normalization_strategy is None: 826 self.normalization_strategy = self.NORMALIZATION_STRATEGY 827 else: 828 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 829 830 self.settings = kwargs 831 832 def __eq__(self, other: t.Any) -> bool: 833 # Does not currently take dialect state into account 834 return type(self) == other 835 836 def __hash__(self) -> int: 837 # Does not currently take dialect state into account 838 return hash(type(self)) 839 840 def normalize_identifier(self, expression: E) -> E: 841 """ 842 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 843 844 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 845 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 846 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 847 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 848 849 There are also dialects like Spark, which are case-insensitive even when quotes are 850 present, and dialects like MySQL, whose resolution rules match those employed by the 851 underlying operating system, for example they may always be case-sensitive in Linux. 852 853 Finally, the normalization behavior of some engines can even be controlled through flags, 854 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 855 856 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 857 that it can analyze queries in the optimizer and successfully capture their semantics. 
858 """ 859 if ( 860 isinstance(expression, exp.Identifier) 861 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 862 and ( 863 not expression.quoted 864 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 865 ) 866 ): 867 expression.set( 868 "this", 869 ( 870 expression.this.upper() 871 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 872 else expression.this.lower() 873 ), 874 ) 875 876 return expression 877 878 def case_sensitive(self, text: str) -> bool: 879 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 880 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 881 return False 882 883 unsafe = ( 884 str.islower 885 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 886 else str.isupper 887 ) 888 return any(unsafe(char) for char in text) 889 890 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 891 """Checks if text can be identified given an identify option. 892 893 Args: 894 text: The text to check. 895 identify: 896 `"always"` or `True`: Always returns `True`. 897 `"safe"`: Only returns `True` if the identifier is case-insensitive. 898 899 Returns: 900 Whether the given text can be identified. 901 """ 902 if identify is True or identify == "always": 903 return True 904 905 if identify == "safe": 906 return not self.case_sensitive(text) 907 908 return False 909 910 def quote_identifier(self, expression: E, identify: bool = True) -> E: 911 """ 912 Adds quotes to a given identifier. 913 914 Args: 915 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 916 identify: If set to `False`, the quotes will only be added if the identifier is deemed 917 "unsafe", with respect to its characters and this dialect's normalization strategy. 918 """ 919 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 920 name = expression.this 921 expression.set( 922 "quoted", 923 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 924 ) 925 926 return expression 927 928 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 929 if isinstance(path, exp.Literal): 930 path_text = path.name 931 if path.is_number: 932 path_text = f"[{path_text}]" 933 try: 934 return parse_json_path(path_text, self) 935 except ParseError as e: 936 if self.STRICT_JSON_PATH_SYNTAX: 937 logger.warning(f"Invalid JSON path syntax. 
{str(e)}") 938 939 return path 940 941 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 942 return self.parser(**opts).parse(self.tokenize(sql), sql) 943 944 def parse_into( 945 self, expression_type: exp.IntoType, sql: str, **opts 946 ) -> t.List[t.Optional[exp.Expression]]: 947 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 948 949 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 950 return self.generator(**opts).generate(expression, copy=copy) 951 952 def transpile(self, sql: str, **opts) -> t.List[str]: 953 return [ 954 self.generate(expression, copy=False, **opts) if expression else "" 955 for expression in self.parse(sql) 956 ] 957 958 def tokenize(self, sql: str) -> t.List[Token]: 959 return self.tokenizer.tokenize(sql) 960 961 @property 962 def tokenizer(self) -> Tokenizer: 963 return self.tokenizer_class(dialect=self) 964 965 @property 966 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 967 return self.jsonpath_tokenizer_class(dialect=self) 968 969 def parser(self, **opts) -> Parser: 970 return self.parser_class(dialect=self, **opts) 971 972 def generator(self, **opts) -> Generator: 973 return self.generator_class(dialect=self, **opts)
822 def __init__(self, **kwargs) -> None: 823 normalization_strategy = kwargs.pop("normalization_strategy", None) 824 825 if normalization_strategy is None: 826 self.normalization_strategy = self.NORMALIZATION_STRATEGY 827 else: 828 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 829 830 self.settings = kwargs
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized.
Possible values:
"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.
Whether the name of the function should be preserved inside the node's metadata. This can be useful for roundtripping deprecated and new functions that share an AST node, e.g. JSON_VALUE vs JSON_EXTRACT_SCALAR in BigQuery.
Whether the base comes first in the `LOG` function. Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`).
Default `NULL` ordering method to use if not explicitly set. Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`.
Whether the behavior of `a / b` depends on the types of `a` and `b`. False means `a / b` is always float division. True means `a / b` is integer division if both `a` and `b` are integers.
A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python `strftime` formats.
Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).
Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from `SELECT *` queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example,

    WITH y(c) AS (
      SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
    ) SELECT c FROM y;

will be rewritten as

    WITH y(c) AS (
      SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
    ) SELECT c FROM y;
Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).
For example:

    WITH data AS (
      SELECT 1 AS id, 2 AS my_id
    )
    SELECT id AS my_id
    FROM data
    WHERE my_id = 1
    GROUP BY my_id
    HAVING my_id = 1

In most dialects, "my_id" would refer to "data.my_id" across the query, except:
- BigQuery, which will forward the alias to the GROUP BY and HAVING clauses, i.e. it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
- Clickhouse, which will forward the alias across the query, i.e. it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
Whether alias reference expansion before qualification should only happen for the GROUP BY clause.
Whether ORDER BY ALL is supported (expands to all the selected columns), as in DuckDB and Spark3/Databricks.
Whether the ARRAY constructor is context-sensitive, i.e. in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3), as the former is of type INT[] while the latter is SUPER.
Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts, e.g. in DuckDB. In dialects that don't support fixed-size arrays, such as Snowflake, this should be interpreted as a subscript/index operator.
Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.
Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).
This flag is used in the optimizer's canonicalize rule and determines whether x will be promoted to the literal's type in x::DATE < '2020-01-01 12:05:03' (i.e., DATETIME). When false, the literal is cast to x's type to match it instead.
Whether number literals can include underscores for better readability.
Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL` must be explicitly specified.
Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse equivalent of CREATE SCHEMA is CREATE DATABASE.
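A minimal sketch of how a dialect might declare this mapping (the `ClickHouseLike` class below is hypothetical); the metaclass autofills the inverse mapping:

    from sqlglot.dialects.dialect import Dialect

    class ClickHouseLike(Dialect):
        # CREATE SCHEMA in the base dialect corresponds to CREATE DATABASE here
        CREATABLE_KIND_MAPPING = {"SCHEMA": "DATABASE"}

    # INVERSE_CREATABLE_KIND_MAPPING is computed automatically by the metaclass
    assert ClickHouseLike.INVERSE_CREATABLE_KIND_MAPPING == {"DATABASE": "SCHEMA"}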
745 @classmethod 746 def get_or_raise(cls, dialect: DialectType) -> Dialect: 747 """ 748 Look up a dialect in the global dialect registry and return it if it exists. 749 750 Args: 751 dialect: The target dialect. If this is a string, it can be optionally followed by 752 additional key-value pairs that are separated by commas and are used to specify 753 dialect settings, such as whether the dialect's identifiers are case-sensitive. 754 755 Example: 756 >>> dialect = dialect_class = get_or_raise("duckdb") 757 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 758 759 Returns: 760 The corresponding Dialect instance. 761 """ 762 763 if not dialect: 764 return cls() 765 if isinstance(dialect, _Dialect): 766 return dialect() 767 if isinstance(dialect, Dialect): 768 return dialect 769 if isinstance(dialect, str): 770 try: 771 dialect_name, *kv_strings = dialect.split(",") 772 kv_pairs = (kv.split("=") for kv in kv_strings) 773 kwargs = {} 774 for pair in kv_pairs: 775 key = pair[0].strip() 776 value: t.Union[bool | str | None] = None 777 778 if len(pair) == 1: 779 # Default initialize standalone settings to True 780 value = True 781 elif len(pair) == 2: 782 value = pair[1].strip() 783 784 kwargs[key] = to_bool(value) 785 786 except ValueError: 787 raise ValueError( 788 f"Invalid dialect format: '{dialect}'. " 789 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 790 ) 791 792 result = cls.get(dialect_name.strip()) 793 if not result: 794 from difflib import get_close_matches 795 796 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 797 if similar: 798 similar = f" Did you mean {similar}?" 799 800 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 801 802 return result(**kwargs) 803 804 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = Dialect.get_or_raise("duckdb")
>>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
806 @classmethod 807 def format_time( 808 cls, expression: t.Optional[str | exp.Expression] 809 ) -> t.Optional[exp.Expression]: 810 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 811 if isinstance(expression, str): 812 return exp.Literal.string( 813 # the time formats are quoted 814 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 815 ) 816 817 if expression and expression.is_string: 818 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 819 820 return expression
Converts a time format in this dialect to its equivalent Python `strftime` format.
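A hedged sketch, assuming Hive's `TIME_MAPPING` covers the `yyyy`, `MM` and `dd` tokens; note that a raw string argument is expected to carry its surrounding quotes:

    import sqlglot  # noqa: F401 -- registers the built-in dialects
    from sqlglot.dialects.dialect import Dialect

    hive = Dialect.get_or_raise("hive")
    # The surrounding single quotes are stripped before the mapping is applied
    hive.format_time("'yyyy-MM-dd'").this  # expected: '%Y-%m-%d'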
840 def normalize_identifier(self, expression: E) -> E: 841 """ 842 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 843 844 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 845 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 846 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 847 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 848 849 There are also dialects like Spark, which are case-insensitive even when quotes are 850 present, and dialects like MySQL, whose resolution rules match those employed by the 851 underlying operating system, for example they may always be case-sensitive in Linux. 852 853 Finally, the normalization behavior of some engines can even be controlled through flags, 854 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 855 856 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 857 that it can analyze queries in the optimizer and successfully capture their semantics. 858 """ 859 if ( 860 isinstance(expression, exp.Identifier) 861 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 862 and ( 863 not expression.quoted 864 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 865 ) 866 ): 867 expression.set( 868 "this", 869 ( 870 expression.this.upper() 871 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 872 else expression.this.lower() 873 ), 874 ) 875 876 return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
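A short sketch of the behaviors described above (Postgres lowercases and Snowflake uppercases unquoted identifiers; quoted identifiers are left untouched under their default strategies):

    import sqlglot  # noqa: F401 -- registers the built-in dialects
    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    ident = exp.to_identifier("FoO")  # unquoted
    Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name   # 'foo'
    Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name  # 'FOO'

    quoted = exp.to_identifier("FoO", quoted=True)
    Dialect.get_or_raise("postgres").normalize_identifier(quoted).name         # 'FoO'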
878 def case_sensitive(self, text: str) -> bool: 879 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 880 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 881 return False 882 883 unsafe = ( 884 str.islower 885 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 886 else str.isupper 887 ) 888 return any(unsafe(char) for char in text)
Checks if text contains any case-sensitive characters, based on the dialect's rules.
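For instance, a sketch against the default strategies of two built-in dialects:

    import sqlglot  # noqa: F401 -- registers the built-in dialects
    from sqlglot.dialects.dialect import Dialect

    # Snowflake normalizes by uppercasing, so lowercase characters are "unsafe"
    Dialect.get_or_raise("snowflake").case_sensitive("foo")  # True
    # Postgres normalizes by lowercasing, so all-lowercase text is safe
    Dialect.get_or_raise("postgres").case_sensitive("foo")   # False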
890 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 891 """Checks if text can be identified given an identify option. 892 893 Args: 894 text: The text to check. 895 identify: 896 `"always"` or `True`: Always returns `True`. 897 `"safe"`: Only returns `True` if the identifier is case-insensitive. 898 899 Returns: 900 Whether the given text can be identified. 901 """ 902 if identify is True or identify == "always": 903 return True 904 905 if identify == "safe": 906 return not self.case_sensitive(text) 907 908 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: `"always"` or `True`: Always returns `True`. `"safe"`: Only returns `True` if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
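A sketch of both modes against Postgres's default LOWERCASE strategy:

    import sqlglot  # noqa: F401 -- registers the built-in dialects
    from sqlglot.dialects.dialect import Dialect

    pg = Dialect.get_or_raise("postgres")
    pg.can_identify("foo", "safe")    # True: no case-sensitive characters
    pg.can_identify("Foo", "safe")    # False: 'F' is case-sensitive here
    pg.can_identify("Foo", "always")  # True: always allowed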
910 def quote_identifier(self, expression: E, identify: bool = True) -> E: 911 """ 912 Adds quotes to a given identifier. 913 914 Args: 915 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 916 identify: If set to `False`, the quotes will only be added if the identifier is deemed 917 "unsafe", with respect to its characters and this dialect's normalization strategy. 918 """ 919 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 920 name = expression.this 921 expression.set( 922 "quoted", 923 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 924 ) 925 926 return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
- identify: If set to `False`, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
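A sketch of the `identify=False` path, where quoting only happens for "unsafe" names (case-sensitive characters, or characters outside `SAFE_IDENTIFIER_RE`):

    import sqlglot  # noqa: F401 -- registers the built-in dialects
    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    pg = Dialect.get_or_raise("postgres")
    # Safe, all-lowercase name: left unquoted
    pg.quote_identifier(exp.to_identifier("foo"), identify=False).args.get("quoted")  # False
    # Mixed-case name is "unsafe" under the LOWERCASE strategy: quoted
    pg.quote_identifier(exp.to_identifier("Foo"), identify=False).args.get("quoted")  # True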
928 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 929 if isinstance(path, exp.Literal): 930 path_text = path.name 931 if path.is_number: 932 path_text = f"[{path_text}]" 933 try: 934 return parse_json_path(path_text, self) 935 except ParseError as e: 936 if self.STRICT_JSON_PATH_SYNTAX: 937 logger.warning(f"Invalid JSON path syntax. {str(e)}") 938 939 return path
988def if_sql( 989 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 990) -> t.Callable[[Generator, exp.If], str]: 991 def _if_sql(self: Generator, expression: exp.If) -> str: 992 return self.func( 993 name, 994 expression.this, 995 expression.args.get("true"), 996 expression.args.get("false") or false_value, 997 ) 998 999 return _if_sql
1002def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1003 this = expression.this 1004 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 1005 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 1006 1007 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
1077def str_position_sql( 1078 self: Generator, 1079 expression: exp.StrPosition, 1080 generate_instance: bool = False, 1081 str_position_func_name: str = "STRPOS", 1082) -> str: 1083 this = self.sql(expression, "this") 1084 substr = self.sql(expression, "substr") 1085 position = self.sql(expression, "position") 1086 instance = expression.args.get("instance") if generate_instance else None 1087 position_offset = "" 1088 1089 if position: 1090 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 1091 this = self.func("SUBSTR", this, position) 1092 position_offset = f" + {position} - 1" 1093 1094 strpos_sql = self.func(str_position_func_name, this, substr, instance) 1095 1096 if position_offset: 1097 zero = exp.Literal.number(0) 1098 # If match is not found (returns 0) the position offset should not be applied 1099 case = exp.If( 1100 this=exp.EQ(this=strpos_sql, expression=zero), 1101 true=zero, 1102 false=strpos_sql + position_offset, 1103 ) 1104 strpos_sql = self.sql(case) 1105 1106 return strpos_sql
1115def var_map_sql( 1116 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1117) -> str: 1118 keys = expression.args["keys"] 1119 values = expression.args["values"] 1120 1121 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1122 self.unsupported("Cannot convert array columns into map.") 1123 return self.func(map_func_name, keys, values) 1124 1125 args = [] 1126 for key, value in zip(keys.expressions, values.expressions): 1127 args.append(self.sql(key)) 1128 args.append(self.sql(value)) 1129 1130 return self.func(map_func_name, *args)
1133def build_formatted_time( 1134 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1135) -> t.Callable[[t.List], E]: 1136 """Helper used for time expressions. 1137 1138 Args: 1139 exp_class: the expression class to instantiate. 1140 dialect: target sql dialect. 1141 default: the default format, True being time. 1142 1143 Returns: 1144 A callable that can be used to return the appropriately formatted time expression. 1145 """ 1146 1147 def _builder(args: t.List): 1148 return exp_class( 1149 this=seq_get(args, 0), 1150 format=Dialect[dialect].format_time( 1151 seq_get(args, 1) 1152 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1153 ), 1154 ) 1155 1156 return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
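A hedged usage sketch; the argument layout (this, then format) follows the builder above, the `dt` column is purely illustrative, and the MySQL module import just ensures the dialect is registered:

    import sqlglot.dialects.mysql  # noqa: F401 -- registers the MySQL dialect
    from sqlglot import exp
    from sqlglot.dialects.dialect import build_formatted_time

    # Builder for a STR_TO_TIME-style parse: args[0] is the value, args[1] the format
    builder = build_formatted_time(exp.StrToTime, "mysql")
    node = builder([exp.column("dt"), exp.Literal.string("%Y-%m-%d")])
    node.args["format"]  # the format literal, mapped through MySQL's TIME_MAPPING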
1159def time_format( 1160 dialect: DialectType = None, 1161) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1162 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1163 """ 1164 Returns the time format for a given expression, unless it's equivalent 1165 to the default time format of the dialect of interest. 1166 """ 1167 time_format = self.format_time(expression) 1168 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1169 1170 return _time_format
1173def build_date_delta( 1174 exp_class: t.Type[E], 1175 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1176 default_unit: t.Optional[str] = "DAY", 1177) -> t.Callable[[t.List], E]: 1178 def _builder(args: t.List) -> E: 1179 unit_based = len(args) == 3 1180 this = args[2] if unit_based else seq_get(args, 0) 1181 unit = None 1182 if unit_based or default_unit: 1183 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1184 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1185 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1186 1187 return _builder
1190def build_date_delta_with_interval( 1191 expression_class: t.Type[E], 1192) -> t.Callable[[t.List], t.Optional[E]]: 1193 def _builder(args: t.List) -> t.Optional[E]: 1194 if len(args) < 2: 1195 return None 1196 1197 interval = args[1] 1198 1199 if not isinstance(interval, exp.Interval): 1200 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1201 1202 return expression_class(this=args[0], expression=interval.this, unit=unit_to_str(interval)) 1203 1204 return _builder
1207def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1208 unit = seq_get(args, 0) 1209 this = seq_get(args, 1) 1210 1211 if isinstance(this, exp.Cast) and this.is_type("date"): 1212 return exp.DateTrunc(unit=unit, this=this) 1213 return exp.TimestampTrunc(this=this, unit=unit)
1216def date_add_interval_sql( 1217 data_type: str, kind: str 1218) -> t.Callable[[Generator, exp.Expression], str]: 1219 def func(self: Generator, expression: exp.Expression) -> str: 1220 this = self.sql(expression, "this") 1221 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1222 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1223 1224 return func
1227def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1228 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1229 args = [unit_to_str(expression), expression.this] 1230 if zone: 1231 args.append(expression.args.get("zone")) 1232 return self.func("DATE_TRUNC", *args) 1233 1234 return _timestamptrunc_sql
1237def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1238 zone = expression.args.get("zone") 1239 if not zone: 1240 from sqlglot.optimizer.annotate_types import annotate_types 1241 1242 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1243 return self.sql(exp.cast(expression.this, target_type)) 1244 if zone.name.lower() in TIMEZONES: 1245 return self.sql( 1246 exp.AtTimeZone( 1247 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 1248 zone=zone, 1249 ) 1250 ) 1251 return self.func("TIMESTAMP", expression.this, zone)
1254def no_time_sql(self: Generator, expression: exp.Time) -> str: 1255 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 1256 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 1257 expr = exp.cast( 1258 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 1259 ) 1260 return self.sql(expr)
1263def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1264 this = expression.this 1265 expr = expression.expression 1266 1267 if expr.name.lower() in TIMEZONES: 1268 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1269 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1270 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1271 return self.sql(this) 1272 1273 this = exp.cast(this, exp.DataType.Type.DATE) 1274 expr = exp.cast(expr, exp.DataType.Type.TIME) 1275 1276 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
1308def timestrtotime_sql( 1309 self: Generator, 1310 expression: exp.TimeStrToTime, 1311 include_precision: bool = False, 1312) -> str: 1313 datatype = exp.DataType.build( 1314 exp.DataType.Type.TIMESTAMPTZ 1315 if expression.args.get("zone") 1316 else exp.DataType.Type.TIMESTAMP 1317 ) 1318 1319 if isinstance(expression.this, exp.Literal) and include_precision: 1320 precision = subsecond_precision(expression.this.name) 1321 if precision > 0: 1322 datatype = exp.DataType.build( 1323 datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))] 1324 ) 1325 1326 return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect))
1334def encode_decode_sql( 1335 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1336) -> str: 1337 charset = expression.args.get("charset") 1338 if charset and charset.name.lower() != "utf-8": 1339 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1340 1341 return self.func(name, expression.this, expression.args.get("replace") if replace else None)
1354def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1355 cond = expression.this 1356 1357 if isinstance(expression.this, exp.Distinct): 1358 cond = expression.this.expressions[0] 1359 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1360 1361 return self.func("sum", exp.func("if", cond, 1, 0))
1364def trim_sql(self: Generator, expression: exp.Trim) -> str: 1365 target = self.sql(expression, "this") 1366 trim_type = self.sql(expression, "position") 1367 remove_chars = self.sql(expression, "expression") 1368 collation = self.sql(expression, "collation") 1369 1370 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1371 if not remove_chars: 1372 return self.trim_sql(expression) 1373 1374 trim_type = f"{trim_type} " if trim_type else "" 1375 remove_chars = f"{remove_chars} " if remove_chars else "" 1376 from_part = "FROM " if trim_type or remove_chars else "" 1377 collation = f" COLLATE {collation}" if collation else "" 1378 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
1399@unsupported_args("position", "occurrence", "parameters") 1400def regexp_extract_sql( 1401 self: Generator, expression: exp.RegexpExtract | exp.RegexpExtractAll 1402) -> str: 1403 group = expression.args.get("group") 1404 1405 # Do not render group if it's the default value for this dialect 1406 if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP): 1407 group = None 1408 1409 return self.func(expression.sql_name(), expression.this, expression.expression, group)
1419def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1420 names = [] 1421 for agg in aggregations: 1422 if isinstance(agg, exp.Alias): 1423 names.append(agg.alias) 1424 else: 1425 """ 1426 This case corresponds to aggregations without aliases being used as suffixes 1427 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1428 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1429 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1430 """ 1431 agg_all_unquoted = agg.transform( 1432 lambda node: ( 1433 exp.Identifier(this=node.name, quoted=False) 1434 if isinstance(node, exp.Identifier) 1435 else node 1436 ) 1437 ) 1438 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1439 1440 return names
1480def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1481 @unsupported_args("count") 1482 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1483 return self.func(name, expression.this, expression.expression) 1484 1485 return _arg_max_or_min_sql
1488def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1489 this = expression.this.copy() 1490 1491 return_type = expression.return_type 1492 if return_type.is_type(exp.DataType.Type.DATE): 1493 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1494 # can truncate timestamp strings, because some dialects can't cast them to DATE 1495 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1496 1497 expression.this.replace(exp.cast(this, return_type)) 1498 return expression
1501def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1502 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1503 if cast and isinstance(expression, exp.TsOrDsAdd): 1504 expression = ts_or_ds_add_cast(expression) 1505 1506 return self.func( 1507 name, 1508 unit_to_var(expression), 1509 expression.expression, 1510 expression.this, 1511 ) 1512 1513 return _delta_sql
1516def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1517 unit = expression.args.get("unit") 1518 1519 if isinstance(unit, exp.Placeholder): 1520 return unit 1521 if unit: 1522 return exp.Literal.string(unit.name) 1523 return exp.Literal.string(default) if default else None
1553def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1554 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1555 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1556 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1557 1558 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1561def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1562 """Remove table refs from columns in when statements.""" 1563 alias = expression.this.args.get("alias") 1564 1565 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1566 return self.dialect.normalize_identifier(identifier).name if identifier else None 1567 1568 targets = {normalize(expression.this.this)} 1569 1570 if alias: 1571 targets.add(normalize(alias.this)) 1572 1573 for when in expression.args["whens"].expressions: 1574 # only remove the target table names from certain parts of WHEN MATCHED / WHEN NOT MATCHED 1575 # they are still valid in the <condition>, the right hand side of each UPDATE and the VALUES part 1576 # (not the column list) of the INSERT 1577 then: exp.Insert | exp.Update | None = when.args.get("then") 1578 if then: 1579 if isinstance(then, exp.Update): 1580 for equals in then.find_all(exp.EQ): 1581 equal_lhs = equals.this 1582 if ( 1583 isinstance(equal_lhs, exp.Column) 1584 and normalize(equal_lhs.args.get("table")) in targets 1585 ): 1586 equal_lhs.replace(exp.column(equal_lhs.this)) 1587 if isinstance(then, exp.Insert): 1588 column_list = then.this 1589 if isinstance(column_list, exp.Tuple): 1590 for column in column_list.expressions: 1591 if normalize(column.args.get("table")) in targets: 1592 column.replace(exp.column(column.this)) 1593 1594 return self.merge_sql(expression)
Removes target table references from columns in WHEN MATCHED / WHEN NOT MATCHED clauses.
1597def build_json_extract_path( 1598 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1599) -> t.Callable[[t.List], F]: 1600 def _builder(args: t.List) -> F: 1601 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1602 for arg in args[1:]: 1603 if not isinstance(arg, exp.Literal): 1604 # We use the fallback parser because we can't really transpile non-literals safely 1605 return expr_type.from_arg_list(args) 1606 1607 text = arg.name 1608 if is_int(text): 1609 index = int(text) 1610 segments.append( 1611 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1612 ) 1613 else: 1614 segments.append(exp.JSONPathKey(this=text)) 1615 1616 # This is done to avoid failing in the expression validator due to the arg count 1617 del args[2:] 1618 return expr_type( 1619 this=seq_get(args, 0), 1620 expression=exp.JSONPath(expressions=segments), 1621 only_json_types=arrow_req_json_type, 1622 ) 1623 1624 return _builder
1627def json_extract_segments( 1628 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1629) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1630 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1631 path = expression.expression 1632 if not isinstance(path, exp.JSONPath): 1633 return rename_func(name)(self, expression) 1634 1635 escape = path.args.get("escape") 1636 1637 segments = [] 1638 for segment in path.expressions: 1639 path = self.sql(segment) 1640 if path: 1641 if isinstance(segment, exp.JSONPathPart) and ( 1642 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1643 ): 1644 if escape: 1645 path = self.escape_str(path) 1646 1647 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1648 1649 segments.append(path) 1650 1651 if op: 1652 return f" {op} ".join([self.sql(expression.this), *segments]) 1653 return self.func(name, expression.this, *segments) 1654 1655 return _json_extract_segments
1665def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1666 cond = expression.expression 1667 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1668 alias = cond.expressions[0] 1669 cond = cond.this 1670 elif isinstance(cond, exp.Predicate): 1671 alias = "_u" 1672 else: 1673 self.unsupported("Unsupported filter condition") 1674 return "" 1675 1676 unnest = exp.Unnest(expressions=[expression.this]) 1677 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1678 return self.sql(exp.Array(expressions=[filtered]))
1690def build_default_decimal_type( 1691 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1692) -> t.Callable[[exp.DataType], exp.DataType]: 1693 def _builder(dtype: exp.DataType) -> exp.DataType: 1694 if dtype.expressions or precision is None: 1695 return dtype 1696 1697 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1698 return exp.DataType.build(f"DECIMAL({params})") 1699 1700 return _builder
1703def build_timestamp_from_parts(args: t.List) -> exp.Func: 1704 if len(args) == 2: 1705 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1706 # so we parse this into Anonymous for now instead of introducing complexity 1707 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1708 1709 return exp.TimestampFromParts.from_arg_list(args)
1716def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str: 1717 start = expression.args.get("start") 1718 end = expression.args.get("end") 1719 step = expression.args.get("step") 1720 1721 if isinstance(start, exp.Cast): 1722 target_type = start.to 1723 elif isinstance(end, exp.Cast): 1724 target_type = end.to 1725 else: 1726 target_type = None 1727 1728 if start and end and target_type and target_type.is_type("date", "timestamp"): 1729 if isinstance(start, exp.Cast) and target_type is start.to: 1730 end = exp.cast(end, target_type) 1731 else: 1732 start = exp.cast(start, target_type) 1733 1734 return self.func("SEQUENCE", start, end, step)
1737def build_regexp_extract(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]: 1738 def _builder(args: t.List, dialect: Dialect) -> E: 1739 return expr_type( 1740 this=seq_get(args, 0), 1741 expression=seq_get(args, 1), 1742 group=seq_get(args, 2) or exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP), 1743 parameters=seq_get(args, 3), 1744 ) 1745 1746 return _builder
1749def explode_to_unnest_sql(self: Generator, expression: exp.Lateral) -> str: 1750 if isinstance(expression.this, exp.Explode): 1751 return self.sql( 1752 exp.Join( 1753 this=exp.Unnest( 1754 expressions=[expression.this.this], 1755 alias=expression.args.get("alias"), 1756 offset=isinstance(expression.this, exp.Posexplode), 1757 ), 1758 kind="cross", 1759 ) 1760 ) 1761 return self.lateral_sql(expression)
1768def no_make_interval_sql(self: Generator, expression: exp.MakeInterval, sep: str = ", ") -> str: 1769 args = [] 1770 for unit, value in expression.args.items(): 1771 if isinstance(value, exp.Kwarg): 1772 value = value.expression 1773 1774 args.append(f"{value} {unit}") 1775 1776 return f"INTERVAL '{self.format_args(*args, sep=sep)}'"