sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time, subsecond_precision
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

    from sqlglot.optimizer.annotate_types import TypeAnnotator

    AnnotatorsType = t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]

logger = logging.getLogger("sqlglot")

UNESCAPED_SEQUENCES = {
    "\\a": "\a",
    "\\b": "\b",
    "\\f": "\f",
    "\\n": "\n",
    "\\r": "\r",
    "\\t": "\t",
    "\\v": "\v",
    "\\\\": "\\",
}


def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
    return lambda self, e: self._annotate_with_type(e, data_type)


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MATERIALIZE = "materialize"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    RISINGWAVE = "risingwave"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""
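A usage sketch: standalone or key=value settings appended to a dialect name (as accepted by `Dialect.get_or_raise`, defined further down) override class-level defaults such as the normalization strategy:

>>> from sqlglot.dialects.dialect import Dialect, NormalizationStrategy
>>> mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
>>> mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
True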
class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
        klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
        klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)

        klass.INVERSE_CREATABLE_KIND_MAPPING = {
            v: k for k, v in klass.CREATABLE_KIND_MAPPING.items()
        }

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.jsonpath_tokenizer_class = klass.__dict__.get(
            "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                **UNESCAPED_SEQUENCES,
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "clickhouse"):
            klass.generator_class.SUPPORTS_NULLABLE_TYPES = False

        if enum not in ("", "athena", "presto", "trino"):
            klass.generator_class.TRY_SUPPORTED = False
            klass.generator_class.SUPPORTS_UESCAPE = False

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if enum not in ("", "doris", "mysql"):
            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass
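A sketch of the registration side effect: because `_Dialect` is the metaclass of `Dialect`, merely defining a subclass adds it to the global registry under its lowercased class name (the `Custom` name is illustrative, not part of the library):

>>> from sqlglot.dialects.dialect import Dialect
>>> class Custom(Dialect):
...     pass
...
>>> Dialect.get("custom") is Custom
True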
class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    SUPPORTS_COLUMN_JOIN_MARKS = False
    """Whether the old-style outer join (+) syntax is supported."""

    COPY_PARAMS_ARE_CSV = True
    """Whether COPY statement parameters are separated by comma or whitespace."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    HEX_LOWERCASE = False
    """Whether the `HEX` function returns a lowercase hexadecimal string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """
311 """ 312 313 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 314 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 315 316 PSEUDOCOLUMNS: t.Set[str] = set() 317 """ 318 Columns that are auto-generated by the engine corresponding to this dialect. 319 For example, such columns may be excluded from `SELECT *` queries. 320 """ 321 322 PREFER_CTE_ALIAS_COLUMN = False 323 """ 324 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 325 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 326 any projection aliases in the subquery. 327 328 For example, 329 WITH y(c) AS ( 330 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 331 ) SELECT c FROM y; 332 333 will be rewritten as 334 335 WITH y(c) AS ( 336 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 337 ) SELECT c FROM y; 338 """ 339 340 COPY_PARAMS_ARE_CSV = True 341 """ 342 Whether COPY statement parameters are separated by comma or whitespace 343 """ 344 345 FORCE_EARLY_ALIAS_REF_EXPANSION = False 346 """ 347 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 348 349 For example: 350 WITH data AS ( 351 SELECT 352 1 AS id, 353 2 AS my_id 354 ) 355 SELECT 356 id AS my_id 357 FROM 358 data 359 WHERE 360 my_id = 1 361 GROUP BY 362 my_id, 363 HAVING 364 my_id = 1 365 366 In most dialects, "my_id" would refer to "data.my_id" across the query, except: 367 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e 368 it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 369 - Clickhouse, which will forward the alias across the query i.e it resolves 370 to "WHERE id = 1 GROUP BY id HAVING id = 1" 371 """ 372 373 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 374 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 375 376 SUPPORTS_ORDER_BY_ALL = False 377 """ 378 Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks 379 """ 380 381 HAS_DISTINCT_ARRAY_CONSTRUCTORS = False 382 """ 383 Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) 384 as the former is of type INT[] vs the latter which is SUPER 385 """ 386 387 SUPPORTS_FIXED_SIZE_ARRAYS = False 388 """ 389 Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. 390 in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should 391 be interpreted as a subscript/index operator. 392 """ 393 394 STRICT_JSON_PATH_SYNTAX = True 395 """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.""" 396 397 ON_CONDITION_EMPTY_BEFORE_ERROR = True 398 """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).""" 399 400 ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True 401 """Whether ArrayAgg needs to filter NULL values.""" 402 403 REGEXP_EXTRACT_DEFAULT_GROUP = 0 404 """The default value for the capturing group.""" 405 406 SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = { 407 exp.Except: True, 408 exp.Intersect: True, 409 exp.Union: True, 410 } 411 """ 412 Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL` 413 must be explicitly specified. 
414 """ 415 416 CREATABLE_KIND_MAPPING: dict[str, str] = {} 417 """ 418 Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse 419 equivalent of CREATE SCHEMA is CREATE DATABASE. 420 """ 421 422 # --- Autofilled --- 423 424 tokenizer_class = Tokenizer 425 jsonpath_tokenizer_class = JSONPathTokenizer 426 parser_class = Parser 427 generator_class = Generator 428 429 # A trie of the time_mapping keys 430 TIME_TRIE: t.Dict = {} 431 FORMAT_TRIE: t.Dict = {} 432 433 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 434 INVERSE_TIME_TRIE: t.Dict = {} 435 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 436 INVERSE_FORMAT_TRIE: t.Dict = {} 437 438 INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {} 439 440 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 441 442 # Delimiters for string literals and identifiers 443 QUOTE_START = "'" 444 QUOTE_END = "'" 445 IDENTIFIER_START = '"' 446 IDENTIFIER_END = '"' 447 448 # Delimiters for bit, hex, byte and unicode literals 449 BIT_START: t.Optional[str] = None 450 BIT_END: t.Optional[str] = None 451 HEX_START: t.Optional[str] = None 452 HEX_END: t.Optional[str] = None 453 BYTE_START: t.Optional[str] = None 454 BYTE_END: t.Optional[str] = None 455 UNICODE_START: t.Optional[str] = None 456 UNICODE_END: t.Optional[str] = None 457 458 DATE_PART_MAPPING = { 459 "Y": "YEAR", 460 "YY": "YEAR", 461 "YYY": "YEAR", 462 "YYYY": "YEAR", 463 "YR": "YEAR", 464 "YEARS": "YEAR", 465 "YRS": "YEAR", 466 "MM": "MONTH", 467 "MON": "MONTH", 468 "MONS": "MONTH", 469 "MONTHS": "MONTH", 470 "D": "DAY", 471 "DD": "DAY", 472 "DAYS": "DAY", 473 "DAYOFMONTH": "DAY", 474 "DAY OF WEEK": "DAYOFWEEK", 475 "WEEKDAY": "DAYOFWEEK", 476 "DOW": "DAYOFWEEK", 477 "DW": "DAYOFWEEK", 478 "WEEKDAY_ISO": "DAYOFWEEKISO", 479 "DOW_ISO": "DAYOFWEEKISO", 480 "DW_ISO": "DAYOFWEEKISO", 481 "DAY OF YEAR": "DAYOFYEAR", 482 "DOY": "DAYOFYEAR", 483 "DY": "DAYOFYEAR", 484 "W": "WEEK", 485 "WK": "WEEK", 486 "WEEKOFYEAR": "WEEK", 487 "WOY": "WEEK", 488 "WY": "WEEK", 489 "WEEK_ISO": "WEEKISO", 490 "WEEKOFYEARISO": "WEEKISO", 491 "WEEKOFYEAR_ISO": "WEEKISO", 492 "Q": "QUARTER", 493 "QTR": "QUARTER", 494 "QTRS": "QUARTER", 495 "QUARTERS": "QUARTER", 496 "H": "HOUR", 497 "HH": "HOUR", 498 "HR": "HOUR", 499 "HOURS": "HOUR", 500 "HRS": "HOUR", 501 "M": "MINUTE", 502 "MI": "MINUTE", 503 "MIN": "MINUTE", 504 "MINUTES": "MINUTE", 505 "MINS": "MINUTE", 506 "S": "SECOND", 507 "SEC": "SECOND", 508 "SECONDS": "SECOND", 509 "SECS": "SECOND", 510 "MS": "MILLISECOND", 511 "MSEC": "MILLISECOND", 512 "MSECS": "MILLISECOND", 513 "MSECOND": "MILLISECOND", 514 "MSECONDS": "MILLISECOND", 515 "MILLISEC": "MILLISECOND", 516 "MILLISECS": "MILLISECOND", 517 "MILLISECON": "MILLISECOND", 518 "MILLISECONDS": "MILLISECOND", 519 "US": "MICROSECOND", 520 "USEC": "MICROSECOND", 521 "USECS": "MICROSECOND", 522 "MICROSEC": "MICROSECOND", 523 "MICROSECS": "MICROSECOND", 524 "USECOND": "MICROSECOND", 525 "USECONDS": "MICROSECOND", 526 "MICROSECONDS": "MICROSECOND", 527 "NS": "NANOSECOND", 528 "NSEC": "NANOSECOND", 529 "NANOSEC": "NANOSECOND", 530 "NSECOND": "NANOSECOND", 531 "NSECONDS": "NANOSECOND", 532 "NANOSECS": "NANOSECOND", 533 "EPOCH_SECOND": "EPOCH", 534 "EPOCH_SECONDS": "EPOCH", 535 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 536 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 537 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 538 "TZH": "TIMEZONE_HOUR", 539 "TZM": "TIMEZONE_MINUTE", 540 "DEC": "DECADE", 541 "DECS": "DECADE", 542 "DECADES": "DECADE", 543 "MIL": "MILLENIUM", 544 "MILS": "MILLENIUM", 545 "MILLENIA": 
"MILLENIUM", 546 "C": "CENTURY", 547 "CENT": "CENTURY", 548 "CENTS": "CENTURY", 549 "CENTURIES": "CENTURY", 550 } 551 552 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 553 exp.DataType.Type.BIGINT: { 554 exp.ApproxDistinct, 555 exp.ArraySize, 556 exp.Length, 557 }, 558 exp.DataType.Type.BOOLEAN: { 559 exp.Between, 560 exp.Boolean, 561 exp.In, 562 exp.RegexpLike, 563 }, 564 exp.DataType.Type.DATE: { 565 exp.CurrentDate, 566 exp.Date, 567 exp.DateFromParts, 568 exp.DateStrToDate, 569 exp.DiToDate, 570 exp.StrToDate, 571 exp.TimeStrToDate, 572 exp.TsOrDsToDate, 573 }, 574 exp.DataType.Type.DATETIME: { 575 exp.CurrentDatetime, 576 exp.Datetime, 577 exp.DatetimeAdd, 578 exp.DatetimeSub, 579 }, 580 exp.DataType.Type.DOUBLE: { 581 exp.ApproxQuantile, 582 exp.Avg, 583 exp.Div, 584 exp.Exp, 585 exp.Ln, 586 exp.Log, 587 exp.Pow, 588 exp.Quantile, 589 exp.Round, 590 exp.SafeDivide, 591 exp.Sqrt, 592 exp.Stddev, 593 exp.StddevPop, 594 exp.StddevSamp, 595 exp.Variance, 596 exp.VariancePop, 597 }, 598 exp.DataType.Type.INT: { 599 exp.Ceil, 600 exp.DatetimeDiff, 601 exp.DateDiff, 602 exp.TimestampDiff, 603 exp.TimeDiff, 604 exp.DateToDi, 605 exp.Levenshtein, 606 exp.Sign, 607 exp.StrPosition, 608 exp.TsOrDiToDi, 609 }, 610 exp.DataType.Type.JSON: { 611 exp.ParseJSON, 612 }, 613 exp.DataType.Type.TIME: { 614 exp.Time, 615 }, 616 exp.DataType.Type.TIMESTAMP: { 617 exp.CurrentTime, 618 exp.CurrentTimestamp, 619 exp.StrToTime, 620 exp.TimeAdd, 621 exp.TimeStrToTime, 622 exp.TimeSub, 623 exp.TimestampAdd, 624 exp.TimestampSub, 625 exp.UnixToTime, 626 }, 627 exp.DataType.Type.TINYINT: { 628 exp.Day, 629 exp.Month, 630 exp.Week, 631 exp.Year, 632 exp.Quarter, 633 }, 634 exp.DataType.Type.VARCHAR: { 635 exp.ArrayConcat, 636 exp.Concat, 637 exp.ConcatWs, 638 exp.DateToDateStr, 639 exp.GroupConcat, 640 exp.Initcap, 641 exp.Lower, 642 exp.Substring, 643 exp.TimeToStr, 644 exp.TimeToTimeStr, 645 exp.Trim, 646 exp.TsOrDsToDateStr, 647 exp.UnixToStr, 648 exp.UnixToTimeStr, 649 exp.Upper, 650 }, 651 } 652 653 ANNOTATORS: AnnotatorsType = { 654 **{ 655 expr_type: lambda self, e: self._annotate_unary(e) 656 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 657 }, 658 **{ 659 expr_type: lambda self, e: self._annotate_binary(e) 660 for expr_type in subclasses(exp.__name__, exp.Binary) 661 }, 662 **{ 663 expr_type: _annotate_with_type_lambda(data_type) 664 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 665 for expr_type in expressions 666 }, 667 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 668 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 669 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 670 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 671 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 672 exp.Bracket: lambda self, e: self._annotate_bracket(e), 673 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 674 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 675 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 676 exp.Count: lambda self, e: self._annotate_with_type( 677 e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT 678 ), 679 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 680 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 681 exp.DateSub: lambda self, e: 
    ANNOTATORS: AnnotatorsType = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Bracket: lambda self, e: self._annotate_bracket(e),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Count: lambda self, e: self._annotate_with_type(
            e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT
        ),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
        exp.DateSub: lambda self, e: self._annotate_timeunit(e),
        exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Div: lambda self, e: self._annotate_div(e),
        exp.Dot: lambda self, e: self._annotate_dot(e),
        exp.Explode: lambda self, e: self._annotate_explode(e),
        exp.Extract: lambda self, e: self._annotate_extract(e),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
            e, exp.DataType.build("ARRAY<DATE>")
        ),
        exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type(
            e, exp.DataType.build("ARRAY<TIMESTAMP>")
        ),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_map(e),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
        exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Struct: lambda self, e: self._annotate_struct(e),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.Timestamp: lambda self, e: self._annotate_with_type(
            e,
            exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
        ),
        exp.ToMap: lambda self, e: self._annotate_to_map(e),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Unnest: lambda self, e: self._annotate_unnest(e),
        exp.VarMap: lambda self, e: self._annotate_map(e),
    }
734 """ 735 736 if not dialect: 737 return cls() 738 if isinstance(dialect, _Dialect): 739 return dialect() 740 if isinstance(dialect, Dialect): 741 return dialect 742 if isinstance(dialect, str): 743 try: 744 dialect_name, *kv_strings = dialect.split(",") 745 kv_pairs = (kv.split("=") for kv in kv_strings) 746 kwargs = {} 747 for pair in kv_pairs: 748 key = pair[0].strip() 749 value: t.Union[bool | str | None] = None 750 751 if len(pair) == 1: 752 # Default initialize standalone settings to True 753 value = True 754 elif len(pair) == 2: 755 value = pair[1].strip() 756 757 # Coerce the value to boolean if it matches to the truthy/falsy values below 758 value_lower = value.lower() 759 if value_lower in ("true", "1"): 760 value = True 761 elif value_lower in ("false", "0"): 762 value = False 763 764 kwargs[key] = value 765 766 except ValueError: 767 raise ValueError( 768 f"Invalid dialect format: '{dialect}'. " 769 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 770 ) 771 772 result = cls.get(dialect_name.strip()) 773 if not result: 774 from difflib import get_close_matches 775 776 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 777 if similar: 778 similar = f" Did you mean {similar}?" 779 780 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 781 782 return result(**kwargs) 783 784 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 785 786 @classmethod 787 def format_time( 788 cls, expression: t.Optional[str | exp.Expression] 789 ) -> t.Optional[exp.Expression]: 790 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 791 if isinstance(expression, str): 792 return exp.Literal.string( 793 # the time formats are quoted 794 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 795 ) 796 797 if expression and expression.is_string: 798 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 799 800 return expression 801 802 def __init__(self, **kwargs) -> None: 803 normalization_strategy = kwargs.pop("normalization_strategy", None) 804 805 if normalization_strategy is None: 806 self.normalization_strategy = self.NORMALIZATION_STRATEGY 807 else: 808 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 809 810 self.settings = kwargs 811 812 def __eq__(self, other: t.Any) -> bool: 813 # Does not currently take dialect state into account 814 return type(self) == other 815 816 def __hash__(self) -> int: 817 # Does not currently take dialect state into account 818 return hash(type(self)) 819 820 def normalize_identifier(self, expression: E) -> E: 821 """ 822 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 823 824 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 825 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 826 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 827 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 828 829 There are also dialects like Spark, which are case-insensitive even when quotes are 830 present, and dialects like MySQL, whose resolution rules match those employed by the 831 underlying operating system, for example they may always be case-sensitive in Linux. 
    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.pop("normalization_strategy", None)

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

        self.settings = kwargs

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression
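    # Editor's sketch: normalization is strategy-driven, so the same unquoted
    # identifier folds differently per dialect:
    #
    #     >>> from sqlglot import exp
    #     >>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).name
    #     'FOO'
    #     >>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).name
    #     'foo'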
{str(e)}") 918 919 return path 920 921 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 922 return self.parser(**opts).parse(self.tokenize(sql), sql) 923 924 def parse_into( 925 self, expression_type: exp.IntoType, sql: str, **opts 926 ) -> t.List[t.Optional[exp.Expression]]: 927 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 928 929 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 930 return self.generator(**opts).generate(expression, copy=copy) 931 932 def transpile(self, sql: str, **opts) -> t.List[str]: 933 return [ 934 self.generate(expression, copy=False, **opts) if expression else "" 935 for expression in self.parse(sql) 936 ] 937 938 def tokenize(self, sql: str) -> t.List[Token]: 939 return self.tokenizer.tokenize(sql) 940 941 @property 942 def tokenizer(self) -> Tokenizer: 943 return self.tokenizer_class(dialect=self) 944 945 @property 946 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 947 return self.jsonpath_tokenizer_class(dialect=self) 948 949 def parser(self, **opts) -> Parser: 950 return self.parser_class(dialect=self, **opts) 951 952 def generator(self, **opts) -> Generator: 953 return self.generator_class(dialect=self, **opts) 954 955 956DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 957 958 959def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 960 return lambda self, expression: self.func(name, *flatten(expression.args.values())) 961 962 963def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 964 if expression.args.get("accuracy"): 965 self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy") 966 return self.func("APPROX_COUNT_DISTINCT", expression.this) 967 968 969def if_sql( 970 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 971) -> t.Callable[[Generator, exp.If], str]: 972 def _if_sql(self: Generator, expression: exp.If) -> str: 973 return self.func( 974 name, 975 expression.this, 976 expression.args.get("true"), 977 expression.args.get("false") or false_value, 978 ) 979 980 return _if_sql 981 982 983def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 984 this = expression.this 985 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 986 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 987 988 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>") 989 990 991def inline_array_sql(self: Generator, expression: exp.Array) -> str: 992 return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]" 993 994 995def inline_array_unless_query(self: Generator, expression: exp.Array) -> str: 996 elem = seq_get(expression.expressions, 0) 997 if isinstance(elem, exp.Expression) and elem.find(exp.Query): 998 return self.func("ARRAY", elem) 999 return inline_array_sql(self, expression) 1000 1001 1002def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 1003 return self.like_sql( 1004 exp.Like( 1005 this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression) 1006 ) 1007 ) 1008 1009 1010def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 1011 zone = self.sql(expression, "this") 1012 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 1013 1014 1015def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 1016 if 
expression.args.get("recursive"): 1017 self.unsupported("Recursive CTEs are unsupported") 1018 expression.args["recursive"] = False 1019 return self.with_sql(expression) 1020 1021 1022def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str: 1023 n = self.sql(expression, "this") 1024 d = self.sql(expression, "expression") 1025 return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)" 1026 1027 1028def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: 1029 self.unsupported("TABLESAMPLE unsupported") 1030 return self.sql(expression.this) 1031 1032 1033def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: 1034 self.unsupported("PIVOT unsupported") 1035 return "" 1036 1037 1038def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: 1039 return self.cast_sql(expression) 1040 1041 1042def no_comment_column_constraint_sql( 1043 self: Generator, expression: exp.CommentColumnConstraint 1044) -> str: 1045 self.unsupported("CommentColumnConstraint unsupported") 1046 return "" 1047 1048 1049def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: 1050 self.unsupported("MAP_FROM_ENTRIES unsupported") 1051 return "" 1052 1053 1054def property_sql(self: Generator, expression: exp.Property) -> str: 1055 return f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}" 1056 1057 1058def str_position_sql( 1059 self: Generator, expression: exp.StrPosition, generate_instance: bool = False 1060) -> str: 1061 this = self.sql(expression, "this") 1062 substr = self.sql(expression, "substr") 1063 position = self.sql(expression, "position") 1064 instance = expression.args.get("instance") if generate_instance else None 1065 position_offset = "" 1066 1067 if position: 1068 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 1069 this = self.func("SUBSTR", this, position) 1070 position_offset = f" + {position} - 1" 1071 1072 return self.func("STRPOS", this, substr, instance) + position_offset 1073 1074 1075def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: 1076 return ( 1077 f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" 1078 ) 1079 1080 1081def var_map_sql( 1082 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1083) -> str: 1084 keys = expression.args["keys"] 1085 values = expression.args["values"] 1086 1087 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1088 self.unsupported("Cannot convert array columns into map.") 1089 return self.func(map_func_name, keys, values) 1090 1091 args = [] 1092 for key, value in zip(keys.expressions, values.expressions): 1093 args.append(self.sql(key)) 1094 args.append(self.sql(value)) 1095 1096 return self.func(map_func_name, *args) 1097 1098 1099def build_formatted_time( 1100 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1101) -> t.Callable[[t.List], E]: 1102 """Helper used for time expressions. 1103 1104 Args: 1105 exp_class: the expression class to instantiate. 1106 dialect: target sql dialect. 1107 default: the default format, True being time. 1108 1109 Returns: 1110 A callable that can be used to return the appropriately formatted time expression. 
1111 """ 1112 1113 def _builder(args: t.List): 1114 return exp_class( 1115 this=seq_get(args, 0), 1116 format=Dialect[dialect].format_time( 1117 seq_get(args, 1) 1118 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1119 ), 1120 ) 1121 1122 return _builder 1123 1124 1125def time_format( 1126 dialect: DialectType = None, 1127) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1128 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1129 """ 1130 Returns the time format for a given expression, unless it's equivalent 1131 to the default time format of the dialect of interest. 1132 """ 1133 time_format = self.format_time(expression) 1134 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1135 1136 return _time_format 1137 1138 1139def build_date_delta( 1140 exp_class: t.Type[E], 1141 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1142 default_unit: t.Optional[str] = "DAY", 1143) -> t.Callable[[t.List], E]: 1144 def _builder(args: t.List) -> E: 1145 unit_based = len(args) == 3 1146 this = args[2] if unit_based else seq_get(args, 0) 1147 unit = None 1148 if unit_based or default_unit: 1149 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1150 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1151 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1152 1153 return _builder 1154 1155 1156def build_date_delta_with_interval( 1157 expression_class: t.Type[E], 1158) -> t.Callable[[t.List], t.Optional[E]]: 1159 def _builder(args: t.List) -> t.Optional[E]: 1160 if len(args) < 2: 1161 return None 1162 1163 interval = args[1] 1164 1165 if not isinstance(interval, exp.Interval): 1166 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1167 1168 return expression_class(this=args[0], expression=interval.this, unit=unit_to_str(interval)) 1169 1170 return _builder 1171 1172 1173def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1174 unit = seq_get(args, 0) 1175 this = seq_get(args, 1) 1176 1177 if isinstance(this, exp.Cast) and this.is_type("date"): 1178 return exp.DateTrunc(unit=unit, this=this) 1179 return exp.TimestampTrunc(this=this, unit=unit) 1180 1181 1182def date_add_interval_sql( 1183 data_type: str, kind: str 1184) -> t.Callable[[Generator, exp.Expression], str]: 1185 def func(self: Generator, expression: exp.Expression) -> str: 1186 this = self.sql(expression, "this") 1187 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1188 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1189 1190 return func 1191 1192 1193def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1194 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1195 args = [unit_to_str(expression), expression.this] 1196 if zone: 1197 args.append(expression.args.get("zone")) 1198 return self.func("DATE_TRUNC", *args) 1199 1200 return _timestamptrunc_sql 1201 1202 1203def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1204 zone = expression.args.get("zone") 1205 if not zone: 1206 from sqlglot.optimizer.annotate_types import annotate_types 1207 1208 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1209 return self.sql(exp.cast(expression.this, target_type)) 1210 if zone.name.lower() in TIMEZONES: 1211 return self.sql( 1212 
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    zone = expression.args.get("zone")
    if not zone:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if zone.name.lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=zone,
            )
        )
    return self.func("TIMESTAMP", expression.this, zone)


def no_time_sql(self: Generator, expression: exp.Time) -> str:
    # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME)
    this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ)
    expr = exp.cast(
        exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME
    )
    return self.sql(expr)


def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str:
    this = expression.this
    expr = expression.expression

    if expr.name.lower() in TIMEZONES:
        # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP)
        this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ)
        this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP)
        return self.sql(this)

    this = exp.cast(this, exp.DataType.Type.DATE)
    expr = exp.cast(expr, exp.DataType.Type.TIME)

    return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(
    self: Generator,
    expression: exp.TimeStrToTime,
    include_precision: bool = False,
) -> str:
    datatype = exp.DataType.build(
        exp.DataType.Type.TIMESTAMPTZ
        if expression.args.get("zone")
        else exp.DataType.Type.TIMESTAMP
    )

    if isinstance(expression.this, exp.Literal) and include_precision:
        precision = subsecond_precision(expression.this.name)
        if precision > 0:
            datatype = exp.DataType.build(
                datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))]
            )

    return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))


# Used for Presto and DuckDB, which use functions that don't support charset and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
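`locate_to_strposition` and `strposition_to_locate_sql` above translate between MySQL-style `LOCATE(substr, str)` argument order and the `exp.StrPosition` node. A quick sketch:

>>> node = locate_to_strposition([exp.Literal.string("a"), exp.column("s")])
>>> isinstance(node, exp.StrPosition), node.this.sql()
(True, 's')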
def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    group = expression.args.get("group")

    # Do not render group if it's the default value for this dialect
    if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP):
        group = None

    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
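`concat_to_dpipe_sql` folds `CONCAT` arguments into nested `exp.DPipe` nodes with `reduce`; the same fold can be sketched standalone:

>>> import sqlglot
>>> from functools import reduce
>>> args = sqlglot.parse_one("CONCAT(a, b, c)").expressions
>>> reduce(lambda x, y: exp.DPipe(this=x, expression=y), args).sql()
'a || b || c'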
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql
def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    return exp.Literal.string(default) if default else None


def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, (exp.Var, exp.Placeholder)):
        return unit
    return exp.Var(this=default) if default else None


@t.overload
def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var:
    pass


@t.overload
def map_date_part(
    part: t.Optional[exp.Expression], dialect: DialectType = Dialect
) -> t.Optional[exp.Expression]:
    pass


def map_date_part(part, dialect: DialectType = Dialect):
    mapped = (
        Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None
    )
    return exp.var(mapped) if mapped else part


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        # Only remove the target names from the THEN clause.
        # They're still valid in the <condition> part of WHEN MATCHED / WHEN NOT MATCHED.
        # Ref: https://github.com/TobikoData/sqlmesh/issues/2934
        then = when.args.get("then")
        if then:
            then.transform(
                lambda node: (
                    exp.column(node.this)
                    if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                    else node
                ),
                copy=False,
            )

    return self.merge_sql(expression)


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder
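A sketch of `build_json_extract_path` above: literal arguments fold into a single `exp.JSONPath`, while any non-literal argument falls back to the plain arg list (the rendered path shown in the comment is indicative):

builder = build_json_extract_path(exp.JSONExtract)
node = builder([exp.column("j"), exp.Literal.string("a"), exp.Literal.number(0)])
# node.expression is an exp.JSONPath that renders as $.a[0]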
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name


def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))


def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
    return self.func(
        "TO_NUMBER",
        expression.this,
        expression.args.get("format"),
        expression.args.get("nlsparam"),
    )


def build_default_decimal_type(
    precision: t.Optional[int] = None, scale: t.Optional[int] = None
) -> t.Callable[[exp.DataType], exp.DataType]:
    def _builder(dtype: exp.DataType) -> exp.DataType:
        if dtype.expressions or precision is None:
            return dtype

        params = f"{precision}{f', {scale}' if scale is not None else ''}"
        return exp.DataType.build(f"DECIMAL({params})")

    return _builder


def build_timestamp_from_parts(args: t.List) -> exp.Func:
    if len(args) == 2:
        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
        # so we parse this into Anonymous for now instead of introducing complexity
        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)

    return exp.TimestampFromParts.from_arg_list(args)


def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
    return self.func(f"SHA{expression.text('length') or '256'}", expression.this)


def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str:
    start = expression.args.get("start")
    end = expression.args.get("end")
    step = expression.args.get("step")

    if isinstance(start, exp.Cast):
        target_type = start.to
    elif isinstance(end, exp.Cast):
        target_type = end.to
    else:
        target_type = None

    if start and end and target_type and target_type.is_type("date", "timestamp"):
        if isinstance(start, exp.Cast) and target_type is start.to:
            end = exp.cast(end, target_type)
        else:
            start = exp.cast(start, target_type)

    return self.func("SEQUENCE", start, end, step)
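`build_default_decimal_type` above lets a dialect pin an implicit precision/scale onto bare `DECIMAL` types; a sketch:

>>> builder = build_default_decimal_type(precision=38, scale=9)
>>> builder(exp.DataType.build("DECIMAL")).sql()
'DECIMAL(38, 9)'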
exp.RegexpExtract( 1700 this=seq_get(args, 0), 1701 expression=seq_get(args, 1), 1702 group=seq_get(args, 2) or exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP), 1703 )
49class Dialects(str, Enum): 50 """Dialects supported by SQLGLot.""" 51 52 DIALECT = "" 53 54 ATHENA = "athena" 55 BIGQUERY = "bigquery" 56 CLICKHOUSE = "clickhouse" 57 DATABRICKS = "databricks" 58 DORIS = "doris" 59 DRILL = "drill" 60 DUCKDB = "duckdb" 61 HIVE = "hive" 62 MATERIALIZE = "materialize" 63 MYSQL = "mysql" 64 ORACLE = "oracle" 65 POSTGRES = "postgres" 66 PRESTO = "presto" 67 PRQL = "prql" 68 REDSHIFT = "redshift" 69 RISINGWAVE = "risingwave" 70 SNOWFLAKE = "snowflake" 71 SPARK = "spark" 72 SPARK2 = "spark2" 73 SQLITE = "sqlite" 74 STARROCKS = "starrocks" 75 TABLEAU = "tableau" 76 TERADATA = "teradata" 77 TRINO = "trino" 78 TSQL = "tsql"
Dialects supported by SQLGlot.
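A dialect can be referenced through this enum or by its string value. A minimal sketch (illustrative, not part of the module):

from sqlglot.dialects.dialect import Dialects

print(Dialects.DUCKDB.value)                  # duckdb
print(Dialects("duckdb") is Dialects.DUCKDB)  # True, since this is a str-valued Enum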
81class NormalizationStrategy(str, AutoName): 82 """Specifies the strategy according to which identifiers should be normalized.""" 83 84 LOWERCASE = auto() 85 """Unquoted identifiers are lowercased.""" 86 87 UPPERCASE = auto() 88 """Unquoted identifiers are uppercased.""" 89 90 CASE_SENSITIVE = auto() 91 """Always case-sensitive, regardless of quotes.""" 92 93 CASE_INSENSITIVE = auto() 94 """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.

- LOWERCASE: Unquoted identifiers are lowercased.
- UPPERCASE: Unquoted identifiers are uppercased.
- CASE_SENSITIVE: Always case-sensitive, regardless of quotes.
- CASE_INSENSITIVE: Always case-insensitive, regardless of quotes.
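As a sketch of how the strategy surfaces in practice, it can be set per instance through the settings string accepted by Dialect.get_or_raise (documented later in this page):

from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

d = Dialect.get_or_raise("mysql, normalization_strategy = uppercase")
print(d.normalization_strategy is NormalizationStrategy.UPPERCASE)  # True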
222class Dialect(metaclass=_Dialect): 223 INDEX_OFFSET = 0 224 """The base index offset for arrays.""" 225 226 WEEK_OFFSET = 0 227 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 228 229 UNNEST_COLUMN_ONLY = False 230 """Whether `UNNEST` table aliases are treated as column aliases.""" 231 232 ALIAS_POST_TABLESAMPLE = False 233 """Whether the table alias comes after tablesample.""" 234 235 TABLESAMPLE_SIZE_IS_PERCENT = False 236 """Whether a size in the table sample clause represents percentage.""" 237 238 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 239 """Specifies the strategy according to which identifiers should be normalized.""" 240 241 IDENTIFIERS_CAN_START_WITH_DIGIT = False 242 """Whether an unquoted identifier can start with a digit.""" 243 244 DPIPE_IS_STRING_CONCAT = True 245 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 246 247 STRICT_STRING_CONCAT = False 248 """Whether `CONCAT`'s arguments must be strings.""" 249 250 SUPPORTS_USER_DEFINED_TYPES = True 251 """Whether user-defined data types are supported.""" 252 253 SUPPORTS_SEMI_ANTI_JOIN = True 254 """Whether `SEMI` or `ANTI` joins are supported.""" 255 256 SUPPORTS_COLUMN_JOIN_MARKS = False 257 """Whether the old-style outer join (+) syntax is supported.""" 258 259 COPY_PARAMS_ARE_CSV = True 260 """Separator of COPY statement parameters.""" 261 262 NORMALIZE_FUNCTIONS: bool | str = "upper" 263 """ 264 Determines how function names are going to be normalized. 265 Possible values: 266 "upper" or True: Convert names to uppercase. 267 "lower": Convert names to lowercase. 268 False: Disables function name normalization. 269 """ 270 271 LOG_BASE_FIRST: t.Optional[bool] = True 272 """ 273 Whether the base comes first in the `LOG` function. 274 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 275 """ 276 277 NULL_ORDERING = "nulls_are_small" 278 """ 279 Default `NULL` ordering method to use if not explicitly set. 280 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 281 """ 282 283 TYPED_DIVISION = False 284 """ 285 Whether the behavior of `a / b` depends on the types of `a` and `b`. 286 False means `a / b` is always float division. 287 True means `a / b` is integer division if both `a` and `b` are integers. 288 """ 289 290 SAFE_DIVISION = False 291 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 292 293 CONCAT_COALESCE = False 294 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 295 296 HEX_LOWERCASE = False 297 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 298 299 DATE_FORMAT = "'%Y-%m-%d'" 300 DATEINT_FORMAT = "'%Y%m%d'" 301 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 302 303 TIME_MAPPING: t.Dict[str, str] = {} 304 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 305 306 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 307 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 308 FORMAT_MAPPING: t.Dict[str, str] = {} 309 """ 310 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 311 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 
312 """ 313 314 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 315 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 316 317 PSEUDOCOLUMNS: t.Set[str] = set() 318 """ 319 Columns that are auto-generated by the engine corresponding to this dialect. 320 For example, such columns may be excluded from `SELECT *` queries. 321 """ 322 323 PREFER_CTE_ALIAS_COLUMN = False 324 """ 325 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 326 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 327 any projection aliases in the subquery. 328 329 For example, 330 WITH y(c) AS ( 331 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 332 ) SELECT c FROM y; 333 334 will be rewritten as 335 336 WITH y(c) AS ( 337 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 338 ) SELECT c FROM y; 339 """ 340 341 COPY_PARAMS_ARE_CSV = True 342 """ 343 Whether COPY statement parameters are separated by comma or whitespace 344 """ 345 346 FORCE_EARLY_ALIAS_REF_EXPANSION = False 347 """ 348 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 349 350 For example: 351 WITH data AS ( 352 SELECT 353 1 AS id, 354 2 AS my_id 355 ) 356 SELECT 357 id AS my_id 358 FROM 359 data 360 WHERE 361 my_id = 1 362 GROUP BY 363 my_id, 364 HAVING 365 my_id = 1 366 367 In most dialects, "my_id" would refer to "data.my_id" across the query, except: 368 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e 369 it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 370 - Clickhouse, which will forward the alias across the query i.e it resolves 371 to "WHERE id = 1 GROUP BY id HAVING id = 1" 372 """ 373 374 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 375 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 376 377 SUPPORTS_ORDER_BY_ALL = False 378 """ 379 Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks 380 """ 381 382 HAS_DISTINCT_ARRAY_CONSTRUCTORS = False 383 """ 384 Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) 385 as the former is of type INT[] vs the latter which is SUPER 386 """ 387 388 SUPPORTS_FIXED_SIZE_ARRAYS = False 389 """ 390 Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. 391 in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should 392 be interpreted as a subscript/index operator. 393 """ 394 395 STRICT_JSON_PATH_SYNTAX = True 396 """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.""" 397 398 ON_CONDITION_EMPTY_BEFORE_ERROR = True 399 """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).""" 400 401 ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True 402 """Whether ArrayAgg needs to filter NULL values.""" 403 404 REGEXP_EXTRACT_DEFAULT_GROUP = 0 405 """The default value for the capturing group.""" 406 407 SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = { 408 exp.Except: True, 409 exp.Intersect: True, 410 exp.Union: True, 411 } 412 """ 413 Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL` 414 must be explicitly specified. 
415 """ 416 417 CREATABLE_KIND_MAPPING: dict[str, str] = {} 418 """ 419 Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse 420 equivalent of CREATE SCHEMA is CREATE DATABASE. 421 """ 422 423 # --- Autofilled --- 424 425 tokenizer_class = Tokenizer 426 jsonpath_tokenizer_class = JSONPathTokenizer 427 parser_class = Parser 428 generator_class = Generator 429 430 # A trie of the time_mapping keys 431 TIME_TRIE: t.Dict = {} 432 FORMAT_TRIE: t.Dict = {} 433 434 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 435 INVERSE_TIME_TRIE: t.Dict = {} 436 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 437 INVERSE_FORMAT_TRIE: t.Dict = {} 438 439 INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {} 440 441 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 442 443 # Delimiters for string literals and identifiers 444 QUOTE_START = "'" 445 QUOTE_END = "'" 446 IDENTIFIER_START = '"' 447 IDENTIFIER_END = '"' 448 449 # Delimiters for bit, hex, byte and unicode literals 450 BIT_START: t.Optional[str] = None 451 BIT_END: t.Optional[str] = None 452 HEX_START: t.Optional[str] = None 453 HEX_END: t.Optional[str] = None 454 BYTE_START: t.Optional[str] = None 455 BYTE_END: t.Optional[str] = None 456 UNICODE_START: t.Optional[str] = None 457 UNICODE_END: t.Optional[str] = None 458 459 DATE_PART_MAPPING = { 460 "Y": "YEAR", 461 "YY": "YEAR", 462 "YYY": "YEAR", 463 "YYYY": "YEAR", 464 "YR": "YEAR", 465 "YEARS": "YEAR", 466 "YRS": "YEAR", 467 "MM": "MONTH", 468 "MON": "MONTH", 469 "MONS": "MONTH", 470 "MONTHS": "MONTH", 471 "D": "DAY", 472 "DD": "DAY", 473 "DAYS": "DAY", 474 "DAYOFMONTH": "DAY", 475 "DAY OF WEEK": "DAYOFWEEK", 476 "WEEKDAY": "DAYOFWEEK", 477 "DOW": "DAYOFWEEK", 478 "DW": "DAYOFWEEK", 479 "WEEKDAY_ISO": "DAYOFWEEKISO", 480 "DOW_ISO": "DAYOFWEEKISO", 481 "DW_ISO": "DAYOFWEEKISO", 482 "DAY OF YEAR": "DAYOFYEAR", 483 "DOY": "DAYOFYEAR", 484 "DY": "DAYOFYEAR", 485 "W": "WEEK", 486 "WK": "WEEK", 487 "WEEKOFYEAR": "WEEK", 488 "WOY": "WEEK", 489 "WY": "WEEK", 490 "WEEK_ISO": "WEEKISO", 491 "WEEKOFYEARISO": "WEEKISO", 492 "WEEKOFYEAR_ISO": "WEEKISO", 493 "Q": "QUARTER", 494 "QTR": "QUARTER", 495 "QTRS": "QUARTER", 496 "QUARTERS": "QUARTER", 497 "H": "HOUR", 498 "HH": "HOUR", 499 "HR": "HOUR", 500 "HOURS": "HOUR", 501 "HRS": "HOUR", 502 "M": "MINUTE", 503 "MI": "MINUTE", 504 "MIN": "MINUTE", 505 "MINUTES": "MINUTE", 506 "MINS": "MINUTE", 507 "S": "SECOND", 508 "SEC": "SECOND", 509 "SECONDS": "SECOND", 510 "SECS": "SECOND", 511 "MS": "MILLISECOND", 512 "MSEC": "MILLISECOND", 513 "MSECS": "MILLISECOND", 514 "MSECOND": "MILLISECOND", 515 "MSECONDS": "MILLISECOND", 516 "MILLISEC": "MILLISECOND", 517 "MILLISECS": "MILLISECOND", 518 "MILLISECON": "MILLISECOND", 519 "MILLISECONDS": "MILLISECOND", 520 "US": "MICROSECOND", 521 "USEC": "MICROSECOND", 522 "USECS": "MICROSECOND", 523 "MICROSEC": "MICROSECOND", 524 "MICROSECS": "MICROSECOND", 525 "USECOND": "MICROSECOND", 526 "USECONDS": "MICROSECOND", 527 "MICROSECONDS": "MICROSECOND", 528 "NS": "NANOSECOND", 529 "NSEC": "NANOSECOND", 530 "NANOSEC": "NANOSECOND", 531 "NSECOND": "NANOSECOND", 532 "NSECONDS": "NANOSECOND", 533 "NANOSECS": "NANOSECOND", 534 "EPOCH_SECOND": "EPOCH", 535 "EPOCH_SECONDS": "EPOCH", 536 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 537 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 538 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 539 "TZH": "TIMEZONE_HOUR", 540 "TZM": "TIMEZONE_MINUTE", 541 "DEC": "DECADE", 542 "DECS": "DECADE", 543 "DECADES": "DECADE", 544 "MIL": "MILLENIUM", 545 "MILS": "MILLENIUM", 546 "MILLENIA": 
"MILLENIUM", 547 "C": "CENTURY", 548 "CENT": "CENTURY", 549 "CENTS": "CENTURY", 550 "CENTURIES": "CENTURY", 551 } 552 553 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 554 exp.DataType.Type.BIGINT: { 555 exp.ApproxDistinct, 556 exp.ArraySize, 557 exp.Length, 558 }, 559 exp.DataType.Type.BOOLEAN: { 560 exp.Between, 561 exp.Boolean, 562 exp.In, 563 exp.RegexpLike, 564 }, 565 exp.DataType.Type.DATE: { 566 exp.CurrentDate, 567 exp.Date, 568 exp.DateFromParts, 569 exp.DateStrToDate, 570 exp.DiToDate, 571 exp.StrToDate, 572 exp.TimeStrToDate, 573 exp.TsOrDsToDate, 574 }, 575 exp.DataType.Type.DATETIME: { 576 exp.CurrentDatetime, 577 exp.Datetime, 578 exp.DatetimeAdd, 579 exp.DatetimeSub, 580 }, 581 exp.DataType.Type.DOUBLE: { 582 exp.ApproxQuantile, 583 exp.Avg, 584 exp.Div, 585 exp.Exp, 586 exp.Ln, 587 exp.Log, 588 exp.Pow, 589 exp.Quantile, 590 exp.Round, 591 exp.SafeDivide, 592 exp.Sqrt, 593 exp.Stddev, 594 exp.StddevPop, 595 exp.StddevSamp, 596 exp.Variance, 597 exp.VariancePop, 598 }, 599 exp.DataType.Type.INT: { 600 exp.Ceil, 601 exp.DatetimeDiff, 602 exp.DateDiff, 603 exp.TimestampDiff, 604 exp.TimeDiff, 605 exp.DateToDi, 606 exp.Levenshtein, 607 exp.Sign, 608 exp.StrPosition, 609 exp.TsOrDiToDi, 610 }, 611 exp.DataType.Type.JSON: { 612 exp.ParseJSON, 613 }, 614 exp.DataType.Type.TIME: { 615 exp.Time, 616 }, 617 exp.DataType.Type.TIMESTAMP: { 618 exp.CurrentTime, 619 exp.CurrentTimestamp, 620 exp.StrToTime, 621 exp.TimeAdd, 622 exp.TimeStrToTime, 623 exp.TimeSub, 624 exp.TimestampAdd, 625 exp.TimestampSub, 626 exp.UnixToTime, 627 }, 628 exp.DataType.Type.TINYINT: { 629 exp.Day, 630 exp.Month, 631 exp.Week, 632 exp.Year, 633 exp.Quarter, 634 }, 635 exp.DataType.Type.VARCHAR: { 636 exp.ArrayConcat, 637 exp.Concat, 638 exp.ConcatWs, 639 exp.DateToDateStr, 640 exp.GroupConcat, 641 exp.Initcap, 642 exp.Lower, 643 exp.Substring, 644 exp.TimeToStr, 645 exp.TimeToTimeStr, 646 exp.Trim, 647 exp.TsOrDsToDateStr, 648 exp.UnixToStr, 649 exp.UnixToTimeStr, 650 exp.Upper, 651 }, 652 } 653 654 ANNOTATORS: AnnotatorsType = { 655 **{ 656 expr_type: lambda self, e: self._annotate_unary(e) 657 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 658 }, 659 **{ 660 expr_type: lambda self, e: self._annotate_binary(e) 661 for expr_type in subclasses(exp.__name__, exp.Binary) 662 }, 663 **{ 664 expr_type: _annotate_with_type_lambda(data_type) 665 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 666 for expr_type in expressions 667 }, 668 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 669 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 670 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 671 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 672 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 673 exp.Bracket: lambda self, e: self._annotate_bracket(e), 674 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 675 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 676 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 677 exp.Count: lambda self, e: self._annotate_with_type( 678 e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT 679 ), 680 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 681 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 682 exp.DateSub: lambda self, e: 
self._annotate_timeunit(e), 683 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 684 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 685 exp.Div: lambda self, e: self._annotate_div(e), 686 exp.Dot: lambda self, e: self._annotate_dot(e), 687 exp.Explode: lambda self, e: self._annotate_explode(e), 688 exp.Extract: lambda self, e: self._annotate_extract(e), 689 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 690 exp.GenerateDateArray: lambda self, e: self._annotate_with_type( 691 e, exp.DataType.build("ARRAY<DATE>") 692 ), 693 exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type( 694 e, exp.DataType.build("ARRAY<TIMESTAMP>") 695 ), 696 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 697 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 698 exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"), 699 exp.Literal: lambda self, e: self._annotate_literal(e), 700 exp.Map: lambda self, e: self._annotate_map(e), 701 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 702 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 703 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 704 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 705 exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), 706 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 707 exp.Struct: lambda self, e: self._annotate_struct(e), 708 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 709 exp.Timestamp: lambda self, e: self._annotate_with_type( 710 e, 711 exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, 712 ), 713 exp.ToMap: lambda self, e: self._annotate_to_map(e), 714 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 715 exp.Unnest: lambda self, e: self._annotate_unnest(e), 716 exp.VarMap: lambda self, e: self._annotate_map(e), 717 } 718 719 @classmethod 720 def get_or_raise(cls, dialect: DialectType) -> Dialect: 721 """ 722 Look up a dialect in the global dialect registry and return it if it exists. 723 724 Args: 725 dialect: The target dialect. If this is a string, it can be optionally followed by 726 additional key-value pairs that are separated by commas and are used to specify 727 dialect settings, such as whether the dialect's identifiers are case-sensitive. 728 729 Example: 730 >>> dialect = dialect_class = get_or_raise("duckdb") 731 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 732 733 Returns: 734 The corresponding Dialect instance. 
735 """ 736 737 if not dialect: 738 return cls() 739 if isinstance(dialect, _Dialect): 740 return dialect() 741 if isinstance(dialect, Dialect): 742 return dialect 743 if isinstance(dialect, str): 744 try: 745 dialect_name, *kv_strings = dialect.split(",") 746 kv_pairs = (kv.split("=") for kv in kv_strings) 747 kwargs = {} 748 for pair in kv_pairs: 749 key = pair[0].strip() 750 value: t.Union[bool | str | None] = None 751 752 if len(pair) == 1: 753 # Default initialize standalone settings to True 754 value = True 755 elif len(pair) == 2: 756 value = pair[1].strip() 757 758 # Coerce the value to boolean if it matches to the truthy/falsy values below 759 value_lower = value.lower() 760 if value_lower in ("true", "1"): 761 value = True 762 elif value_lower in ("false", "0"): 763 value = False 764 765 kwargs[key] = value 766 767 except ValueError: 768 raise ValueError( 769 f"Invalid dialect format: '{dialect}'. " 770 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 771 ) 772 773 result = cls.get(dialect_name.strip()) 774 if not result: 775 from difflib import get_close_matches 776 777 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 778 if similar: 779 similar = f" Did you mean {similar}?" 780 781 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 782 783 return result(**kwargs) 784 785 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 786 787 @classmethod 788 def format_time( 789 cls, expression: t.Optional[str | exp.Expression] 790 ) -> t.Optional[exp.Expression]: 791 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 792 if isinstance(expression, str): 793 return exp.Literal.string( 794 # the time formats are quoted 795 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 796 ) 797 798 if expression and expression.is_string: 799 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 800 801 return expression 802 803 def __init__(self, **kwargs) -> None: 804 normalization_strategy = kwargs.pop("normalization_strategy", None) 805 806 if normalization_strategy is None: 807 self.normalization_strategy = self.NORMALIZATION_STRATEGY 808 else: 809 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 810 811 self.settings = kwargs 812 813 def __eq__(self, other: t.Any) -> bool: 814 # Does not currently take dialect state into account 815 return type(self) == other 816 817 def __hash__(self) -> int: 818 # Does not currently take dialect state into account 819 return hash(type(self)) 820 821 def normalize_identifier(self, expression: E) -> E: 822 """ 823 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 824 825 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 826 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 827 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 828 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 829 830 There are also dialects like Spark, which are case-insensitive even when quotes are 831 present, and dialects like MySQL, whose resolution rules match those employed by the 832 underlying operating system, for example they may always be case-sensitive in Linux. 
833 834 Finally, the normalization behavior of some engines can even be controlled through flags, 835 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 836 837 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 838 that it can analyze queries in the optimizer and successfully capture their semantics. 839 """ 840 if ( 841 isinstance(expression, exp.Identifier) 842 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 843 and ( 844 not expression.quoted 845 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 846 ) 847 ): 848 expression.set( 849 "this", 850 ( 851 expression.this.upper() 852 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 853 else expression.this.lower() 854 ), 855 ) 856 857 return expression 858 859 def case_sensitive(self, text: str) -> bool: 860 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 861 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 862 return False 863 864 unsafe = ( 865 str.islower 866 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 867 else str.isupper 868 ) 869 return any(unsafe(char) for char in text) 870 871 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 872 """Checks if text can be identified given an identify option. 873 874 Args: 875 text: The text to check. 876 identify: 877 `"always"` or `True`: Always returns `True`. 878 `"safe"`: Only returns `True` if the identifier is case-insensitive. 879 880 Returns: 881 Whether the given text can be identified. 882 """ 883 if identify is True or identify == "always": 884 return True 885 886 if identify == "safe": 887 return not self.case_sensitive(text) 888 889 return False 890 891 def quote_identifier(self, expression: E, identify: bool = True) -> E: 892 """ 893 Adds quotes to a given identifier. 894 895 Args: 896 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 897 identify: If set to `False`, the quotes will only be added if the identifier is deemed 898 "unsafe", with respect to its characters and this dialect's normalization strategy. 899 """ 900 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 901 name = expression.this 902 expression.set( 903 "quoted", 904 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 905 ) 906 907 return expression 908 909 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 910 if isinstance(path, exp.Literal): 911 path_text = path.name 912 if path.is_number: 913 path_text = f"[{path_text}]" 914 try: 915 return parse_json_path(path_text, self) 916 except ParseError as e: 917 if self.STRICT_JSON_PATH_SYNTAX: 918 logger.warning(f"Invalid JSON path syntax. 
{str(e)}") 919 920 return path 921 922 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 923 return self.parser(**opts).parse(self.tokenize(sql), sql) 924 925 def parse_into( 926 self, expression_type: exp.IntoType, sql: str, **opts 927 ) -> t.List[t.Optional[exp.Expression]]: 928 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 929 930 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 931 return self.generator(**opts).generate(expression, copy=copy) 932 933 def transpile(self, sql: str, **opts) -> t.List[str]: 934 return [ 935 self.generate(expression, copy=False, **opts) if expression else "" 936 for expression in self.parse(sql) 937 ] 938 939 def tokenize(self, sql: str) -> t.List[Token]: 940 return self.tokenizer.tokenize(sql) 941 942 @property 943 def tokenizer(self) -> Tokenizer: 944 return self.tokenizer_class(dialect=self) 945 946 @property 947 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 948 return self.jsonpath_tokenizer_class(dialect=self) 949 950 def parser(self, **opts) -> Parser: 951 return self.parser_class(dialect=self, **opts) 952 953 def generator(self, **opts) -> Generator: 954 return self.generator_class(dialect=self, **opts)
803 def __init__(self, **kwargs) -> None: 804 normalization_strategy = kwargs.pop("normalization_strategy", None) 805 806 if normalization_strategy is None: 807 self.normalization_strategy = self.NORMALIZATION_STRATEGY 808 else: 809 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 810 811 self.settings = kwargs
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized. Possible values:

- "upper" or True: Convert names to uppercase.
- "lower": Convert names to lowercase.
- False: Disables function name normalization.
Whether the base comes first in the LOG function. Possible values: True, False, None (two arguments are not supported by LOG).
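For instance, the flag can be inspected directly on a dialect instance; a minimal sketch (the concrete True/False values are assumptions about the Hive and BigQuery subclasses, not guarantees):

from sqlglot.dialects.dialect import Dialect

# Hive's LOG(base, x) places the base first; BigQuery's LOG(x, base) does not (assumed)
print(Dialect.get_or_raise("hive").LOG_BASE_FIRST)
print(Dialect.get_or_raise("bigquery").LOG_BASE_FIRST)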
Default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last".
Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division. True means a / b is integer division if both a and b are integers.
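A hedged sketch of why the flag matters when transpiling (the dialect pair is illustrative and the exact output shape may vary by version):

import sqlglot

# In a typed-division dialect, 3 / 2 between integers truncates; a float-division
# target has to make that semantic difference explicit in the generated SQL
print(sqlglot.transpile("SELECT 3 / 2", read="tsql", write="duckdb")[0])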
A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python strftime formats.
Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed off of TIME_MAPPING.
Mapping of an escaped sequence (\\n) to its unescaped version (a literal newline character).
Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example,

WITH y(c) AS (
  SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
  SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).
For example:

WITH data AS (
  SELECT
    1 AS id,
    2 AS my_id
)
SELECT
  id AS my_id
FROM
  data
WHERE
  my_id = 1
GROUP BY
  my_id
HAVING
  my_id = 1

In most dialects, "my_id" would refer to "data.my_id" across the query, except:

- BigQuery, which will forward the alias to GROUP BY + HAVING clauses, i.e. it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
- Clickhouse, which will forward the alias across the query, i.e. it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
Whether alias reference expansion before qualification should only happen for the GROUP BY clause.
Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
Whether the ARRAY constructor is context-sensitive, i.e. in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3), as the former is of type INT[] vs the latter, which is SUPER.
Whether expressions such as x::INT[5] should be parsed as fixed-size array definitions/casts, e.g. in DuckDB. In dialects which don't support fixed-size arrays, such as Snowflake, this should be interpreted as a subscript/index operator.
Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.
Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).
Whether a set operation uses DISTINCT by default. This is None when either DISTINCT or ALL must be explicitly specified.
Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse equivalent of CREATE SCHEMA is CREATE DATABASE.
719 @classmethod 720 def get_or_raise(cls, dialect: DialectType) -> Dialect: 721 """ 722 Look up a dialect in the global dialect registry and return it if it exists. 723 724 Args: 725 dialect: The target dialect. If this is a string, it can be optionally followed by 726 additional key-value pairs that are separated by commas and are used to specify 727 dialect settings, such as whether the dialect's identifiers are case-sensitive. 728 729 Example: 730 >>> dialect = dialect_class = get_or_raise("duckdb") 731 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 732 733 Returns: 734 The corresponding Dialect instance. 735 """ 736 737 if not dialect: 738 return cls() 739 if isinstance(dialect, _Dialect): 740 return dialect() 741 if isinstance(dialect, Dialect): 742 return dialect 743 if isinstance(dialect, str): 744 try: 745 dialect_name, *kv_strings = dialect.split(",") 746 kv_pairs = (kv.split("=") for kv in kv_strings) 747 kwargs = {} 748 for pair in kv_pairs: 749 key = pair[0].strip() 750 value: t.Union[bool | str | None] = None 751 752 if len(pair) == 1: 753 # Default initialize standalone settings to True 754 value = True 755 elif len(pair) == 2: 756 value = pair[1].strip() 757 758 # Coerce the value to boolean if it matches to the truthy/falsy values below 759 value_lower = value.lower() 760 if value_lower in ("true", "1"): 761 value = True 762 elif value_lower in ("false", "0"): 763 value = False 764 765 kwargs[key] = value 766 767 except ValueError: 768 raise ValueError( 769 f"Invalid dialect format: '{dialect}'. " 770 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 771 ) 772 773 result = cls.get(dialect_name.strip()) 774 if not result: 775 from difflib import get_close_matches 776 777 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 778 if similar: 779 similar = f" Did you mean {similar}?" 780 781 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 782 783 return result(**kwargs) 784 785 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
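A short usage sketch based on the code above (the error message wording is illustrative):

from sqlglot.dialects.dialect import Dialect

Dialect.get_or_raise("duckdb")  # returns a DuckDB Dialect instance
Dialect.get_or_raise(None)      # falsy input returns the default Dialect()
try:
    Dialect.get_or_raise("duckbd")  # unknown name
except ValueError as e:
    print(e)  # Unknown dialect 'duckbd'. Did you mean duckdb?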
787 @classmethod 788 def format_time( 789 cls, expression: t.Optional[str | exp.Expression] 790 ) -> t.Optional[exp.Expression]: 791 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 792 if isinstance(expression, str): 793 return exp.Literal.string( 794 # the time formats are quoted 795 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 796 ) 797 798 if expression and expression.is_string: 799 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 800 801 return expression
Converts a time format in this dialect to its equivalent Python strftime format.
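A sketch, assuming Hive's TIME_MAPPING translates yyyy/MM/dd tokens (as it does upstream); note that plain string inputs are expected to carry their quotes:

from sqlglot.dialects.dialect import Dialect

lit = Dialect["hive"].format_time("'yyyy-MM-dd'")
print(lit.this)  # %Y-%m-%d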
821 def normalize_identifier(self, expression: E) -> E: 822 """ 823 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 824 825 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 826 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 827 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 828 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 829 830 There are also dialects like Spark, which are case-insensitive even when quotes are 831 present, and dialects like MySQL, whose resolution rules match those employed by the 832 underlying operating system, for example they may always be case-sensitive in Linux. 833 834 Finally, the normalization behavior of some engines can even be controlled through flags, 835 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 836 837 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 838 that it can analyze queries in the optimizer and successfully capture their semantics. 839 """ 840 if ( 841 isinstance(expression, exp.Identifier) 842 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 843 and ( 844 not expression.quoted 845 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 846 ) 847 ): 848 expression.set( 849 "this", 850 ( 851 expression.this.upper() 852 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 853 else expression.this.lower() 854 ), 855 ) 856 857 return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
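A minimal sketch of the behavior described above (Postgres lowercases, Snowflake uppercases unquoted identifiers):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

ident = exp.to_identifier("FoO")  # unquoted, so it is subject to normalization
print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name)   # foo
print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name)  # FOO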
859 def case_sensitive(self, text: str) -> bool: 860 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 861 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 862 return False 863 864 unsafe = ( 865 str.islower 866 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 867 else str.isupper 868 ) 869 return any(unsafe(char) for char in text)
Checks if text contains any case-sensitive characters, based on the dialect's rules.
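For example, under an UPPERCASE strategy (as in Snowflake) lowercase characters are the ones that would be changed by normalization, hence "unsafe"; a sketch:

from sqlglot.dialects.dialect import Dialect

snowflake = Dialect.get_or_raise("snowflake")
print(snowflake.case_sensitive("FOO"))  # False: survives uppercasing unchanged
print(snowflake.case_sensitive("foo"))  # True: uppercasing would alter it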
871 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 872 """Checks if text can be identified given an identify option. 873 874 Args: 875 text: The text to check. 876 identify: 877 `"always"` or `True`: Always returns `True`. 878 `"safe"`: Only returns `True` if the identifier is case-insensitive. 879 880 Returns: 881 Whether the given text can be identified. 882 """ 883 if identify is True or identify == "always": 884 return True 885 886 if identify == "safe": 887 return not self.case_sensitive(text) 888 889 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify:
  - "always" or True: Always returns True.
  - "safe": Only returns True if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
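A sketch under Postgres' LOWERCASE strategy:

from sqlglot.dialects.dialect import Dialect

postgres = Dialect.get_or_raise("postgres")
print(postgres.can_identify("foo"))            # True: "safe" and not case-sensitive
print(postgres.can_identify("Foo"))            # False: the uppercase char makes it unsafe
print(postgres.can_identify("Foo", "always"))  # True: always identify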
891 def quote_identifier(self, expression: E, identify: bool = True) -> E: 892 """ 893 Adds quotes to a given identifier. 894 895 Args: 896 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 897 identify: If set to `False`, the quotes will only be added if the identifier is deemed 898 "unsafe", with respect to its characters and this dialect's normalization strategy. 899 """ 900 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 901 name = expression.this 902 expression.set( 903 "quoted", 904 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 905 ) 906 907 return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an Identifier, this method is a no-op.
- identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
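A sketch (the DuckDB choice is illustrative; any dialect instance works):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

duckdb = Dialect.get_or_raise("duckdb")
print(duckdb.quote_identifier(exp.to_identifier("foo")).sql())                  # "foo"
print(duckdb.quote_identifier(exp.to_identifier("foo"), identify=False).sql())  # foo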
909 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 910 if isinstance(path, exp.Literal): 911 path_text = path.name 912 if path.is_number: 913 path_text = f"[{path_text}]" 914 try: 915 return parse_json_path(path_text, self) 916 except ParseError as e: 917 if self.STRICT_JSON_PATH_SYNTAX: 918 logger.warning(f"Invalid JSON path syntax. {str(e)}") 919 920 return path
970def if_sql( 971 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 972) -> t.Callable[[Generator, exp.If], str]: 973 def _if_sql(self: Generator, expression: exp.If) -> str: 974 return self.func( 975 name, 976 expression.this, 977 expression.args.get("true"), 978 expression.args.get("false") or false_value, 979 ) 980 981 return _if_sql
984def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 985 this = expression.this 986 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 987 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 988 989 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
1059def str_position_sql( 1060 self: Generator, expression: exp.StrPosition, generate_instance: bool = False 1061) -> str: 1062 this = self.sql(expression, "this") 1063 substr = self.sql(expression, "substr") 1064 position = self.sql(expression, "position") 1065 instance = expression.args.get("instance") if generate_instance else None 1066 position_offset = "" 1067 1068 if position: 1069 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 1070 this = self.func("SUBSTR", this, position) 1071 position_offset = f" + {position} - 1" 1072 1073 return self.func("STRPOS", this, substr, instance) + position_offset
1082def var_map_sql( 1083 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1084) -> str: 1085 keys = expression.args["keys"] 1086 values = expression.args["values"] 1087 1088 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1089 self.unsupported("Cannot convert array columns into map.") 1090 return self.func(map_func_name, keys, values) 1091 1092 args = [] 1093 for key, value in zip(keys.expressions, values.expressions): 1094 args.append(self.sql(key)) 1095 args.append(self.sql(value)) 1096 1097 return self.func(map_func_name, *args)
1100def build_formatted_time( 1101 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1102) -> t.Callable[[t.List], E]: 1103 """Helper used for time expressions. 1104 1105 Args: 1106 exp_class: the expression class to instantiate. 1107 dialect: target sql dialect. 1108 default: the default format, True being time. 1109 1110 Returns: 1111 A callable that can be used to return the appropriately formatted time expression. 1112 """ 1113 1114 def _builder(args: t.List): 1115 return exp_class( 1116 this=seq_get(args, 0), 1117 format=Dialect[dialect].format_time( 1118 seq_get(args, 1) 1119 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1120 ), 1121 ) 1122 1123 return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format; True means the dialect's default TIME_FORMAT is used.
Returns:
A callable that can be used to return the appropriately formatted time expression.
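A sketch using Hive's format tokens (assuming, as upstream, that they map onto strftime ones):

from sqlglot import exp
from sqlglot.dialects.dialect import build_formatted_time

builder = build_formatted_time(exp.StrToTime, "hive")
node = builder([exp.column("ts"), exp.Literal.string("yyyy-MM-dd")])
print(node.args["format"].this)  # %Y-%m-%d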
1126def time_format( 1127 dialect: DialectType = None, 1128) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1129 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1130 """ 1131 Returns the time format for a given expression, unless it's equivalent 1132 to the default time format of the dialect of interest. 1133 """ 1134 time_format = self.format_time(expression) 1135 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1136 1137 return _time_format
1140def build_date_delta( 1141 exp_class: t.Type[E], 1142 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1143 default_unit: t.Optional[str] = "DAY", 1144) -> t.Callable[[t.List], E]: 1145 def _builder(args: t.List) -> E: 1146 unit_based = len(args) == 3 1147 this = args[2] if unit_based else seq_get(args, 0) 1148 unit = None 1149 if unit_based or default_unit: 1150 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1151 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1152 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1153 1154 return _builder
1157def build_date_delta_with_interval( 1158 expression_class: t.Type[E], 1159) -> t.Callable[[t.List], t.Optional[E]]: 1160 def _builder(args: t.List) -> t.Optional[E]: 1161 if len(args) < 2: 1162 return None 1163 1164 interval = args[1] 1165 1166 if not isinstance(interval, exp.Interval): 1167 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1168 1169 return expression_class(this=args[0], expression=interval.this, unit=unit_to_str(interval)) 1170 1171 return _builder
1174def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1175 unit = seq_get(args, 0) 1176 this = seq_get(args, 1) 1177 1178 if isinstance(this, exp.Cast) and this.is_type("date"): 1179 return exp.DateTrunc(unit=unit, this=this) 1180 return exp.TimestampTrunc(this=this, unit=unit)
1183def date_add_interval_sql( 1184 data_type: str, kind: str 1185) -> t.Callable[[Generator, exp.Expression], str]: 1186 def func(self: Generator, expression: exp.Expression) -> str: 1187 this = self.sql(expression, "this") 1188 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1189 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1190 1191 return func
1194def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1195 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1196 args = [unit_to_str(expression), expression.this] 1197 if zone: 1198 args.append(expression.args.get("zone")) 1199 return self.func("DATE_TRUNC", *args) 1200 1201 return _timestamptrunc_sql
1204def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1205 zone = expression.args.get("zone") 1206 if not zone: 1207 from sqlglot.optimizer.annotate_types import annotate_types 1208 1209 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1210 return self.sql(exp.cast(expression.this, target_type)) 1211 if zone.name.lower() in TIMEZONES: 1212 return self.sql( 1213 exp.AtTimeZone( 1214 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 1215 zone=zone, 1216 ) 1217 ) 1218 return self.func("TIMESTAMP", expression.this, zone)
1221def no_time_sql(self: Generator, expression: exp.Time) -> str: 1222 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 1223 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 1224 expr = exp.cast( 1225 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 1226 ) 1227 return self.sql(expr)
1230def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1231 this = expression.this 1232 expr = expression.expression 1233 1234 if expr.name.lower() in TIMEZONES: 1235 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1236 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1237 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1238 return self.sql(this) 1239 1240 this = exp.cast(this, exp.DataType.Type.DATE) 1241 expr = exp.cast(expr, exp.DataType.Type.TIME) 1242 1243 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
1275def timestrtotime_sql( 1276 self: Generator, 1277 expression: exp.TimeStrToTime, 1278 include_precision: bool = False, 1279) -> str: 1280 datatype = exp.DataType.build( 1281 exp.DataType.Type.TIMESTAMPTZ 1282 if expression.args.get("zone") 1283 else exp.DataType.Type.TIMESTAMP 1284 ) 1285 1286 if isinstance(expression.this, exp.Literal) and include_precision: 1287 precision = subsecond_precision(expression.this.name) 1288 if precision > 0: 1289 datatype = exp.DataType.build( 1290 datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))] 1291 ) 1292 1293 return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect))
1301def encode_decode_sql( 1302 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1303) -> str: 1304 charset = expression.args.get("charset") 1305 if charset and charset.name.lower() != "utf-8": 1306 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1307 1308 return self.func(name, expression.this, expression.args.get("replace") if replace else None)
1321def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1322 cond = expression.this 1323 1324 if isinstance(expression.this, exp.Distinct): 1325 cond = expression.this.expressions[0] 1326 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1327 1328 return self.func("sum", exp.func("if", cond, 1, 0))
1331def trim_sql(self: Generator, expression: exp.Trim) -> str: 1332 target = self.sql(expression, "this") 1333 trim_type = self.sql(expression, "position") 1334 remove_chars = self.sql(expression, "expression") 1335 collation = self.sql(expression, "collation") 1336 1337 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1338 if not remove_chars: 1339 return self.trim_sql(expression) 1340 1341 trim_type = f"{trim_type} " if trim_type else "" 1342 remove_chars = f"{remove_chars} " if remove_chars else "" 1343 from_part = "FROM " if trim_type or remove_chars else "" 1344 collation = f" COLLATE {collation}" if collation else "" 1345 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
1366def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 1367 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 1368 if bad_args: 1369 self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") 1370 1371 group = expression.args.get("group") 1372 1373 # Do not render group if it's the default value for this dialect 1374 if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP): 1375 group = None 1376 1377 return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
1380def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 1381 bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers"))) 1382 if bad_args: 1383 self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") 1384 1385 return self.func( 1386 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 1387 )
1390def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1391 names = [] 1392 for agg in aggregations: 1393 if isinstance(agg, exp.Alias): 1394 names.append(agg.alias) 1395 else: 1396 """ 1397 This case corresponds to aggregations without aliases being used as suffixes 1398 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1399 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1400 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1401 """ 1402 agg_all_unquoted = agg.transform( 1403 lambda node: ( 1404 exp.Identifier(this=node.name, quoted=False) 1405 if isinstance(node, exp.Identifier) 1406 else node 1407 ) 1408 ) 1409 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1410 1411 return names
1451def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1452 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1453 if expression.args.get("count"): 1454 self.unsupported(f"Only two arguments are supported in function {name}.") 1455 1456 return self.func(name, expression.this, expression.expression) 1457 1458 return _arg_max_or_min_sql
1461def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1462 this = expression.this.copy() 1463 1464 return_type = expression.return_type 1465 if return_type.is_type(exp.DataType.Type.DATE): 1466 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1467 # can truncate timestamp strings, because some dialects can't cast them to DATE 1468 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1469 1470 expression.this.replace(exp.cast(this, return_type)) 1471 return expression
1474def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1475 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1476 if cast and isinstance(expression, exp.TsOrDsAdd): 1477 expression = ts_or_ds_add_cast(expression) 1478 1479 return self.func( 1480 name, 1481 unit_to_var(expression), 1482 expression.expression, 1483 expression.this, 1484 ) 1485 1486 return _delta_sql
1489def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1490 unit = expression.args.get("unit") 1491 1492 if isinstance(unit, exp.Placeholder): 1493 return unit 1494 if unit: 1495 return exp.Literal.string(unit.name) 1496 return exp.Literal.string(default) if default else None
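A sketch of the fallback behavior (DateAdd is just one expression type that carries a unit arg):

from sqlglot import exp
from sqlglot.dialects.dialect import unit_to_str

expr = exp.DateAdd(this=exp.column("d"), expression=exp.Literal.number(1))
print(unit_to_str(expr))  # 'DAY': the default kicks in when no unit is set
expr.set("unit", exp.var("MONTH"))
print(unit_to_str(expr))  # 'MONTH'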
1526def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1527 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1528 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1529 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1530 1531 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1534def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1535 """Remove table refs from columns in when statements.""" 1536 alias = expression.this.args.get("alias") 1537 1538 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1539 return self.dialect.normalize_identifier(identifier).name if identifier else None 1540 1541 targets = {normalize(expression.this.this)} 1542 1543 if alias: 1544 targets.add(normalize(alias.this)) 1545 1546 for when in expression.expressions: 1547 # only remove the target names from the THEN clause 1548 # theyre still valid in the <condition> part of WHEN MATCHED / WHEN NOT MATCHED 1549 # ref: https://github.com/TobikoData/sqlmesh/issues/2934 1550 then = when.args.get("then") 1551 if then: 1552 then.transform( 1553 lambda node: ( 1554 exp.column(node.this) 1555 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1556 else node 1557 ), 1558 copy=False, 1559 ) 1560 1561 return self.merge_sql(expression)
Remove table refs from columns in when statements.
1564def build_json_extract_path( 1565 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1566) -> t.Callable[[t.List], F]: 1567 def _builder(args: t.List) -> F: 1568 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1569 for arg in args[1:]: 1570 if not isinstance(arg, exp.Literal): 1571 # We use the fallback parser because we can't really transpile non-literals safely 1572 return expr_type.from_arg_list(args) 1573 1574 text = arg.name 1575 if is_int(text): 1576 index = int(text) 1577 segments.append( 1578 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1579 ) 1580 else: 1581 segments.append(exp.JSONPathKey(this=text)) 1582 1583 # This is done to avoid failing in the expression validator due to the arg count 1584 del args[2:] 1585 return expr_type( 1586 this=seq_get(args, 0), 1587 expression=exp.JSONPath(expressions=segments), 1588 only_json_types=arrow_req_json_type, 1589 ) 1590 1591 return _builder
1594def json_extract_segments( 1595 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1596) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1597 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1598 path = expression.expression 1599 if not isinstance(path, exp.JSONPath): 1600 return rename_func(name)(self, expression) 1601 1602 segments = [] 1603 for segment in path.expressions: 1604 path = self.sql(segment) 1605 if path: 1606 if isinstance(segment, exp.JSONPathPart) and ( 1607 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1608 ): 1609 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1610 1611 segments.append(path) 1612 1613 if op: 1614 return f" {op} ".join([self.sql(expression.this), *segments]) 1615 return self.func(name, expression.this, *segments) 1616 1617 return _json_extract_segments
1627def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1628 cond = expression.expression 1629 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1630 alias = cond.expressions[0] 1631 cond = cond.this 1632 elif isinstance(cond, exp.Predicate): 1633 alias = "_u" 1634 else: 1635 self.unsupported("Unsupported filter condition") 1636 return "" 1637 1638 unnest = exp.Unnest(expressions=[expression.this]) 1639 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1640 return self.sql(exp.Array(expressions=[filtered]))
1652def build_default_decimal_type( 1653 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1654) -> t.Callable[[exp.DataType], exp.DataType]: 1655 def _builder(dtype: exp.DataType) -> exp.DataType: 1656 if dtype.expressions or precision is None: 1657 return dtype 1658 1659 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1660 return exp.DataType.build(f"DECIMAL({params})") 1661 1662 return _builder
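A sketch of the builder in isolation:

from sqlglot import exp
from sqlglot.dialects.dialect import build_default_decimal_type

builder = build_default_decimal_type(precision=38, scale=9)
print(builder(exp.DataType.build("DECIMAL")).sql())         # DECIMAL(38, 9)
print(builder(exp.DataType.build("DECIMAL(10, 2)")).sql())  # DECIMAL(10, 2), left untouched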
1665def build_timestamp_from_parts(args: t.List) -> exp.Func: 1666 if len(args) == 2: 1667 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1668 # so we parse this into Anonymous for now instead of introducing complexity 1669 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1670 1671 return exp.TimestampFromParts.from_arg_list(args)
1678def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str: 1679 start = expression.args.get("start") 1680 end = expression.args.get("end") 1681 step = expression.args.get("step") 1682 1683 if isinstance(start, exp.Cast): 1684 target_type = start.to 1685 elif isinstance(end, exp.Cast): 1686 target_type = end.to 1687 else: 1688 target_type = None 1689 1690 if start and end and target_type and target_type.is_type("date", "timestamp"): 1691 if isinstance(start, exp.Cast) and target_type is start.to: 1692 end = exp.cast(end, target_type) 1693 else: 1694 start = exp.cast(start, target_type) 1695 1696 return self.func("SEQUENCE", start, end, step)