
Merging upstream version 16.4.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 20:04:59 +01:00
parent 8a4abed982
commit 71f21d9752
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
90 changed files with 35638 additions and 33343 deletions

View file

@ -84,58 +84,17 @@ def parse(sql: str, read: DialectType = None, **opts) -> t.List[t.Optional[Expre
@t.overload
def parse_one(
sql: str,
read: None = None,
into: t.Type[E] = ...,
**opts,
) -> E:
def parse_one(sql: str, *, into: t.Type[E], **opts) -> E:
...
@t.overload
def parse_one(
sql: str,
read: DialectType,
into: t.Type[E],
**opts,
) -> E:
...
@t.overload
def parse_one(
sql: str,
read: None = None,
into: t.Union[str, t.Collection[t.Union[str, t.Type[Expression]]]] = ...,
**opts,
) -> Expression:
...
@t.overload
def parse_one(
sql: str,
read: DialectType,
into: t.Union[str, t.Collection[t.Union[str, t.Type[Expression]]]],
**opts,
) -> Expression:
...
@t.overload
def parse_one(
sql: str,
**opts,
) -> Expression:
def parse_one(sql: str, **opts) -> Expression:
...
def parse_one(
sql: str,
read: DialectType = None,
into: t.Optional[exp.IntoType] = None,
**opts,
sql: str, read: DialectType = None, into: t.Optional[exp.IntoType] = None, **opts
) -> Expression:
"""
Parses the given SQL string and returns a syntax tree for the first parsed SQL statement.
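
A quick sketch of how the collapsed overloads are meant to be called; the table and column names below are illustrative, while `read` and `into` are the documented parse_one parameters:

import sqlglot
from sqlglot import exp

# Without `into`, parse_one returns a generic Expression (here an exp.Select).
select = sqlglot.parse_one("SELECT a FROM foo WHERE a > 1", read="duckdb")

# With `into`, the return type narrows to the requested expression class.
table = sqlglot.parse_one("db.some_table", into=exp.Table)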

View file

@ -9,7 +9,7 @@ Currently many of the common operations are covered and more functionality will
## Instructions
* [Install SQLGlot](https://github.com/tobymao/sqlglot/blob/main/README.md#install); that is all that is required to generate SQL. [The examples](#examples) show generating SQL and then executing that SQL on a specific engine, which requires that engine's client library.
* Find/replace all `from pyspark.sql` with `from sqlglot.dataframe`.
* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('<table_name>', <column_structure>)`.
* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('<table_name>', <column_structure>, dialect="spark")`.
* The column structure can be defined the following ways:
* Dictionary where the keys are column names and the values are strings of the Spark SQL type names.
* Ex: `{'cola': 'string', 'colb': 'int'}`
@ -33,12 +33,16 @@ import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F
sqlglot.schema.add_table('employee', {
'employee_id': 'INT',
'fname': 'STRING',
'lname': 'STRING',
'age': 'INT',
}) # Register the table structure prior to reading from the table
sqlglot.schema.add_table(
'employee',
{
'employee_id': 'INT',
'fname': 'STRING',
'lname': 'STRING',
'age': 'INT',
},
dialect="spark",
) # Register the table structure prior to reading from the table
spark = SparkSession()
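
To round out the snippet above, a minimal sketch of what typically follows (the `employee` table is the one registered above; `spark.table`, `F.col`, `F.countDistinct` and `DataFrame.sql` are the sqlglot.dataframe counterparts of the PySpark API):

df = (
    spark
    .table('employee')
    .groupBy(F.col("age"))
    .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
)
# DataFrame.sql returns the generated statement(s); pass a dialect to target an
# engine other than Spark.
print(df.sql(dialect="spark", pretty=True)[0])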

View file

@ -5,6 +5,7 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.types import DataType
from sqlglot.dialects import Spark
from sqlglot.helper import flatten, is_iterable
if t.TYPE_CHECKING:
@ -22,6 +23,10 @@ class Column:
expression = sqlglot.maybe_parse(expression, dialect="spark")
if expression is None:
raise ValueError(f"Could not parse {expression}")
if isinstance(expression, exp.Column):
expression.transform(Spark.normalize_identifier, copy=False)
self.expression: exp.Expression = expression
def __repr__(self):

View file

@ -316,6 +316,7 @@ class DataFrame:
expression.alias_or_name: expression.type.sql("spark")
for expression in select_expression.expressions
},
dialect="spark",
)
cache_storage_level = select_expression.args["cache_storage_level"]
options = [

View file

@ -5,6 +5,7 @@ import typing as t
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
from sqlglot.dialects import Spark
from sqlglot.helper import ensure_list
NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
@ -19,6 +20,7 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[
for expression in expressions:
identifiers = expression.find_all(exp.Identifier)
for identifier in identifiers:
Spark.normalize_identifier(identifier)
replace_alias_name_with_cte_name(spark, expression_context, identifier)
replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)

View file

@ -4,7 +4,8 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
from sqlglot.helper import object_to_dict, should_identify
from sqlglot.dialects import Spark
from sqlglot.helper import object_to_dict
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.dataframe import DataFrame
@ -18,17 +19,14 @@ class DataFrameReader:
def table(self, tableName: str) -> DataFrame:
from sqlglot.dataframe.sql.dataframe import DataFrame
sqlglot.schema.add_table(tableName)
sqlglot.schema.add_table(tableName, dialect="spark")
return DataFrame(
self.spark,
exp.Select()
.from_(tableName)
.from_(exp.to_table(tableName, dialect="spark").transform(Spark.normalize_identifier))
.select(
*(
column if should_identify(column, "safe") else f'"{column}"'
for column in sqlglot.schema.column_names(tableName)
)
*(column for column in sqlglot.schema.column_names(tableName, dialect="spark"))
),
)
@ -73,7 +71,7 @@ class DataFrameWriter:
)
df = self._df.copy(output_expression_container=output_expression_container)
if self._by_name:
columns = sqlglot.schema.column_names(tableName, only_visible=True)
columns = sqlglot.schema.column_names(tableName, only_visible=True, dialect="spark")
df = df._convert_leaf_to_cte().select(*columns)
return self.copy(_df=df)

View file

@ -4,6 +4,7 @@ import re
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot._typing import E
from sqlglot.dialects.dialect import (
Dialect,
datestrtodate_sql,
@ -106,6 +107,9 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
class BigQuery(Dialect):
UNNEST_COLUMN_ONLY = True
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
TIME_MAPPING = {
"%D": "%m/%d/%y",
}
@ -126,6 +130,20 @@ class BigQuery(Dialect):
"TZH": "%z",
}
@classmethod
def normalize_identifier(cls, expression: E) -> E:
# In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least).
# The following check is essentially a heuristic to detect tables based on whether or
# not they're qualified.
if (
isinstance(expression, exp.Identifier)
and not (isinstance(expression.parent, exp.Table) and expression.parent.db)
and not expression.meta.get("is_table")
):
expression.set("this", expression.this.lower())
return expression
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", '"', '"""', "'''"]
COMMENTS = ["--", "#", ("/*", "*/")]
@ -176,6 +194,7 @@ class BigQuery(Dialect):
"DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd),
"DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
"DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
"GENERATE_ARRAY": exp.GenerateSeries.from_arg_list,
"PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")(
[seq_get(args, 1), seq_get(args, 0)]
),
@ -201,6 +220,7 @@ class BigQuery(Dialect):
"TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
"TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
"TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub),
"TO_JSON_STRING": exp.JSONFormat.from_arg_list,
}
FUNCTION_PARSERS = {
@ -289,6 +309,8 @@ class BigQuery(Dialect):
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
exp.DateStrToDate: datestrtodate_sql,
exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")),
exp.JSONFormat: rename_func("TO_JSON_STRING"),
exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
exp.GroupConcat: rename_func("STRING_AGG"),
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
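
A small sketch of what the new normalize_identifier heuristic means in practice for BigQuery (the project, dataset, table and column names are illustrative; normalize_identifiers is the optimizer entry point updated later in this diff):

from sqlglot import parse_one
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

expr = parse_one("SELECT MyCol FROM my_project.my_dataset.MyTable", read="bigquery")
# The unqualified column identifier is lowercased, but the qualified table name
# keeps its casing because BigQuery table names are case-sensitive.
print(normalize_identifiers(expr, dialect="bigquery").sql("bigquery"))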

View file

@ -345,7 +345,7 @@ class ClickHouse(Dialect):
"CONCAT",
*[
exp.func("if", e.is_(exp.null()), e, exp.cast(e, "text"))
for e in expression.expressions
for e in t.cast(t.List[exp.Condition], expression.expressions)
],
)

View file

@ -4,6 +4,7 @@ import typing as t
from enum import Enum
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
@ -11,14 +12,6 @@ from sqlglot.time import format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie
if t.TYPE_CHECKING:
from sqlglot._typing import E
# Only Snowflake is currently known to resolve unquoted identifiers as uppercase.
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
RESOLVES_IDENTIFIERS_AS_UPPERCASE = {"snowflake"}
class Dialects(str, Enum):
DIALECT = ""
@ -117,6 +110,9 @@ class _Dialect(type):
"IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0],
}
if enum not in ("", "bigquery"):
dialect_properties["SELECT_KINDS"] = ()
# Pass required dialect properties to the tokenizer, parser and generator classes
for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
for name, value in dialect_properties.items():
@ -126,6 +122,8 @@ class _Dialect(type):
if not klass.STRICT_STRING_CONCAT:
klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe
klass.generator_class.can_identify = klass.can_identify
return klass
@ -139,6 +137,10 @@ class Dialect(metaclass=_Dialect):
# Determines whether or not the table alias comes after tablesample
ALIAS_POST_TABLESAMPLE = False
# Determines whether or not unquoted identifiers are resolved as uppercase
# When set to None, it means that the dialect treats all identifiers as case-insensitive
RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False
# Determines whether or not an unquoted identifier can start with a digit
IDENTIFIERS_CAN_START_WITH_DIGIT = False
@ -213,6 +215,66 @@ class Dialect(metaclass=_Dialect):
return expression
@classmethod
def normalize_identifier(cls, expression: E) -> E:
"""
Normalizes an unquoted identifier to either lower or upper case, thus essentially
making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
they will be normalized regardless of being quoted or not.
"""
if isinstance(expression, exp.Identifier) and (
not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
):
expression.set(
"this",
expression.this.upper()
if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
else expression.this.lower(),
)
return expression
@classmethod
def case_sensitive(cls, text: str) -> bool:
"""Checks if text contains any case sensitive characters, based on the dialect's rules."""
if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
return False
unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
return any(unsafe(char) for char in text)
@classmethod
def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
"""Checks if text can be identified given an identify option.
Args:
text: The text to check.
identify:
"always" or `True`: Always returns true.
"safe": True if the identifier is case-insensitive.
Returns:
Whether or not the given text can be identified.
"""
if identify is True or identify == "always":
return True
if identify == "safe":
return not cls.case_sensitive(text)
return False
@classmethod
def quote_identifier(cls, expression: E, identify: bool = True) -> E:
if isinstance(expression, exp.Identifier):
name = expression.this
expression.set(
"quoted",
identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
)
return expression
def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
return self.parser(**opts).parse(self.tokenize(sql), sql)
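
A minimal sketch of the class-level API introduced above, using two dialects whose settings appear elsewhere in this diff (Snowflake resolves unquoted identifiers as uppercase, DuckDB treats identifiers as case-insensitive):

from sqlglot import exp
from sqlglot.dialects import DuckDB, Snowflake

ident = exp.to_identifier("Foo")  # unquoted, mixed-case spelling
print(Snowflake.normalize_identifier(ident.copy()).name)  # expected: FOO
print(DuckDB.normalize_identifier(ident.copy()).name)     # expected: foo

# "safe" identification only succeeds when the text contains no characters that
# are case-sensitive for the dialect.
print(Snowflake.can_identify("FOO", "safe"), Snowflake.can_identify("Foo", "safe"))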

View file

@ -85,9 +85,17 @@ def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract
)
def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
sql = self.func("TO_JSON", expression.this, expression.args.get("options"))
return f"CAST({sql} AS TEXT)"
class DuckDB(Dialect):
NULL_ORDERING = "nulls_are_last"
# https://duckdb.org/docs/sql/introduction.html#creating-a-new-table
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@ -167,7 +175,7 @@ class DuckDB(Dialect):
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0])
if isinstance(seq_get(e.expressions, 0), exp.Select)
if e.expressions and e.expressions[0].find(exp.Select)
else rename_func("LIST_VALUE")(self, e),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArraySort: _array_sort_sql,
@ -192,6 +200,7 @@ class DuckDB(Dialect):
exp.IntDiv: lambda self, e: self.binary(e, "//"),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONFormat: _json_format_sql,
exp.JSONBExtract: arrow_json_extract_sql,
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
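
The new exp.JSONFormat plumbing ties together with the BigQuery change above; a hedged sketch (the column and table names are illustrative):

import sqlglot

# TO_JSON_STRING parses into exp.JSONFormat, which DuckDB's _json_format_sql
# should now render as CAST(TO_JSON(...) AS TEXT).
print(sqlglot.transpile("SELECT TO_JSON_STRING(col) FROM tbl", read="bigquery", write="duckdb")[0])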

View file

@ -86,13 +86,17 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
this = expression.this
if not this.type:
from sqlglot.optimizer.annotate_types import annotate_types
if isinstance(this, exp.Cast) and this.is_type("json") and this.this.is_string:
# Since FROM_JSON requires a nested type, we always wrap the json string with
# an array to ensure that "naked" strings like "'a'" will be handled correctly
wrapped_json = exp.Literal.string(f"[{this.this.name}]")
annotate_types(this)
from_json = self.func("FROM_JSON", wrapped_json, self.func("SCHEMA_OF_JSON", wrapped_json))
to_json = self.func("TO_JSON", from_json)
# This strips the [, ] delimiters of the dummy array printed by TO_JSON
return self.func("REGEXP_EXTRACT", to_json, "'^.(.*).$'", "1")
if this.type.is_type("json"):
return self.sql(this)
return self.func("TO_JSON", this, expression.args.get("options"))
@ -153,6 +157,9 @@ class Hive(Dialect):
ALIAS_POST_TABLESAMPLE = True
IDENTIFIERS_CAN_START_WITH_DIGIT = True
# https://spark.apache.org/docs/latest/sql-ref-identifier.html#description
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
TIME_MAPPING = {
"y": "%Y",
"Y": "%Y",
@ -268,9 +275,9 @@ class Hive(Dialect):
QUERY_MODIFIER_PARSERS = {
**parser.Parser.QUERY_MODIFIER_PARSERS,
"distribute": lambda self: self._parse_sort(exp.Distribute, "DISTRIBUTE", "BY"),
"sort": lambda self: self._parse_sort(exp.Sort, "SORT", "BY"),
"cluster": lambda self: self._parse_sort(exp.Cluster, "CLUSTER", "BY"),
"cluster": lambda self: self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY),
"distribute": lambda self: self._parse_sort(exp.Distribute, TokenType.DISTRIBUTE_BY),
"sort": lambda self: self._parse_sort(exp.Sort, TokenType.SORT_BY),
}
def _parse_types(
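
Since CLUSTER BY, DISTRIBUTE BY and SORT BY are now single tokens, these Hive/Spark query modifiers are parsed via _parse_sort with a TokenType rather than a word sequence; a small sketch (identifiers are illustrative):

import sqlglot

sql = "SELECT * FROM logs DISTRIBUTE BY day SORT BY ts"
# These modifiers should round-trip through the Hive dialect unchanged.
print(sqlglot.transpile(sql, read="hive", write="hive")[0])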

View file

@ -123,6 +123,8 @@ class MySQL(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"CHARSET": TokenType.CHARACTER_SET,
"FORCE": TokenType.FORCE,
"IGNORE": TokenType.IGNORE,
"LONGBLOB": TokenType.LONGBLOB,
"LONGTEXT": TokenType.LONGTEXT,
"MEDIUMBLOB": TokenType.MEDIUMBLOB,
@ -180,6 +182,9 @@ class MySQL(Dialect):
class Parser(parser.Parser):
FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE}
TABLE_ALIAS_TOKENS = (
parser.Parser.TABLE_ALIAS_TOKENS - parser.Parser.TABLE_INDEX_HINT_TOKENS
)
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@ -389,7 +394,7 @@ class MySQL(Dialect):
LOCKING_READS_SUPPORTED = True
NULL_ORDERING_SUPPORTED = False
JOIN_HINTS = False
TABLE_HINTS = False
TABLE_HINTS = True
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
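
With FORCE and IGNORE tokenized and TABLE_HINTS enabled, MySQL index hints parse into the new exp.IndexTableHint and are generated back out; a hedged sketch (table and index names are illustrative):

import sqlglot

sql = "SELECT * FROM t1 USE INDEX (i1)"
# The hint should survive a parse/generate round trip through the MySQL dialect;
# FORCE and IGNORE hints go through the same code path.
print(sqlglot.parse_one(sql, read="mysql").sql("mysql"))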

View file

@ -103,24 +103,15 @@ def _str_to_time_sql(
def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
time_format = self.format_time(expression)
if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT):
return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"
return exp.cast(_str_to_time_sql(self, expression), "DATE").sql(dialect="presto")
return exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE").sql(dialect="presto")
def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
this = expression.this
if not isinstance(this, exp.CurrentDate):
this = self.func(
"DATE_PARSE",
self.func(
"SUBSTR",
this if this.is_string else exp.cast(this, "VARCHAR"),
exp.Literal.number(1),
exp.Literal.number(10),
),
Presto.DATE_FORMAT,
)
this = exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE")
return self.func(
"DATE_ADD",
@ -181,6 +172,11 @@ class Presto(Dialect):
TIME_MAPPING = MySQL.TIME_MAPPING
STRICT_STRING_CONCAT = True
# https://github.com/trinodb/trino/issues/17
# https://github.com/trinodb/trino/issues/12289
# https://github.com/prestodb/presto/issues/2863
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
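
The TS_OR_DS_* rewrites above now lean on nested casts instead of SUBSTR/DATE_PARSE; a hedged sketch of the expected shape (the column name is illustrative):

import sqlglot

# Expected shape: SELECT CAST(CAST(created_at AS TIMESTAMP) AS DATE)
print(sqlglot.transpile("SELECT TS_OR_DS_TO_DATE(created_at)", write="presto")[0])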

View file

@ -14,6 +14,9 @@ def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONEx
class Redshift(Postgres):
# https://docs.aws.amazon.com/redshift/latest/dg/r_names.html
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
TIME_MAPPING = {
**Postgres.TIME_MAPPING,

View file

@ -167,6 +167,8 @@ def _parse_convert_timezone(args: t.List) -> exp.Expression:
class Snowflake(Dialect):
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
NULL_ORDERING = "nulls_are_large"
TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
@ -283,11 +285,12 @@ class Snowflake(Dialect):
"NCHAR VARYING": TokenType.VARCHAR,
"PUT": TokenType.COMMAND,
"RENAME": TokenType.REPLACE,
"SAMPLE": TokenType.TABLE_SAMPLE,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
"TIMESTAMPNTZ": TokenType.TIMESTAMP,
"SAMPLE": TokenType.TABLE_SAMPLE,
"TOP": TokenType.TOP,
}
SINGLE_TOKENS = {

View file

@ -59,6 +59,9 @@ def _transform_create(expression: exp.Expression) -> exp.Expression:
class SQLite(Dialect):
# https://sqlite.org/forum/forumpost/5e575586ac5c711b?raw
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]

View file

@ -31,18 +31,19 @@ class Teradata(Dialect):
# https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Comparison-Operators-and-Functions/Comparison-Operators/ANSI-Compliance
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"BYTEINT": TokenType.SMALLINT,
"SEL": TokenType.SELECT,
"INS": TokenType.INSERT,
"MOD": TokenType.MOD,
"LT": TokenType.LT,
"LE": TokenType.LTE,
"GT": TokenType.GT,
"GE": TokenType.GTE,
"^=": TokenType.NEQ,
"BYTEINT": TokenType.SMALLINT,
"GE": TokenType.GTE,
"GT": TokenType.GT,
"INS": TokenType.INSERT,
"LE": TokenType.LTE,
"LT": TokenType.LT,
"MOD": TokenType.MOD,
"NE": TokenType.NEQ,
"NOT=": TokenType.NEQ,
"SEL": TokenType.SELECT,
"ST_GEOMETRY": TokenType.GEOMETRY,
"TOP": TokenType.TOP,
}
# Teradata does not support % as a modulo operator

View file

@ -1301,7 +1301,14 @@ class Constraint(Expression):
class Delete(Expression):
arg_types = {"with": False, "this": False, "using": False, "where": False, "returning": False}
arg_types = {
"with": False,
"this": False,
"using": False,
"where": False,
"returning": False,
"limit": False,
}
def delete(
self,
@ -1844,6 +1851,10 @@ class CollateProperty(Property):
arg_types = {"this": True}
class CopyGrantsProperty(Property):
arg_types = {}
class DataBlocksizeProperty(Property):
arg_types = {
"size": False,
@ -2245,6 +2256,16 @@ QUERY_MODIFIERS = {
}
# https://learn.microsoft.com/en-us/sql/t-sql/queries/hints-transact-sql-table?view=sql-server-ver16
class WithTableHint(Expression):
arg_types = {"expressions": True}
# https://dev.mysql.com/doc/refman/8.0/en/index-hints.html
class IndexTableHint(Expression):
arg_types = {"this": True, "expressions": False, "target": False}
class Table(Expression):
arg_types = {
"this": True,
@ -2402,6 +2423,7 @@ class Update(Expression):
"from": False,
"where": False,
"returning": False,
"limit": False,
}
@ -2434,8 +2456,6 @@ class Select(Subqueryable):
"expressions": False,
"hint": False,
"distinct": False,
"struct": False, # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#return_query_results_as_a_value_table
"value": False,
"into": False,
"from": False,
**QUERY_MODIFIERS,
@ -3223,15 +3243,15 @@ class Star(Expression):
return self.name
class Parameter(Expression):
class Parameter(Condition):
arg_types = {"this": True, "wrapped": False}
class SessionParameter(Expression):
class SessionParameter(Condition):
arg_types = {"this": True, "kind": False}
class Placeholder(Expression):
class Placeholder(Condition):
arg_types = {"this": False, "kind": False}
@ -3333,6 +3353,7 @@ class DataType(Expression):
UINT128 = auto()
UINT256 = auto()
UNIQUEIDENTIFIER = auto()
USERDEFINED = "USER-DEFINED"
UUID = auto()
VARBINARY = auto()
VARCHAR = auto()
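
With the new "limit" argument on Delete and Update (and the matching parser/generator changes later in this diff), LIMIT clauses on these statements survive a round trip; a small sketch (table and column names are illustrative):

import sqlglot

print(sqlglot.parse_one("DELETE FROM sessions WHERE expired = 1 LIMIT 100", read="mysql").sql("mysql"))
print(sqlglot.parse_one("UPDATE sessions SET expired = 1 WHERE ts < 0 LIMIT 100", read="mysql").sql("mysql"))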

View file

@ -5,7 +5,7 @@ import typing as t
from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
from sqlglot.helper import apply_index_offset, csv, seq_get, should_identify
from sqlglot.helper import apply_index_offset, csv, seq_get
from sqlglot.time import format_time
from sqlglot.tokens import TokenType
@ -56,39 +56,40 @@ class Generator:
exp.TsOrDsAdd: lambda self, e: self.func(
"TS_OR_DS_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
),
exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}",
exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})",
exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
exp.CopyGrantsProperty: lambda self, e: "COPY GRANTS",
exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}",
exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
exp.ExternalProperty: lambda self, e: "EXTERNAL",
exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
exp.MaterializedProperty: lambda self, e: "MATERIALIZED",
exp.NoPrimaryIndexProperty: lambda self, e: "NO PRIMARY INDEX",
exp.OnCommitProperty: lambda self, e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS",
exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
exp.StabilityProperty: lambda self, e: e.name,
exp.TemporaryProperty: lambda self, e: f"TEMPORARY",
exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
exp.TransientProperty: lambda self, e: "TRANSIENT",
exp.StabilityProperty: lambda self, e: e.name,
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
exp.UppercaseColumnConstraint: lambda self, e: f"UPPERCASE",
exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
exp.VolatileProperty: lambda self, e: "VOLATILE",
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
exp.UppercaseColumnConstraint: lambda self, e: f"UPPERCASE",
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})",
exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}",
exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
}
# Whether or not null ordering is supported in order by
@ -142,6 +143,9 @@ class Generator:
# Whether or not comparing against booleans (e.g. x IS TRUE) is supported
IS_BOOL_ALLOWED = True
# https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE")
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@ -182,6 +186,7 @@ class Generator:
exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA,
exp.ChecksumProperty: exp.Properties.Location.POST_NAME,
exp.CollateProperty: exp.Properties.Location.POST_SCHEMA,
exp.CopyGrantsProperty: exp.Properties.Location.POST_SCHEMA,
exp.Cluster: exp.Properties.Location.POST_SCHEMA,
exp.DataBlocksizeProperty: exp.Properties.Location.POST_NAME,
exp.DefinerProperty: exp.Properties.Location.POST_CREATE,
@ -263,6 +268,8 @@ class Generator:
NORMALIZE_FUNCTIONS: bool | str = "upper"
NULL_ORDERING = "nulls_are_small"
can_identify: t.Callable[[str, str | bool], bool]
# Delimiters for quotes, identifiers and the corresponding escape characters
QUOTE_START = "'"
QUOTE_END = "'"
@ -771,9 +778,11 @@ class Generator:
return this
def rawstring_sql(self, expression: exp.RawString) -> str:
string = expression.this
if self.RAW_START:
return f"{self.RAW_START}{expression.name}{self.RAW_END}"
return self.sql(exp.Literal.string(expression.name.replace("\\", "\\\\")))
return f"{self.RAW_START}{self.escape_str(expression.this)}{self.RAW_END}"
string = self.escape_str(string.replace("\\", "\\\\"))
return f"{self.QUOTE_START}{string}{self.QUOTE_END}"
def datatypesize_sql(self, expression: exp.DataTypeSize) -> str:
this = self.sql(expression, "this")
@ -815,7 +824,8 @@ class Generator:
)
where_sql = self.sql(expression, "where")
returning = self.sql(expression, "returning")
sql = f"DELETE{this}{using_sql}{where_sql}{returning}"
limit = self.sql(expression, "limit")
sql = f"DELETE{this}{using_sql}{where_sql}{returning}{limit}"
return self.prepend_ctes(expression, sql)
def drop_sql(self, expression: exp.Drop) -> str:
@ -883,7 +893,7 @@ class Generator:
text = text.replace(self.IDENTIFIER_END, self._escaped_identifier_end)
if (
expression.quoted
or should_identify(text, self.identify)
or self.can_identify(text, self.identify)
or lower in self.RESERVED_KEYWORDS
or (not self.IDENTIFIERS_CAN_START_WITH_DIGIT and text[:1].isdigit())
):
@ -1157,6 +1167,15 @@ class Generator:
null = f" NULL DEFINED AS {null}" if null else ""
return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}"
def withtablehint_sql(self, expression: exp.WithTableHint) -> str:
return f"WITH ({self.expressions(expression, flat=True)})"
def indextablehint_sql(self, expression: exp.IndexTableHint) -> str:
this = f"{self.sql(expression, 'this')} INDEX"
target = self.sql(expression, "target")
target = f" FOR {target}" if target else ""
return f"{this}{target} ({self.expressions(expression, flat=True)})"
def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:
table = ".".join(
part
@ -1170,8 +1189,8 @@ class Generator:
alias = self.sql(expression, "alias")
alias = f"{sep}{alias}" if alias else ""
hints = self.expressions(expression, key="hints", flat=True)
hints = f" WITH ({hints})" if hints and self.TABLE_HINTS else ""
hints = self.expressions(expression, key="hints", sep=" ")
hints = f" {hints}" if hints and self.TABLE_HINTS else ""
pivots = self.expressions(expression, key="pivots", sep=" ", flat=True)
pivots = f" {pivots}" if pivots else ""
joins = self.expressions(expression, key="joins", sep="")
@ -1238,7 +1257,8 @@ class Generator:
from_sql = self.sql(expression, "from")
where_sql = self.sql(expression, "where")
returning = self.sql(expression, "returning")
sql = f"UPDATE {this} SET {set_sql}{from_sql}{where_sql}{returning}"
limit = self.sql(expression, "limit")
sql = f"UPDATE {this} SET {set_sql}{from_sql}{where_sql}{returning}{limit}"
return self.prepend_ctes(expression, sql)
def values_sql(self, expression: exp.Values) -> str:
@ -1413,10 +1433,13 @@ class Generator:
def literal_sql(self, expression: exp.Literal) -> str:
text = expression.this or ""
if expression.is_string:
text = text.replace(self.QUOTE_END, self._escaped_quote_end)
if self.pretty:
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
text = f"{self.QUOTE_START}{text}{self.QUOTE_END}"
text = f"{self.QUOTE_START}{self.escape_str(text)}{self.QUOTE_END}"
return text
def escape_str(self, text: str) -> str:
text = text.replace(self.QUOTE_END, self._escaped_quote_end)
if self.pretty:
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
return text
def loaddata_sql(self, expression: exp.LoadData) -> str:
@ -1565,9 +1588,30 @@ class Generator:
hint = self.sql(expression, "hint")
distinct = self.sql(expression, "distinct")
distinct = f" {distinct}" if distinct else ""
kind = expression.args.get("kind")
kind = f" AS {kind}" if kind else ""
kind = self.sql(expression, "kind").upper()
expressions = self.expressions(expression)
if kind:
if kind in self.SELECT_KINDS:
kind = f" AS {kind}"
else:
if kind == "STRUCT":
expressions = self.expressions(
sqls=[
self.sql(
exp.Struct(
expressions=[
exp.column(e.output_name).eq(
e.this if isinstance(e, exp.Alias) else e
)
for e in expression.expressions
]
)
)
]
)
kind = ""
expressions = f"{self.sep()}{expressions}" if expressions else expressions
sql = self.query_modifiers(
expression,
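
The SELECT kind handling above keeps AS STRUCT / AS VALUE only for generators that list them in SELECT_KINDS and otherwise folds a STRUCT kind into an equivalent struct projection; a hedged sketch (the aliases are illustrative, and the exact output depends on how each target dialect renders exp.Struct):

import sqlglot

sql = "SELECT AS STRUCT 1 AS a, 'x' AS b"
print(sqlglot.transpile(sql, read="bigquery", write="bigquery")[0])  # kind preserved
print(sqlglot.transpile(sql, read="bigquery", write="duckdb")[0])    # rewritten as a struct expression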

View file

@ -14,7 +14,6 @@ from itertools import count
if t.TYPE_CHECKING:
from sqlglot import exp
from sqlglot._typing import E, T
from sqlglot.dialects.dialect import DialectType
from sqlglot.expressions import Expression
CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
@ -23,7 +22,12 @@ logger = logging.getLogger("sqlglot")
class AutoName(Enum):
"""This is used for creating enum classes where `auto()` is the string form of the corresponding value's name."""
"""
This is used for creating Enum classes where `auto()` is the string form
of the corresponding enum's identifier (e.g. FOO.value results in "FOO").
Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
"""
def _generate_next_value_(name, _start, _count, _last_values):
return name
@ -52,7 +56,7 @@ def ensure_list(value):
Ensures that a value is a list, otherwise casts or wraps it into one.
Args:
value: the value of interest.
value: The value of interest.
Returns:
The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
@ -80,7 +84,7 @@ def ensure_collection(value):
Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
Args:
value: the value of interest.
value: The value of interest.
Returns:
The value if it's a collection, or else the value wrapped in a list.
@ -97,8 +101,8 @@ def csv(*args: str, sep: str = ", ") -> str:
Formats any number of string arguments as CSV.
Args:
args: the string arguments to format.
sep: the argument separator.
args: The string arguments to format.
sep: The argument separator.
Returns:
The arguments formatted as a CSV string.
@ -115,9 +119,9 @@ def subclasses(
Returns all subclasses for a collection of classes, possibly excluding some of them.
Args:
module_name: the name of the module to search for subclasses in.
classes: class(es) we want to find the subclasses of.
exclude: class(es) we want to exclude from the returned list.
module_name: The name of the module to search for subclasses in.
classes: Class(es) we want to find the subclasses of.
exclude: Class(es) we want to exclude from the returned list.
Returns:
The target subclasses.
@ -140,13 +144,13 @@ def apply_index_offset(
Applies an offset to a given integer literal expression.
Args:
this: the target of the index
expressions: the expression the offset will be applied to, wrapped in a list.
offset: the offset that will be applied.
this: The target of the index.
expressions: The expression the offset will be applied to, wrapped in a list.
offset: The offset that will be applied.
Returns:
The original expression with the offset applied to it, wrapped in a list. If the provided
`expressions` argument contains more than one expressions, it's returned unaffected.
`expressions` argument contains more than one expression, it's returned unaffected.
"""
if not offset or len(expressions) != 1:
return expressions
@ -189,8 +193,8 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) ->
Applies a transformation to a given expression until a fix point is reached.
Args:
expression: the expression to be transformed.
func: the transformation to be applied.
expression: The expression to be transformed.
func: The transformation to be applied.
Returns:
The transformed expression.
@ -198,6 +202,7 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) ->
while True:
for n, *_ in reversed(tuple(expression.walk())):
n._hash = hash(n)
start = hash(expression)
expression = func(expression)
@ -205,6 +210,7 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) ->
n._hash = None
if start == hash(expression):
break
return expression
@ -213,7 +219,7 @@ def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
Sorts a given directed acyclic graph in topological order.
Args:
dag: the graph to be sorted.
dag: The graph to be sorted.
Returns:
A list that contains all of the graph's nodes in topological order.
@ -261,7 +267,7 @@ def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
Args:
read_csv: a `ReadCSV` function call
read_csv: A `ReadCSV` function call.
Yields:
A python csv reader.
@ -288,8 +294,8 @@ def find_new_name(taken: t.Collection[str], base: str) -> str:
Searches for a new name.
Args:
taken: a collection of taken names.
base: base name to alter.
taken: A collection of taken names.
base: Base name to alter.
Returns:
The new, available name.
@ -327,10 +333,10 @@ def split_num_words(
Perform a split on a value and return N words as a result with `None` used for words that don't exist.
Args:
value: the value to be split.
sep: the value to use to split on.
min_num_words: the minimum number of words that are going to be in the result.
fill_from_start: indicates that if `None` values should be inserted at the start or end of the list.
value: The value to be split.
sep: The value to use to split on.
min_num_words: The minimum number of words that are going to be in the result.
fill_from_start: Indicates whether `None` values should be inserted at the start or end of the list.
Examples:
>>> split_num_words("db.table", ".", 3)
@ -360,7 +366,7 @@ def is_iterable(value: t.Any) -> bool:
False
Args:
value: the value to check if it is an iterable.
value: The value to check if it is an iterable.
Returns:
A `bool` value indicating if it is an iterable.
@ -380,7 +386,7 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
[1, 2, 3]
Args:
values: the value to be flattened.
values: The value to be flattened.
Yields:
Non-iterable elements in `values`.
@ -396,7 +402,7 @@ def dict_depth(d: t.Dict) -> int:
"""
Get the nesting depth of a dictionary.
For example:
Example:
>>> dict_depth(None)
0
>>> dict_depth({})
@ -407,12 +413,6 @@ def dict_depth(d: t.Dict) -> int:
2
>>> dict_depth({"a": {"b": {}}})
3
Args:
d (dict): dictionary
Returns:
int: depth
"""
try:
return 1 + dict_depth(next(iter(d.values())))
@ -425,36 +425,5 @@ def dict_depth(d: t.Dict) -> int:
def first(it: t.Iterable[T]) -> T:
"""Returns the first element from an iterable.
Useful for sets.
"""
"""Returns the first element from an iterable (useful for sets)."""
return next(i for i in it)
def case_sensitive(text: str, dialect: DialectType) -> bool:
"""Checks if text contains any case sensitive characters depending on dialect."""
from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE
unsafe = str.islower if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
return any(unsafe(char) for char in text)
def should_identify(text: str, identify: str | bool, dialect: DialectType = None) -> bool:
"""Checks if text should be identified given an identify option.
Args:
text: the text to check.
identify:
"always" or `True`: always returns true.
"safe": true if there is no uppercase or lowercase character in `text`, depending on `dialect`.
dialect: the dialect to use in order to decide whether a text should be identified.
Returns:
Whether or not a string should be identified.
"""
if identify is True or identify == "always":
return True
if identify == "safe":
return not case_sensitive(text, dialect)
return False

View file

@ -187,7 +187,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
},
}
ANNOTATORS = {
ANNOTATORS: t.Dict = {
**{
expr_type: lambda self, e: self._annotate_unary(e)
for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))

View file

@ -1,12 +1,15 @@
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE, DialectType
from sqlglot.dialects.dialect import Dialect, DialectType
def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
"""
Normalize all unquoted identifiers to either lower or upper case, depending on
the dialect. This essentially makes those identifiers case-insensitive.
Normalize all unquoted identifiers to either lower or upper case, depending
on the dialect. This essentially makes those identifiers case-insensitive.
Note:
Some dialects (e.g. BigQuery) treat identifiers as case-insensitive even
when they're quoted, so in these cases all identifiers are normalized.
Example:
>>> import sqlglot
@ -21,16 +24,4 @@ def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
Returns:
The transformed expression.
"""
return expression.transform(_normalize, dialect, copy=False)
def _normalize(node: exp.Expression, dialect: DialectType = None) -> exp.Expression:
if isinstance(node, exp.Identifier) and not node.quoted:
node.set(
"this",
node.this.upper()
if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE
else node.this.lower(),
)
return node
return expression.transform(Dialect.get_or_raise(dialect).normalize_identifier, copy=False)

View file

@ -5,9 +5,9 @@ import typing as t
from sqlglot import alias, exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import DialectType
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import case_sensitive, seq_get
from sqlglot.helper import seq_get
from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
from sqlglot.schema import Schema, ensure_schema
@ -417,19 +417,9 @@ def _qualify_outputs(scope):
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
"""Makes sure all identifiers that need to be quoted are quoted."""
def _quote(expression: E) -> E:
if isinstance(expression, exp.Identifier):
name = expression.this
expression.set(
"quoted",
identify
or case_sensitive(name, dialect=dialect)
or not exp.SAFE_IDENTIFIER_RE.match(name),
)
return expression
return expression.transform(_quote, copy=False)
return expression.transform(
Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
)
class Resolver:

View file

@ -408,9 +408,14 @@ def remove_where_true(expression):
if always_true(where.this):
where.parent.set("where", None)
for join in expression.find_all(exp.Join):
if always_true(join.args.get("on")):
join.set("kind", "CROSS")
if (
always_true(join.args.get("on"))
and not join.args.get("using")
and not join.args.get("method")
):
join.set("on", None)
join.set("side", None)
join.set("kind", "CROSS")
def always_true(expression):

View file

@ -9,7 +9,7 @@ from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
from sqlglot.helper import apply_index_offset, ensure_list, seq_get
from sqlglot.time import format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import in_trie, new_trie
from sqlglot.trie import TrieResult, in_trie, new_trie
if t.TYPE_CHECKING:
from sqlglot._typing import E
@ -177,6 +177,7 @@ class Parser(metaclass=_Parser):
TokenType.BIGSERIAL,
TokenType.XML,
TokenType.UNIQUEIDENTIFIER,
TokenType.USERDEFINED,
TokenType.MONEY,
TokenType.SMALLMONEY,
TokenType.ROWVERSION,
@ -465,7 +466,7 @@ class Parser(metaclass=_Parser):
}
EXPRESSION_PARSERS = {
exp.Cluster: lambda self: self._parse_sort(exp.Cluster, "CLUSTER", "BY"),
exp.Cluster: lambda self: self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY),
exp.Column: lambda self: self._parse_column(),
exp.Condition: lambda self: self._parse_conjunction(),
exp.DataType: lambda self: self._parse_types(),
@ -484,7 +485,7 @@ class Parser(metaclass=_Parser):
exp.Properties: lambda self: self._parse_properties(),
exp.Qualify: lambda self: self._parse_qualify(),
exp.Returning: lambda self: self._parse_returning(),
exp.Sort: lambda self: self._parse_sort(exp.Sort, "SORT", "BY"),
exp.Sort: lambda self: self._parse_sort(exp.Sort, TokenType.SORT_BY),
exp.Table: lambda self: self._parse_table_parts(),
exp.TableAlias: lambda self: self._parse_table_alias(),
exp.Where: lambda self: self._parse_where(),
@ -540,8 +541,7 @@ class Parser(metaclass=_Parser):
exp.Literal, this=token.text, is_string=False
),
TokenType.STAR: lambda self, _: self.expression(
exp.Star,
**{"except": self._parse_except(), "replace": self._parse_replace()},
exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()}
),
TokenType.NULL: lambda self, _: self.expression(exp.Null),
TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True),
@ -584,9 +584,10 @@ class Parser(metaclass=_Parser):
"BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
"CHARACTER SET": lambda self: self._parse_character_set(),
"CHECKSUM": lambda self: self._parse_checksum(),
"CLUSTER": lambda self: self._parse_cluster(),
"CLUSTER BY": lambda self: self._parse_cluster(),
"COLLATE": lambda self: self._parse_property_assignment(exp.CollateProperty),
"COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
"COPY": lambda self: self._parse_copy_property(),
"DATABLOCKSIZE": lambda self, **kwargs: self._parse_datablocksize(**kwargs),
"DEFINER": lambda self: self._parse_definer(),
"DETERMINISTIC": lambda self: self.expression(
@ -780,6 +781,8 @@ class Parser(metaclass=_Parser):
CLONE_KINDS = {"TIMESTAMP", "OFFSET", "STATEMENT"}
TABLE_INDEX_HINT_TOKENS = {TokenType.FORCE, TokenType.IGNORE, TokenType.USE}
WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}
WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER}
WINDOW_SIDES = {"FOLLOWING", "PRECEDING"}
@ -788,7 +791,8 @@ class Parser(metaclass=_Parser):
STRICT_CAST = True
CONCAT_NULL_OUTPUTS_STRING = False # A NULL arg in CONCAT yields NULL by default
# A NULL arg in CONCAT yields NULL by default
CONCAT_NULL_OUTPUTS_STRING = False
CONVERT_TYPE_FIRST = False
@ -1423,11 +1427,14 @@ class Parser(metaclass=_Parser):
return self.expression(exp.ChecksumProperty, on=on, default=self._match(TokenType.DEFAULT))
def _parse_cluster(self) -> t.Optional[exp.Cluster]:
if not self._match_text_seq("BY"):
return self.expression(exp.Cluster, expressions=self._parse_csv(self._parse_ordered))
def _parse_copy_property(self) -> t.Optional[exp.CopyGrantsProperty]:
if not self._match_text_seq("GRANTS"):
self._retreat(self._index - 1)
return None
return self.expression(exp.Cluster, expressions=self._parse_csv(self._parse_ordered))
return self.expression(exp.CopyGrantsProperty)
def _parse_freespace(self) -> exp.FreespaceProperty:
self._match(TokenType.EQ)
@ -1779,6 +1786,7 @@ class Parser(metaclass=_Parser):
using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table()),
where=self._parse_where(),
returning=self._parse_returning(),
limit=self._parse_limit(),
)
def _parse_update(self) -> exp.Update:
@ -1790,6 +1798,7 @@ class Parser(metaclass=_Parser):
"from": self._parse_from(modifiers=True),
"where": self._parse_where(),
"returning": self._parse_returning(),
"limit": self._parse_limit(),
},
)
@ -2268,6 +2277,33 @@ class Parser(metaclass=_Parser):
partition_by=self._parse_partition_by(),
)
def _parse_table_hints(self) -> t.Optional[t.List[exp.Expression]]:
hints: t.List[exp.Expression] = []
if self._match_pair(TokenType.WITH, TokenType.L_PAREN):
# https://learn.microsoft.com/en-us/sql/t-sql/queries/hints-transact-sql-table?view=sql-server-ver16
hints.append(
self.expression(
exp.WithTableHint,
expressions=self._parse_csv(
lambda: self._parse_function() or self._parse_var(any_token=True)
),
)
)
self._match_r_paren()
else:
# https://dev.mysql.com/doc/refman/8.0/en/index-hints.html
while self._match_set(self.TABLE_INDEX_HINT_TOKENS):
hint = exp.IndexTableHint(this=self._prev.text.upper())
self._match_texts({"INDEX", "KEY"})
if self._match(TokenType.FOR):
hint.set("target", self._advance_any() and self._prev.text.upper())
hint.set("expressions", self._parse_wrapped_id_vars())
hints.append(hint)
return hints or None
def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
return (
(not schema and self._parse_function(optional_parens=False))
@ -2335,12 +2371,7 @@ class Parser(metaclass=_Parser):
if not this.args.get("pivots"):
this.set("pivots", self._parse_pivots())
if self._match_pair(TokenType.WITH, TokenType.L_PAREN):
this.set(
"hints",
self._parse_csv(lambda: self._parse_function() or self._parse_var(any_token=True)),
)
self._match_r_paren()
this.set("hints", self._parse_table_hints())
if not self.ALIAS_POST_TABLESAMPLE:
table_sample = self._parse_table_sample()
@ -2610,8 +2641,8 @@ class Parser(metaclass=_Parser):
exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)
)
def _parse_sort(self, exp_class: t.Type[E], *texts: str) -> t.Optional[E]:
if not self._match_text_seq(*texts):
def _parse_sort(self, exp_class: t.Type[E], token: TokenType) -> t.Optional[E]:
if not self._match(token):
return None
return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered))
@ -3655,7 +3686,11 @@ class Parser(metaclass=_Parser):
def _parse_concat(self) -> t.Optional[exp.Expression]:
args = self._parse_csv(self._parse_conjunction)
if self.CONCAT_NULL_OUTPUTS_STRING:
args = [exp.func("COALESCE", arg, exp.Literal.string("")) for arg in args]
args = [
exp.func("COALESCE", exp.cast(arg, "text"), exp.Literal.string(""))
for arg in args
if arg
]
# Some dialects (e.g. Trino) don't allow a single-argument CONCAT call, so when
# we find such a call we replace it with its argument.
@ -4553,13 +4588,16 @@ class Parser(metaclass=_Parser):
curr = self._curr.text.upper()
key = curr.split(" ")
this.append(curr)
self._advance()
result, trie = in_trie(trie, key)
if result == 0:
if result == TrieResult.FAILED:
break
if result == 2:
if result == TrieResult.EXISTS:
subparser = parsers[" ".join(this)]
return subparser
self._retreat(index)
return None

View file

@ -6,10 +6,10 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
from sqlglot._typing import T
from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import ParseError, SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import in_trie, new_trie
from sqlglot.trie import TrieResult, in_trie, new_trie
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.types import StructType
@ -135,10 +135,10 @@ class AbstractMappingSchema(t.Generic[T]):
parts = self.table_parts(table)[0 : len(self.supported_table_args)]
value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
if value == 0:
if value == TrieResult.FAILED:
return None
if value == 1:
if value == TrieResult.PREFIX:
possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
if len(possibilities) == 1:
@ -289,7 +289,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
def _normalize(self, schema: t.Dict) -> t.Dict:
"""
Converts all identifiers in the schema into lowercase, unless they're quoted.
Normalizes all identifiers in the schema.
Args:
schema: the schema to normalize.
@ -304,7 +304,9 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
columns = nested_get(schema, *zip(keys, keys))
assert columns is not None
normalized_keys = [self._normalize_name(key, dialect=self.dialect) for key in keys]
normalized_keys = [
self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
]
for column_name, column_type in columns.items():
nested_set(
normalized_mapping,
@ -321,12 +323,15 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
value = normalized_table.args.get(arg)
if isinstance(value, (str, exp.Identifier)):
normalized_table.set(
arg, exp.to_identifier(self._normalize_name(value, dialect=dialect))
arg,
exp.to_identifier(self._normalize_name(value, dialect=dialect, is_table=True)),
)
return normalized_table
def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = None) -> str:
def _normalize_name(
self, name: str | exp.Identifier, dialect: DialectType = None, is_table: bool = False
) -> str:
dialect = dialect or self.dialect
try:
@ -335,11 +340,12 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
return name if isinstance(name, str) else name.name
name = identifier.name
if not self.normalize or identifier.quoted:
if not self.normalize:
return name
return name.upper() if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else name.lower()
# This can be useful for normalize_identifier
identifier.meta["is_table"] = is_table
return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name
def _depth(self) -> int:
# The columns themselves are a mapping, but we don't want to include those
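
Schema normalization now defers to the dialect's normalize_identifier, with is_table metadata so table-name heuristics (e.g. BigQuery's) can kick in; a small sketch against the in-memory schema (the names are illustrative):

import sqlglot

sqlglot.schema.add_table("Employees", {"EmployeeId": "INT", "FName": "STRING"}, dialect="spark")
# Spark/Hive treat identifiers as case-insensitive, so both the table key and the
# column names should come back lowercased.
print(sqlglot.schema.column_names("Employees", dialect="spark"))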

View file

@ -2,7 +2,7 @@ import typing as t
# The generic time format is based on python time.strftime.
# https://docs.python.org/3/library/time.html#time.strftime
from sqlglot.trie import in_trie, new_trie
from sqlglot.trie import TrieResult, in_trie, new_trie
def format_time(
@ -37,7 +37,7 @@ def format_time(
chars = string[start:end]
result, current = in_trie(current, chars[-1])
if result == 0:
if result == TrieResult.FAILED:
if sym:
end -= 1
chars = sym
@ -45,11 +45,12 @@ def format_time(
start += len(chars)
chunks.append(chars)
current = trie
elif result == 2:
elif result == TrieResult.EXISTS:
sym = chars
end += 1
if result and end > size:
if result != TrieResult.FAILED and end > size:
chunks.append(chars)
return "".join(mapping.get(chars, chars) for chars in chunks)

View file

@ -4,7 +4,7 @@ import typing as t
from enum import auto
from sqlglot.helper import AutoName
from sqlglot.trie import in_trie, new_trie
from sqlglot.trie import TrieResult, in_trie, new_trie
class TokenType(AutoName):
@ -137,6 +137,7 @@ class TokenType(AutoName):
BIGSERIAL = auto()
XML = auto()
UNIQUEIDENTIFIER = auto()
USERDEFINED = auto()
MONEY = auto()
SMALLMONEY = auto()
ROWVERSION = auto()
@ -163,6 +164,7 @@ class TokenType(AutoName):
CACHE = auto()
CASE = auto()
CHARACTER_SET = auto()
CLUSTER_BY = auto()
COLLATE = auto()
COMMAND = auto()
COMMENT = auto()
@ -182,6 +184,7 @@ class TokenType(AutoName):
DESCRIBE = auto()
DICTIONARY = auto()
DISTINCT = auto()
DISTRIBUTE_BY = auto()
DIV = auto()
DROP = auto()
ELSE = auto()
@ -196,6 +199,7 @@ class TokenType(AutoName):
FINAL = auto()
FIRST = auto()
FOR = auto()
FORCE = auto()
FOREIGN_KEY = auto()
FORMAT = auto()
FROM = auto()
@ -208,6 +212,7 @@ class TokenType(AutoName):
HAVING = auto()
HINT = auto()
IF = auto()
IGNORE = auto()
ILIKE = auto()
ILIKE_ANY = auto()
IN = auto()
@ -282,6 +287,7 @@ class TokenType(AutoName):
SHOW = auto()
SIMILAR_TO = auto()
SOME = auto()
SORT_BY = auto()
STRUCT = auto()
TABLE_SAMPLE = auto()
TEMPORARY = auto()
@ -509,6 +515,7 @@ class Tokenizer(metaclass=_Tokenizer):
"UNCACHE": TokenType.UNCACHE,
"CASE": TokenType.CASE,
"CHARACTER SET": TokenType.CHARACTER_SET,
"CLUSTER BY": TokenType.CLUSTER_BY,
"COLLATE": TokenType.COLLATE,
"COLUMN": TokenType.COLUMN,
"COMMIT": TokenType.COMMIT,
@ -526,6 +533,7 @@ class Tokenizer(metaclass=_Tokenizer):
"DESC": TokenType.DESC,
"DESCRIBE": TokenType.DESCRIBE,
"DISTINCT": TokenType.DISTINCT,
"DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
"DIV": TokenType.DIV,
"DROP": TokenType.DROP,
"ELSE": TokenType.ELSE,
@ -617,6 +625,7 @@ class Tokenizer(metaclass=_Tokenizer):
"SHOW": TokenType.SHOW,
"SIMILAR TO": TokenType.SIMILAR_TO,
"SOME": TokenType.SOME,
"SORT BY": TokenType.SORT_BY,
"TABLE": TokenType.TABLE,
"TABLESAMPLE": TokenType.TABLE_SAMPLE,
"TEMP": TokenType.TEMPORARY,
@ -717,6 +726,7 @@ class Tokenizer(metaclass=_Tokenizer):
"PREPARE": TokenType.COMMAND,
"TRUNCATE": TokenType.COMMAND,
"VACUUM": TokenType.COMMAND,
"USER-DEFINED": TokenType.USERDEFINED,
}
WHITE_SPACE: t.Dict[t.Optional[str], TokenType] = {
@ -905,13 +915,13 @@ class Tokenizer(metaclass=_Tokenizer):
while chars:
if skip:
result = 1
result = TrieResult.PREFIX
else:
result, trie = in_trie(trie, char.upper())
if result == 0:
if result == TrieResult.FAILED:
break
if result == 2:
if result == TrieResult.EXISTS:
word = chars
size += 1

View file

@ -1,14 +1,21 @@
import typing as t
from enum import Enum, auto
key = t.Sequence[t.Hashable]
class TrieResult(Enum):
FAILED = auto()
PREFIX = auto()
EXISTS = auto()
def new_trie(keywords: t.Iterable[key], trie: t.Optional[t.Dict] = None) -> t.Dict:
"""
Creates a new trie out of a collection of keywords.
The trie is represented as a sequence of nested dictionaries keyed by either single character
strings, or by 0, which is used to designate that a keyword is in the trie.
The trie is represented as a sequence of nested dictionaries keyed by either single
character strings, or by 0, which is used to designate that a keyword is in the trie.
Example:
>>> new_trie(["bla", "foo", "blab"])
@ -25,46 +32,50 @@ def new_trie(keywords: t.Iterable[key], trie: t.Optional[t.Dict] = None) -> t.Di
for key in keywords:
current = trie
for char in key:
current = current.setdefault(char, {})
current[0] = True
return trie
def in_trie(trie: t.Dict, key: key) -> t.Tuple[int, t.Dict]:
def in_trie(trie: t.Dict, key: key) -> t.Tuple[TrieResult, t.Dict]:
"""
Checks whether a key is in a trie.
Examples:
>>> in_trie(new_trie(["cat"]), "bob")
(0, {'c': {'a': {'t': {0: True}}}})
(<TrieResult.FAILED: 1>, {'c': {'a': {'t': {0: True}}}})
>>> in_trie(new_trie(["cat"]), "ca")
(1, {'t': {0: True}})
(<TrieResult.PREFIX: 2>, {'t': {0: True}})
>>> in_trie(new_trie(["cat"]), "cat")
(2, {0: True})
(<TrieResult.EXISTS: 3>, {0: True})
Args:
trie: the trie to be searched.
key: the target key.
trie: The trie to be searched.
key: The target key.
Returns:
A pair `(value, subtrie)`, where `subtrie` is the sub-trie we get at the point where the search stops, and `value`
is either 0 (search was unsuccessful), 1 (`value` is a prefix of a keyword in `trie`) or 2 (`key is in `trie`).
A pair `(value, subtrie)`, where `subtrie` is the sub-trie we get at the point
where the search stops, and `value` is a TrieResult value that can be one of:
- TrieResult.FAILED: the search was unsuccessful
- TrieResult.PREFIX: `value` is a prefix of a keyword in `trie`
- TrieResult.EXISTS: `key` exists in `trie`
"""
if not key:
return (0, trie)
return (TrieResult.FAILED, trie)
current = trie
for char in key:
if char not in current:
return (0, current)
return (TrieResult.FAILED, current)
current = current[char]
if 0 in current:
return (2, current)
return (1, current)
return (TrieResult.EXISTS, current)
return (TrieResult.PREFIX, current)