
Merging upstream version 11.2.3.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 15:40:23 +01:00
parent c6f7c6bbe1
commit 428b7dd76f
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
93 changed files with 33054 additions and 31671 deletions

View file

@@ -9,6 +9,7 @@ repos:
         require_serial: true
         files: ^(sqlglot/|tests/|setup.py)
       - id: isort
+        args: [--combine-as]
         name: isort
         entry: isort
         language: system

View file

@ -71,6 +71,8 @@ Changes:
- Breaking: Change Power to binary expression. - Breaking: Change Power to binary expression.
- Breaking: Removed mapping of "}}" to BLOCK_END token.
- New: x GLOB y support. - New: x GLOB y support.
v10.5.0 v10.5.0

File diff suppressed because one or more lines are too long (repeated for 65 files)

View file

@@ -9,38 +9,45 @@ from __future__ import annotations

 import typing as t

 from sqlglot import expressions as exp
-from sqlglot.dialects import Dialect, Dialects
-from sqlglot.diff import diff
-from sqlglot.errors import ErrorLevel, ParseError, TokenError, UnsupportedError
-from sqlglot.expressions import Expression
-from sqlglot.expressions import alias_ as alias
-from sqlglot.expressions import (
-    and_,
-    column,
-    condition,
-    except_,
-    from_,
-    intersect,
-    maybe_parse,
-    not_,
-    or_,
-    select,
-    subquery,
-)
-from sqlglot.expressions import table_ as table
-from sqlglot.expressions import to_column, to_table, union
-from sqlglot.generator import Generator
-from sqlglot.parser import Parser
-from sqlglot.schema import MappingSchema, Schema
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.dialects.dialect import Dialect as Dialect, Dialects as Dialects
+from sqlglot.diff import diff as diff
+from sqlglot.errors import (
+    ErrorLevel as ErrorLevel,
+    ParseError as ParseError,
+    TokenError as TokenError,
+    UnsupportedError as UnsupportedError,
+)
+from sqlglot.expressions import (
+    Expression as Expression,
+    alias_ as alias,
+    and_ as and_,
+    column as column,
+    condition as condition,
+    except_ as except_,
+    from_ as from_,
+    intersect as intersect,
+    maybe_parse as maybe_parse,
+    not_ as not_,
+    or_ as or_,
+    select as select,
+    subquery as subquery,
+    table_ as table,
+    to_column as to_column,
+    to_table as to_table,
+    union as union,
+)
+from sqlglot.generator import Generator as Generator
+from sqlglot.parser import Parser as Parser
+from sqlglot.schema import MappingSchema as MappingSchema, Schema as Schema
+from sqlglot.tokens import Tokenizer as Tokenizer, TokenType as TokenType

 if t.TYPE_CHECKING:
-    from sqlglot.dialects.dialect import DialectType
+    from sqlglot.dialects.dialect import DialectType as DialectType

 T = t.TypeVar("T", bound=Expression)

-__version__ = "11.2.0"
+__version__ = "11.2.3"

 pretty = False
 """Whether to format generated SQL by default."""

View file

@@ -4,8 +4,7 @@ import typing as t

 from sqlglot import exp as expression
 from sqlglot.dataframe.sql.column import Column
-from sqlglot.helper import ensure_list
-from sqlglot.helper import flatten as _flatten
+from sqlglot.helper import ensure_list, flatten as _flatten

 if t.TYPE_CHECKING:
     from sqlglot.dataframe.sql._typing import ColumnOrLiteral, ColumnOrName

View file

@@ -38,7 +38,10 @@ def _date_add_sql(
 ) -> t.Callable[[generator.Generator, exp.Expression], str]:
     def func(self, expression):
         this = self.sql(expression, "this")
-        return f"{data_type}_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=expression.args.get('unit') or exp.Literal.string('day')))})"
+        unit = expression.args.get("unit")
+        unit = exp.var(unit.name.upper() if unit else "DAY")
+        interval = exp.Interval(this=expression.expression, unit=unit)
+        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

     return func

@@ -235,6 +238,7 @@ class BigQuery(Dialect):
             exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
             exp.TimeStrToTime: timestrtotime_sql,
             exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
+            exp.TsOrDsAdd: _date_add_sql("DATE", "ADD"),
             exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
             exp.VariancePop: rename_func("VAR_POP"),
             exp.Values: _derived_table_values_to_unnest,
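
With `exp.TsOrDsAdd` routed through `_date_add_sql("DATE", "ADD")`, the interval unit is now emitted as an upper-cased variable (defaulting to DAY) instead of a lower-cased string literal. A hedged sketch, assuming TS_OR_DS_ADD parses in the default dialect:

    import sqlglot

    # Expected to land on BigQuery's DATE_ADD(x, INTERVAL 1 DAY) form.
    print(sqlglot.transpile("SELECT TS_OR_DS_ADD(x, 1, 'day')", write="bigquery")[0])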

View file

@@ -462,6 +462,11 @@ class MySQL(Dialect):
         TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMBLOB)
         TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB)

+        PROPERTIES_LOCATION = {
+            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
+        }
+
         def show_sql(self, expression):
             this = f" {expression.name}"
             full = " FULL" if expression.args.get("full") else ""

View file

@@ -318,3 +318,8 @@ class Postgres(Dialect):
             if isinstance(seq_get(e.expressions, 0), exp.Select)
             else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]",
         }
+
+        PROPERTIES_LOCATION = {
+            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
+        }

View file

@@ -150,6 +150,10 @@ class Snowflake(Dialect):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "ARRAYAGG": exp.ArrayAgg.from_arg_list,
+            "DATE_TRUNC": lambda args: exp.DateTrunc(
+                unit=exp.Literal.string(seq_get(args, 0).name),  # type: ignore
+                this=seq_get(args, 1),
+            ),
             "IFF": exp.If.from_arg_list,
             "TO_TIMESTAMP": _snowflake_to_timestamp,
             "ARRAY_CONSTRUCT": exp.Array.from_arg_list,

@@ -215,7 +219,6 @@ class Snowflake(Dialect):
         }

     class Generator(generator.Generator):
-        CREATE_TRANSIENT = True
         PARAMETER_TOKEN = "$"

         TRANSFORMS = {

@@ -252,6 +255,11 @@ class Snowflake(Dialect):
             "replace": "RENAME",
         }

+        PROPERTIES_LOCATION = {
+            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.SetProperty: exp.Properties.Location.UNSUPPORTED,
+        }
+
         def ilikeany_sql(self, expression: exp.ILikeAny) -> str:
             return self.binary(expression, "ILIKE ANY")
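
Snowflake's DATE_TRUNC takes the unit as its first argument, so the new lambda pins it as a string-literal `unit` on `exp.DateTrunc`. A quick check (a sketch, not part of the diff):

    from sqlglot import exp, parse_one

    node = parse_one("SELECT DATE_TRUNC('MONTH', d)", read="snowflake").find(exp.DateTrunc)
    print(node.args["unit"])  # a Literal wrapping 'MONTH'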

View file

@@ -8,9 +8,12 @@ from sqlglot.helper import seq_get

 def _create_sql(self, e):
     kind = e.args.get("kind")
-    temporary = e.args.get("temporary")
+    properties = e.args.get("properties")

-    if kind.upper() == "TABLE" and temporary is True:
+    if kind.upper() == "TABLE" and any(
+        isinstance(prop, exp.TemporaryProperty)
+        for prop in (properties.expressions if properties else [])
+    ):
         return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
     return create_with_partitions_sql(self, e)
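
Since CREATE options now arrive as a Properties node, the Spark override looks for a TemporaryProperty among `properties.expressions` rather than the removed `temporary` arg. A sketch of the behaviour this preserves:

    import sqlglot

    # A temporary CTAS should still be rewritten to Spark's TEMPORARY VIEW form.
    print(sqlglot.transpile("CREATE TEMPORARY TABLE t AS SELECT 1", read="spark", write="spark")[0])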

View file

@@ -114,6 +114,7 @@ class Teradata(Dialect):
         PROPERTIES_LOCATION = {
             **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
             exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX,
+            exp.VolatilityProperty: exp.Properties.Location.POST_CREATE,
         }

     def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str:
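
Placing `exp.VolatilityProperty` at POST_CREATE keeps Teradata's keyword order, where VOLATILE sits between CREATE and TABLE. A sketch:

    import sqlglot

    # Expected to round-trip unchanged on Teradata.
    print(sqlglot.transpile("CREATE VOLATILE TABLE a (b INT)", read="teradata", write="teradata")[0])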

View file

@@ -11,8 +11,7 @@ from collections import defaultdict
 from dataclasses import dataclass
 from heapq import heappop, heappush

-from sqlglot import Dialect
-from sqlglot import expressions as exp
+from sqlglot import Dialect, expressions as exp
 from sqlglot.helper import ensure_collection

@@ -58,7 +57,12 @@ if t.TYPE_CHECKING:
 Edit = t.Union[Insert, Remove, Move, Update, Keep]


-def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
+def diff(
+    source: exp.Expression,
+    target: exp.Expression,
+    matchings: t.List[t.Tuple[exp.Expression, exp.Expression]] | None = None,
+    **kwargs: t.Any,
+) -> t.List[Edit]:
     """
     Returns the list of changes between the source and the target expressions.

@@ -80,13 +84,38 @@ def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
     Args:
         source: the source expression.
         target: the target expression against which the diff should be calculated.
+        matchings: the list of pre-matched node pairs which is used to help the algorithm's
+            heuristics produce better results for subtrees that are known by a caller to be matching.
+            Note: expression references in this list must refer to the same node objects that are
+            referenced in source / target trees.

     Returns:
         the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the
         target expression trees. This list represents a sequence of steps needed to transform the source
         expression tree into the target one.
     """
-    return ChangeDistiller().diff(source.copy(), target.copy())
+    matchings = matchings or []
+    matching_ids = {id(n) for pair in matchings for n in pair}
+
+    def compute_node_mappings(
+        original: exp.Expression, copy: exp.Expression
+    ) -> t.Dict[int, exp.Expression]:
+        return {
+            id(old_node): new_node
+            for (old_node, _, _), (new_node, _, _) in zip(original.walk(), copy.walk())
+            if id(old_node) in matching_ids
+        }
+
+    source_copy = source.copy()
+    target_copy = target.copy()
+
+    node_mappings = {
+        **compute_node_mappings(source, source_copy),
+        **compute_node_mappings(target, target_copy),
+    }
+
+    matchings_copy = [(node_mappings[id(s)], node_mappings[id(t)]) for s, t in matchings]
+
+    return ChangeDistiller(**kwargs).diff(source_copy, target_copy, matchings=matchings_copy)


 LEAF_EXPRESSION_TYPES = (

@@ -109,16 +138,26 @@ class ChangeDistiller:
         self.t = t
         self._sql_generator = Dialect().generator()

-    def diff(self, source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
+    def diff(
+        self,
+        source: exp.Expression,
+        target: exp.Expression,
+        matchings: t.List[t.Tuple[exp.Expression, exp.Expression]] | None = None,
+    ) -> t.List[Edit]:
+        matchings = matchings or []
+        pre_matched_nodes = {id(s): id(t) for s, t in matchings}
+        if len({n for pair in pre_matched_nodes.items() for n in pair}) != 2 * len(matchings):
+            raise ValueError("Each node can be referenced at most once in the list of matchings")
+
         self._source = source
         self._target = target
         self._source_index = {id(n[0]): n[0] for n in source.bfs()}
         self._target_index = {id(n[0]): n[0] for n in target.bfs()}
-        self._unmatched_source_nodes = set(self._source_index)
-        self._unmatched_target_nodes = set(self._target_index)
+        self._unmatched_source_nodes = set(self._source_index) - set(pre_matched_nodes)
+        self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values())
         self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {}

-        matching_set = self._compute_matching_set()
+        matching_set = self._compute_matching_set() | {(s, t) for s, t in pre_matched_nodes.items()}
         return self._generate_edit_script(matching_set)

     def _generate_edit_script(self, matching_set: t.Set[t.Tuple[int, int]]) -> t.List[Edit]:
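
For callers, the new `matchings` parameter takes node pairs from the original (uncopied) trees; `diff` remaps them onto its internal copies before handing them to the ChangeDistiller. A minimal sketch of the API defined above:

    from sqlglot import exp, parse_one
    from sqlglot.diff import diff

    source = parse_one("SELECT a + b FROM t")
    target = parse_one("SELECT a - b FROM t")

    # Pre-match the table nodes; per the docstring note, the pairs must
    # reference nodes of these exact trees.
    matchings = [(source.find(exp.Table), target.find(exp.Table))]
    edit_script = diff(source, target, matchings=matchings)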

View file

@@ -82,7 +82,7 @@ class Expression(metaclass=_Expression):
     key = "expression"
     arg_types = {"this": True}
-    __slots__ = ("args", "parent", "arg_key", "comments", "_type")
+    __slots__ = ("args", "parent", "arg_key", "comments", "_type", "_meta")

     def __init__(self, **args: t.Any):
         self.args: t.Dict[str, t.Any] = args

@@ -90,6 +90,7 @@ class Expression(metaclass=_Expression):
         self.arg_key: t.Optional[str] = None
         self.comments: t.Optional[t.List[str]] = None
         self._type: t.Optional[DataType] = None
+        self._meta: t.Optional[t.Dict[str, t.Any]] = None

         for arg_key, value in self.args.items():
             self._set_parent(arg_key, value)

@@ -219,10 +220,23 @@ class Expression(metaclass=_Expression):
             dtype = DataType.build(dtype)
         self._type = dtype  # type: ignore

+    @property
+    def meta(self) -> t.Dict[str, t.Any]:
+        if self._meta is None:
+            self._meta = {}
+        return self._meta
+
     def __deepcopy__(self, memo):
         copy = self.__class__(**deepcopy(self.args))
-        copy.comments = self.comments
-        copy.type = self.type
+        if self.comments is not None:
+            copy.comments = deepcopy(self.comments)
+
+        if self._type is not None:
+            copy._type = self._type.copy()
+
+        if self._meta is not None:
+            copy._meta = deepcopy(self._meta)
+
         return copy

     def copy(self):

@@ -329,6 +343,15 @@ class Expression(metaclass=_Expression):
         """
         return self.find_ancestor(Select)

+    def root(self) -> Expression:
+        """
+        Returns the root expression of this tree.
+        """
+        expression = self
+        while expression.parent:
+            expression = expression.parent
+        return expression
+
     def walk(self, bfs=True, prune=None):
         """
         Returns a generator object which visits all nodes in this tree.

@@ -767,21 +790,10 @@ class Create(Expression):
         "this": True,
         "kind": True,
         "expression": False,
-        "set": False,
-        "multiset": False,
-        "global_temporary": False,
-        "volatile": False,
         "exists": False,
         "properties": False,
-        "temporary": False,
-        "transient": False,
-        "external": False,
         "replace": False,
         "unique": False,
-        "materialized": False,
-        "data": False,
-        "statistics": False,
-        "no_primary_index": False,
         "indexes": False,
         "no_schema_binding": False,
         "begin": False,
@@ -1336,42 +1348,92 @@ class Property(Expression):
     arg_types = {"this": True, "value": True}


+class AfterJournalProperty(Property):
+    arg_types = {"no": True, "dual": False, "local": False}
+
+
 class AlgorithmProperty(Property):
     arg_types = {"this": True}


+class AutoIncrementProperty(Property):
+    arg_types = {"this": True}
+
+
+class BlockCompressionProperty(Property):
+    arg_types = {"autotemp": False, "always": False, "default": True, "manual": True, "never": True}
+
+
+class CharacterSetProperty(Property):
+    arg_types = {"this": True, "default": True}
+
+
+class ChecksumProperty(Property):
+    arg_types = {"on": False, "default": False}
+
+
+class CollateProperty(Property):
+    arg_types = {"this": True}
+
+
+class DataBlocksizeProperty(Property):
+    arg_types = {"size": False, "units": False, "min": False, "default": False}
+
+
 class DefinerProperty(Property):
     arg_types = {"this": True}


-class SqlSecurityProperty(Property):
-    arg_types = {"definer": True}
-
-
-class TableFormatProperty(Property):
-    arg_types = {"this": True}
-
-
-class PartitionedByProperty(Property):
-    arg_types = {"this": True}
-
-
-class FileFormatProperty(Property):
-    arg_types = {"this": True}
-
-
 class DistKeyProperty(Property):
     arg_types = {"this": True}


-class SortKeyProperty(Property):
-    arg_types = {"this": True, "compound": False}
-
-
 class DistStyleProperty(Property):
     arg_types = {"this": True}


+class EngineProperty(Property):
+    arg_types = {"this": True}
+
+
+class ExecuteAsProperty(Property):
+    arg_types = {"this": True}
+
+
+class ExternalProperty(Property):
+    arg_types = {"this": False}
+
+
+class FallbackProperty(Property):
+    arg_types = {"no": True, "protection": False}
+
+
+class FileFormatProperty(Property):
+    arg_types = {"this": True}
+
+
+class FreespaceProperty(Property):
+    arg_types = {"this": True, "percent": False}
+
+
+class IsolatedLoadingProperty(Property):
+    arg_types = {
+        "no": True,
+        "concurrent": True,
+        "for_all": True,
+        "for_insert": True,
+        "for_none": True,
+    }
+
+
+class JournalProperty(Property):
+    arg_types = {"no": True, "dual": False, "before": False}
+
+
+class LanguageProperty(Property):
+    arg_types = {"this": True}
+
+
 class LikeProperty(Property):
     arg_types = {"this": True, "expressions": False}
@@ -1380,23 +1442,37 @@ class LocationProperty(Property):
     arg_types = {"this": True}


-class EngineProperty(Property):
-    arg_types = {"this": True}
+class LockingProperty(Property):
+    arg_types = {
+        "this": False,
+        "kind": True,
+        "for_or_in": True,
+        "lock_type": True,
+        "override": False,
+    }


-class AutoIncrementProperty(Property):
-    arg_types = {"this": True}
+class LogProperty(Property):
+    arg_types = {"no": True}


-class CharacterSetProperty(Property):
-    arg_types = {"this": True, "default": True}
+class MaterializedProperty(Property):
+    arg_types = {"this": False}


-class CollateProperty(Property):
-    arg_types = {"this": True}
+class MergeBlockRatioProperty(Property):
+    arg_types = {"this": False, "no": False, "default": False, "percent": False}


-class SchemaCommentProperty(Property):
+class NoPrimaryIndexProperty(Property):
+    arg_types = {"this": False}
+
+
+class OnCommitProperty(Property):
+    arg_type = {"this": False}
+
+
+class PartitionedByProperty(Property):
     arg_types = {"this": True}
@@ -1404,18 +1480,6 @@ class ReturnsProperty(Property):
     arg_types = {"this": True, "is_table": False, "table": False}


-class LanguageProperty(Property):
-    arg_types = {"this": True}
-
-
-class ExecuteAsProperty(Property):
-    arg_types = {"this": True}
-
-
-class VolatilityProperty(Property):
-    arg_types = {"this": True}
-
-
 class RowFormatDelimitedProperty(Property):
     # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml
     arg_types = {
@@ -1433,70 +1497,50 @@ class RowFormatSerdeProperty(Property):
     arg_types = {"this": True}


+class SchemaCommentProperty(Property):
+    arg_types = {"this": True}
+
+
 class SerdeProperties(Property):
     arg_types = {"expressions": True}


-class FallbackProperty(Property):
-    arg_types = {"no": True, "protection": False}
+class SetProperty(Property):
+    arg_types = {"multi": True}


+class SortKeyProperty(Property):
+    arg_types = {"this": True, "compound": False}
+
+
+class SqlSecurityProperty(Property):
+    arg_types = {"definer": True}
+
+
+class TableFormatProperty(Property):
+    arg_types = {"this": True}
+
+
+class TemporaryProperty(Property):
+    arg_types = {"global_": True}
+
+
+class TransientProperty(Property):
+    arg_types = {"this": False}
+
+
+class VolatilityProperty(Property):
+    arg_types = {"this": True}
+
+
+class WithDataProperty(Property):
+    arg_types = {"no": True, "statistics": False}
+
+
 class WithJournalTableProperty(Property):
     arg_types = {"this": True}


-class LogProperty(Property):
-    arg_types = {"no": True}
-
-
-class JournalProperty(Property):
-    arg_types = {"no": True, "dual": False, "before": False}
-
-
-class AfterJournalProperty(Property):
-    arg_types = {"no": True, "dual": False, "local": False}
-
-
-class ChecksumProperty(Property):
-    arg_types = {"on": False, "default": False}
-
-
-class FreespaceProperty(Property):
-    arg_types = {"this": True, "percent": False}
-
-
-class MergeBlockRatioProperty(Property):
-    arg_types = {"this": False, "no": False, "default": False, "percent": False}
-
-
-class DataBlocksizeProperty(Property):
-    arg_types = {"size": False, "units": False, "min": False, "default": False}
-
-
-class BlockCompressionProperty(Property):
-    arg_types = {"autotemp": False, "always": False, "default": True, "manual": True, "never": True}
-
-
-class IsolatedLoadingProperty(Property):
-    arg_types = {
-        "no": True,
-        "concurrent": True,
-        "for_all": True,
-        "for_insert": True,
-        "for_none": True,
-    }
-
-
-class LockingProperty(Property):
-    arg_types = {
-        "this": False,
-        "kind": True,
-        "for_or_in": True,
-        "lock_type": True,
-        "override": False,
-    }
-
-
 class Properties(Expression):
     arg_types = {"expressions": True}
@@ -1533,7 +1577,7 @@ class Properties(Expression):
     # Form: alias selection
     #   create [POST_CREATE]
     #   table a [POST_NAME]
-    #   as [POST_ALIAS] (select * from b)
+    #   as [POST_ALIAS] (select * from b) [POST_EXPRESSION]
     #   index (c) [POST_INDEX]
     class Location(AutoName):
         POST_CREATE = auto()

@@ -1541,6 +1585,7 @@ class Properties(Expression):
         POST_SCHEMA = auto()
         POST_WITH = auto()
         POST_ALIAS = auto()
+        POST_EXPRESSION = auto()
         POST_INDEX = auto()
         UNSUPPORTED = auto()
@@ -1797,6 +1842,10 @@ class Union(Subqueryable):
     def named_selects(self):
         return self.this.unnest().named_selects

+    @property
+    def is_star(self) -> bool:
+        return self.this.is_star or self.expression.is_star
+
     @property
     def selects(self):
         return self.this.unnest().selects

@@ -2424,6 +2473,10 @@ class Select(Subqueryable):
     def named_selects(self) -> t.List[str]:
         return [e.output_name for e in self.expressions if e.alias_or_name]

+    @property
+    def is_star(self) -> bool:
+        return any(expression.is_star for expression in self.expressions)
+
     @property
     def selects(self) -> t.List[Expression]:
         return self.expressions

@@ -2446,6 +2499,10 @@ class Subquery(DerivedTable, Unionable):
             expression = expression.this
         return expression

+    @property
+    def is_star(self) -> bool:
+        return self.this.is_star
+
     @property
     def output_name(self):
         return self.alias

@@ -2478,6 +2535,7 @@ class Tag(Expression):
 class Pivot(Expression):
     arg_types = {
         "this": False,
+        "alias": False,
         "expressions": True,
         "field": True,
         "unpivot": True,

@@ -2603,6 +2661,7 @@ class DataType(Expression):
         IMAGE = auto()
         VARIANT = auto()
         OBJECT = auto()
+        INET = auto()
         NULL = auto()
         UNKNOWN = auto()  # Sentinel value, useful for type annotation
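
Taken together, these hunks add three small tree helpers: a lazily created `meta` dict that `__deepcopy__` now preserves, a `root()` walk up the parent chain, and `is_star` across Select, Union and Subquery. A sketch:

    from sqlglot import exp, parse_one

    tree = parse_one("SELECT * FROM t")
    print(tree.is_star)                      # True: the projection contains a star

    table = tree.find(exp.Table)
    print(table.root() is tree)              # True: parents chase back to the root

    table.meta["tag"] = "fact"               # created on first access
    print(tree.copy().find(exp.Table).meta)  # deepcopy keeps it: {'tag': 'fact'}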

View file

@@ -64,15 +64,22 @@ class Generator:
             "TS_OR_DS_ADD", e.this, e.expression, e.args.get("unit")
         ),
         exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
-        exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'this')}",
+        exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}",
+        exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
+        exp.ExternalProperty: lambda self, e: "EXTERNAL",
         exp.LanguageProperty: lambda self, e: self.naked_property(e),
         exp.LocationProperty: lambda self, e: self.naked_property(e),
+        exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
+        exp.MaterializedProperty: lambda self, e: "MATERIALIZED",
+        exp.NoPrimaryIndexProperty: lambda self, e: "NO PRIMARY INDEX",
+        exp.OnCommitProperty: lambda self, e: "ON COMMIT PRESERVE ROWS",
         exp.ReturnsProperty: lambda self, e: self.naked_property(e),
-        exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
+        exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
+        exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
+        exp.TemporaryProperty: lambda self, e: f"{'GLOBAL ' if e.args.get('global_') else ''}TEMPORARY",
+        exp.TransientProperty: lambda self, e: "TRANSIENT",
         exp.VolatilityProperty: lambda self, e: e.name,
         exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
-        exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
-        exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
         exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
         exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
         exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",

@@ -87,9 +94,6 @@ class Generator:
         exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
     }

-    # Whether 'CREATE ... TRANSIENT ... TABLE' is allowed
-    CREATE_TRANSIENT = False
-
     # Whether or not null ordering is supported in order by
     NULL_ORDERING_SUPPORTED = True
@@ -112,6 +116,7 @@ class Generator:
         exp.DataType.Type.LONGTEXT: "TEXT",
         exp.DataType.Type.MEDIUMBLOB: "BLOB",
         exp.DataType.Type.LONGBLOB: "BLOB",
+        exp.DataType.Type.INET: "INET",
     }

     STAR_MAPPING = {

@@ -140,6 +145,7 @@ class Generator:
         exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA,
         exp.EngineProperty: exp.Properties.Location.POST_SCHEMA,
         exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA,
+        exp.ExternalProperty: exp.Properties.Location.POST_CREATE,
         exp.FallbackProperty: exp.Properties.Location.POST_NAME,
         exp.FileFormatProperty: exp.Properties.Location.POST_WITH,
         exp.FreespaceProperty: exp.Properties.Location.POST_NAME,

@@ -150,7 +156,10 @@ class Generator:
         exp.LocationProperty: exp.Properties.Location.POST_SCHEMA,
         exp.LockingProperty: exp.Properties.Location.POST_ALIAS,
         exp.LogProperty: exp.Properties.Location.POST_NAME,
+        exp.MaterializedProperty: exp.Properties.Location.POST_CREATE,
         exp.MergeBlockRatioProperty: exp.Properties.Location.POST_NAME,
+        exp.NoPrimaryIndexProperty: exp.Properties.Location.POST_EXPRESSION,
+        exp.OnCommitProperty: exp.Properties.Location.POST_EXPRESSION,
         exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,
         exp.Property: exp.Properties.Location.POST_WITH,
         exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA,

@@ -158,10 +167,14 @@ class Generator:
         exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA,
         exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA,
         exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA,
+        exp.SetProperty: exp.Properties.Location.POST_CREATE,
         exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA,
         exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
         exp.TableFormatProperty: exp.Properties.Location.POST_WITH,
+        exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,
+        exp.TransientProperty: exp.Properties.Location.POST_CREATE,
         exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA,
+        exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
         exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
     }
@@ -537,34 +550,9 @@ class Generator:
         else:
             expression_sql = f" AS{expression_sql}"

-        temporary = " TEMPORARY" if expression.args.get("temporary") else ""
-        transient = (
-            " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
-        )
-        external = " EXTERNAL" if expression.args.get("external") else ""
         replace = " OR REPLACE" if expression.args.get("replace") else ""
-        exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
         unique = " UNIQUE" if expression.args.get("unique") else ""
-        materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
-        set_ = " SET" if expression.args.get("set") else ""
-        multiset = " MULTISET" if expression.args.get("multiset") else ""
-        global_temporary = " GLOBAL TEMPORARY" if expression.args.get("global_temporary") else ""
-        volatile = " VOLATILE" if expression.args.get("volatile") else ""
-
-        data = expression.args.get("data")
-        if data is None:
-            data = ""
-        elif data:
-            data = " WITH DATA"
-        else:
-            data = " WITH NO DATA"
-
-        statistics = expression.args.get("statistics")
-        if statistics is None:
-            statistics = ""
-        elif statistics:
-            statistics = " AND STATISTICS"
-        else:
-            statistics = " AND NO STATISTICS"
-
-        no_primary_index = " NO PRIMARY INDEX" if expression.args.get("no_primary_index") else ""
+        exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""

         indexes = expression.args.get("indexes")
         index_sql = ""

@@ -605,28 +593,24 @@ class Generator:
             wrapped=False,
         )

-        modifiers = "".join(
-            (
-                replace,
-                temporary,
-                transient,
-                external,
-                unique,
-                materialized,
-                set_,
-                multiset,
-                global_temporary,
-                volatile,
-                postcreate_props_sql,
-            )
-        )
+        modifiers = "".join((replace, unique, postcreate_props_sql))
+
+        postexpression_props_sql = ""
+        if properties_locs.get(exp.Properties.Location.POST_EXPRESSION):
+            postexpression_props_sql = self.properties(
+                exp.Properties(
+                    expressions=properties_locs[exp.Properties.Location.POST_EXPRESSION]
+                ),
+                sep=" ",
+                prefix=" ",
+                wrapped=False,
+            )

         no_schema_binding = (
             " WITH NO SCHEMA BINDING" if expression.args.get("no_schema_binding") else ""
         )

-        post_expression_modifiers = "".join((data, statistics, no_primary_index))
-
-        expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}"
+        expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{postexpression_props_sql}{index_sql}{no_schema_binding}"
         return self.prepend_ctes(expression, expression_sql)

     def describe_sql(self, expression: exp.Describe) -> str:
@@ -810,6 +794,8 @@ class Generator:
                 properties_locs[exp.Properties.Location.POST_CREATE].append(p)
             elif p_loc == exp.Properties.Location.POST_ALIAS:
                 properties_locs[exp.Properties.Location.POST_ALIAS].append(p)
+            elif p_loc == exp.Properties.Location.POST_EXPRESSION:
+                properties_locs[exp.Properties.Location.POST_EXPRESSION].append(p)
             elif p_loc == exp.Properties.Location.UNSUPPORTED:
                 self.unsupported(f"Unsupported property {p.key}")

@@ -931,6 +917,14 @@ class Generator:
         override = " OVERRIDE" if expression.args.get("override") else ""
         return f"LOCKING {kind}{this} {for_or_in} {lock_type}{override}"

+    def withdataproperty_sql(self, expression: exp.WithDataProperty) -> str:
+        data_sql = f"WITH {'NO ' if expression.args.get('no') else ''}DATA"
+        statistics = expression.args.get("statistics")
+        statistics_sql = ""
+        if statistics is not None:
+            statistics_sql = f" AND {'NO ' if not statistics else ''}STATISTICS"
+        return f"{data_sql}{statistics_sql}"
+
     def insert_sql(self, expression: exp.Insert) -> str:
         overwrite = expression.args.get("overwrite")
@@ -1003,10 +997,6 @@ class Generator:
         system_time = expression.args.get("system_time")
         system_time = f" {self.sql(expression, 'system_time')}" if system_time else ""

-        if alias and pivots:
-            pivots = f"{pivots}{alias}"
-            alias = ""
-
         return f"{table}{system_time}{alias}{hints}{laterals}{joins}{pivots}"

     def tablesample_sql(self, expression: exp.TableSample) -> str:

@@ -1034,11 +1024,13 @@ class Generator:
     def pivot_sql(self, expression: exp.Pivot) -> str:
         this = self.sql(expression, "this")
+        alias = self.sql(expression, "alias")
+        alias = f" AS {alias}" if alias else ""
         unpivot = expression.args.get("unpivot")
         direction = "UNPIVOT" if unpivot else "PIVOT"
         expressions = self.expressions(expression, key="expressions")
         field = self.sql(expression, "field")

-        return f"{this} {direction}({expressions} FOR {field})"
+        return f"{this} {direction}({expressions} FOR {field}){alias}"

     def tuple_sql(self, expression: exp.Tuple) -> str:
         return f"({self.expressions(expression, flat=True)})"

View file

@@ -144,6 +144,7 @@ class Parser(metaclass=_Parser):
         TokenType.IMAGE,
         TokenType.VARIANT,
         TokenType.OBJECT,
+        TokenType.INET,
         *NESTED_TYPE_TOKENS,
     }
@@ -509,73 +510,82 @@ class Parser(metaclass=_Parser):
     }

     PROPERTY_PARSERS = {
+        "AFTER": lambda self: self._parse_afterjournal(
+            no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
+        ),
+        "ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty),
         "AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty),
+        "BEFORE": lambda self: self._parse_journal(
+            no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
+        ),
+        "BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
         "CHARACTER SET": lambda self: self._parse_character_set(),
+        "CHECKSUM": lambda self: self._parse_checksum(),
         "CLUSTER BY": lambda self: self.expression(
             exp.Cluster, expressions=self._parse_csv(self._parse_ordered)
         ),
-        "LOCATION": lambda self: self._parse_property_assignment(exp.LocationProperty),
-        "PARTITION BY": lambda self: self._parse_partitioned_by(),
-        "PARTITIONED BY": lambda self: self._parse_partitioned_by(),
-        "PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
-        "COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
-        "STORED": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
-        "DISTKEY": lambda self: self._parse_distkey(),
-        "DISTSTYLE": lambda self: self._parse_property_assignment(exp.DistStyleProperty),
-        "SORTKEY": lambda self: self._parse_sortkey(),
-        "LIKE": lambda self: self._parse_create_like(),
-        "RETURNS": lambda self: self._parse_returns(),
-        "ROW": lambda self: self._parse_row(),
         "COLLATE": lambda self: self._parse_property_assignment(exp.CollateProperty),
-        "FORMAT": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
-        "TABLE_FORMAT": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
-        "USING": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
-        "LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty),
-        "EXECUTE": lambda self: self._parse_property_assignment(exp.ExecuteAsProperty),
+        "COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
+        "DATABLOCKSIZE": lambda self: self._parse_datablocksize(
+            default=self._prev.text.upper() == "DEFAULT"
+        ),
+        "DEFINER": lambda self: self._parse_definer(),
         "DETERMINISTIC": lambda self: self.expression(
             exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
         ),
+        "DISTKEY": lambda self: self._parse_distkey(),
+        "DISTSTYLE": lambda self: self._parse_property_assignment(exp.DistStyleProperty),
+        "EXECUTE": lambda self: self._parse_property_assignment(exp.ExecuteAsProperty),
+        "EXTERNAL": lambda self: self.expression(exp.ExternalProperty),
+        "FALLBACK": lambda self: self._parse_fallback(no=self._prev.text.upper() == "NO"),
+        "FORMAT": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
+        "FREESPACE": lambda self: self._parse_freespace(),
+        "GLOBAL": lambda self: self._parse_temporary(global_=True),
         "IMMUTABLE": lambda self: self.expression(
             exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
         ),
-        "STABLE": lambda self: self.expression(
-            exp.VolatilityProperty, this=exp.Literal.string("STABLE")
-        ),
-        "VOLATILE": lambda self: self.expression(
-            exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
-        ),
-        "WITH": lambda self: self._parse_with_property(),
-        "TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property),
-        "FALLBACK": lambda self: self._parse_fallback(no=self._prev.text.upper() == "NO"),
-        "LOG": lambda self: self._parse_log(no=self._prev.text.upper() == "NO"),
-        "BEFORE": lambda self: self._parse_journal(
-            no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
-        ),
         "JOURNAL": lambda self: self._parse_journal(
             no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
         ),
-        "AFTER": lambda self: self._parse_afterjournal(
-            no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
-        ),
+        "LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty),
+        "LIKE": lambda self: self._parse_create_like(),
         "LOCAL": lambda self: self._parse_afterjournal(no=False, dual=False, local=True),
-        "NOT": lambda self: self._parse_afterjournal(no=False, dual=False, local=False),
-        "CHECKSUM": lambda self: self._parse_checksum(),
-        "FREESPACE": lambda self: self._parse_freespace(),
+        "LOCATION": lambda self: self._parse_property_assignment(exp.LocationProperty),
+        "LOCK": lambda self: self._parse_locking(),
+        "LOCKING": lambda self: self._parse_locking(),
+        "LOG": lambda self: self._parse_log(no=self._prev.text.upper() == "NO"),
+        "MATERIALIZED": lambda self: self.expression(exp.MaterializedProperty),
+        "MAX": lambda self: self._parse_datablocksize(),
+        "MAXIMUM": lambda self: self._parse_datablocksize(),
         "MERGEBLOCKRATIO": lambda self: self._parse_mergeblockratio(
             no=self._prev.text.upper() == "NO", default=self._prev.text.upper() == "DEFAULT"
         ),
         "MIN": lambda self: self._parse_datablocksize(),
         "MINIMUM": lambda self: self._parse_datablocksize(),
-        "MAX": lambda self: self._parse_datablocksize(),
-        "MAXIMUM": lambda self: self._parse_datablocksize(),
-        "DATABLOCKSIZE": lambda self: self._parse_datablocksize(
-            default=self._prev.text.upper() == "DEFAULT"
-        ),
-        "BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
-        "ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty),
-        "DEFINER": lambda self: self._parse_definer(),
-        "LOCK": lambda self: self._parse_locking(),
-        "LOCKING": lambda self: self._parse_locking(),
+        "MULTISET": lambda self: self.expression(exp.SetProperty, multi=True),
+        "NO": lambda self: self._parse_noprimaryindex(),
+        "NOT": lambda self: self._parse_afterjournal(no=False, dual=False, local=False),
+        "ON": lambda self: self._parse_oncommit(),
+        "PARTITION BY": lambda self: self._parse_partitioned_by(),
+        "PARTITIONED BY": lambda self: self._parse_partitioned_by(),
+        "PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
+        "RETURNS": lambda self: self._parse_returns(),
+        "ROW": lambda self: self._parse_row(),
+        "SET": lambda self: self.expression(exp.SetProperty, multi=False),
+        "SORTKEY": lambda self: self._parse_sortkey(),
+        "STABLE": lambda self: self.expression(
+            exp.VolatilityProperty, this=exp.Literal.string("STABLE")
+        ),
+        "STORED": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
+        "TABLE_FORMAT": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
+        "TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property),
+        "TEMPORARY": lambda self: self._parse_temporary(global_=False),
+        "TRANSIENT": lambda self: self.expression(exp.TransientProperty),
+        "USING": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
+        "VOLATILE": lambda self: self.expression(
+            exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
+        ),
+        "WITH": lambda self: self._parse_with_property(),
     }

     CONSTRAINT_PARSERS = {
@@ -979,15 +989,7 @@ class Parser(metaclass=_Parser):
         replace = self._prev.text.upper() == "REPLACE" or self._match_pair(
             TokenType.OR, TokenType.REPLACE
         )
-        set_ = self._match(TokenType.SET)  # Teradata
-        multiset = self._match_text_seq("MULTISET")  # Teradata
-        global_temporary = self._match_text_seq("GLOBAL", "TEMPORARY")  # Teradata
-        volatile = self._match(TokenType.VOLATILE)  # Teradata
-        temporary = self._match(TokenType.TEMPORARY)
-        transient = self._match_text_seq("TRANSIENT")
-        external = self._match_text_seq("EXTERNAL")
         unique = self._match(TokenType.UNIQUE)
-        materialized = self._match(TokenType.MATERIALIZED)

         if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False):
             self._match(TokenType.TABLE)
@@ -1005,16 +1007,17 @@ class Parser(metaclass=_Parser):
         exists = self._parse_exists(not_=True)
         this = None
         expression = None
-        data = None
-        statistics = None
-        no_primary_index = None
         indexes = None
         no_schema_binding = None
         begin = None

         if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
             this = self._parse_user_defined_function(kind=create_token.token_type)
-            properties = self._parse_properties()
+            temp_properties = self._parse_properties()
+            if properties and temp_properties:
+                properties.expressions.extend(temp_properties.expressions)
+            elif temp_properties:
+                properties = temp_properties

             self._match(TokenType.ALIAS)
             begin = self._match(TokenType.BEGIN)

@@ -1036,7 +1039,7 @@ class Parser(metaclass=_Parser):
             if self._match(TokenType.COMMA):
                 temp_properties = self._parse_properties(before=True)
                 if properties and temp_properties:
-                    properties.expressions.append(temp_properties.expressions)
+                    properties.expressions.extend(temp_properties.expressions)
                 elif temp_properties:
                     properties = temp_properties

@@ -1045,7 +1048,7 @@ class Parser(metaclass=_Parser):
             # exp.Properties.Location.POST_SCHEMA and POST_WITH
             temp_properties = self._parse_properties()
             if properties and temp_properties:
-                properties.expressions.append(temp_properties.expressions)
+                properties.expressions.extend(temp_properties.expressions)
             elif temp_properties:
                 properties = temp_properties
@@ -1059,24 +1062,19 @@ class Parser(metaclass=_Parser):
             ):
                 temp_properties = self._parse_properties()
                 if properties and temp_properties:
-                    properties.expressions.append(temp_properties.expressions)
+                    properties.expressions.extend(temp_properties.expressions)
                 elif temp_properties:
                     properties = temp_properties

             expression = self._parse_ddl_select()

             if create_token.token_type == TokenType.TABLE:
-                if self._match_text_seq("WITH", "DATA"):
-                    data = True
-                elif self._match_text_seq("WITH", "NO", "DATA"):
-                    data = False
-
-                if self._match_text_seq("AND", "STATISTICS"):
-                    statistics = True
-                elif self._match_text_seq("AND", "NO", "STATISTICS"):
-                    statistics = False
-
-                no_primary_index = self._match_text_seq("NO", "PRIMARY", "INDEX")
+                # exp.Properties.Location.POST_EXPRESSION
+                temp_properties = self._parse_properties()
+                if properties and temp_properties:
+                    properties.expressions.extend(temp_properties.expressions)
+                elif temp_properties:
+                    properties = temp_properties

                 indexes = []
                 while True:

@@ -1086,7 +1084,7 @@ class Parser(metaclass=_Parser):
                     if self._match(TokenType.PARTITION_BY, advance=False):
                         temp_properties = self._parse_properties()
                         if properties and temp_properties:
-                            properties.expressions.append(temp_properties.expressions)
+                            properties.expressions.extend(temp_properties.expressions)
                         elif temp_properties:
                             properties = temp_properties
@@ -1102,22 +1100,11 @@ class Parser(metaclass=_Parser):
             exp.Create,
             this=this,
             kind=create_token.text,
+            unique=unique,
             expression=expression,
-            set=set_,
-            multiset=multiset,
-            global_temporary=global_temporary,
-            volatile=volatile,
             exists=exists,
             properties=properties,
-            temporary=temporary,
-            transient=transient,
-            external=external,
             replace=replace,
-            unique=unique,
-            materialized=materialized,
-            data=data,
-            statistics=statistics,
-            no_primary_index=no_primary_index,
             indexes=indexes,
             no_schema_binding=no_schema_binding,
             begin=begin,
@@ -1196,15 +1183,21 @@ class Parser(metaclass=_Parser):
     def _parse_with_property(
         self,
     ) -> t.Union[t.Optional[exp.Expression], t.List[t.Optional[exp.Expression]]]:
+        self._match(TokenType.WITH)
         if self._match(TokenType.L_PAREN, advance=False):
             return self._parse_wrapped_csv(self._parse_property)

+        if self._match_text_seq("JOURNAL"):
+            return self._parse_withjournaltable()
+
+        if self._match_text_seq("DATA"):
+            return self._parse_withdata(no=False)
+        elif self._match_text_seq("NO", "DATA"):
+            return self._parse_withdata(no=True)
+
         if not self._next:
             return None

-        if self._next.text.upper() == "JOURNAL":
-            return self._parse_withjournaltable()
-
         return self._parse_withisolatedloading()

     # https://dev.mysql.com/doc/refman/8.0/en/create-view.html

@@ -1221,7 +1214,7 @@ class Parser(metaclass=_Parser):
         return exp.DefinerProperty(this=f"{user}@{host}")

     def _parse_withjournaltable(self) -> exp.Expression:
-        self._match_text_seq("WITH", "JOURNAL", "TABLE")
+        self._match(TokenType.TABLE)
         self._match(TokenType.EQ)
         return self.expression(exp.WithJournalTableProperty, this=self._parse_table_parts())
@@ -1319,7 +1312,6 @@ class Parser(metaclass=_Parser):
         )

     def _parse_withisolatedloading(self) -> exp.Expression:
-        self._match(TokenType.WITH)
         no = self._match_text_seq("NO")
         concurrent = self._match_text_seq("CONCURRENT")
         self._match_text_seq("ISOLATED", "LOADING")
@@ -1397,6 +1389,24 @@ class Parser(metaclass=_Parser):
             this=self._parse_schema() or self._parse_bracket(self._parse_field()),
         )
 
+    def _parse_withdata(self, no=False) -> exp.Expression:
+        if self._match_text_seq("AND", "STATISTICS"):
+            statistics = True
+        elif self._match_text_seq("AND", "NO", "STATISTICS"):
+            statistics = False
+        else:
+            statistics = None
+
+        return self.expression(exp.WithDataProperty, no=no, statistics=statistics)
+
+    def _parse_noprimaryindex(self) -> exp.Expression:
+        self._match_text_seq("PRIMARY", "INDEX")
+        return exp.NoPrimaryIndexProperty()
+
+    def _parse_oncommit(self) -> exp.Expression:
+        self._match_text_seq("COMMIT", "PRESERVE", "ROWS")
+        return exp.OnCommitProperty()
+
     def _parse_distkey(self) -> exp.Expression:
         return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var))
@@ -1450,6 +1460,10 @@ class Parser(metaclass=_Parser):
         return self.expression(exp.ReturnsProperty, this=value, is_table=is_table)
 
+    def _parse_temporary(self, global_=False) -> exp.Expression:
+        self._match(TokenType.TEMPORARY)  # in case calling from "GLOBAL"
+        return self.expression(exp.TemporaryProperty, global_=global_)
+
     def _parse_describe(self) -> exp.Expression:
         kind = self._match_set(self.CREATABLES) and self._prev.text
         this = self._parse_table()
@@ -2042,6 +2056,9 @@ class Parser(metaclass=_Parser):
         if alias:
             this.set("alias", alias)
 
+        if not this.args.get("pivots"):
+            this.set("pivots", self._parse_pivots())
+
         if self._match_pair(TokenType.WITH, TokenType.L_PAREN):
             this.set(
                 "hints",
@@ -2182,7 +2199,12 @@ class Parser(metaclass=_Parser):
         self._match_r_paren()
 
-        return self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot)
+        pivot = self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot)
+
+        if not self._match_set((TokenType.PIVOT, TokenType.UNPIVOT), advance=False):
+            pivot.set("alias", self._parse_table_alias())
+
+        return pivot
 
     def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Expression]:
         if not skip_where_token and not self._match(TokenType.WHERE):
@@ -3783,11 +3805,12 @@ class Parser(metaclass=_Parser):
         return None
 
-    def _match_set(self, types):
+    def _match_set(self, types, advance=True):
         if not self._curr:
             return None
 
         if self._curr.token_type in types:
-            self._advance()
+            if advance:
+                self._advance()
             return True
@@ -3816,8 +3839,9 @@ class Parser(metaclass=_Parser):
         if expression and self._prev_comments:
             expression.comments = self._prev_comments
 
-    def _match_texts(self, texts):
+    def _match_texts(self, texts, advance=True):
         if self._curr and self._curr.text.upper() in texts:
-            self._advance()
+            if advance:
+                self._advance()
             return True
         return False
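
Taken together, these parser changes move Teradata's CREATE modifiers (SET/MULTISET, VOLATILE, GLOBAL TEMPORARY, WITH DATA, ON COMMIT, NO PRIMARY INDEX) out of dedicated exp.Create args and into generic properties. A minimal sketch of the resulting behavior, assuming sqlglot >= 11.2; the expected outputs in the comments come from the fixtures and dialect tests further down:

    from sqlglot import parse_one

    # ON COMMIT now round-trips as a property (see the identity fixtures below).
    print(parse_one("CREATE TABLE a (b INT) ON COMMIT PRESERVE ROWS").sql())
    # CREATE TABLE a (b INT) ON COMMIT PRESERVE ROWS

    # MULTISET is recognized when reading Teradata and dropped by dialects
    # with no equivalent (see the Snowflake test below).
    print(parse_one("CREATE MULTISET TABLE a (b INT)", read="teradata").sql(dialect="snowflake"))
    # CREATE TABLE a (b INT)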
View file
@@ -32,6 +32,9 @@ def dump(node: Node) -> JSON:
             obj["type"] = node.type.sql()
         if node.comments:
             obj["comments"] = node.comments
+        if node._meta is not None:
+            obj["meta"] = node._meta
+
         return obj
     return node
@@ -57,11 +60,9 @@ def load(obj: JSON) -> Node:
         klass = getattr(module, class_name)
 
         expression = klass(**{k: load(v) for k, v in obj["args"].items()})
-        type_ = obj.get("type")
-        if type_:
-            expression.type = exp.DataType.build(type_)
-        comments = obj.get("comments")
-        if comments:
-            expression.comments = load(comments)
+        expression.type = obj.get("type")
+        expression.comments = obj.get("comments")
+        expression._meta = obj.get("meta")
 
         return expression
     return obj
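
Since dump and load now carry the new meta field, metadata attached to an expression survives serialization. A minimal sketch, assuming sqlglot >= 11.2; the meta key is hypothetical:

    from sqlglot import parse_one
    from sqlglot.serde import dump, load

    expr = parse_one("SELECT * FROM x")
    expr.meta["source"] = "etl_job"  # hypothetical key, purely for illustration

    # meta round-trips through the JSON-friendly dump/load pair
    assert load(dump(expr)).meta == expr.meta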
View file
@@ -115,6 +115,7 @@ class TokenType(AutoName):
     IMAGE = auto()
     VARIANT = auto()
     OBJECT = auto()
+    INET = auto()
 
     # keywords
     ALIAS = auto()
@@ -437,16 +438,8 @@ class Tokenizer(metaclass=_Tokenizer):
     _IDENTIFIER_ESCAPES: t.Set[str] = set()
 
     KEYWORDS = {
-        **{
-            f"{key}{postfix}": TokenType.BLOCK_START
-            for key in ("{%", "{#")
-            for postfix in ("", "+", "-")
-        },
-        **{
-            f"{prefix}{key}": TokenType.BLOCK_END
-            for key in ("%}", "#}")
-            for prefix in ("", "+", "-")
-        },
+        **{f"{{%{postfix}": TokenType.BLOCK_START for postfix in ("", "+", "-")},
+        **{f"{prefix}%}}": TokenType.BLOCK_END for prefix in ("", "+", "-")},
         "{{+": TokenType.BLOCK_START,
         "{{-": TokenType.BLOCK_START,
         "+}}": TokenType.BLOCK_END,
@@ -533,6 +526,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "IGNORE NULLS": TokenType.IGNORE_NULLS,
         "IN": TokenType.IN,
         "INDEX": TokenType.INDEX,
+        "INET": TokenType.INET,
         "INNER": TokenType.INNER,
         "INSERT": TokenType.INSERT,
         "INTERVAL": TokenType.INTERVAL,
@@ -701,7 +695,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "VACUUM": TokenType.COMMAND,
     }
 
-    WHITE_SPACE = {
+    WHITE_SPACE: t.Dict[str, TokenType] = {
         " ": TokenType.SPACE,
         "\t": TokenType.SPACE,
         "\n": TokenType.BREAK,
@@ -723,7 +717,7 @@ class Tokenizer(metaclass=_Tokenizer):
     NUMERIC_LITERALS: t.Dict[str, str] = {}
     ENCODE: t.Optional[str] = None
 
-    COMMENTS = ["--", ("/*", "*/")]
+    COMMENTS = ["--", ("/*", "*/"), ("{#", "#}")]
     KEYWORD_TRIE = None  # autofilled
 
     IDENTIFIER_CAN_START_WITH_DIGIT = False
@@ -778,20 +772,14 @@ class Tokenizer(metaclass=_Tokenizer):
             self._start = self._current
             self._advance()
 
-            if not self._char:
+            if self._char is None:
                 break
 
-            white_space = self.WHITE_SPACE.get(self._char)  # type: ignore
-            identifier_end = self._IDENTIFIERS.get(self._char)  # type: ignore
-
-            if white_space:
-                if white_space == TokenType.BREAK:
-                    self._col = 1
-                    self._line += 1
-            elif self._char.isdigit():  # type:ignore
-                self._scan_number()
-            elif identifier_end:
-                self._scan_identifier(identifier_end)
-            else:
-                self._scan_keywords()
+            if self._char not in self.WHITE_SPACE:
+                if self._char.isdigit():
+                    self._scan_number()
+                elif self._char in self._IDENTIFIERS:
+                    self._scan_identifier(self._IDENTIFIERS[self._char])
+                else:
+                    self._scan_keywords()
@@ -807,13 +795,23 @@ class Tokenizer(metaclass=_Tokenizer):
             return self.sql[start:end]
         return ""
 
+    def _line_break(self, char: t.Optional[str]) -> bool:
+        return self.WHITE_SPACE.get(char) == TokenType.BREAK  # type: ignore
+
     def _advance(self, i: int = 1) -> None:
+        if self._line_break(self._char):
+            self._set_new_line()
+
         self._col += i
         self._current += i
         self._end = self._current >= self.size  # type: ignore
         self._char = self.sql[self._current - 1]  # type: ignore
         self._peek = self.sql[self._current] if self._current < self.size else ""  # type: ignore
 
+    def _set_new_line(self) -> None:
+        self._col = 1
+        self._line += 1
+
     @property
     def _text(self) -> str:
         return self.sql[self._start : self._current]
@@ -917,7 +915,7 @@ class Tokenizer(metaclass=_Tokenizer):
                 self._comments.append(self._text[comment_start_size : -comment_end_size + 1])  # type: ignore
                 self._advance(comment_end_size - 1)
             else:
-                while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK:  # type: ignore
+                while not self._end and not self._line_break(self._peek):
                     self._advance()
                 self._comments.append(self._text[comment_start_size:])  # type: ignore
@@ -926,6 +924,7 @@ class Tokenizer(metaclass=_Tokenizer):
         if comment_start_line == self._prev_token_line:
             self.tokens[-1].comments.extend(self._comments)
             self._comments = []
+            self._prev_token_line = self._line
 
         return True
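
With "{#" / "#}" moved from the BLOCK_START/BLOCK_END keywords to the COMMENTS list, jinja comments are now attached to the preceding token like any SQL comment, and _set_new_line keeps line counts accurate across breaks inside comments and strings. A minimal sketch, assuming sqlglot >= 11.2:

    from sqlglot.tokens import Tokenizer

    tokens = Tokenizer().tokenize("SELECT {# jinja comment #} 1")

    # the comment body hangs off the SELECT token instead of producing
    # BLOCK_START/BLOCK_END tokens
    print(tokens[0].comments)  # [' jinja comment ']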
View file
@@ -2,8 +2,7 @@ import datetime
 import inspect
 import unittest
 
-from sqlglot import expressions as exp
-from sqlglot import parse_one
+from sqlglot import expressions as exp, parse_one
 from sqlglot.dataframe.sql import functions as SF
 from sqlglot.errors import ErrorLevel
View file
@@ -1,8 +1,7 @@
 from unittest import mock
 
 import sqlglot
-from sqlglot.dataframe.sql import functions as F
-from sqlglot.dataframe.sql import types
+from sqlglot.dataframe.sql import functions as F, types
 from sqlglot.dataframe.sql.session import SparkSession
 from sqlglot.schema import MappingSchema
 from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
View file
@@ -285,6 +285,10 @@ class TestDialect(Validator):
             read={"oracle": "CAST(a AS NUMBER)"},
             write={"oracle": "CAST(a AS NUMBER)"},
         )
+        self.validate_all(
+            "CAST('127.0.0.1/32' AS INET)",
+            read={"postgres": "INET '127.0.0.1/32'"},
+        )
 
     def test_if_null(self):
         self.validate_all(
@@ -509,7 +513,7 @@ class TestDialect(Validator):
                 "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)",
             },
             write={
-                "bigquery": "DATE_ADD(x, INTERVAL 1 'day')",
+                "bigquery": "DATE_ADD(x, INTERVAL 1 DAY)",
                 "drill": "DATE_ADD(x, INTERVAL 1 DAY)",
                 "duckdb": "x + INTERVAL 1 day",
                 "hive": "DATE_ADD(x, 1)",
@@ -526,7 +530,7 @@ class TestDialect(Validator):
         self.validate_all(
             "DATE_ADD(x, 1)",
             write={
-                "bigquery": "DATE_ADD(x, INTERVAL 1 'day')",
+                "bigquery": "DATE_ADD(x, INTERVAL 1 DAY)",
                 "drill": "DATE_ADD(x, INTERVAL 1 DAY)",
                 "duckdb": "x + INTERVAL 1 DAY",
                 "hive": "DATE_ADD(x, 1)",
@@ -540,6 +544,7 @@ class TestDialect(Validator):
             "DATE_TRUNC('day', x)",
             write={
                 "mysql": "DATE(x)",
+                "snowflake": "DATE_TRUNC('day', x)",
             },
         )
         self.validate_all(
@@ -576,6 +581,7 @@ class TestDialect(Validator):
             "DATE_TRUNC('year', x)",
             read={
                 "bigquery": "DATE_TRUNC(x, year)",
+                "snowflake": "DATE_TRUNC(year, x)",
                 "starrocks": "DATE_TRUNC('year', x)",
                 "spark": "TRUNC(x, 'year')",
             },
@@ -583,6 +589,7 @@ class TestDialect(Validator):
                 "bigquery": "DATE_TRUNC(x, year)",
                 "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')",
                 "postgres": "DATE_TRUNC('year', x)",
+                "snowflake": "DATE_TRUNC('year', x)",
                 "starrocks": "DATE_TRUNC('year', x)",
                 "spark": "TRUNC(x, 'year')",
             },
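
Both expectation changes are visible from the top-level API; a quick sketch, assuming sqlglot >= 11.2:

    import sqlglot

    # BigQuery interval units are now emitted bare instead of quoted
    print(sqlglot.transpile("DATE_ADD(x, 1)", write="bigquery")[0])
    # DATE_ADD(x, INTERVAL 1 DAY)

    # Postgres INET literals are read into a CAST
    print(sqlglot.transpile("INET '127.0.0.1/32'", read="postgres")[0])
    # CAST('127.0.0.1/32' AS INET)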
View file
@@ -397,6 +397,12 @@ class TestSnowflake(Validator):
             },
         )
 
+        self.validate_all(
+            "CREATE TABLE a (b INT)",
+            read={"teradata": "CREATE MULTISET TABLE a (b INT)"},
+            write={"snowflake": "CREATE TABLE a (b INT)"},
+        )
+
     def test_user_defined_functions(self):
         self.validate_all(
             "CREATE FUNCTION a(x DATE, y BIGINT) RETURNS ARRAY LANGUAGE JAVASCRIPT AS $$ SELECT 1 $$",
View file
@@ -213,6 +213,13 @@ TBLPROPERTIES (
         self.validate_identity("TRIM(LEADING 'SL' FROM 'SSparkSQLS')")
         self.validate_identity("TRIM(TRAILING 'SL' FROM 'SSparkSQLS')")
 
+        self.validate_all(
+            "SELECT DATE_ADD(my_date_column, 1)",
+            write={
+                "spark": "SELECT DATE_ADD(my_date_column, 1)",
+                "bigquery": "SELECT DATE_ADD(my_date_column, INTERVAL 1 DAY)",
+            },
+        )
         self.validate_all(
             "AGGREGATE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)",
             write={
View file
@@ -35,6 +35,8 @@ class TestTeradata(Validator):
             write={"teradata": "SELECT a FROM b"},
         )
 
+        self.validate_identity("CREATE VOLATILE TABLE a (b INT)")
+
     def test_insert(self):
         self.validate_all(
             "INS INTO x SELECT * FROM y", write={"teradata": "INSERT INTO x SELECT * FROM y"}
View file
@@ -305,6 +305,7 @@ SELECT a FROM test TABLESAMPLE(100 ROWS)
 SELECT a FROM test TABLESAMPLE BERNOULLI (50)
 SELECT a FROM test TABLESAMPLE SYSTEM (75)
 SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q'))
+SELECT 1 FROM a.b.table1 AS t UNPIVOT((c3) FOR c4 IN (a, b))
 SELECT a FROM test PIVOT(SOMEAGG(x, y, z) FOR q IN (1))
 SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) PIVOT(MAX(b) FOR c IN ('d'))
 SELECT a FROM (SELECT a, b FROM test) PIVOT(SUM(x) FOR y IN ('z', 'q'))
@@ -557,10 +558,11 @@ CREATE TABLE a, BEFORE JOURNAL, AFTER JOURNAL, FREESPACE=1, DEFAULT DATABLOCKSIZ
 CREATE TABLE a, DUAL JOURNAL, DUAL AFTER JOURNAL, MERGEBLOCKRATIO=1 PERCENT, DATABLOCKSIZE=10 KILOBYTES (a INT)
 CREATE TABLE a, DUAL BEFORE JOURNAL, LOCAL AFTER JOURNAL, MAXIMUM DATABLOCKSIZE, BLOCKCOMPRESSION=AUTOTEMP(c1 INT) (a INT)
 CREATE SET GLOBAL TEMPORARY TABLE a, NO BEFORE JOURNAL, NO AFTER JOURNAL, MINIMUM DATABLOCKSIZE, BLOCKCOMPRESSION=NEVER (a INT)
-CREATE MULTISET VOLATILE TABLE a, NOT LOCAL AFTER JOURNAL, FREESPACE=1 PERCENT, DATABLOCKSIZE=10 BYTES, WITH NO CONCURRENT ISOLATED LOADING FOR ALL (a INT)
+CREATE MULTISET TABLE a, NOT LOCAL AFTER JOURNAL, FREESPACE=1 PERCENT, DATABLOCKSIZE=10 BYTES, WITH NO CONCURRENT ISOLATED LOADING FOR ALL (a INT)
 CREATE ALGORITHM=UNDEFINED DEFINER=foo@% SQL SECURITY DEFINER VIEW a AS (SELECT a FROM b)
 CREATE TEMPORARY TABLE x AS SELECT a FROM d
 CREATE TEMPORARY TABLE IF NOT EXISTS x AS SELECT a FROM d
+CREATE TABLE a (b INT) ON COMMIT PRESERVE ROWS
 CREATE VIEW x AS SELECT a FROM b
 CREATE VIEW IF NOT EXISTS x AS SELECT a FROM b
 CREATE VIEW z (a, b COMMENT 'b', c COMMENT 'c') AS SELECT a, b, c FROM d
View file
@@ -1,6 +1,6 @@
 import unittest
 
-from sqlglot import parse_one
+from sqlglot import exp, parse_one
 from sqlglot.diff import Insert, Keep, Move, Remove, Update, diff
 from sqlglot.expressions import Join, to_identifier
@@ -128,6 +128,33 @@ class TestDiff(unittest.TestCase):
             ],
         )
 
+    def test_pre_matchings(self):
+        expr_src = parse_one("SELECT 1")
+        expr_tgt = parse_one("SELECT 1, 2, 3, 4")
+
+        self._validate_delta_only(
+            diff(expr_src, expr_tgt),
+            [
+                Remove(expr_src),
+                Insert(expr_tgt),
+                Insert(exp.Literal.number(2)),
+                Insert(exp.Literal.number(3)),
+                Insert(exp.Literal.number(4)),
+            ],
+        )
+
+        self._validate_delta_only(
+            diff(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt)]),
+            [
+                Insert(exp.Literal.number(2)),
+                Insert(exp.Literal.number(3)),
+                Insert(exp.Literal.number(4)),
+            ],
+        )
+
+        with self.assertRaises(ValueError):
+            diff(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt), (expr_src, expr_tgt)])
+
     def _validate_delta_only(self, actual_diff, expected_delta):
         actual_delta = _delta_only(actual_diff)
         self.assertEqual(set(actual_delta), set(expected_delta))
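
The new matchings argument lets callers pin node pairs before the diff algorithm runs, keeping known-equal roots out of the edit script; a node appearing twice in matchings raises a ValueError. A minimal sketch, assuming sqlglot >= 11.2:

    from sqlglot import parse_one
    from sqlglot.diff import diff

    src = parse_one("SELECT 1")
    tgt = parse_one("SELECT 1, 2, 3, 4")

    # pinning the two Select roots leaves only the inserted literals
    # (2, 3 and 4) in the resulting edit script, per the test above
    edits = diff(src, tgt, matchings=[(src, tgt)])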
View file
@@ -91,6 +91,11 @@ class TestExpressions(unittest.TestCase):
         self.assertIsInstance(column.parent_select, exp.Select)
         self.assertIsNone(column.find_ancestor(exp.Join))
 
+    def test_root(self):
+        ast = parse_one("select * from (select a from x)")
+        self.assertIs(ast, ast.root())
+        self.assertIs(ast, ast.find(exp.Column).root())
+
     def test_alias_or_name(self):
         expression = parse_one(
             "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz"
@@ -767,3 +772,36 @@ FROM foo""",
             exp.rename_table("t1", "t2").sql(),
             "ALTER TABLE t1 RENAME TO t2",
         )
+
+    def test_is_star(self):
+        assert parse_one("*").is_star
+        assert parse_one("foo.*").is_star
+        assert parse_one("SELECT * FROM foo").is_star
+        assert parse_one("(SELECT * FROM foo)").is_star
+        assert parse_one("SELECT *, 1 FROM foo").is_star
+        assert parse_one("SELECT foo.* FROM foo").is_star
+        assert parse_one("SELECT * EXCEPT (a, b) FROM foo").is_star
+        assert parse_one("SELECT foo.* EXCEPT (foo.a, foo.b) FROM foo").is_star
+        assert parse_one("SELECT * REPLACE (a AS b, b AS C)").is_star
+        assert parse_one("SELECT * EXCEPT (a, b) REPLACE (a AS b, b AS C)").is_star
+        assert parse_one("SELECT * INTO newevent FROM event").is_star
+        assert parse_one("SELECT * FROM foo UNION SELECT * FROM bar").is_star
+        assert parse_one("SELECT * FROM bla UNION SELECT 1 AS x").is_star
+        assert parse_one("SELECT 1 AS x UNION SELECT * FROM bla").is_star
+        assert parse_one("SELECT 1 AS x UNION SELECT 1 AS x UNION SELECT * FROM foo").is_star
+
+    def test_set_metadata(self):
+        ast = parse_one("SELECT foo.col FROM foo")
+
+        self.assertIsNone(ast._meta)
+
+        # calling ast.meta would lazily instantiate self._meta
+        self.assertEqual(ast.meta, {})
+        self.assertEqual(ast._meta, {})
+
+        ast.meta["some_meta_key"] = "some_meta_value"
+        self.assertEqual(ast.meta.get("some_meta_key"), "some_meta_value")
+        self.assertEqual(ast.meta.get("some_other_meta_key"), None)
+
+        ast.meta["some_other_meta_key"] = "some_other_meta_value"
+        self.assertEqual(ast.meta.get("some_other_meta_key"), "some_other_meta_value")
View file
@@ -31,3 +31,9 @@ class TestSerDe(unittest.TestCase):
         after = self.dump_load(before)
         self.assertEqual(before.type, after.type)
         self.assertEqual(before.this.type, after.this.type)
+
+    def test_meta(self):
+        before = parse_one("SELECT * FROM X")
+        before.meta["x"] = 1
+        after = self.dump_load(before)
+        self.assertEqual(before.meta, after.meta)
View file
@@ -18,6 +18,18 @@ class TestTokens(unittest.TestCase):
         for sql, comment in sql_comment:
             self.assertEqual(tokenizer.tokenize(sql)[0].comments, comment)
 
+    def test_token_line(self):
+        tokens = Tokenizer().tokenize(
+            """SELECT /*
+            line break
+            */
+            'x
+            y',
+            x"""
+        )
+
+        self.assertEqual(tokens[-1].line, 6)
+
     def test_jinja(self):
         tokenizer = Tokenizer()
 
@@ -26,6 +38,7 @@ class TestTokens(unittest.TestCase):
             SELECT
             {{ x }},
             {{- x -}},
+            {# it's a comment #}
             {% for x in y -%}
                 a {{+ b }}
             {% endfor %};
View file
@@ -28,6 +28,11 @@ class TestTranspile(unittest.TestCase):
         self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime")
         self.assertEqual(transpile("SELECT 1 row")[0], "SELECT 1 AS row")
 
+        self.assertEqual(
+            transpile("SELECT 1 FROM a.b.table1 t UNPIVOT((c3) FOR c4 IN (a, b))")[0],
+            "SELECT 1 FROM a.b.table1 AS t UNPIVOT((c3) FOR c4 IN (a, b))",
+        )
+
         for key in ("union", "over", "from", "join"):
             with self.subTest(f"alias {key}"):
                 self.validate(f"SELECT x AS {key}", f"SELECT x AS {key}")