sqlglot.dialects.dialect
"""Dialect base machinery for sqlglot.

Defines the ``Dialect`` base class (plus its registering metaclass and the
``Dialects`` enum of known dialect names) and a collection of small helper
functions shared by concrete dialect implementations for parsing and SQL
generation.
"""

from __future__ import annotations

import typing as t
from enum import Enum

from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

if t.TYPE_CHECKING:
    from sqlglot._typing import E


# Only Snowflake is currently known to resolve unquoted identifiers as uppercase.
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
RESOLVES_IDENTIFIERS_AS_UPPERCASE = {"snowflake"}


class Dialects(str, Enum):
    """Names of the SQL dialects supported by sqlglot."""

    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TRINO = "trino"
    TSQL = "tsql"
    DATABRICKS = "databricks"
    DRILL = "drill"
    TERADATA = "teradata"


class _Dialect(type):
    """Metaclass that registers every ``Dialect`` subclass by name and
    autofills derived attributes (tries, quote/identifier delimiters, the
    tokenizer/parser/generator classes) from the subclass definition."""

    # Registry of dialect name -> dialect class, populated in __new__.
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        # Allows comparisons such as ``Snowflake == "snowflake"`` and
        # ``Snowflake == Snowflake()``.
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Register under the canonical enum value when the class name matches a
        # known dialect, otherwise under the lowercased class name.
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.time_trie = new_trie(klass.time_mapping)
        klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()}
        klass.inverse_time_trie = new_trie(klass.inverse_time_mapping)

        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first entry of each tokenizer mapping defines the canonical
        # delimiters used when generating SQL for this dialect.
        klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.identifier_start, klass.identifier_end = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.bit_start, klass.bit_end = get_start_end(TokenType.BIT_STRING)
        klass.hex_start, klass.hex_end = get_start_end(TokenType.HEX_STRING)
        klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING)
        klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING)

        klass.tokenizer_class.identifiers_can_start_with_digit = (
            klass.identifiers_can_start_with_digit
        )

        return klass


class Dialect(metaclass=_Dialect):
    """Base class for SQL dialects; subclass and override the settings and
    nested ``Tokenizer``/``Parser``/``Generator`` classes to define one."""

    index_offset = 0
    unnest_column_only = False
    alias_post_tablesample = False
    identifiers_can_start_with_digit = False
    normalize_functions: t.Optional[str] = "upper"
    null_ordering = "nulls_are_small"

    date_format = "'%Y-%m-%d'"
    dateint_format = "'%Y%m%d'"
    time_format = "'%Y-%m-%d %H:%M:%S'"
    time_mapping: t.Dict[str, str] = {}

    # autofilled by the _Dialect metaclass
    quote_start = None
    quote_end = None
    identifier_start = None
    identifier_end = None

    time_trie = None
    inverse_time_mapping = None
    inverse_time_trie = None
    tokenizer_class = None
    parser_class = None
    generator_class = None

    def __eq__(self, other: t.Any) -> bool:
        # Delegates to _Dialect.__eq__, so instances compare equal to their
        # class and to the dialect's name string.
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve *dialect* (name, class or instance) to a Dialect class.

        Raises:
            ValueError: if *dialect* is an unknown dialect name.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Translate a time-format string into this dialect's format tokens."""
        if isinstance(expression, str):
            return exp.Literal.string(
                format_time(
                    expression[1:-1],  # the time formats are quoted
                    cls.time_mapping,
                    cls.time_trie,
                )
            )
        if expression and expression.is_string:
            return exp.Literal.string(
                format_time(
                    expression.this,
                    cls.time_mapping,
                    cls.time_trie,
                )
            )
        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily instantiated and cached per Dialect instance.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()  # type: ignore
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(  # type: ignore
            **{
                "index_offset": self.index_offset,
                "unnest_column_only": self.unnest_column_only,
                "alias_post_tablesample": self.alias_post_tablesample,
                "null_ordering": self.null_ordering,
                **opts,
            },
        )

    def generator(self, **opts) -> Generator:
        return self.generator_class(  # type: ignore
            **{
                "quote_start": self.quote_start,
                "quote_end": self.quote_end,
                "bit_start": self.bit_start,
                "bit_end": self.bit_end,
                "hex_start": self.hex_start,
                "hex_end": self.hex_end,
                "byte_start": self.byte_start,
                "byte_end": self.byte_end,
                "raw_start": self.raw_start,
                "raw_end": self.raw_end,
                "identifier_start": self.identifier_start,
                "identifier_end": self.identifier_end,
                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
                "index_offset": self.index_offset,
                "time_mapping": self.inverse_time_mapping,
                "time_trie": self.inverse_time_trie,
                "unnest_column_only": self.unnest_column_only,
                "alias_post_tablesample": self.alias_post_tablesample,
                "identifiers_can_start_with_digit": self.identifiers_can_start_with_digit,
                "normalize_functions": self.normalize_functions,
                "null_ordering": self.null_ordering,
                **opts,
            }
        )


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Return a generator transform that renders the expression as a call to *name*."""
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(self: Generator, expression: exp.If) -> str:
    return self.func(
        "IF", expression.this, expression.args.get("true"), expression.args.get("false")
    )


def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    return self.binary(expression, "->")


def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    return self.binary(expression, "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    # Emulate case-insensitive LIKE by lowercasing the left-hand side.
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this),
            expression=expression.args["expression"],
        )
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    self.unsupported("Properties unsupported")
    return ""


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        # Emulate a start position by searching in a substring and shifting
        # the result back into coordinates of the original string.
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    this = self.sql(expression, "this")
    struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
    return f"{this}.{struct_key}"


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave keys and values: MAP(k1, v1, k2, v2, ...)
    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))
    return self.func(map_func_name, *args)


def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].time_format if default is True else default or None)
            ),
        )

    return _format_time


def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
        expression.set("this", schema)

    return self.create_sql(expression)


def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for DATE_ADD/DATE_DIFF-style functions, optionally
    normalizing unit names through *unit_mapping*."""

    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func


def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for date-delta functions whose second argument is an INTERVAL."""

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0],
            expression=expression,
            unit=exp.Literal.string(interval.text("unit")),
        )

    return func


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
    )


def locate_to_strposition(args: t.List) -> exp.Expression:
    # LOCATE(substr, str, pos) -> StrPosition(this=str, substr=substr, position=pos)
    return exp.StrPosition(
        this=seq_get(args, 1),
        substr=seq_get(args, 0),
        position=seq_get(args, 2),
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    # NOTE: annotation fixed from exp.Left -- this transform handles RIGHT.
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return f"CAST({self.sql(expression, 'this')} AS DATE)"


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.time_format, _dialect.date_format):
            return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
        return f"CAST({self.sql(expression, 'this')} AS DATE)"

    return _ts_or_ds_to_date_sql


# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
class
Dialects(builtins.str, enum.Enum):
24class Dialects(str, Enum): 25 DIALECT = "" 26 27 BIGQUERY = "bigquery" 28 CLICKHOUSE = "clickhouse" 29 DUCKDB = "duckdb" 30 HIVE = "hive" 31 MYSQL = "mysql" 32 ORACLE = "oracle" 33 POSTGRES = "postgres" 34 PRESTO = "presto" 35 REDSHIFT = "redshift" 36 SNOWFLAKE = "snowflake" 37 SPARK = "spark" 38 SPARK2 = "spark2" 39 SQLITE = "sqlite" 40 STARROCKS = "starrocks" 41 TABLEAU = "tableau" 42 TRINO = "trino" 43 TSQL = "tsql" 44 DATABRICKS = "databricks" 45 DRILL = "drill" 46 TERADATA = "teradata"
An enumeration of the SQL dialect names supported by sqlglot; each member's value is the string key used to look up the corresponding `Dialect` class.
DIALECT =
<Dialects.DIALECT: ''>
BIGQUERY =
<Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE =
<Dialects.CLICKHOUSE: 'clickhouse'>
DUCKDB =
<Dialects.DUCKDB: 'duckdb'>
HIVE =
<Dialects.HIVE: 'hive'>
MYSQL =
<Dialects.MYSQL: 'mysql'>
ORACLE =
<Dialects.ORACLE: 'oracle'>
POSTGRES =
<Dialects.POSTGRES: 'postgres'>
PRESTO =
<Dialects.PRESTO: 'presto'>
REDSHIFT =
<Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE =
<Dialects.SNOWFLAKE: 'snowflake'>
SPARK =
<Dialects.SPARK: 'spark'>
SPARK2 =
<Dialects.SPARK2: 'spark2'>
SQLITE =
<Dialects.SQLITE: 'sqlite'>
STARROCKS =
<Dialects.STARROCKS: 'starrocks'>
TABLEAU =
<Dialects.TABLEAU: 'tableau'>
TRINO =
<Dialects.TRINO: 'trino'>
TSQL =
<Dialects.TSQL: 'tsql'>
DATABRICKS =
<Dialects.DATABRICKS: 'databricks'>
DRILL =
<Dialects.DRILL: 'drill'>
TERADATA =
<Dialects.TERADATA: 'teradata'>
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
class
Dialect:
115class Dialect(metaclass=_Dialect): 116 index_offset = 0 117 unnest_column_only = False 118 alias_post_tablesample = False 119 identifiers_can_start_with_digit = False 120 normalize_functions: t.Optional[str] = "upper" 121 null_ordering = "nulls_are_small" 122 123 date_format = "'%Y-%m-%d'" 124 dateint_format = "'%Y%m%d'" 125 time_format = "'%Y-%m-%d %H:%M:%S'" 126 time_mapping: t.Dict[str, str] = {} 127 128 # autofilled 129 quote_start = None 130 quote_end = None 131 identifier_start = None 132 identifier_end = None 133 134 time_trie = None 135 inverse_time_mapping = None 136 inverse_time_trie = None 137 tokenizer_class = None 138 parser_class = None 139 generator_class = None 140 141 def __eq__(self, other: t.Any) -> bool: 142 return type(self) == other 143 144 def __hash__(self) -> int: 145 return hash(type(self)) 146 147 @classmethod 148 def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: 149 if not dialect: 150 return cls 151 if isinstance(dialect, _Dialect): 152 return dialect 153 if isinstance(dialect, Dialect): 154 return dialect.__class__ 155 156 result = cls.get(dialect) 157 if not result: 158 raise ValueError(f"Unknown dialect '{dialect}'") 159 160 return result 161 162 @classmethod 163 def format_time( 164 cls, expression: t.Optional[str | exp.Expression] 165 ) -> t.Optional[exp.Expression]: 166 if isinstance(expression, str): 167 return exp.Literal.string( 168 format_time( 169 expression[1:-1], # the time formats are quoted 170 cls.time_mapping, 171 cls.time_trie, 172 ) 173 ) 174 if expression and expression.is_string: 175 return exp.Literal.string( 176 format_time( 177 expression.this, 178 cls.time_mapping, 179 cls.time_trie, 180 ) 181 ) 182 return expression 183 184 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 185 return self.parser(**opts).parse(self.tokenize(sql), sql) 186 187 def parse_into( 188 self, expression_type: exp.IntoType, sql: str, **opts 189 ) -> t.List[t.Optional[exp.Expression]]: 190 return 
self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 191 192 def generate(self, expression: t.Optional[exp.Expression], **opts) -> str: 193 return self.generator(**opts).generate(expression) 194 195 def transpile(self, sql: str, **opts) -> t.List[str]: 196 return [self.generate(expression, **opts) for expression in self.parse(sql)] 197 198 def tokenize(self, sql: str) -> t.List[Token]: 199 return self.tokenizer.tokenize(sql) 200 201 @property 202 def tokenizer(self) -> Tokenizer: 203 if not hasattr(self, "_tokenizer"): 204 self._tokenizer = self.tokenizer_class() # type: ignore 205 return self._tokenizer 206 207 def parser(self, **opts) -> Parser: 208 return self.parser_class( # type: ignore 209 **{ 210 "index_offset": self.index_offset, 211 "unnest_column_only": self.unnest_column_only, 212 "alias_post_tablesample": self.alias_post_tablesample, 213 "null_ordering": self.null_ordering, 214 **opts, 215 }, 216 ) 217 218 def generator(self, **opts) -> Generator: 219 return self.generator_class( # type: ignore 220 **{ 221 "quote_start": self.quote_start, 222 "quote_end": self.quote_end, 223 "bit_start": self.bit_start, 224 "bit_end": self.bit_end, 225 "hex_start": self.hex_start, 226 "hex_end": self.hex_end, 227 "byte_start": self.byte_start, 228 "byte_end": self.byte_end, 229 "raw_start": self.raw_start, 230 "raw_end": self.raw_end, 231 "identifier_start": self.identifier_start, 232 "identifier_end": self.identifier_end, 233 "string_escape": self.tokenizer_class.STRING_ESCAPES[0], 234 "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0], 235 "index_offset": self.index_offset, 236 "time_mapping": self.inverse_time_mapping, 237 "time_trie": self.inverse_time_trie, 238 "unnest_column_only": self.unnest_column_only, 239 "alias_post_tablesample": self.alias_post_tablesample, 240 "identifiers_can_start_with_digit": self.identifiers_can_start_with_digit, 241 "normalize_functions": self.normalize_functions, 242 "null_ordering": 
self.null_ordering, 243 **opts, 244 } 245 )
@classmethod
def
get_or_raise( cls, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> Type[sqlglot.dialects.dialect.Dialect]:
147 @classmethod 148 def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: 149 if not dialect: 150 return cls 151 if isinstance(dialect, _Dialect): 152 return dialect 153 if isinstance(dialect, Dialect): 154 return dialect.__class__ 155 156 result = cls.get(dialect) 157 if not result: 158 raise ValueError(f"Unknown dialect '{dialect}'") 159 160 return result
@classmethod
def
format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
162 @classmethod 163 def format_time( 164 cls, expression: t.Optional[str | exp.Expression] 165 ) -> t.Optional[exp.Expression]: 166 if isinstance(expression, str): 167 return exp.Literal.string( 168 format_time( 169 expression[1:-1], # the time formats are quoted 170 cls.time_mapping, 171 cls.time_trie, 172 ) 173 ) 174 if expression and expression.is_string: 175 return exp.Literal.string( 176 format_time( 177 expression.this, 178 cls.time_mapping, 179 cls.time_trie, 180 ) 181 ) 182 return expression
def
parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
207 def parser(self, **opts) -> Parser: 208 return self.parser_class( # type: ignore 209 **{ 210 "index_offset": self.index_offset, 211 "unnest_column_only": self.unnest_column_only, 212 "alias_post_tablesample": self.alias_post_tablesample, 213 "null_ordering": self.null_ordering, 214 **opts, 215 }, 216 )
218 def generator(self, **opts) -> Generator: 219 return self.generator_class( # type: ignore 220 **{ 221 "quote_start": self.quote_start, 222 "quote_end": self.quote_end, 223 "bit_start": self.bit_start, 224 "bit_end": self.bit_end, 225 "hex_start": self.hex_start, 226 "hex_end": self.hex_end, 227 "byte_start": self.byte_start, 228 "byte_end": self.byte_end, 229 "raw_start": self.raw_start, 230 "raw_end": self.raw_end, 231 "identifier_start": self.identifier_start, 232 "identifier_end": self.identifier_end, 233 "string_escape": self.tokenizer_class.STRING_ESCAPES[0], 234 "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0], 235 "index_offset": self.index_offset, 236 "time_mapping": self.inverse_time_mapping, 237 "time_trie": self.inverse_time_trie, 238 "unnest_column_only": self.unnest_column_only, 239 "alias_post_tablesample": self.alias_post_tablesample, 240 "identifiers_can_start_with_digit": self.identifiers_can_start_with_digit, 241 "normalize_functions": self.normalize_functions, 242 "null_ordering": self.null_ordering, 243 **opts, 244 } 245 )
def
rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def
approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
def
arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
def
arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
def
inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
def
no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
def
no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
def
no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
def
no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
def
no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
def
no_properties_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Properties) -> str:
def
no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
def
str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
334def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: 335 this = self.sql(expression, "this") 336 substr = self.sql(expression, "substr") 337 position = self.sql(expression, "position") 338 if position: 339 return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1" 340 return f"STRPOS({this}, {substr})"
def
struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
def
var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
349def var_map_sql( 350 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 351) -> str: 352 keys = expression.args["keys"] 353 values = expression.args["values"] 354 355 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 356 self.unsupported("Cannot convert array columns into map.") 357 return self.func(map_func_name, keys, values) 358 359 args = [] 360 for key, value in zip(keys.expressions, values.expressions): 361 args.append(self.sql(key)) 362 args.append(self.sql(value)) 363 return self.func(map_func_name, *args)
def
format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[bool, str, NoneType] = None) -> Callable[[List], ~E]:
366def format_time_lambda( 367 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 368) -> t.Callable[[t.List], E]: 369 """Helper used for time expressions. 370 371 Args: 372 exp_class: the expression class to instantiate. 373 dialect: target sql dialect. 374 default: the default format, True being time. 375 376 Returns: 377 A callable that can be used to return the appropriately formatted time expression. 378 """ 379 380 def _format_time(args: t.List): 381 return exp_class( 382 this=seq_get(args, 0), 383 format=Dialect[dialect].format_time( 384 seq_get(args, 1) 385 or (Dialect[dialect].time_format if default is True else default or None) 386 ), 387 ) 388 389 return _format_time
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
def
create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
392def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: 393 """ 394 In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the 395 PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding 396 columns are removed from the create statement. 397 """ 398 has_schema = isinstance(expression.this, exp.Schema) 399 is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW") 400 401 if has_schema and is_partitionable: 402 expression = expression.copy() 403 prop = expression.find(exp.PartitionedByProperty) 404 if prop and prop.this and not isinstance(prop.this, exp.Schema): 405 schema = expression.this 406 columns = {v.name.upper() for v in prop.this.expressions} 407 partitions = [col for col in schema.expressions if col.name.upper() in columns] 408 schema.set("expressions", [e for e in schema.expressions if e not in partitions]) 409 prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) 410 expression.set("this", schema) 411 412 return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def parse_date_delta(exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date-delta functions, optionally remapping unit names."""

    def _parse(args: t.List) -> E:
        if len(args) == 3:
            # Unit-first form: (unit, amount, date).
            unit, this = args[0], args[2]
        else:
            # Two-argument form defaults the unit to DAY.
            unit, this = exp.Literal.string("DAY"), seq_get(args, 0)

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _parse
def parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for date arithmetic functions whose second argument is an INTERVAL."""

    def _parse(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        amount = interval.this
        if amount and amount.is_string:
            # Normalize string interval amounts (e.g. INTERVAL '5' DAY) to numbers.
            amount = exp.Literal.number(amount.this)

        return expression_class(
            this=args[0],
            expression=amount,
            unit=exp.Literal.string(interval.text("unit")),
        )

    return _parse
def date_trunc_to_time(args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
def timestamptrunc_sql(self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
def strposition_to_locate_sql(self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def left_to_substring_sql(self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
def right_to_substring_sql(self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
def timestrtotime_sql(self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
def datestrtodate_sql(self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
def max_or_greatest(self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
def count_if_to_sum(self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as SUM(IF(cond, 1, 0))."""
    condition = expression.this

    if isinstance(condition, exp.Distinct):
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
        condition = condition.expressions[0]

    return self.func("sum", exp.func("if", condition, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render a TRIM expression, emitting the full TRIM([position] [chars] FROM target
    [COLLATE ...]) form only when database-specific pieces are present."""
    subject = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collate = self.sql(expression, "collation")

    if not chars and not collate:
        # Nothing database-specific: defer to the generator's TRIM/LTRIM/RTRIM rendering.
        return self.trim_sql(expression)

    pieces = [
        f"{position} " if position else "",
        f"{chars} " if chars else "",
        "FROM " if position or chars else "",
        subject,
        f" COLLATE {collate}" if collate else "",
    ]
    return "TRIM(" + "".join(pieces) + ")"
def str_to_time_sql(self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
def ts_or_ds_to_date_sql(dialect: str) -> Callable:
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Create a generator transform that renders TsOrDsToDate as CAST(... AS DATE)."""

    def _sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target_dialect = Dialect.get_or_raise(dialect)
        fmt = self.format_time(expression)

        if fmt and fmt not in (target_dialect.time_format, target_dialect.date_format):
            # A non-default format means the value must be parsed as a time first.
            return f"CAST({str_to_time_sql(self, expression)} AS DATE)"

        return f"CAST({self.sql(expression, 'this')} AS DATE)"

    return _sql
def pivot_column_names(aggregations: List[sqlglot.expressions.Expression], dialect: Optional[Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect]]]) -> List[str]:
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Return the column-name suffixes implied by a pivot's aggregation expressions.

    Aliased aggregations contribute their alias; unaliased ones contribute their
    generated SQL with identifiers unquoted.

    Args:
        aggregations: the pivot's aggregation expressions.
        dialect: dialect used to generate SQL for unaliased aggregations.

    Returns:
        One name per aggregation, in order.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            # NOTE: the original used a bare triple-quoted string here, which is an
            # evaluated-and-discarded expression statement, not a comment (PEP 8).
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names