Merging upstream version 16.4.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent 8a4abed982
commit 71f21d9752
90 changed files with 35638 additions and 33343 deletions
@@ -84,58 +84,17 @@ def parse(sql: str, read: DialectType = None, **opts) -> t.List[t.Optional[Expression]]:

 @t.overload
-def parse_one(
-    sql: str,
-    read: None = None,
-    into: t.Type[E] = ...,
-    **opts,
-) -> E:
-    ...
-
-
-@t.overload
-def parse_one(
-    sql: str,
-    read: DialectType,
-    into: t.Type[E],
-    **opts,
-) -> E:
+def parse_one(sql: str, *, into: t.Type[E], **opts) -> E:
     ...


 @t.overload
-def parse_one(
-    sql: str,
-    read: None = None,
-    into: t.Union[str, t.Collection[t.Union[str, t.Type[Expression]]]] = ...,
-    **opts,
-) -> Expression:
-    ...
-
-
-@t.overload
-def parse_one(
-    sql: str,
-    read: DialectType,
-    into: t.Union[str, t.Collection[t.Union[str, t.Type[Expression]]]],
-    **opts,
-) -> Expression:
-    ...
-
-
-@t.overload
-def parse_one(
-    sql: str,
-    **opts,
-) -> Expression:
+def parse_one(sql: str, **opts) -> Expression:
     ...


 def parse_one(
-    sql: str,
-    read: DialectType = None,
-    into: t.Optional[exp.IntoType] = None,
-    **opts,
+    sql: str, read: DialectType = None, into: t.Optional[exp.IntoType] = None, **opts
 ) -> Expression:
     """
     Parses the given SQL string and returns a syntax tree for the first parsed SQL statement.
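For context, the two surviving overloads cover the same call patterns as before; a minimal usage sketch of `parse_one` (standard sqlglot API):

```python
import sqlglot
from sqlglot import exp

# Without `into`, the result is typed as a plain Expression.
tree = sqlglot.parse_one("SELECT a FROM t", read="duckdb")

# With `into`, the overload narrows the static return type to the requested class.
select = sqlglot.parse_one("SELECT a FROM t", into=exp.Select)
print(select.sql(dialect="spark"))  # SELECT a FROM t
```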
@@ -9,7 +9,7 @@ Currently many of the common operations are covered and more functionality will

 ## Instructions
 * [Install SQLGlot](https://github.com/tobymao/sqlglot/blob/main/README.md#install) and that is all that is required to just generate SQL. [The examples](#examples) show generating SQL and then executing that SQL on a specific engine, which requires that engine's client library.
 * Find/replace all `from pyspark.sql` with `from sqlglot.dataframe`.
-* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('<table_name>', <column_structure>)`.
+* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('<table_name>', <column_structure>, dialect="spark")`.
   * The column structure can be defined in the following ways:
     * Dictionary where the keys are column names and the values are strings of the Spark SQL type name.
       * Ex: `{'cola': 'string', 'colb': 'int'}`
@@ -33,12 +33,16 @@ import sqlglot
 from sqlglot.dataframe.sql.session import SparkSession
 from sqlglot.dataframe.sql import functions as F

-sqlglot.schema.add_table('employee', {
-    'employee_id': 'INT',
-    'fname': 'STRING',
-    'lname': 'STRING',
-    'age': 'INT',
-})  # Register the table structure prior to reading from the table
+sqlglot.schema.add_table(
+    'employee',
+    {
+        'employee_id': 'INT',
+        'fname': 'STRING',
+        'lname': 'STRING',
+        'age': 'INT',
+    },
+    dialect="spark",
+)  # Register the table structure prior to reading from the table

 spark = SparkSession()
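With the table registered under the `spark` dialect as above, the rest of the README flow is unchanged; a short sketch, assuming `DataFrame.sql()` returns the generated statements as a list as in the README's other examples:

```python
df = spark.table('employee').select(F.col('employee_id'), F.col('fname'))
for statement in df.sql():  # generated Spark SQL for the lineage above
    print(statement)
```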
@@ -5,6 +5,7 @@ import typing as t
 import sqlglot
 from sqlglot import expressions as exp
 from sqlglot.dataframe.sql.types import DataType
+from sqlglot.dialects import Spark
 from sqlglot.helper import flatten, is_iterable

 if t.TYPE_CHECKING:

@@ -22,6 +23,10 @@ class Column:
             expression = sqlglot.maybe_parse(expression, dialect="spark")
         if expression is None:
             raise ValueError(f"Could not parse {expression}")
+
+        if isinstance(expression, exp.Column):
+            expression.transform(Spark.normalize_identifier, copy=False)
+
         self.expression: exp.Expression = expression

     def __repr__(self):
@@ -316,6 +316,7 @@ class DataFrame:
                 expression.alias_or_name: expression.type.sql("spark")
                 for expression in select_expression.expressions
             },
+            dialect="spark",
         )
         cache_storage_level = select_expression.args["cache_storage_level"]
         options = [
@@ -5,6 +5,7 @@ import typing as t
 from sqlglot import expressions as exp
 from sqlglot.dataframe.sql.column import Column
 from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
+from sqlglot.dialects import Spark
 from sqlglot.helper import ensure_list

 NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])

@@ -19,6 +20,7 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[
     for expression in expressions:
         identifiers = expression.find_all(exp.Identifier)
         for identifier in identifiers:
+            Spark.normalize_identifier(identifier)
             replace_alias_name_with_cte_name(spark, expression_context, identifier)
             replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)
@@ -4,7 +4,8 @@ import typing as t

 import sqlglot
 from sqlglot import expressions as exp
-from sqlglot.helper import object_to_dict, should_identify
+from sqlglot.dialects import Spark
+from sqlglot.helper import object_to_dict

 if t.TYPE_CHECKING:
     from sqlglot.dataframe.sql.dataframe import DataFrame

@@ -18,17 +19,14 @@ class DataFrameReader:
     def table(self, tableName: str) -> DataFrame:
         from sqlglot.dataframe.sql.dataframe import DataFrame

-        sqlglot.schema.add_table(tableName)
+        sqlglot.schema.add_table(tableName, dialect="spark")

         return DataFrame(
             self.spark,
             exp.Select()
-            .from_(tableName)
+            .from_(exp.to_table(tableName, dialect="spark").transform(Spark.normalize_identifier))
             .select(
-                *(
-                    column if should_identify(column, "safe") else f'"{column}"'
-                    for column in sqlglot.schema.column_names(tableName)
-                )
+                *(column for column in sqlglot.schema.column_names(tableName, dialect="spark"))
             ),
         )

@@ -73,7 +71,7 @@ class DataFrameWriter:
         )
         df = self._df.copy(output_expression_container=output_expression_container)
         if self._by_name:
-            columns = sqlglot.schema.column_names(tableName, only_visible=True)
+            columns = sqlglot.schema.column_names(tableName, only_visible=True, dialect="spark")
             df = df._convert_leaf_to_cte().select(*columns)

         return self.copy(_df=df)
@@ -4,6 +4,7 @@ import re
 import typing as t

 from sqlglot import exp, generator, parser, tokens, transforms
+from sqlglot._typing import E
 from sqlglot.dialects.dialect import (
     Dialect,
     datestrtodate_sql,

@@ -106,6 +107,9 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
 class BigQuery(Dialect):
     UNNEST_COLUMN_ONLY = True

+    # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity
+    RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+
     TIME_MAPPING = {
         "%D": "%m/%d/%y",
     }

@@ -126,6 +130,20 @@ class BigQuery(Dialect):
         "TZH": "%z",
     }

+    @classmethod
+    def normalize_identifier(cls, expression: E) -> E:
+        # In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least).
+        # The following check is essentially a heuristic to detect tables based on whether or
+        # not they're qualified.
+        if (
+            isinstance(expression, exp.Identifier)
+            and not (isinstance(expression.parent, exp.Table) and expression.parent.db)
+            and not expression.meta.get("is_table")
+        ):
+            expression.set("this", expression.this.lower())
+
+        return expression
+
     class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'", '"', '"""', "'''"]
         COMMENTS = ["--", "#", ("/*", "*/")]

@@ -176,6 +194,7 @@ class BigQuery(Dialect):
             "DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd),
             "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
             "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
+            "GENERATE_ARRAY": exp.GenerateSeries.from_arg_list,
             "PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")(
                 [seq_get(args, 1), seq_get(args, 0)]
             ),

@@ -201,6 +220,7 @@ class BigQuery(Dialect):
             "TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
             "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
             "TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub),
+            "TO_JSON_STRING": exp.JSONFormat.from_arg_list,
         }

         FUNCTION_PARSERS = {

@@ -289,6 +309,8 @@ class BigQuery(Dialect):
             exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
             exp.DateStrToDate: datestrtodate_sql,
             exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")),
+            exp.JSONFormat: rename_func("TO_JSON_STRING"),
+            exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
             exp.GroupConcat: rename_func("STRING_AGG"),
             exp.ILike: no_ilike_sql,
             exp.IntDiv: rename_func("DIV"),
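A sketch of the table-detection heuristic above: an unqualified identifier is lowercased, while the identifier of a db-qualified table keeps its case (the behavior follows directly from the code shown):

```python
from sqlglot import exp, parse_one
from sqlglot.dialects import BigQuery

# Unqualified identifier: treated as case-insensitive, so it is lowercased.
print(BigQuery.normalize_identifier(exp.to_identifier("MyCte")).this)  # mycte

# Identifier of a qualified table (its parent is an exp.Table with a db part): left alone.
table = parse_one("SELECT * FROM proj.dataset.MyTable", read="bigquery").find(exp.Table)
print(BigQuery.normalize_identifier(table.this).this)  # MyTable
```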
@@ -345,7 +345,7 @@ class ClickHouse(Dialect):
             "CONCAT",
             *[
                 exp.func("if", e.is_(exp.null()), e, exp.cast(e, "text"))
-                for e in expression.expressions
+                for e in t.cast(t.List[exp.Condition], expression.expressions)
             ],
         )
@@ -4,6 +4,7 @@ import typing as t
 from enum import Enum

 from sqlglot import exp
+from sqlglot._typing import E
 from sqlglot.generator import Generator
 from sqlglot.helper import flatten, seq_get
 from sqlglot.parser import Parser

@@ -11,14 +12,6 @@ from sqlglot.time import format_time
 from sqlglot.tokens import Token, Tokenizer, TokenType
 from sqlglot.trie import new_trie

-if t.TYPE_CHECKING:
-    from sqlglot._typing import E
-
-
-# Only Snowflake is currently known to resolve unquoted identifiers as uppercase.
-# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
-RESOLVES_IDENTIFIERS_AS_UPPERCASE = {"snowflake"}
-

 class Dialects(str, Enum):
     DIALECT = ""

@@ -117,6 +110,9 @@ class _Dialect(type):
             "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0],
         }

+        if enum not in ("", "bigquery"):
+            dialect_properties["SELECT_KINDS"] = ()
+
         # Pass required dialect properties to the tokenizer, parser and generator classes
         for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
             for name, value in dialect_properties.items():

@@ -126,6 +122,8 @@ class _Dialect(type):
         if not klass.STRICT_STRING_CONCAT:
             klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

+        klass.generator_class.can_identify = klass.can_identify
+
         return klass
@@ -139,6 +137,10 @@ class Dialect(metaclass=_Dialect):
     # Determines whether or not the table alias comes after tablesample
     ALIAS_POST_TABLESAMPLE = False

+    # Determines whether or not unquoted identifiers are resolved as uppercase
+    # When set to None, it means that the dialect treats all identifiers as case-insensitive
+    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False
+
     # Determines whether or not an unquoted identifier can start with a digit
     IDENTIFIERS_CAN_START_WITH_DIGIT = False
@@ -213,6 +215,66 @@ class Dialect(metaclass=_Dialect):

         return expression

+    @classmethod
+    def normalize_identifier(cls, expression: E) -> E:
+        """
+        Normalizes an unquoted identifier to either lower or upper case, thus essentially
+        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
+        they will be normalized regardless of being quoted or not.
+        """
+        if isinstance(expression, exp.Identifier) and (
+            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
+        ):
+            expression.set(
+                "this",
+                expression.this.upper()
+                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
+                else expression.this.lower(),
+            )
+
+        return expression
+
+    @classmethod
+    def case_sensitive(cls, text: str) -> bool:
+        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
+        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
+            return False
+
+        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
+        return any(unsafe(char) for char in text)
+
+    @classmethod
+    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
+        """Checks if text can be identified given an identify option.
+
+        Args:
+            text: The text to check.
+            identify:
+                "always" or `True`: Always returns true.
+                "safe": True if the identifier is case-insensitive.
+
+        Returns:
+            Whether or not the given text can be identified.
+        """
+        if identify is True or identify == "always":
+            return True
+
+        if identify == "safe":
+            return not cls.case_sensitive(text)
+
+        return False
+
+    @classmethod
+    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
+        if isinstance(expression, exp.Identifier):
+            name = expression.this
+            expression.set(
+                "quoted",
+                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
+            )
+
+        return expression
+
     def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
         return self.parser(**opts).parse(self.tokenize(sql), sql)
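These classmethods move identifier policy onto the dialect itself; a small sketch whose outputs follow from the code above:

```python
from sqlglot import exp
from sqlglot.dialects import Hive, Snowflake

# Snowflake resolves unquoted identifiers as uppercase.
print(Snowflake.normalize_identifier(exp.to_identifier("Foo")).this)  # FOO
print(Snowflake.case_sensitive("Foo"))        # True: a lowercase char is significant
print(Snowflake.can_identify("FOO", "safe"))  # True: nothing case-sensitive to lose

# A dialect with RESOLVES_IDENTIFIERS_AS_UPPERCASE = None treats all identifiers
# as case-insensitive, so no text is ever considered case-sensitive.
print(Hive.case_sensitive("Foo"))  # False
```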
@@ -85,9 +85,17 @@ def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract) -> str:
     )


+def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
+    sql = self.func("TO_JSON", expression.this, expression.args.get("options"))
+    return f"CAST({sql} AS TEXT)"
+
+
 class DuckDB(Dialect):
     NULL_ORDERING = "nulls_are_last"

+    # https://duckdb.org/docs/sql/introduction.html#creating-a-new-table
+    RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+
     class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,

@@ -167,7 +175,7 @@ class DuckDB(Dialect):
             **generator.Generator.TRANSFORMS,
             exp.ApproxDistinct: approx_count_distinct_sql,
             exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0])
-            if isinstance(seq_get(e.expressions, 0), exp.Select)
+            if e.expressions and e.expressions[0].find(exp.Select)
             else rename_func("LIST_VALUE")(self, e),
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.ArraySort: _array_sort_sql,

@@ -192,6 +200,7 @@ class DuckDB(Dialect):
             exp.IntDiv: lambda self, e: self.binary(e, "//"),
             exp.JSONExtract: arrow_json_extract_sql,
             exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
+            exp.JSONFormat: _json_format_sql,
             exp.JSONBExtract: arrow_json_extract_sql,
             exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
             exp.LogicalOr: rename_func("BOOL_OR"),
@@ -86,13 +86,17 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:

 def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
     this = expression.this
-    if not this.type:
-        from sqlglot.optimizer.annotate_types import annotate_types
+    if isinstance(this, exp.Cast) and this.is_type("json") and this.this.is_string:
+        # Since FROM_JSON requires a nested type, we always wrap the json string with
+        # an array to ensure that "naked" strings like "'a'" will be handled correctly
+        wrapped_json = exp.Literal.string(f"[{this.this.name}]")

-        annotate_types(this)
+        from_json = self.func("FROM_JSON", wrapped_json, self.func("SCHEMA_OF_JSON", wrapped_json))
+        to_json = self.func("TO_JSON", from_json)

-    if this.type.is_type("json"):
-        return self.sql(this)
-    return self.func("TO_JSON", this, expression.args.get("options"))
+        # This strips the [, ] delimiters of the dummy array printed by TO_JSON
+        return self.func("REGEXP_EXTRACT", to_json, "'^.(.*).$'", "1")
+
+    return self.func("TO_JSON", this, expression.args.get("options"))

@@ -153,6 +157,9 @@ class Hive(Dialect):
     ALIAS_POST_TABLESAMPLE = True
     IDENTIFIERS_CAN_START_WITH_DIGIT = True

+    # https://spark.apache.org/docs/latest/sql-ref-identifier.html#description
+    RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+
     TIME_MAPPING = {
         "y": "%Y",
         "Y": "%Y",

@@ -268,9 +275,9 @@ class Hive(Dialect):

         QUERY_MODIFIER_PARSERS = {
             **parser.Parser.QUERY_MODIFIER_PARSERS,
-            "distribute": lambda self: self._parse_sort(exp.Distribute, "DISTRIBUTE", "BY"),
-            "sort": lambda self: self._parse_sort(exp.Sort, "SORT", "BY"),
-            "cluster": lambda self: self._parse_sort(exp.Cluster, "CLUSTER", "BY"),
+            "cluster": lambda self: self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY),
+            "distribute": lambda self: self._parse_sort(exp.Distribute, TokenType.DISTRIBUTE_BY),
+            "sort": lambda self: self._parse_sort(exp.Sort, TokenType.SORT_BY),
         }

         def _parse_types(
@@ -123,6 +123,8 @@ class MySQL(Dialect):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
             "CHARSET": TokenType.CHARACTER_SET,
+            "FORCE": TokenType.FORCE,
+            "IGNORE": TokenType.IGNORE,
             "LONGBLOB": TokenType.LONGBLOB,
             "LONGTEXT": TokenType.LONGTEXT,
             "MEDIUMBLOB": TokenType.MEDIUMBLOB,

@@ -180,6 +182,9 @@ class MySQL(Dialect):

     class Parser(parser.Parser):
         FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE}
+        TABLE_ALIAS_TOKENS = (
+            parser.Parser.TABLE_ALIAS_TOKENS - parser.Parser.TABLE_INDEX_HINT_TOKENS
+        )

         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,

@@ -389,7 +394,7 @@ class MySQL(Dialect):
         LOCKING_READS_SUPPORTED = True
         NULL_ORDERING_SUPPORTED = False
         JOIN_HINTS = False
-        TABLE_HINTS = False
+        TABLE_HINTS = True

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
@@ -103,24 +103,15 @@ def _str_to_time_sql(
 def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
     time_format = self.format_time(expression)
     if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT):
-        return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
-    return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"
+        return exp.cast(_str_to_time_sql(self, expression), "DATE").sql(dialect="presto")
+    return exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE").sql(dialect="presto")


 def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
-    this = expression.this
-
-    if not isinstance(this, exp.CurrentDate):
-        this = self.func(
-            "DATE_PARSE",
-            self.func(
-                "SUBSTR",
-                this if this.is_string else exp.cast(this, "VARCHAR"),
-                exp.Literal.number(1),
-                exp.Literal.number(10),
-            ),
-            Presto.DATE_FORMAT,
-        )
+    this = exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE")

     return self.func(
         "DATE_ADD",

@@ -181,6 +172,11 @@ class Presto(Dialect):
     TIME_MAPPING = MySQL.TIME_MAPPING
     STRICT_STRING_CONCAT = True

+    # https://github.com/trinodb/trino/issues/17
+    # https://github.com/trinodb/trino/issues/12289
+    # https://github.com/prestodb/presto/issues/2863
+    RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+
     class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
@@ -14,6 +14,9 @@ def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str:


 class Redshift(Postgres):
+    # https://docs.aws.amazon.com/redshift/latest/dg/r_names.html
+    RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+
     TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
     TIME_MAPPING = {
         **Postgres.TIME_MAPPING,
@@ -167,6 +167,8 @@ def _parse_convert_timezone(args: t.List) -> exp.Expression:


 class Snowflake(Dialect):
+    # https://docs.snowflake.com/en/sql-reference/identifiers-syntax
+    RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
     NULL_ORDERING = "nulls_are_large"
     TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"

@@ -283,11 +285,12 @@ class Snowflake(Dialect):
             "NCHAR VARYING": TokenType.VARCHAR,
             "PUT": TokenType.COMMAND,
             "RENAME": TokenType.REPLACE,
+            "SAMPLE": TokenType.TABLE_SAMPLE,
             "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
             "TIMESTAMP_NTZ": TokenType.TIMESTAMP,
             "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
             "TIMESTAMPNTZ": TokenType.TIMESTAMP,
-            "SAMPLE": TokenType.TABLE_SAMPLE,
             "TOP": TokenType.TOP,
         }

         SINGLE_TOKENS = {
@@ -59,6 +59,9 @@ def _transform_create(expression: exp.Expression) -> exp.Expression:


 class SQLite(Dialect):
+    # https://sqlite.org/forum/forumpost/5e575586ac5c711b?raw
+    RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+
     class Tokenizer(tokens.Tokenizer):
         IDENTIFIERS = ['"', ("[", "]"), "`"]
         HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]
@@ -31,18 +31,19 @@ class Teradata(Dialect):
         # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Comparison-Operators-and-Functions/Comparison-Operators/ANSI-Compliance
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
-            "BYTEINT": TokenType.SMALLINT,
-            "SEL": TokenType.SELECT,
-            "INS": TokenType.INSERT,
-            "MOD": TokenType.MOD,
-            "LT": TokenType.LT,
-            "LE": TokenType.LTE,
-            "GT": TokenType.GT,
-            "GE": TokenType.GTE,
             "^=": TokenType.NEQ,
+            "BYTEINT": TokenType.SMALLINT,
+            "GE": TokenType.GTE,
+            "GT": TokenType.GT,
+            "INS": TokenType.INSERT,
+            "LE": TokenType.LTE,
+            "LT": TokenType.LT,
+            "MOD": TokenType.MOD,
             "NE": TokenType.NEQ,
             "NOT=": TokenType.NEQ,
+            "SEL": TokenType.SELECT,
             "ST_GEOMETRY": TokenType.GEOMETRY,
             "TOP": TokenType.TOP,
         }

         # Teradata does not support % as a modulo operator
@@ -1301,7 +1301,14 @@ class Constraint(Expression):


 class Delete(Expression):
-    arg_types = {"with": False, "this": False, "using": False, "where": False, "returning": False}
+    arg_types = {
+        "with": False,
+        "this": False,
+        "using": False,
+        "where": False,
+        "returning": False,
+        "limit": False,
+    }

     def delete(
         self,

@@ -1844,6 +1851,10 @@ class CollateProperty(Property):
     arg_types = {"this": True}


+class CopyGrantsProperty(Property):
+    arg_types = {}
+
+
 class DataBlocksizeProperty(Property):
     arg_types = {
         "size": False,

@@ -2245,6 +2256,16 @@ QUERY_MODIFIERS = {
 }


+# https://learn.microsoft.com/en-us/sql/t-sql/queries/hints-transact-sql-table?view=sql-server-ver16
+class WithTableHint(Expression):
+    arg_types = {"expressions": True}
+
+
+# https://dev.mysql.com/doc/refman/8.0/en/index-hints.html
+class IndexTableHint(Expression):
+    arg_types = {"this": True, "expressions": False, "target": False}
+
+
 class Table(Expression):
     arg_types = {
         "this": True,

@@ -2402,6 +2423,7 @@ class Update(Expression):
         "from": False,
         "where": False,
         "returning": False,
+        "limit": False,
     }

@@ -2434,8 +2456,6 @@ class Select(Subqueryable):
         "expressions": False,
         "hint": False,
         "distinct": False,
-        "struct": False,  # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#return_query_results_as_a_value_table
-        "value": False,
         "into": False,
         "from": False,
         **QUERY_MODIFIERS,

@@ -3223,15 +3243,15 @@ class Star(Expression):
         return self.name


-class Parameter(Expression):
+class Parameter(Condition):
     arg_types = {"this": True, "wrapped": False}


-class SessionParameter(Expression):
+class SessionParameter(Condition):
     arg_types = {"this": True, "kind": False}


-class Placeholder(Expression):
+class Placeholder(Condition):
     arg_types = {"this": False, "kind": False}


@@ -3333,6 +3353,7 @@ class DataType(Expression):
         UINT128 = auto()
         UINT256 = auto()
         UNIQUEIDENTIFIER = auto()
+        USERDEFINED = "USER-DEFINED"
         UUID = auto()
         VARBINARY = auto()
         VARCHAR = auto()
@@ -5,7 +5,7 @@ import typing as t

 from sqlglot import exp
 from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
-from sqlglot.helper import apply_index_offset, csv, seq_get, should_identify
+from sqlglot.helper import apply_index_offset, csv, seq_get
 from sqlglot.time import format_time
 from sqlglot.tokens import TokenType

@@ -56,39 +56,40 @@ class Generator:
         exp.TsOrDsAdd: lambda self, e: self.func(
             "TS_OR_DS_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
         ),
-        exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
+        exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
+        exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
         exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}",
+        exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})",
+        exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
+        exp.CopyGrantsProperty: lambda self, e: "COPY GRANTS",
+        exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
+        exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
+        exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}",
+        exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
         exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
         exp.ExternalProperty: lambda self, e: "EXTERNAL",
+        exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
         exp.LanguageProperty: lambda self, e: self.naked_property(e),
         exp.LocationProperty: lambda self, e: self.naked_property(e),
         exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
         exp.MaterializedProperty: lambda self, e: "MATERIALIZED",
         exp.NoPrimaryIndexProperty: lambda self, e: "NO PRIMARY INDEX",
         exp.OnCommitProperty: lambda self, e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS",
+        exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
+        exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
         exp.ReturnsProperty: lambda self, e: self.naked_property(e),
         exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
+        exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
         exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
+        exp.StabilityProperty: lambda self, e: e.name,
         exp.TemporaryProperty: lambda self, e: f"TEMPORARY",
+        exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
         exp.TransientProperty: lambda self, e: "TRANSIENT",
-        exp.StabilityProperty: lambda self, e: e.name,
+        exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
+        exp.UppercaseColumnConstraint: lambda self, e: f"UPPERCASE",
+        exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
         exp.VolatileProperty: lambda self, e: "VOLATILE",
         exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
-        exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
-        exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
-        exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
-        exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
-        exp.UppercaseColumnConstraint: lambda self, e: f"UPPERCASE",
-        exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
-        exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
-        exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})",
-        exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
-        exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
-        exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
-        exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}",
-        exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
     }

     # Whether or not null ordering is supported in order by

@@ -142,6 +143,9 @@ class Generator:
     # Whether or not comparing against booleans (e.g. x IS TRUE) is supported
     IS_BOOL_ALLOWED = True

+    # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
+    SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE")
+
     TYPE_MAPPING = {
         exp.DataType.Type.NCHAR: "CHAR",
         exp.DataType.Type.NVARCHAR: "VARCHAR",

@@ -182,6 +186,7 @@ class Generator:
         exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA,
         exp.ChecksumProperty: exp.Properties.Location.POST_NAME,
         exp.CollateProperty: exp.Properties.Location.POST_SCHEMA,
+        exp.CopyGrantsProperty: exp.Properties.Location.POST_SCHEMA,
         exp.Cluster: exp.Properties.Location.POST_SCHEMA,
         exp.DataBlocksizeProperty: exp.Properties.Location.POST_NAME,
         exp.DefinerProperty: exp.Properties.Location.POST_CREATE,

@@ -263,6 +268,8 @@ class Generator:
     NORMALIZE_FUNCTIONS: bool | str = "upper"
     NULL_ORDERING = "nulls_are_small"

+    can_identify: t.Callable[[str, str | bool], bool]
+
     # Delimiters for quotes, identifiers and the corresponding escape characters
     QUOTE_START = "'"
     QUOTE_END = "'"

@@ -771,9 +778,11 @@ class Generator:
         return this

     def rawstring_sql(self, expression: exp.RawString) -> str:
+        string = expression.this
         if self.RAW_START:
-            return f"{self.RAW_START}{self.escape_str(expression.this)}{self.RAW_END}"
-        return self.sql(exp.Literal.string(expression.name.replace("\\", "\\\\")))
+            return f"{self.RAW_START}{expression.name}{self.RAW_END}"
+
+        string = self.escape_str(string.replace("\\", "\\\\"))
+        return f"{self.QUOTE_START}{string}{self.QUOTE_END}"

     def datatypesize_sql(self, expression: exp.DataTypeSize) -> str:
         this = self.sql(expression, "this")

@@ -815,7 +824,8 @@ class Generator:
         )
         where_sql = self.sql(expression, "where")
         returning = self.sql(expression, "returning")
-        sql = f"DELETE{this}{using_sql}{where_sql}{returning}"
+        limit = self.sql(expression, "limit")
+        sql = f"DELETE{this}{using_sql}{where_sql}{returning}{limit}"
         return self.prepend_ctes(expression, sql)

     def drop_sql(self, expression: exp.Drop) -> str:

@@ -883,7 +893,7 @@ class Generator:
         text = text.replace(self.IDENTIFIER_END, self._escaped_identifier_end)
         if (
             expression.quoted
-            or should_identify(text, self.identify)
+            or self.can_identify(text, self.identify)
             or lower in self.RESERVED_KEYWORDS
             or (not self.IDENTIFIERS_CAN_START_WITH_DIGIT and text[:1].isdigit())
         ):

@@ -1157,6 +1167,15 @@ class Generator:
         null = f" NULL DEFINED AS {null}" if null else ""
         return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}"

+    def withtablehint_sql(self, expression: exp.WithTableHint) -> str:
+        return f"WITH ({self.expressions(expression, flat=True)})"
+
+    def indextablehint_sql(self, expression: exp.IndexTableHint) -> str:
+        this = f"{self.sql(expression, 'this')} INDEX"
+        target = self.sql(expression, "target")
+        target = f" FOR {target}" if target else ""
+        return f"{this}{target} ({self.expressions(expression, flat=True)})"
+
     def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:
         table = ".".join(
             part

@@ -1170,8 +1189,8 @@ class Generator:

         alias = self.sql(expression, "alias")
         alias = f"{sep}{alias}" if alias else ""
-        hints = self.expressions(expression, key="hints", flat=True)
-        hints = f" WITH ({hints})" if hints and self.TABLE_HINTS else ""
+        hints = self.expressions(expression, key="hints", sep=" ")
+        hints = f" {hints}" if hints and self.TABLE_HINTS else ""
         pivots = self.expressions(expression, key="pivots", sep=" ", flat=True)
         pivots = f" {pivots}" if pivots else ""
         joins = self.expressions(expression, key="joins", sep="")

@@ -1238,7 +1257,8 @@ class Generator:
         from_sql = self.sql(expression, "from")
         where_sql = self.sql(expression, "where")
         returning = self.sql(expression, "returning")
-        sql = f"UPDATE {this} SET {set_sql}{from_sql}{where_sql}{returning}"
+        limit = self.sql(expression, "limit")
+        sql = f"UPDATE {this} SET {set_sql}{from_sql}{where_sql}{returning}{limit}"
         return self.prepend_ctes(expression, sql)

     def values_sql(self, expression: exp.Values) -> str:

@@ -1413,10 +1433,13 @@ class Generator:
     def literal_sql(self, expression: exp.Literal) -> str:
         text = expression.this or ""
         if expression.is_string:
-            text = text.replace(self.QUOTE_END, self._escaped_quote_end)
-            if self.pretty:
-                text = text.replace("\n", self.SENTINEL_LINE_BREAK)
-            text = f"{self.QUOTE_START}{text}{self.QUOTE_END}"
+            text = f"{self.QUOTE_START}{self.escape_str(text)}{self.QUOTE_END}"
         return text

+    def escape_str(self, text: str) -> str:
+        text = text.replace(self.QUOTE_END, self._escaped_quote_end)
+        if self.pretty:
+            text = text.replace("\n", self.SENTINEL_LINE_BREAK)
+        return text
+
     def loaddata_sql(self, expression: exp.LoadData) -> str:

@@ -1565,9 +1588,30 @@ class Generator:
         hint = self.sql(expression, "hint")
         distinct = self.sql(expression, "distinct")
         distinct = f" {distinct}" if distinct else ""
-        kind = expression.args.get("kind")
-        kind = f" AS {kind}" if kind else ""
+        kind = self.sql(expression, "kind").upper()
         expressions = self.expressions(expression)
+
+        if kind:
+            if kind in self.SELECT_KINDS:
+                kind = f" AS {kind}"
+            else:
+                if kind == "STRUCT":
+                    expressions = self.expressions(
+                        sqls=[
+                            self.sql(
+                                exp.Struct(
+                                    expressions=[
+                                        exp.column(e.output_name).eq(
+                                            e.this if isinstance(e, exp.Alias) else e
+                                        )
+                                        for e in expression.expressions
+                                    ]
+                                )
+                            )
+                        ]
+                    )
+                kind = ""
+
         expressions = f"{self.sep()}{expressions}" if expressions else expressions
         sql = self.query_modifiers(
             expression,
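The new `SELECT_KINDS` handling means BigQuery-style kinds survive only on dialects that declare them; elsewhere a STRUCT kind is rewritten into an `exp.Struct` projection. An illustrative transpile (exact output may vary by target dialect and version):

```python
import sqlglot

sql = "SELECT AS STRUCT 1 AS a, 2 AS b"
# BigQuery keeps the kind; Spark declares no SELECT kinds, so the
# projection is rewritten as a struct expression instead.
print(sqlglot.transpile(sql, read="bigquery", write="bigquery")[0])
print(sqlglot.transpile(sql, read="bigquery", write="spark")[0])
```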
@@ -14,7 +14,6 @@ from itertools import count
 if t.TYPE_CHECKING:
     from sqlglot import exp
     from sqlglot._typing import E, T
-    from sqlglot.dialects.dialect import DialectType
     from sqlglot.expressions import Expression

 CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")

@@ -23,7 +22,12 @@ logger = logging.getLogger("sqlglot")


 class AutoName(Enum):
-    """This is used for creating enum classes where `auto()` is the string form of the corresponding value's name."""
+    """
+    This is used for creating Enum classes where `auto()` is the string form
+    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").
+
+    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
+    """

     def _generate_next_value_(name, _start, _count, _last_values):
         return name

@@ -52,7 +56,7 @@ def ensure_list(value):
     Ensures that a value is a list, otherwise casts or wraps it into one.

     Args:
-        value: the value of interest.
+        value: The value of interest.

     Returns:
         The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.

@@ -80,7 +84,7 @@ def ensure_collection(value):
     Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.

     Args:
-        value: the value of interest.
+        value: The value of interest.

     Returns:
         The value if it's a collection, or else the value wrapped in a list.

@@ -97,8 +101,8 @@ def csv(*args: str, sep: str = ", ") -> str:
     Formats any number of string arguments as CSV.

     Args:
-        args: the string arguments to format.
-        sep: the argument separator.
+        args: The string arguments to format.
+        sep: The argument separator.

     Returns:
         The arguments formatted as a CSV string.

@@ -115,9 +119,9 @@ def subclasses(
     Returns all subclasses for a collection of classes, possibly excluding some of them.

     Args:
-        module_name: the name of the module to search for subclasses in.
-        classes: class(es) we want to find the subclasses of.
-        exclude: class(es) we want to exclude from the returned list.
+        module_name: The name of the module to search for subclasses in.
+        classes: Class(es) we want to find the subclasses of.
+        exclude: Class(es) we want to exclude from the returned list.

     Returns:
         The target subclasses.

@@ -140,13 +144,13 @@ def apply_index_offset(
     Applies an offset to a given integer literal expression.

     Args:
-        this: the target of the index
-        expressions: the expression the offset will be applied to, wrapped in a list.
-        offset: the offset that will be applied.
+        this: The target of the index.
+        expressions: The expression the offset will be applied to, wrapped in a list.
+        offset: The offset that will be applied.

     Returns:
         The original expression with the offset applied to it, wrapped in a list. If the provided
-        `expressions` argument contains more than one expressions, it's returned unaffected.
+        `expressions` argument contains more than one expression, it's returned unaffected.
     """
     if not offset or len(expressions) != 1:
         return expressions

@@ -189,8 +193,8 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
     Applies a transformation to a given expression until a fix point is reached.

     Args:
-        expression: the expression to be transformed.
-        func: the transformation to be applied.
+        expression: The expression to be transformed.
+        func: The transformation to be applied.

     Returns:
         The transformed expression.

@@ -198,6 +202,7 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
     while True:
         for n, *_ in reversed(tuple(expression.walk())):
             n._hash = hash(n)
+
         start = hash(expression)
         expression = func(expression)

@@ -205,6 +210,7 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
             n._hash = None
+
         if start == hash(expression):
             break

     return expression

@@ -213,7 +219,7 @@ def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
     Sorts a given directed acyclic graph in topological order.

     Args:
-        dag: the graph to be sorted.
+        dag: The graph to be sorted.

     Returns:
         A list that contains all of the graph's nodes in topological order.

@@ -261,7 +267,7 @@ def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
     Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.

     Args:
-        read_csv: a `ReadCSV` function call
+        read_csv: A `ReadCSV` function call.

     Yields:
         A python csv reader.

@@ -288,8 +294,8 @@ def find_new_name(taken: t.Collection[str], base: str) -> str:
     Searches for a new name.

     Args:
-        taken: a collection of taken names.
-        base: base name to alter.
+        taken: A collection of taken names.
+        base: Base name to alter.

     Returns:
         The new, available name.

@@ -327,10 +333,10 @@ def split_num_words(
     Perform a split on a value and return N words as a result with `None` used for words that don't exist.

     Args:
-        value: the value to be split.
-        sep: the value to use to split on.
-        min_num_words: the minimum number of words that are going to be in the result.
-        fill_from_start: indicates that if `None` values should be inserted at the start or end of the list.
+        value: The value to be split.
+        sep: The value to use to split on.
+        min_num_words: The minimum number of words that are going to be in the result.
+        fill_from_start: Indicates that if `None` values should be inserted at the start or end of the list.

     Examples:
         >>> split_num_words("db.table", ".", 3)

@@ -360,7 +366,7 @@ def is_iterable(value: t.Any) -> bool:
     False

     Args:
-        value: the value to check if it is an iterable.
+        value: The value to check if it is an iterable.

     Returns:
         A `bool` value indicating if it is an iterable.

@@ -380,7 +386,7 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
     [1, 2, 3]

     Args:
-        values: the value to be flattened.
+        values: The value to be flattened.

     Yields:
         Non-iterable elements in `values`.

@@ -396,7 +402,7 @@ def dict_depth(d: t.Dict) -> int:
     """
     Get the nesting depth of a dictionary.

-    For example:
+    Example:
         >>> dict_depth(None)
         0
         >>> dict_depth({})

@@ -407,12 +413,6 @@ def dict_depth(d: t.Dict) -> int:
         2
         >>> dict_depth({"a": {"b": {}}})
         3
-
-    Args:
-        d (dict): dictionary
-
-    Returns:
-        int: depth
     """
     try:
         return 1 + dict_depth(next(iter(d.values())))

@@ -425,36 +425,5 @@ def dict_depth(d: t.Dict) -> int:


 def first(it: t.Iterable[T]) -> T:
-    """Returns the first element from an iterable.
-
-    Useful for sets.
-    """
+    """Returns the first element from an iterable (useful for sets)."""
     return next(i for i in it)
-
-
-def case_sensitive(text: str, dialect: DialectType) -> bool:
-    """Checks if text contains any case sensitive characters depending on dialect."""
-    from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE
-
-    unsafe = str.islower if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
-    return any(unsafe(char) for char in text)
-
-
-def should_identify(text: str, identify: str | bool, dialect: DialectType = None) -> bool:
-    """Checks if text should be identified given an identify option.
-
-    Args:
-        text: the text to check.
-        identify:
-            "always" or `True`: always returns true.
-            "safe": true if there is no uppercase or lowercase character in `text`, depending on `dialect`.
-        dialect: the dialect to use in order to decide whether a text should be identified.
-
-    Returns:
-        Whether or not a string should be identified.
-    """
-    if identify is True or identify == "always":
-        return True
-    if identify == "safe":
-        return not case_sensitive(text, dialect)
-    return False
@@ -187,7 +187,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
         },
     }

-    ANNOTATORS = {
+    ANNOTATORS: t.Dict = {
         **{
             expr_type: lambda self, e: self._annotate_unary(e)
             for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
@@ -1,12 +1,15 @@
 from sqlglot import exp
 from sqlglot._typing import E
-from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE, DialectType
+from sqlglot.dialects.dialect import Dialect, DialectType


 def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
     """
-    Normalize all unquoted identifiers to either lower or upper case, depending on
-    the dialect. This essentially makes those identifiers case-insensitive.
+    Normalize all unquoted identifiers to either lower or upper case, depending
+    on the dialect. This essentially makes those identifiers case-insensitive.
+
+    Note:
+        Some dialects (e.g. BigQuery) treat identifiers as case-insensitive even
+        when they're quoted, so in these cases all identifiers are normalized.

     Example:
         >>> import sqlglot

@@ -21,16 +24,4 @@ def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
     Returns:
         The transformed expression.
     """
-    return expression.transform(_normalize, dialect, copy=False)
-
-
-def _normalize(node: exp.Expression, dialect: DialectType = None) -> exp.Expression:
-    if isinstance(node, exp.Identifier) and not node.quoted:
-        node.set(
-            "this",
-            node.this.upper()
-            if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE
-            else node.this.lower(),
-        )
-
-    return node
+    return expression.transform(Dialect.get_or_raise(dialect).normalize_identifier, copy=False)
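The entry point now just delegates to the dialect's `normalize_identifier`; a usage sketch consistent with the function's docstring:

```python
import sqlglot
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
# For Snowflake, unquoted identifiers are uppercased; the quoted "Foo" is untouched.
print(normalize_identifiers(expression, dialect="snowflake").sql(dialect="snowflake"))
```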
@@ -5,9 +5,9 @@ import typing as t

 from sqlglot import alias, exp
 from sqlglot._typing import E
-from sqlglot.dialects.dialect import DialectType
+from sqlglot.dialects.dialect import Dialect, DialectType
 from sqlglot.errors import OptimizeError
-from sqlglot.helper import case_sensitive, seq_get
+from sqlglot.helper import seq_get
 from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
 from sqlglot.schema import Schema, ensure_schema

@@ -417,19 +417,9 @@ def _qualify_outputs(scope):

 def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
     """Makes sure all identifiers that need to be quoted are quoted."""
-
-    def _quote(expression: E) -> E:
-        if isinstance(expression, exp.Identifier):
-            name = expression.this
-            expression.set(
-                "quoted",
-                identify
-                or case_sensitive(name, dialect=dialect)
-                or not exp.SAFE_IDENTIFIER_RE.match(name),
-            )
-        return expression
-
-    return expression.transform(_quote, copy=False)
+    return expression.transform(
+        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
+    )


 class Resolver:
@@ -408,9 +408,14 @@ def remove_where_true(expression):
         if always_true(where.this):
             where.parent.set("where", None)
     for join in expression.find_all(exp.Join):
-        if always_true(join.args.get("on")):
-            join.set("kind", "CROSS")
+        if (
+            always_true(join.args.get("on"))
+            and not join.args.get("using")
+            and not join.args.get("method")
+        ):
+            join.set("on", None)
+            join.set("side", None)
+            join.set("kind", "CROSS")


 def always_true(expression):
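A sketch of the stricter rule: an always-true ON clause is now only collapsed to a CROSS JOIN when the join has no USING clause and no join method, and the on/side args are cleared along the way:

```python
import sqlglot
from sqlglot.optimizer.simplify import remove_where_true

tree = sqlglot.parse_one("SELECT * FROM a JOIN b ON TRUE WHERE TRUE")
remove_where_true(tree)  # mutates the tree in place
print(tree.sql())  # e.g. SELECT * FROM a CROSS JOIN b
```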
@ -9,7 +9,7 @@ from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
|
|||
from sqlglot.helper import apply_index_offset, ensure_list, seq_get
|
||||
from sqlglot.time import format_time
|
||||
from sqlglot.tokens import Token, Tokenizer, TokenType
|
||||
from sqlglot.trie import in_trie, new_trie
|
||||
from sqlglot.trie import TrieResult, in_trie, new_trie
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from sqlglot._typing import E
|
||||
|
@ -177,6 +177,7 @@ class Parser(metaclass=_Parser):
|
|||
TokenType.BIGSERIAL,
|
||||
TokenType.XML,
|
||||
TokenType.UNIQUEIDENTIFIER,
|
||||
TokenType.USERDEFINED,
|
||||
TokenType.MONEY,
|
||||
TokenType.SMALLMONEY,
|
||||
TokenType.ROWVERSION,
|
||||
|
@ -465,7 +466,7 @@ class Parser(metaclass=_Parser):
|
|||
}
|
||||
|
||||
EXPRESSION_PARSERS = {
|
||||
exp.Cluster: lambda self: self._parse_sort(exp.Cluster, "CLUSTER", "BY"),
|
||||
exp.Cluster: lambda self: self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY),
|
||||
exp.Column: lambda self: self._parse_column(),
|
||||
exp.Condition: lambda self: self._parse_conjunction(),
|
||||
exp.DataType: lambda self: self._parse_types(),
|
||||
|
@ -484,7 +485,7 @@ class Parser(metaclass=_Parser):
|
|||
exp.Properties: lambda self: self._parse_properties(),
|
||||
exp.Qualify: lambda self: self._parse_qualify(),
|
||||
exp.Returning: lambda self: self._parse_returning(),
|
||||
exp.Sort: lambda self: self._parse_sort(exp.Sort, "SORT", "BY"),
|
||||
exp.Sort: lambda self: self._parse_sort(exp.Sort, TokenType.SORT_BY),
|
||||
exp.Table: lambda self: self._parse_table_parts(),
|
||||
exp.TableAlias: lambda self: self._parse_table_alias(),
|
||||
exp.Where: lambda self: self._parse_where(),
|
||||
|
@ -540,8 +541,7 @@ class Parser(metaclass=_Parser):
|
|||
exp.Literal, this=token.text, is_string=False
|
||||
),
|
||||
TokenType.STAR: lambda self, _: self.expression(
|
||||
exp.Star,
|
||||
**{"except": self._parse_except(), "replace": self._parse_replace()},
|
||||
exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()}
|
||||
),
|
||||
TokenType.NULL: lambda self, _: self.expression(exp.Null),
|
||||
TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True),
|
||||
|
@ -584,9 +584,10 @@ class Parser(metaclass=_Parser):
|
|||
"BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
|
||||
"CHARACTER SET": lambda self: self._parse_character_set(),
|
||||
"CHECKSUM": lambda self: self._parse_checksum(),
|
||||
"CLUSTER": lambda self: self._parse_cluster(),
|
||||
"CLUSTER BY": lambda self: self._parse_cluster(),
|
||||
"COLLATE": lambda self: self._parse_property_assignment(exp.CollateProperty),
|
||||
"COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
|
||||
"COPY": lambda self: self._parse_copy_property(),
|
||||
"DATABLOCKSIZE": lambda self, **kwargs: self._parse_datablocksize(**kwargs),
|
||||
"DEFINER": lambda self: self._parse_definer(),
|
||||
"DETERMINISTIC": lambda self: self.expression(
|
||||
|
@ -780,6 +781,8 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
CLONE_KINDS = {"TIMESTAMP", "OFFSET", "STATEMENT"}
|
||||
|
||||
TABLE_INDEX_HINT_TOKENS = {TokenType.FORCE, TokenType.IGNORE, TokenType.USE}
|
||||
|
||||
WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}
|
||||
WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER}
|
||||
WINDOW_SIDES = {"FOLLOWING", "PRECEDING"}
|
||||
|
@ -788,7 +791,8 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
STRICT_CAST = True
|
||||
|
||||
CONCAT_NULL_OUTPUTS_STRING = False # A NULL arg in CONCAT yields NULL by default
|
||||
# A NULL arg in CONCAT yields NULL by default
|
||||
CONCAT_NULL_OUTPUTS_STRING = False
|
||||
|
||||
CONVERT_TYPE_FIRST = False
|
||||
|
||||
|
@ -1423,11 +1427,14 @@ class Parser(metaclass=_Parser):
|
|||
return self.expression(exp.ChecksumProperty, on=on, default=self._match(TokenType.DEFAULT))
|
||||
|
||||
def _parse_cluster(self) -> t.Optional[exp.Cluster]:
|
||||
if not self._match_text_seq("BY"):
|
||||
return self.expression(exp.Cluster, expressions=self._parse_csv(self._parse_ordered))
|
||||
|
||||
def _parse_copy_property(self) -> t.Optional[exp.CopyGrantsProperty]:
|
||||
if not self._match_text_seq("GRANTS"):
|
||||
self._retreat(self._index - 1)
|
||||
return None
|
||||
|
||||
return self.expression(exp.Cluster, expressions=self._parse_csv(self._parse_ordered))
|
||||
return self.expression(exp.CopyGrantsProperty)
|
||||
|
||||
def _parse_freespace(self) -> exp.FreespaceProperty:
|
||||
self._match(TokenType.EQ)
|
||||
|
@ -1779,6 +1786,7 @@ class Parser(metaclass=_Parser):
|
|||
using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table()),
|
||||
where=self._parse_where(),
|
||||
returning=self._parse_returning(),
|
||||
limit=self._parse_limit(),
|
||||
)
|
||||
|
||||
def _parse_update(self) -> exp.Update:
|
||||
|
@ -1790,6 +1798,7 @@ class Parser(metaclass=_Parser):
|
|||
"from": self._parse_from(modifiers=True),
|
||||
"where": self._parse_where(),
|
||||
"returning": self._parse_returning(),
|
||||
"limit": self._parse_limit(),
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -2268,6 +2277,33 @@ class Parser(metaclass=_Parser):
|
|||
partition_by=self._parse_partition_by(),
|
||||
)
|
||||
|
||||
def _parse_table_hints(self) -> t.Optional[t.List[exp.Expression]]:
|
||||
hints: t.List[exp.Expression] = []
|
||||
if self._match_pair(TokenType.WITH, TokenType.L_PAREN):
|
||||
# https://learn.microsoft.com/en-us/sql/t-sql/queries/hints-transact-sql-table?view=sql-server-ver16
|
||||
hints.append(
|
||||
self.expression(
|
||||
exp.WithTableHint,
|
||||
expressions=self._parse_csv(
|
||||
lambda: self._parse_function() or self._parse_var(any_token=True)
|
||||
),
|
||||
)
|
||||
)
|
||||
self._match_r_paren()
|
||||
else:
|
||||
# https://dev.mysql.com/doc/refman/8.0/en/index-hints.html
|
||||
while self._match_set(self.TABLE_INDEX_HINT_TOKENS):
|
||||
hint = exp.IndexTableHint(this=self._prev.text.upper())
|
||||
|
||||
self._match_texts({"INDEX", "KEY"})
|
||||
if self._match(TokenType.FOR):
|
||||
hint.set("target", self._advance_any() and self._prev.text.upper())
|
||||
|
||||
hint.set("expressions", self._parse_wrapped_id_vars())
|
||||
hints.append(hint)
|
||||
|
||||
return hints or None
|
||||
|
||||
def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
|
||||
return (
|
||||
(not schema and self._parse_function(optional_parens=False))
|
||||
|
@ -2335,12 +2371,7 @@ class Parser(metaclass=_Parser):
|
|||
if not this.args.get("pivots"):
|
||||
this.set("pivots", self._parse_pivots())
|
||||
|
||||
if self._match_pair(TokenType.WITH, TokenType.L_PAREN):
|
||||
this.set(
|
||||
"hints",
|
||||
self._parse_csv(lambda: self._parse_function() or self._parse_var(any_token=True)),
|
||||
)
|
||||
self._match_r_paren()
|
||||
this.set("hints", self._parse_table_hints())
|
||||
|
||||
if not self.ALIAS_POST_TABLESAMPLE:
|
||||
table_sample = self._parse_table_sample()
|
||||
|
@@ -2610,8 +2641,8 @@ class Parser(metaclass=_Parser):
            exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)
        )

    def _parse_sort(self, exp_class: t.Type[E], *texts: str) -> t.Optional[E]:
        if not self._match_text_seq(*texts):
    def _parse_sort(self, exp_class: t.Type[E], token: TokenType) -> t.Optional[E]:
        if not self._match(token):
            return None
        return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered))
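_parse_sort now consumes a single pre-tokenized keyword (see the SORT_BY / DISTRIBUTE_BY / CLUSTER_BY token and keyword additions further down) instead of matching a sequence of words. A minimal sketch with Hive, which uses all of these clauses:

    import sqlglot

    # DISTRIBUTE BY and SORT BY now arrive as single multi-word tokens.
    ast = sqlglot.parse_one("SELECT * FROM t DISTRIBUTE BY x SORT BY y", read="hive")
    print(ast.sql(dialect="hive"))  # SELECT * FROM t DISTRIBUTE BY x SORT BY y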
@@ -3655,7 +3686,11 @@ class Parser(metaclass=_Parser):
    def _parse_concat(self) -> t.Optional[exp.Expression]:
        args = self._parse_csv(self._parse_conjunction)
        if self.CONCAT_NULL_OUTPUTS_STRING:
            args = [exp.func("COALESCE", arg, exp.Literal.string("")) for arg in args]
            args = [
                exp.func("COALESCE", exp.cast(arg, "text"), exp.Literal.string(""))
                for arg in args
                if arg
            ]

        # Some dialects (e.g. Trino) don't allow a single-argument CONCAT call, so when
        # we find such a call we replace it with its argument.
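The new branch wraps each CONCAT argument as COALESCE(CAST(arg AS TEXT), '') and skips empty arguments, so NULLs behave as empty strings in dialects that set CONCAT_NULL_OUTPUTS_STRING. The same expression can be built with the public helpers the hunk itself uses:

    from sqlglot import exp

    # The per-argument rewrite applied above.
    arg = exp.column("a")
    wrapped = exp.func("COALESCE", exp.cast(arg, "text"), exp.Literal.string(""))
    print(wrapped.sql())  # COALESCE(CAST(a AS TEXT), '')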
@@ -4553,13 +4588,16 @@ class Parser(metaclass=_Parser):
            curr = self._curr.text.upper()
            key = curr.split(" ")
            this.append(curr)

            self._advance()
            result, trie = in_trie(trie, key)
            if result == 0:
            if result == TrieResult.FAILED:
                break
            if result == 2:

            if result == TrieResult.EXISTS:
                subparser = parsers[" ".join(this)]
                return subparser

        self._retreat(index)
        return None
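This lookup walks a trie keyed by whole words (curr.split(" ") lets one multi-word token contribute several steps). A sketch of a word-keyed trie using the public helpers; the keys here are illustrative, not the parser's actual table:

    from sqlglot.trie import TrieResult, in_trie, new_trie

    trie = new_trie([("ALTER", "TABLE"), ("ALTER", "VIEW")])
    assert in_trie(trie, ("ALTER",))[0] == TrieResult.PREFIX
    assert in_trie(trie, ("ALTER", "TABLE"))[0] == TrieResult.EXISTS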
sqlglot/schema.py
@@ -6,10 +6,10 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
from sqlglot._typing import T
from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import ParseError, SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import in_trie, new_trie
from sqlglot.trie import TrieResult, in_trie, new_trie

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.types import StructType
@@ -135,10 +135,10 @@ class AbstractMappingSchema(t.Generic[T]):
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)

        if value == 0:
        if value == TrieResult.FAILED:
            return None

        if value == 1:
        if value == TrieResult.PREFIX:
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)

            if len(possibilities) == 1:
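The PREFIX branch is what lets a partially qualified table name resolve when it is unambiguous. A sketch, assuming a small nested schema (behavior may vary by sqlglot version):

    from sqlglot.schema import MappingSchema

    # "t" is only a prefix of the full (db, t) key, but it is unique,
    # so the lookup above still resolves it.
    schema = MappingSchema({"db": {"t": {"x": "INT"}}})
    print(schema.column_names("t"))  # ['x']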
@@ -289,7 +289,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Converts all identifiers in the schema into lowercase, unless they're quoted.
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.
@@ -304,7 +304,9 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
        columns = nested_get(schema, *zip(keys, keys))
        assert columns is not None

        normalized_keys = [self._normalize_name(key, dialect=self.dialect) for key in keys]
        normalized_keys = [
            self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
        ]
        for column_name, column_type in columns.items():
            nested_set(
                normalized_mapping,
@@ -321,12 +323,15 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg, exp.to_identifier(self._normalize_name(value, dialect=dialect))
                    arg,
                    exp.to_identifier(self._normalize_name(value, dialect=dialect, is_table=True)),
                )

        return normalized_table

    def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = None) -> str:
    def _normalize_name(
        self, name: str | exp.Identifier, dialect: DialectType = None, is_table: bool = False
    ) -> str:
        dialect = dialect or self.dialect

        try:
@@ -335,11 +340,12 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
            return name if isinstance(name, str) else name.name

        name = identifier.name

        if not self.normalize or identifier.quoted:
        if not self.normalize:
            return name

        return name.upper() if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else name.lower()
        # This can be useful for normalize_identifier
        identifier.meta["is_table"] = is_table
        return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name

    def _depth(self) -> int:
        # The columns themselves are a mapping, but we don't want to include those
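Normalization is now delegated to each dialect's normalize_identifier hook (which also handles quoted identifiers), replacing the RESOLVES_IDENTIFIERS_AS_UPPERCASE membership test. A sketch of the dialect-specific folding, with expected outputs as comments:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    # Unquoted identifiers fold per dialect: Snowflake upper-cases, DuckDB lower-cases.
    print(Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("MyCol")).name)  # MYCOL
    print(Dialect.get_or_raise("duckdb").normalize_identifier(exp.to_identifier("MyCol")).name)     # mycol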
sqlglot/time.py
@@ -2,7 +2,7 @@ import typing as t

# The generic time format is based on python time.strftime.
# https://docs.python.org/3/library/time.html#time.strftime
from sqlglot.trie import in_trie, new_trie
from sqlglot.trie import TrieResult, in_trie, new_trie


def format_time(
@@ -37,7 +37,7 @@ def format_time(
        chars = string[start:end]
        result, current = in_trie(current, chars[-1])

        if result == 0:
        if result == TrieResult.FAILED:
            if sym:
                end -= 1
                chars = sym
@@ -45,11 +45,12 @@ def format_time(
            start += len(chars)
            chunks.append(chars)
            current = trie
        elif result == 2:
        elif result == TrieResult.EXISTS:
            sym = chars

        end += 1

    if result and end > size:
    if result != TrieResult.FAILED and end > size:
        chunks.append(chars)

    return "".join(mapping.get(chars, chars) for chars in chunks)
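format_time greedily matches chunks of the format string against a trie built from the mapping keys and substitutes the longest hit, passing other characters through. A minimal sketch with an illustrative mapping:

    from sqlglot.time import format_time

    print(format_time("%Y-%m-%d", {"%Y": "yyyy", "%m": "mm", "%d": "dd"}))  # yyyy-mm-dd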
sqlglot/tokens.py
@@ -4,7 +4,7 @@ import typing as t
from enum import auto

from sqlglot.helper import AutoName
from sqlglot.trie import in_trie, new_trie
from sqlglot.trie import TrieResult, in_trie, new_trie


class TokenType(AutoName):
@@ -137,6 +137,7 @@ class TokenType(AutoName):
    BIGSERIAL = auto()
    XML = auto()
    UNIQUEIDENTIFIER = auto()
    USERDEFINED = auto()
    MONEY = auto()
    SMALLMONEY = auto()
    ROWVERSION = auto()
@@ -163,6 +164,7 @@ class TokenType(AutoName):
    CACHE = auto()
    CASE = auto()
    CHARACTER_SET = auto()
    CLUSTER_BY = auto()
    COLLATE = auto()
    COMMAND = auto()
    COMMENT = auto()
@@ -182,6 +184,7 @@ class TokenType(AutoName):
    DESCRIBE = auto()
    DICTIONARY = auto()
    DISTINCT = auto()
    DISTRIBUTE_BY = auto()
    DIV = auto()
    DROP = auto()
    ELSE = auto()
@@ -196,6 +199,7 @@ class TokenType(AutoName):
    FINAL = auto()
    FIRST = auto()
    FOR = auto()
    FORCE = auto()
    FOREIGN_KEY = auto()
    FORMAT = auto()
    FROM = auto()
@@ -208,6 +212,7 @@ class TokenType(AutoName):
    HAVING = auto()
    HINT = auto()
    IF = auto()
    IGNORE = auto()
    ILIKE = auto()
    ILIKE_ANY = auto()
    IN = auto()
@@ -282,6 +287,7 @@ class TokenType(AutoName):
    SHOW = auto()
    SIMILAR_TO = auto()
    SOME = auto()
    SORT_BY = auto()
    STRUCT = auto()
    TABLE_SAMPLE = auto()
    TEMPORARY = auto()
@@ -509,6 +515,7 @@ class Tokenizer(metaclass=_Tokenizer):
        "UNCACHE": TokenType.UNCACHE,
        "CASE": TokenType.CASE,
        "CHARACTER SET": TokenType.CHARACTER_SET,
        "CLUSTER BY": TokenType.CLUSTER_BY,
        "COLLATE": TokenType.COLLATE,
        "COLUMN": TokenType.COLUMN,
        "COMMIT": TokenType.COMMIT,
@@ -526,6 +533,7 @@ class Tokenizer(metaclass=_Tokenizer):
        "DESC": TokenType.DESC,
        "DESCRIBE": TokenType.DESCRIBE,
        "DISTINCT": TokenType.DISTINCT,
        "DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
        "DIV": TokenType.DIV,
        "DROP": TokenType.DROP,
        "ELSE": TokenType.ELSE,
@@ -617,6 +625,7 @@ class Tokenizer(metaclass=_Tokenizer):
        "SHOW": TokenType.SHOW,
        "SIMILAR TO": TokenType.SIMILAR_TO,
        "SOME": TokenType.SOME,
        "SORT BY": TokenType.SORT_BY,
        "TABLE": TokenType.TABLE,
        "TABLESAMPLE": TokenType.TABLE_SAMPLE,
        "TEMP": TokenType.TEMPORARY,
@@ -717,6 +726,7 @@ class Tokenizer(metaclass=_Tokenizer):
        "PREPARE": TokenType.COMMAND,
        "TRUNCATE": TokenType.COMMAND,
        "VACUUM": TokenType.COMMAND,
        "USER-DEFINED": TokenType.USERDEFINED,
    }

    WHITE_SPACE: t.Dict[t.Optional[str], TokenType] = {
@@ -905,13 +915,13 @@ class Tokenizer(metaclass=_Tokenizer):

        while chars:
            if skip:
                result = 1
                result = TrieResult.PREFIX
            else:
                result, trie = in_trie(trie, char.upper())

            if result == 0:
            if result == TrieResult.FAILED:
                break
            if result == 2:
            if result == TrieResult.EXISTS:
                word = chars

            size += 1
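With the multi-word keywords registered above, the tokenizer's trie scan emits them as single tokens. A sketch (exact token lists vary by sqlglot version):

    from sqlglot.tokens import Tokenizer, TokenType

    tokens = Tokenizer().tokenize("SELECT x FROM t SORT BY x")
    # "SORT BY" surfaces as one SORT_BY token rather than two words.
    assert TokenType.SORT_BY in {tok.token_type for tok in tokens}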
sqlglot/trie.py
@@ -1,14 +1,21 @@
import typing as t
from enum import Enum, auto

key = t.Sequence[t.Hashable]


class TrieResult(Enum):
    FAILED = auto()
    PREFIX = auto()
    EXISTS = auto()


def new_trie(keywords: t.Iterable[key], trie: t.Optional[t.Dict] = None) -> t.Dict:
    """
    Creates a new trie out of a collection of keywords.

    The trie is represented as a sequence of nested dictionaries keyed by either single character
    strings, or by 0, which is used to designate that a keyword is in the trie.
    The trie is represented as a sequence of nested dictionaries keyed by either single
    character strings, or by 0, which is used to designate that a keyword is in the trie.

    Example:
        >>> new_trie(["bla", "foo", "blab"])
@@ -25,46 +32,50 @@ def new_trie(keywords: t.Iterable[key], trie: t.Optional[t.Dict] = None) -> t.Dict:

    for key in keywords:
        current = trie

        for char in key:
            current = current.setdefault(char, {})

        current[0] = True

    return trie


def in_trie(trie: t.Dict, key: key) -> t.Tuple[int, t.Dict]:
def in_trie(trie: t.Dict, key: key) -> t.Tuple[TrieResult, t.Dict]:
    """
    Checks whether a key is in a trie.

    Examples:
        >>> in_trie(new_trie(["cat"]), "bob")
        (0, {'c': {'a': {'t': {0: True}}}})
        (<TrieResult.FAILED: 1>, {'c': {'a': {'t': {0: True}}}})

        >>> in_trie(new_trie(["cat"]), "ca")
        (1, {'t': {0: True}})
        (<TrieResult.PREFIX: 2>, {'t': {0: True}})

        >>> in_trie(new_trie(["cat"]), "cat")
        (2, {0: True})
        (<TrieResult.EXISTS: 3>, {0: True})

    Args:
        trie: the trie to be searched.
        key: the target key.
        trie: The trie to be searched.
        key: The target key.

    Returns:
        A pair `(value, subtrie)`, where `subtrie` is the sub-trie we get at the point where the search stops, and `value`
        is either 0 (search was unsuccessful), 1 (`value` is a prefix of a keyword in `trie`) or 2 (`key` is in `trie`).
        A pair `(value, subtrie)`, where `subtrie` is the sub-trie we get at the point
        where the search stops, and `value` is a TrieResult value that can be one of:

        - TrieResult.FAILED: the search was unsuccessful
        - TrieResult.PREFIX: `value` is a prefix of a keyword in `trie`
        - TrieResult.EXISTS: `key` exists in `trie`
    """
    if not key:
        return (0, trie)
        return (TrieResult.FAILED, trie)

    current = trie

    for char in key:
        if char not in current:
            return (0, current)
            return (TrieResult.FAILED, current)
        current = current[char]

    if 0 in current:
        return (2, current)
        return (TrieResult.EXISTS, current)

    return (1, current)
    return (TrieResult.PREFIX, current)
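The three outcomes, exercised exactly as in the doctests above:

    from sqlglot.trie import TrieResult, in_trie, new_trie

    trie = new_trie(["cat"])
    assert in_trie(trie, "bob")[0] == TrieResult.FAILED
    assert in_trie(trie, "ca")[0] == TrieResult.PREFIX
    assert in_trie(trie, "cat")[0] == TrieResult.EXISTS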