
Merging upstream version 11.2.3.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 15:40:23 +01:00
parent c6f7c6bbe1
commit 428b7dd76f
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
93 changed files with 33054 additions and 31671 deletions

View file

@@ -9,6 +9,7 @@ repos:
require_serial: true
files: ^(sqlglot/|tests/|setup.py)
- id: isort
args: [--combine-as]
name: isort
entry: isort
language: system

View file

@@ -71,6 +71,8 @@ Changes:
- Breaking: Change Power to binary expression.
- Breaking: Removed mapping of "}}" to BLOCK_END token.
- New: x GLOB y support.
v10.5.0

File diff suppressed because one or more lines are too long (repeated for 65 of the 93 changed files)

View file

@@ -9,38 +9,45 @@ from __future__ import annotations
import typing as t
from sqlglot import expressions as exp
from sqlglot.dialects import Dialect, Dialects
from sqlglot.diff import diff
from sqlglot.errors import ErrorLevel, ParseError, TokenError, UnsupportedError
from sqlglot.expressions import Expression
from sqlglot.expressions import alias_ as alias
from sqlglot.expressions import (
and_,
column,
condition,
except_,
from_,
intersect,
maybe_parse,
not_,
or_,
select,
subquery,
from sqlglot.dialects.dialect import Dialect as Dialect, Dialects as Dialects
from sqlglot.diff import diff as diff
from sqlglot.errors import (
ErrorLevel as ErrorLevel,
ParseError as ParseError,
TokenError as TokenError,
UnsupportedError as UnsupportedError,
)
from sqlglot.expressions import table_ as table
from sqlglot.expressions import to_column, to_table, union
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema, Schema
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.expressions import (
Expression as Expression,
alias_ as alias,
and_ as and_,
column as column,
condition as condition,
except_ as except_,
from_ as from_,
intersect as intersect,
maybe_parse as maybe_parse,
not_ as not_,
or_ as or_,
select as select,
subquery as subquery,
table_ as table,
to_column as to_column,
to_table as to_table,
union as union,
)
from sqlglot.generator import Generator as Generator
from sqlglot.parser import Parser as Parser
from sqlglot.schema import MappingSchema as MappingSchema, Schema as Schema
from sqlglot.tokens import Tokenizer as Tokenizer, TokenType as TokenType
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
from sqlglot.dialects.dialect import DialectType as DialectType
T = t.TypeVar("T", bound=Expression)
__version__ = "11.2.0"
__version__ = "11.2.3"
pretty = False
"""Whether to format generated SQL by default."""

View file

@@ -4,8 +4,7 @@ import typing as t
from sqlglot import exp as expression
from sqlglot.dataframe.sql.column import Column
from sqlglot.helper import ensure_list
from sqlglot.helper import flatten as _flatten
from sqlglot.helper import ensure_list, flatten as _flatten
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import ColumnOrLiteral, ColumnOrName

View file

@@ -38,7 +38,10 @@ def _date_add_sql(
) -> t.Callable[[generator.Generator, exp.Expression], str]:
def func(self, expression):
this = self.sql(expression, "this")
return f"{data_type}_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=expression.args.get('unit') or exp.Literal.string('day')))})"
unit = expression.args.get("unit")
unit = exp.var(unit.name.upper() if unit else "DAY")
interval = exp.Interval(this=expression.expression, unit=unit)
return f"{data_type}_{kind}({this}, {self.sql(interval)})"
return func
@@ -235,6 +238,7 @@ class BigQuery(Dialect):
exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
exp.TsOrDsAdd: _date_add_sql("DATE", "ADD"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.VariancePop: rename_func("VAR_POP"),
exp.Values: _derived_table_values_to_unnest,
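The reworked _date_add_sql emits the interval unit as a bare keyword (defaulting to DAY) instead of a quoted string, and the new exp.TsOrDsAdd mapping routes Spark-style two-argument DATE_ADD through the same code path. A usage sketch mirroring the dialect tests added further down in this commit:

import sqlglot

# Spark's DATE_ADD(col, 1) now becomes a valid BigQuery interval expression;
# previously the unit rendered as the string literal 'day'.
sqlglot.transpile("SELECT DATE_ADD(my_date_column, 1)", read="spark", write="bigquery")
# ['SELECT DATE_ADD(my_date_column, INTERVAL 1 DAY)']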

View file

@@ -462,6 +462,11 @@ class MySQL(Dialect):
TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMBLOB)
TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB)
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
}
def show_sql(self, expression):
this = f" {expression.name}"
full = " FULL" if expression.args.get("full") else ""

View file

@@ -318,3 +318,8 @@ class Postgres(Dialect):
if isinstance(seq_get(e.expressions, 0), exp.Select)
else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]",
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
}
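In both MySQL and Postgres, TransientProperty now maps to Location.UNSUPPORTED, so the generator should drop the property (logging an "Unsupported property" warning at the default error level) instead of emitting invalid SQL. A hedged sketch of the expected behavior:

import sqlglot

# TRANSIENT is Snowflake-specific; the Postgres output should omit it.
sqlglot.transpile("CREATE TRANSIENT TABLE a (b INT)", read="snowflake", write="postgres")
# ['CREATE TABLE a (b INT)'], with a warning logged for the dropped property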

View file

@@ -150,6 +150,10 @@ class Snowflake(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
"DATE_TRUNC": lambda args: exp.DateTrunc(
unit=exp.Literal.string(seq_get(args, 0).name), # type: ignore
this=seq_get(args, 1),
),
"IFF": exp.If.from_arg_list,
"TO_TIMESTAMP": _snowflake_to_timestamp,
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
@@ -215,7 +219,6 @@ class Snowflake(Dialect):
}
class Generator(generator.Generator):
CREATE_TRANSIENT = True
PARAMETER_TOKEN = "$"
TRANSFORMS = {
@@ -252,6 +255,11 @@ class Snowflake(Dialect):
"replace": "RENAME",
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.SetProperty: exp.Properties.Location.UNSUPPORTED,
}
def ilikeany_sql(self, expression: exp.ILikeAny) -> str:
return self.binary(expression, "ILIKE ANY")
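Snowflake's DATE_TRUNC takes the unit as its first argument, so the new FUNCTIONS entry swaps the arguments into the canonical exp.DateTrunc shape on read. A sketch matching the dialect tests added below:

import sqlglot

# The Snowflake argument order is normalized, so cross-dialect round-trips
# produce the conventional quoted-unit form.
sqlglot.transpile("DATE_TRUNC(year, x)", read="snowflake", write="postgres")
# ["DATE_TRUNC('year', x)"]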

View file

@@ -8,9 +8,12 @@ from sqlglot.helper import seq_get
def _create_sql(self, e):
kind = e.args.get("kind")
temporary = e.args.get("temporary")
properties = e.args.get("properties")
if kind.upper() == "TABLE" and temporary is True:
if kind.upper() == "TABLE" and any(
isinstance(prop, exp.TemporaryProperty)
for prop in (properties.expressions if properties else [])
):
return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
return create_with_partitions_sql(self, e)
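Because `temporary` is no longer a dedicated Create arg (see the expressions.py and parser.py hunks below), Spark now looks for a TemporaryProperty in the parsed properties list. A sketch of the behavior this preserves:

import sqlglot

# Spark has no TEMPORARY TABLE, so it should still be rewritten as a temporary view.
sqlglot.transpile("CREATE TEMPORARY TABLE x AS SELECT a FROM d", write="spark")
# ['CREATE TEMPORARY VIEW x AS SELECT a FROM d']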

View file

@@ -114,6 +114,7 @@ class Teradata(Dialect):
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX,
exp.VolatilityProperty: exp.Properties.Location.POST_CREATE,
}
def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str:
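Moving VolatilityProperty to POST_CREATE keeps VOLATILE between CREATE and TABLE, matching the identity test added to the Teradata suite below:

import sqlglot

sqlglot.transpile("CREATE VOLATILE TABLE a (b INT)", read="teradata", write="teradata")
# ['CREATE VOLATILE TABLE a (b INT)']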

View file

@@ -11,8 +11,7 @@ from collections import defaultdict
from dataclasses import dataclass
from heapq import heappop, heappush
from sqlglot import Dialect
from sqlglot import expressions as exp
from sqlglot import Dialect, expressions as exp
from sqlglot.helper import ensure_collection
@@ -58,7 +57,12 @@ if t.TYPE_CHECKING:
Edit = t.Union[Insert, Remove, Move, Update, Keep]
def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
def diff(
source: exp.Expression,
target: exp.Expression,
matchings: t.List[t.Tuple[exp.Expression, exp.Expression]] | None = None,
**kwargs: t.Any,
) -> t.List[Edit]:
"""
Returns the list of changes between the source and the target expressions.
@@ -80,13 +84,38 @@ def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
Args:
source: the source expression.
target: the target expression against which the diff should be calculated.
matchings: the list of pre-matched node pairs which is used to help the algorithm's
heuristics produce better results for subtrees that are known by a caller to be matching.
Note: expression references in this list must refer to the same node objects that are
referenced in source / target trees.
Returns:
the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the
target expression trees. This list represents a sequence of steps needed to transform the source
expression tree into the target one.
"""
return ChangeDistiller().diff(source.copy(), target.copy())
matchings = matchings or []
matching_ids = {id(n) for pair in matchings for n in pair}
def compute_node_mappings(
original: exp.Expression, copy: exp.Expression
) -> t.Dict[int, exp.Expression]:
return {
id(old_node): new_node
for (old_node, _, _), (new_node, _, _) in zip(original.walk(), copy.walk())
if id(old_node) in matching_ids
}
source_copy = source.copy()
target_copy = target.copy()
node_mappings = {
**compute_node_mappings(source, source_copy),
**compute_node_mappings(target, target_copy),
}
matchings_copy = [(node_mappings[id(s)], node_mappings[id(t)]) for s, t in matchings]
return ChangeDistiller(**kwargs).diff(source_copy, target_copy, matchings=matchings_copy)
LEAF_EXPRESSION_TYPES = (
@@ -109,16 +138,26 @@ class ChangeDistiller:
self.t = t
self._sql_generator = Dialect().generator()
def diff(self, source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
def diff(
self,
source: exp.Expression,
target: exp.Expression,
matchings: t.List[t.Tuple[exp.Expression, exp.Expression]] | None = None,
) -> t.List[Edit]:
matchings = matchings or []
pre_matched_nodes = {id(s): id(t) for s, t in matchings}
if len({n for pair in pre_matched_nodes.items() for n in pair}) != 2 * len(matchings):
raise ValueError("Each node can be referenced at most once in the list of matchings")
self._source = source
self._target = target
self._source_index = {id(n[0]): n[0] for n in source.bfs()}
self._target_index = {id(n[0]): n[0] for n in target.bfs()}
self._unmatched_source_nodes = set(self._source_index)
self._unmatched_target_nodes = set(self._target_index)
self._unmatched_source_nodes = set(self._source_index) - set(pre_matched_nodes)
self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values())
self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {}
matching_set = self._compute_matching_set()
matching_set = self._compute_matching_set() | {(s, t) for s, t in pre_matched_nodes.items()}
return self._generate_edit_script(matching_set)
def _generate_edit_script(self, matching_set: t.Set[t.Tuple[int, int]]) -> t.List[Edit]:
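A usage sketch of the new matchings hint, mirroring test_pre_matchings near the end of this commit:

from sqlglot import parse_one
from sqlglot.diff import diff

source = parse_one("SELECT 1")
target = parse_one("SELECT 1, 2, 3, 4")

# Pre-matching the two roots pins them together, so the edit script contains
# only the inserted literals instead of a Remove/Insert pair for the SELECTs.
edits = diff(source, target, matchings=[(source, target)])

# Referencing the same node in more than one pair raises ValueError, per the
# check in ChangeDistiller.diff above.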

View file

@@ -82,7 +82,7 @@ class Expression(metaclass=_Expression):
key = "expression"
arg_types = {"this": True}
__slots__ = ("args", "parent", "arg_key", "comments", "_type")
__slots__ = ("args", "parent", "arg_key", "comments", "_type", "_meta")
def __init__(self, **args: t.Any):
self.args: t.Dict[str, t.Any] = args
@@ -90,6 +90,7 @@ class Expression(metaclass=_Expression):
self.arg_key: t.Optional[str] = None
self.comments: t.Optional[t.List[str]] = None
self._type: t.Optional[DataType] = None
self._meta: t.Optional[t.Dict[str, t.Any]] = None
for arg_key, value in self.args.items():
self._set_parent(arg_key, value)
@@ -219,10 +220,23 @@ class Expression(metaclass=_Expression):
dtype = DataType.build(dtype)
self._type = dtype # type: ignore
@property
def meta(self) -> t.Dict[str, t.Any]:
if self._meta is None:
self._meta = {}
return self._meta
def __deepcopy__(self, memo):
copy = self.__class__(**deepcopy(self.args))
copy.comments = self.comments
copy.type = self.type
if self.comments is not None:
copy.comments = deepcopy(self.comments)
if self._type is not None:
copy._type = self._type.copy()
if self._meta is not None:
copy._meta = deepcopy(self._meta)
return copy
def copy(self):
@@ -329,6 +343,15 @@ class Expression(metaclass=_Expression):
"""
return self.find_ancestor(Select)
def root(self) -> Expression:
"""
Returns the root expression of this tree.
"""
expression = self
while expression.parent:
expression = expression.parent
return expression
def walk(self, bfs=True, prune=None):
"""
Returns a generator object which visits all nodes in this tree.
@@ -767,21 +790,10 @@ class Create(Expression):
"this": True,
"kind": True,
"expression": False,
"set": False,
"multiset": False,
"global_temporary": False,
"volatile": False,
"exists": False,
"properties": False,
"temporary": False,
"transient": False,
"external": False,
"replace": False,
"unique": False,
"materialized": False,
"data": False,
"statistics": False,
"no_primary_index": False,
"indexes": False,
"no_schema_binding": False,
"begin": False,
@@ -1336,42 +1348,92 @@ class Property(Expression):
arg_types = {"this": True, "value": True}
class AfterJournalProperty(Property):
arg_types = {"no": True, "dual": False, "local": False}
class AlgorithmProperty(Property):
arg_types = {"this": True}
class AutoIncrementProperty(Property):
arg_types = {"this": True}
class BlockCompressionProperty(Property):
arg_types = {"autotemp": False, "always": False, "default": True, "manual": True, "never": True}
class CharacterSetProperty(Property):
arg_types = {"this": True, "default": True}
class ChecksumProperty(Property):
arg_types = {"on": False, "default": False}
class CollateProperty(Property):
arg_types = {"this": True}
class DataBlocksizeProperty(Property):
arg_types = {"size": False, "units": False, "min": False, "default": False}
class DefinerProperty(Property):
arg_types = {"this": True}
class SqlSecurityProperty(Property):
arg_types = {"definer": True}
class TableFormatProperty(Property):
arg_types = {"this": True}
class PartitionedByProperty(Property):
arg_types = {"this": True}
class FileFormatProperty(Property):
arg_types = {"this": True}
class DistKeyProperty(Property):
arg_types = {"this": True}
class SortKeyProperty(Property):
arg_types = {"this": True, "compound": False}
class DistStyleProperty(Property):
arg_types = {"this": True}
class EngineProperty(Property):
arg_types = {"this": True}
class ExecuteAsProperty(Property):
arg_types = {"this": True}
class ExternalProperty(Property):
arg_types = {"this": False}
class FallbackProperty(Property):
arg_types = {"no": True, "protection": False}
class FileFormatProperty(Property):
arg_types = {"this": True}
class FreespaceProperty(Property):
arg_types = {"this": True, "percent": False}
class IsolatedLoadingProperty(Property):
arg_types = {
"no": True,
"concurrent": True,
"for_all": True,
"for_insert": True,
"for_none": True,
}
class JournalProperty(Property):
arg_types = {"no": True, "dual": False, "before": False}
class LanguageProperty(Property):
arg_types = {"this": True}
class LikeProperty(Property):
arg_types = {"this": True, "expressions": False}
@@ -1380,23 +1442,37 @@ class LocationProperty(Property):
arg_types = {"this": True}
class EngineProperty(Property):
arg_types = {"this": True}
class LockingProperty(Property):
arg_types = {
"this": False,
"kind": True,
"for_or_in": True,
"lock_type": True,
"override": False,
}
class AutoIncrementProperty(Property):
arg_types = {"this": True}
class LogProperty(Property):
arg_types = {"no": True}
class CharacterSetProperty(Property):
arg_types = {"this": True, "default": True}
class MaterializedProperty(Property):
arg_types = {"this": False}
class CollateProperty(Property):
arg_types = {"this": True}
class MergeBlockRatioProperty(Property):
arg_types = {"this": False, "no": False, "default": False, "percent": False}
class SchemaCommentProperty(Property):
class NoPrimaryIndexProperty(Property):
arg_types = {"this": False}
class OnCommitProperty(Property):
arg_types = {"this": False}
class PartitionedByProperty(Property):
arg_types = {"this": True}
@@ -1404,18 +1480,6 @@ class ReturnsProperty(Property):
arg_types = {"this": True, "is_table": False, "table": False}
class LanguageProperty(Property):
arg_types = {"this": True}
class ExecuteAsProperty(Property):
arg_types = {"this": True}
class VolatilityProperty(Property):
arg_types = {"this": True}
class RowFormatDelimitedProperty(Property):
# https://cwiki.apache.org/confluence/display/hive/languagemanual+dml
arg_types = {
@@ -1433,70 +1497,50 @@ class RowFormatSerdeProperty(Property):
arg_types = {"this": True}
class SchemaCommentProperty(Property):
arg_types = {"this": True}
class SerdeProperties(Property):
arg_types = {"expressions": True}
class FallbackProperty(Property):
arg_types = {"no": True, "protection": False}
class SetProperty(Property):
arg_types = {"multi": True}
class SortKeyProperty(Property):
arg_types = {"this": True, "compound": False}
class SqlSecurityProperty(Property):
arg_types = {"definer": True}
class TableFormatProperty(Property):
arg_types = {"this": True}
class TemporaryProperty(Property):
arg_types = {"global_": True}
class TransientProperty(Property):
arg_types = {"this": False}
class VolatilityProperty(Property):
arg_types = {"this": True}
class WithDataProperty(Property):
arg_types = {"no": True, "statistics": False}
class WithJournalTableProperty(Property):
arg_types = {"this": True}
class LogProperty(Property):
arg_types = {"no": True}
class JournalProperty(Property):
arg_types = {"no": True, "dual": False, "before": False}
class AfterJournalProperty(Property):
arg_types = {"no": True, "dual": False, "local": False}
class ChecksumProperty(Property):
arg_types = {"on": False, "default": False}
class FreespaceProperty(Property):
arg_types = {"this": True, "percent": False}
class MergeBlockRatioProperty(Property):
arg_types = {"this": False, "no": False, "default": False, "percent": False}
class DataBlocksizeProperty(Property):
arg_types = {"size": False, "units": False, "min": False, "default": False}
class BlockCompressionProperty(Property):
arg_types = {"autotemp": False, "always": False, "default": True, "manual": True, "never": True}
class IsolatedLoadingProperty(Property):
arg_types = {
"no": True,
"concurrent": True,
"for_all": True,
"for_insert": True,
"for_none": True,
}
class LockingProperty(Property):
arg_types = {
"this": False,
"kind": True,
"for_or_in": True,
"lock_type": True,
"override": False,
}
class Properties(Expression):
arg_types = {"expressions": True}
@@ -1533,7 +1577,7 @@ class Properties(Expression):
# Form: alias selection
# create [POST_CREATE]
# table a [POST_NAME]
# as [POST_ALIAS] (select * from b)
# as [POST_ALIAS] (select * from b) [POST_EXPRESSION]
# index (c) [POST_INDEX]
class Location(AutoName):
POST_CREATE = auto()
@@ -1541,6 +1585,7 @@ class Properties(Expression):
POST_SCHEMA = auto()
POST_WITH = auto()
POST_ALIAS = auto()
POST_EXPRESSION = auto()
POST_INDEX = auto()
UNSUPPORTED = auto()
@@ -1797,6 +1842,10 @@ class Union(Subqueryable):
def named_selects(self):
return self.this.unnest().named_selects
@property
def is_star(self) -> bool:
return self.this.is_star or self.expression.is_star
@property
def selects(self):
return self.this.unnest().selects
@@ -2424,6 +2473,10 @@ class Select(Subqueryable):
def named_selects(self) -> t.List[str]:
return [e.output_name for e in self.expressions if e.alias_or_name]
@property
def is_star(self) -> bool:
return any(expression.is_star for expression in self.expressions)
@property
def selects(self) -> t.List[Expression]:
return self.expressions
@@ -2446,6 +2499,10 @@ class Subquery(DerivedTable, Unionable):
expression = expression.this
return expression
@property
def is_star(self) -> bool:
return self.this.is_star
@property
def output_name(self):
return self.alias
@@ -2478,6 +2535,7 @@ class Tag(Expression):
class Pivot(Expression):
arg_types = {
"this": False,
"alias": False,
"expressions": True,
"field": True,
"unpivot": True,
@@ -2603,6 +2661,7 @@ class DataType(Expression):
IMAGE = auto()
VARIANT = auto()
OBJECT = auto()
INET = auto()
NULL = auto()
UNKNOWN = auto() # Sentinel value, useful for type annotation
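Taken together, the new meta slot, the root() helper, and the is_star properties let callers annotate nodes and interrogate trees; a small sketch mirroring the tests at the end of this commit:

from sqlglot import exp, parse_one

ast = parse_one("SELECT * FROM (SELECT a FROM x)")
ast.meta["origin"] = "example.sql"       # lazily creates _meta; preserved by __deepcopy__
assert ast.find(exp.Column).root() is ast
assert ast.is_star                       # any star projection qualifies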

View file

@@ -64,15 +64,22 @@ class Generator:
"TS_OR_DS_ADD", e.this, e.expression, e.args.get("unit")
),
exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'this')}",
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}",
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
exp.ExternalProperty: lambda self, e: "EXTERNAL",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
exp.MaterializedProperty: lambda self, e: "MATERIALIZED",
exp.NoPrimaryIndexProperty: lambda self, e: "NO PRIMARY INDEX",
exp.OnCommitProperty: lambda self, e: "ON COMMIT PRESERVE ROWS",
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
exp.TemporaryProperty: lambda self, e: f"{'GLOBAL ' if e.args.get('global_') else ''}TEMPORARY",
exp.TransientProperty: lambda self, e: "TRANSIENT",
exp.VolatilityProperty: lambda self, e: e.name,
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
@@ -87,9 +94,6 @@ class Generator:
exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
}
# Whether 'CREATE ... TRANSIENT ... TABLE' is allowed
CREATE_TRANSIENT = False
# Whether or not null ordering is supported in order by
NULL_ORDERING_SUPPORTED = True
@@ -112,6 +116,7 @@ class Generator:
exp.DataType.Type.LONGTEXT: "TEXT",
exp.DataType.Type.MEDIUMBLOB: "BLOB",
exp.DataType.Type.LONGBLOB: "BLOB",
exp.DataType.Type.INET: "INET",
}
STAR_MAPPING = {
@@ -140,6 +145,7 @@ class Generator:
exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA,
exp.EngineProperty: exp.Properties.Location.POST_SCHEMA,
exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA,
exp.ExternalProperty: exp.Properties.Location.POST_CREATE,
exp.FallbackProperty: exp.Properties.Location.POST_NAME,
exp.FileFormatProperty: exp.Properties.Location.POST_WITH,
exp.FreespaceProperty: exp.Properties.Location.POST_NAME,
@@ -150,7 +156,10 @@ class Generator:
exp.LocationProperty: exp.Properties.Location.POST_SCHEMA,
exp.LockingProperty: exp.Properties.Location.POST_ALIAS,
exp.LogProperty: exp.Properties.Location.POST_NAME,
exp.MaterializedProperty: exp.Properties.Location.POST_CREATE,
exp.MergeBlockRatioProperty: exp.Properties.Location.POST_NAME,
exp.NoPrimaryIndexProperty: exp.Properties.Location.POST_EXPRESSION,
exp.OnCommitProperty: exp.Properties.Location.POST_EXPRESSION,
exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,
exp.Property: exp.Properties.Location.POST_WITH,
exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA,
@@ -158,10 +167,14 @@ class Generator:
exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA,
exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA,
exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA,
exp.SetProperty: exp.Properties.Location.POST_CREATE,
exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA,
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
exp.TableFormatProperty: exp.Properties.Location.POST_WITH,
exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,
exp.TransientProperty: exp.Properties.Location.POST_CREATE,
exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA,
exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
}
@@ -537,34 +550,9 @@ class Generator:
else:
expression_sql = f" AS{expression_sql}"
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
transient = (
" TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
)
external = " EXTERNAL" if expression.args.get("external") else ""
replace = " OR REPLACE" if expression.args.get("replace") else ""
exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
unique = " UNIQUE" if expression.args.get("unique") else ""
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
set_ = " SET" if expression.args.get("set") else ""
multiset = " MULTISET" if expression.args.get("multiset") else ""
global_temporary = " GLOBAL TEMPORARY" if expression.args.get("global_temporary") else ""
volatile = " VOLATILE" if expression.args.get("volatile") else ""
data = expression.args.get("data")
if data is None:
data = ""
elif data:
data = " WITH DATA"
else:
data = " WITH NO DATA"
statistics = expression.args.get("statistics")
if statistics is None:
statistics = ""
elif statistics:
statistics = " AND STATISTICS"
else:
statistics = " AND NO STATISTICS"
no_primary_index = " NO PRIMARY INDEX" if expression.args.get("no_primary_index") else ""
exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
indexes = expression.args.get("indexes")
index_sql = ""
@@ -605,28 +593,24 @@
wrapped=False,
)
modifiers = "".join(
(
replace,
temporary,
transient,
external,
unique,
materialized,
set_,
multiset,
global_temporary,
volatile,
postcreate_props_sql,
)
modifiers = "".join((replace, unique, postcreate_props_sql))
postexpression_props_sql = ""
if properties_locs.get(exp.Properties.Location.POST_EXPRESSION):
postexpression_props_sql = self.properties(
exp.Properties(
expressions=properties_locs[exp.Properties.Location.POST_EXPRESSION]
),
sep=" ",
prefix=" ",
wrapped=False,
)
no_schema_binding = (
" WITH NO SCHEMA BINDING" if expression.args.get("no_schema_binding") else ""
)
post_expression_modifiers = "".join((data, statistics, no_primary_index))
expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}"
expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{postexpression_props_sql}{index_sql}{no_schema_binding}"
return self.prepend_ctes(expression, expression_sql)
def describe_sql(self, expression: exp.Describe) -> str:
@@ -810,6 +794,8 @@ class Generator:
properties_locs[exp.Properties.Location.POST_CREATE].append(p)
elif p_loc == exp.Properties.Location.POST_ALIAS:
properties_locs[exp.Properties.Location.POST_ALIAS].append(p)
elif p_loc == exp.Properties.Location.POST_EXPRESSION:
properties_locs[exp.Properties.Location.POST_EXPRESSION].append(p)
elif p_loc == exp.Properties.Location.UNSUPPORTED:
self.unsupported(f"Unsupported property {p.key}")
@@ -931,6 +917,14 @@ class Generator:
override = " OVERRIDE" if expression.args.get("override") else ""
return f"LOCKING {kind}{this} {for_or_in} {lock_type}{override}"
def withdataproperty_sql(self, expression: exp.WithDataProperty) -> str:
data_sql = f"WITH {'NO ' if expression.args.get('no') else ''}DATA"
statistics = expression.args.get("statistics")
statistics_sql = ""
if statistics is not None:
statistics_sql = f" AND {'NO ' if not statistics else ''}STATISTICS"
return f"{data_sql}{statistics_sql}"
def insert_sql(self, expression: exp.Insert) -> str:
overwrite = expression.args.get("overwrite")
@@ -1003,10 +997,6 @@ class Generator:
system_time = expression.args.get("system_time")
system_time = f" {self.sql(expression, 'system_time')}" if system_time else ""
if alias and pivots:
pivots = f"{pivots}{alias}"
alias = ""
return f"{table}{system_time}{alias}{hints}{laterals}{joins}{pivots}"
def tablesample_sql(self, expression: exp.TableSample) -> str:
@@ -1034,11 +1024,13 @@ class Generator:
def pivot_sql(self, expression: exp.Pivot) -> str:
this = self.sql(expression, "this")
alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else ""
unpivot = expression.args.get("unpivot")
direction = "UNPIVOT" if unpivot else "PIVOT"
expressions = self.expressions(expression, key="expressions")
field = self.sql(expression, "field")
return f"{this} {direction}({expressions} FOR {field})"
return f"{this} {direction}({expressions} FOR {field}){alias}"
def tuple_sql(self, expression: exp.Tuple) -> str:
return f"({self.expressions(expression, flat=True)})"

View file

@@ -144,6 +144,7 @@ class Parser(metaclass=_Parser):
TokenType.IMAGE,
TokenType.VARIANT,
TokenType.OBJECT,
TokenType.INET,
*NESTED_TYPE_TOKENS,
}
@@ -509,73 +510,82 @@ class Parser(metaclass=_Parser):
}
PROPERTY_PARSERS = {
"AFTER": lambda self: self._parse_afterjournal(
no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
),
"ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty),
"AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty),
"BEFORE": lambda self: self._parse_journal(
no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
),
"BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
"CHARACTER SET": lambda self: self._parse_character_set(),
"CHECKSUM": lambda self: self._parse_checksum(),
"CLUSTER BY": lambda self: self.expression(
exp.Cluster, expressions=self._parse_csv(self._parse_ordered)
),
"LOCATION": lambda self: self._parse_property_assignment(exp.LocationProperty),
"PARTITION BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
"COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
"STORED": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
"DISTKEY": lambda self: self._parse_distkey(),
"DISTSTYLE": lambda self: self._parse_property_assignment(exp.DistStyleProperty),
"SORTKEY": lambda self: self._parse_sortkey(),
"LIKE": lambda self: self._parse_create_like(),
"RETURNS": lambda self: self._parse_returns(),
"ROW": lambda self: self._parse_row(),
"COLLATE": lambda self: self._parse_property_assignment(exp.CollateProperty),
"FORMAT": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
"TABLE_FORMAT": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
"USING": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
"LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty),
"EXECUTE": lambda self: self._parse_property_assignment(exp.ExecuteAsProperty),
"COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
"DATABLOCKSIZE": lambda self: self._parse_datablocksize(
default=self._prev.text.upper() == "DEFAULT"
),
"DEFINER": lambda self: self._parse_definer(),
"DETERMINISTIC": lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
),
"DISTKEY": lambda self: self._parse_distkey(),
"DISTSTYLE": lambda self: self._parse_property_assignment(exp.DistStyleProperty),
"EXECUTE": lambda self: self._parse_property_assignment(exp.ExecuteAsProperty),
"EXTERNAL": lambda self: self.expression(exp.ExternalProperty),
"FALLBACK": lambda self: self._parse_fallback(no=self._prev.text.upper() == "NO"),
"FORMAT": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
"FREESPACE": lambda self: self._parse_freespace(),
"GLOBAL": lambda self: self._parse_temporary(global_=True),
"IMMUTABLE": lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
),
"STABLE": lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("STABLE")
),
"VOLATILE": lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
),
"WITH": lambda self: self._parse_with_property(),
"TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property),
"FALLBACK": lambda self: self._parse_fallback(no=self._prev.text.upper() == "NO"),
"LOG": lambda self: self._parse_log(no=self._prev.text.upper() == "NO"),
"BEFORE": lambda self: self._parse_journal(
no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
),
"JOURNAL": lambda self: self._parse_journal(
no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
),
"AFTER": lambda self: self._parse_afterjournal(
no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
),
"LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty),
"LIKE": lambda self: self._parse_create_like(),
"LOCAL": lambda self: self._parse_afterjournal(no=False, dual=False, local=True),
"NOT": lambda self: self._parse_afterjournal(no=False, dual=False, local=False),
"CHECKSUM": lambda self: self._parse_checksum(),
"FREESPACE": lambda self: self._parse_freespace(),
"LOCATION": lambda self: self._parse_property_assignment(exp.LocationProperty),
"LOCK": lambda self: self._parse_locking(),
"LOCKING": lambda self: self._parse_locking(),
"LOG": lambda self: self._parse_log(no=self._prev.text.upper() == "NO"),
"MATERIALIZED": lambda self: self.expression(exp.MaterializedProperty),
"MAX": lambda self: self._parse_datablocksize(),
"MAXIMUM": lambda self: self._parse_datablocksize(),
"MERGEBLOCKRATIO": lambda self: self._parse_mergeblockratio(
no=self._prev.text.upper() == "NO", default=self._prev.text.upper() == "DEFAULT"
),
"MIN": lambda self: self._parse_datablocksize(),
"MINIMUM": lambda self: self._parse_datablocksize(),
"MAX": lambda self: self._parse_datablocksize(),
"MAXIMUM": lambda self: self._parse_datablocksize(),
"DATABLOCKSIZE": lambda self: self._parse_datablocksize(
default=self._prev.text.upper() == "DEFAULT"
"MULTISET": lambda self: self.expression(exp.SetProperty, multi=True),
"NO": lambda self: self._parse_noprimaryindex(),
"NOT": lambda self: self._parse_afterjournal(no=False, dual=False, local=False),
"ON": lambda self: self._parse_oncommit(),
"PARTITION BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
"RETURNS": lambda self: self._parse_returns(),
"ROW": lambda self: self._parse_row(),
"SET": lambda self: self.expression(exp.SetProperty, multi=False),
"SORTKEY": lambda self: self._parse_sortkey(),
"STABLE": lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("STABLE")
),
"BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
"ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty),
"DEFINER": lambda self: self._parse_definer(),
"LOCK": lambda self: self._parse_locking(),
"LOCKING": lambda self: self._parse_locking(),
"STORED": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
"TABLE_FORMAT": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
"TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property),
"TEMPORARY": lambda self: self._parse_temporary(global_=False),
"TRANSIENT": lambda self: self.expression(exp.TransientProperty),
"USING": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
"VOLATILE": lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
),
"WITH": lambda self: self._parse_with_property(),
}
CONSTRAINT_PARSERS = {
@@ -979,15 +989,7 @@ class Parser(metaclass=_Parser):
replace = self._prev.text.upper() == "REPLACE" or self._match_pair(
TokenType.OR, TokenType.REPLACE
)
set_ = self._match(TokenType.SET) # Teradata
multiset = self._match_text_seq("MULTISET") # Teradata
global_temporary = self._match_text_seq("GLOBAL", "TEMPORARY") # Teradata
volatile = self._match(TokenType.VOLATILE) # Teradata
temporary = self._match(TokenType.TEMPORARY)
transient = self._match_text_seq("TRANSIENT")
external = self._match_text_seq("EXTERNAL")
unique = self._match(TokenType.UNIQUE)
materialized = self._match(TokenType.MATERIALIZED)
if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False):
self._match(TokenType.TABLE)
@@ -1005,16 +1007,17 @@ class Parser(metaclass=_Parser):
exists = self._parse_exists(not_=True)
this = None
expression = None
data = None
statistics = None
no_primary_index = None
indexes = None
no_schema_binding = None
begin = None
if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
this = self._parse_user_defined_function(kind=create_token.token_type)
properties = self._parse_properties()
temp_properties = self._parse_properties()
if properties and temp_properties:
properties.expressions.extend(temp_properties.expressions)
elif temp_properties:
properties = temp_properties
self._match(TokenType.ALIAS)
begin = self._match(TokenType.BEGIN)
@@ -1036,7 +1039,7 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.COMMA):
temp_properties = self._parse_properties(before=True)
if properties and temp_properties:
properties.expressions.append(temp_properties.expressions)
properties.expressions.extend(temp_properties.expressions)
elif temp_properties:
properties = temp_properties
@@ -1045,7 +1048,7 @@ class Parser(metaclass=_Parser):
# exp.Properties.Location.POST_SCHEMA and POST_WITH
temp_properties = self._parse_properties()
if properties and temp_properties:
properties.expressions.append(temp_properties.expressions)
properties.expressions.extend(temp_properties.expressions)
elif temp_properties:
properties = temp_properties
@@ -1059,24 +1062,19 @@ class Parser(metaclass=_Parser):
):
temp_properties = self._parse_properties()
if properties and temp_properties:
properties.expressions.append(temp_properties.expressions)
properties.expressions.extend(temp_properties.expressions)
elif temp_properties:
properties = temp_properties
expression = self._parse_ddl_select()
if create_token.token_type == TokenType.TABLE:
if self._match_text_seq("WITH", "DATA"):
data = True
elif self._match_text_seq("WITH", "NO", "DATA"):
data = False
if self._match_text_seq("AND", "STATISTICS"):
statistics = True
elif self._match_text_seq("AND", "NO", "STATISTICS"):
statistics = False
no_primary_index = self._match_text_seq("NO", "PRIMARY", "INDEX")
# exp.Properties.Location.POST_EXPRESSION
temp_properties = self._parse_properties()
if properties and temp_properties:
properties.expressions.extend(temp_properties.expressions)
elif temp_properties:
properties = temp_properties
indexes = []
while True:
@@ -1086,7 +1084,7 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.PARTITION_BY, advance=False):
temp_properties = self._parse_properties()
if properties and temp_properties:
properties.expressions.append(temp_properties.expressions)
properties.expressions.extend(temp_properties.expressions)
elif temp_properties:
properties = temp_properties
@@ -1102,22 +1100,11 @@ class Parser(metaclass=_Parser):
exp.Create,
this=this,
kind=create_token.text,
unique=unique,
expression=expression,
set=set_,
multiset=multiset,
global_temporary=global_temporary,
volatile=volatile,
exists=exists,
properties=properties,
temporary=temporary,
transient=transient,
external=external,
replace=replace,
unique=unique,
materialized=materialized,
data=data,
statistics=statistics,
no_primary_index=no_primary_index,
indexes=indexes,
no_schema_binding=no_schema_binding,
begin=begin,
@@ -1196,15 +1183,21 @@ class Parser(metaclass=_Parser):
def _parse_with_property(
self,
) -> t.Union[t.Optional[exp.Expression], t.List[t.Optional[exp.Expression]]]:
self._match(TokenType.WITH)
if self._match(TokenType.L_PAREN, advance=False):
return self._parse_wrapped_csv(self._parse_property)
if self._match_text_seq("JOURNAL"):
return self._parse_withjournaltable()
if self._match_text_seq("DATA"):
return self._parse_withdata(no=False)
elif self._match_text_seq("NO", "DATA"):
return self._parse_withdata(no=True)
if not self._next:
return None
if self._next.text.upper() == "JOURNAL":
return self._parse_withjournaltable()
return self._parse_withisolatedloading()
# https://dev.mysql.com/doc/refman/8.0/en/create-view.html
@@ -1221,7 +1214,7 @@ class Parser(metaclass=_Parser):
return exp.DefinerProperty(this=f"{user}@{host}")
def _parse_withjournaltable(self) -> exp.Expression:
self._match_text_seq("WITH", "JOURNAL", "TABLE")
self._match(TokenType.TABLE)
self._match(TokenType.EQ)
return self.expression(exp.WithJournalTableProperty, this=self._parse_table_parts())
@@ -1319,7 +1312,6 @@ class Parser(metaclass=_Parser):
)
def _parse_withisolatedloading(self) -> exp.Expression:
self._match(TokenType.WITH)
no = self._match_text_seq("NO")
concurrent = self._match_text_seq("CONCURRENT")
self._match_text_seq("ISOLATED", "LOADING")
@@ -1397,6 +1389,24 @@ class Parser(metaclass=_Parser):
this=self._parse_schema() or self._parse_bracket(self._parse_field()),
)
def _parse_withdata(self, no=False) -> exp.Expression:
if self._match_text_seq("AND", "STATISTICS"):
statistics = True
elif self._match_text_seq("AND", "NO", "STATISTICS"):
statistics = False
else:
statistics = None
return self.expression(exp.WithDataProperty, no=no, statistics=statistics)
def _parse_noprimaryindex(self) -> exp.Expression:
self._match_text_seq("PRIMARY", "INDEX")
return exp.NoPrimaryIndexProperty()
def _parse_oncommit(self) -> exp.Expression:
self._match_text_seq("COMMIT", "PRESERVE", "ROWS")
return exp.OnCommitProperty()
def _parse_distkey(self) -> exp.Expression:
return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var))
@@ -1450,6 +1460,10 @@ class Parser(metaclass=_Parser):
return self.expression(exp.ReturnsProperty, this=value, is_table=is_table)
def _parse_temporary(self, global_=False) -> exp.Expression:
self._match(TokenType.TEMPORARY) # in case calling from "GLOBAL"
return self.expression(exp.TemporaryProperty, global_=global_)
def _parse_describe(self) -> exp.Expression:
kind = self._match_set(self.CREATABLES) and self._prev.text
this = self._parse_table()
@@ -2042,6 +2056,9 @@ class Parser(metaclass=_Parser):
if alias:
this.set("alias", alias)
if not this.args.get("pivots"):
this.set("pivots", self._parse_pivots())
if self._match_pair(TokenType.WITH, TokenType.L_PAREN):
this.set(
"hints",
@@ -2182,7 +2199,12 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
return self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot)
pivot = self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot)
if not self._match_set((TokenType.PIVOT, TokenType.UNPIVOT), advance=False):
pivot.set("alias", self._parse_table_alias())
return pivot
def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Expression]:
if not skip_where_token and not self._match(TokenType.WHERE):
@@ -3783,11 +3805,12 @@ class Parser(metaclass=_Parser):
return None
def _match_set(self, types):
def _match_set(self, types, advance=True):
if not self._curr:
return None
if self._curr.token_type in types:
if advance:
self._advance()
return True
@@ -3816,8 +3839,9 @@ class Parser(metaclass=_Parser):
if expression and self._prev_comments:
expression.comments = self._prev_comments
def _match_texts(self, texts):
def _match_texts(self, texts, advance=True):
if self._curr and self._curr.text.upper() in texts:
if advance:
self._advance()
return True
return False
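The Teradata-flavored CREATE options (SET/MULTISET, [GLOBAL] TEMPORARY, VOLATILE, WITH [NO] DATA, NO PRIMARY INDEX, ON COMMIT) are now ordinary properties handled by PROPERTY_PARSERS and placed via PROPERTIES_LOCATION, rather than ad-hoc Create args. A round-trip sketch based on the fixture added below:

import sqlglot

sqlglot.transpile("CREATE TABLE a (b INT) ON COMMIT PRESERVE ROWS")[0]
# 'CREATE TABLE a (b INT) ON COMMIT PRESERVE ROWS'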

View file

@@ -32,6 +32,9 @@ def dump(node: Node) -> JSON:
obj["type"] = node.type.sql()
if node.comments:
obj["comments"] = node.comments
if node._meta is not None:
obj["meta"] = node._meta
return obj
return node
@@ -57,11 +60,9 @@ def load(obj: JSON) -> Node:
klass = getattr(module, class_name)
expression = klass(**{k: load(v) for k, v in obj["args"].items()})
type_ = obj.get("type")
if type_:
expression.type = exp.DataType.build(type_)
comments = obj.get("comments")
if comments:
expression.comments = load(comments)
expression.type = obj.get("type")
expression.comments = obj.get("comments")
expression._meta = obj.get("meta")
return expression
return obj
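Since load now assigns the dumped values directly, type, comments, and the new meta field all round-trip; a sketch mirroring test_meta below:

from sqlglot import parse_one
from sqlglot.serde import dump, load

before = parse_one("SELECT * FROM x")
before.meta["x"] = 1
assert load(dump(before)).meta == before.meta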

View file

@@ -115,6 +115,7 @@ class TokenType(AutoName):
IMAGE = auto()
VARIANT = auto()
OBJECT = auto()
INET = auto()
# keywords
ALIAS = auto()
@@ -437,16 +438,8 @@ class Tokenizer(metaclass=_Tokenizer):
_IDENTIFIER_ESCAPES: t.Set[str] = set()
KEYWORDS = {
**{
f"{key}{postfix}": TokenType.BLOCK_START
for key in ("{%", "{#")
for postfix in ("", "+", "-")
},
**{
f"{prefix}{key}": TokenType.BLOCK_END
for key in ("%}", "#}")
for prefix in ("", "+", "-")
},
**{f"{{%{postfix}": TokenType.BLOCK_START for postfix in ("", "+", "-")},
**{f"{prefix}%}}": TokenType.BLOCK_END for prefix in ("", "+", "-")},
"{{+": TokenType.BLOCK_START,
"{{-": TokenType.BLOCK_START,
"+}}": TokenType.BLOCK_END,
@@ -533,6 +526,7 @@ class Tokenizer(metaclass=_Tokenizer):
"IGNORE NULLS": TokenType.IGNORE_NULLS,
"IN": TokenType.IN,
"INDEX": TokenType.INDEX,
"INET": TokenType.INET,
"INNER": TokenType.INNER,
"INSERT": TokenType.INSERT,
"INTERVAL": TokenType.INTERVAL,
@@ -701,7 +695,7 @@ class Tokenizer(metaclass=_Tokenizer):
"VACUUM": TokenType.COMMAND,
}
WHITE_SPACE = {
WHITE_SPACE: t.Dict[str, TokenType] = {
" ": TokenType.SPACE,
"\t": TokenType.SPACE,
"\n": TokenType.BREAK,
@@ -723,7 +717,7 @@ class Tokenizer(metaclass=_Tokenizer):
NUMERIC_LITERALS: t.Dict[str, str] = {}
ENCODE: t.Optional[str] = None
COMMENTS = ["--", ("/*", "*/")]
COMMENTS = ["--", ("/*", "*/"), ("{#", "#}")]
KEYWORD_TRIE = None # autofilled
IDENTIFIER_CAN_START_WITH_DIGIT = False
@@ -778,20 +772,14 @@ class Tokenizer(metaclass=_Tokenizer):
self._start = self._current
self._advance()
if not self._char:
if self._char is None:
break
white_space = self.WHITE_SPACE.get(self._char) # type: ignore
identifier_end = self._IDENTIFIERS.get(self._char) # type: ignore
if white_space:
if white_space == TokenType.BREAK:
self._col = 1
self._line += 1
elif self._char.isdigit(): # type:ignore
if self._char not in self.WHITE_SPACE:
if self._char.isdigit():
self._scan_number()
elif identifier_end:
self._scan_identifier(identifier_end)
elif self._char in self._IDENTIFIERS:
self._scan_identifier(self._IDENTIFIERS[self._char])
else:
self._scan_keywords()
@@ -807,13 +795,23 @@ class Tokenizer(metaclass=_Tokenizer):
return self.sql[start:end]
return ""
def _line_break(self, char: t.Optional[str]) -> bool:
return self.WHITE_SPACE.get(char) == TokenType.BREAK # type: ignore
def _advance(self, i: int = 1) -> None:
if self._line_break(self._char):
self._set_new_line()
self._col += i
self._current += i
self._end = self._current >= self.size # type: ignore
self._char = self.sql[self._current - 1] # type: ignore
self._peek = self.sql[self._current] if self._current < self.size else "" # type: ignore
def _set_new_line(self) -> None:
self._col = 1
self._line += 1
@property
def _text(self) -> str:
return self.sql[self._start : self._current]
@@ -917,7 +915,7 @@ class Tokenizer(metaclass=_Tokenizer):
self._comments.append(self._text[comment_start_size : -comment_end_size + 1]) # type: ignore
self._advance(comment_end_size - 1)
else:
while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK: # type: ignore
while not self._end and not self._line_break(self._peek):
self._advance()
self._comments.append(self._text[comment_start_size:]) # type: ignore
@@ -926,6 +924,7 @@ class Tokenizer(metaclass=_Tokenizer):
if comment_start_line == self._prev_token_line:
self.tokens[-1].comments.extend(self._comments)
self._comments = []
self._prev_token_line = self._line
return True
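Two behavioral notes on this hunk: {# ... #} is now tokenized as an ordinary comment pair rather than jinja BLOCK_START/BLOCK_END tokens, and the _line_break/_set_new_line helpers advance line numbers inside multi-line comments and strings. A sketch matching test_token_line below:

from sqlglot.tokens import Tokenizer

tokens = Tokenizer().tokenize("SELECT /*\nline break\n*/\n'x\ny',\nx")
assert tokens[-1].line == 6  # the trailing identifier is reported on line 6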

View file

@@ -2,8 +2,7 @@ import datetime
import inspect
import unittest
from sqlglot import expressions as exp
from sqlglot import parse_one
from sqlglot import expressions as exp, parse_one
from sqlglot.dataframe.sql import functions as SF
from sqlglot.errors import ErrorLevel

View file

@@ -1,8 +1,7 @@
from unittest import mock
import sqlglot
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql import types
from sqlglot.dataframe.sql import functions as F, types
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.schema import MappingSchema
from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator

View file

@@ -285,6 +285,10 @@ class TestDialect(Validator):
read={"oracle": "CAST(a AS NUMBER)"},
write={"oracle": "CAST(a AS NUMBER)"},
)
self.validate_all(
"CAST('127.0.0.1/32' AS INET)",
read={"postgres": "INET '127.0.0.1/32'"},
)
def test_if_null(self):
self.validate_all(
@@ -509,7 +513,7 @@
"starrocks": "DATE_ADD(x, INTERVAL 1 DAY)",
},
write={
"bigquery": "DATE_ADD(x, INTERVAL 1 'day')",
"bigquery": "DATE_ADD(x, INTERVAL 1 DAY)",
"drill": "DATE_ADD(x, INTERVAL 1 DAY)",
"duckdb": "x + INTERVAL 1 day",
"hive": "DATE_ADD(x, 1)",
@@ -526,7 +530,7 @@
self.validate_all(
"DATE_ADD(x, 1)",
write={
"bigquery": "DATE_ADD(x, INTERVAL 1 'day')",
"bigquery": "DATE_ADD(x, INTERVAL 1 DAY)",
"drill": "DATE_ADD(x, INTERVAL 1 DAY)",
"duckdb": "x + INTERVAL 1 DAY",
"hive": "DATE_ADD(x, 1)",
@@ -540,6 +544,7 @@
"DATE_TRUNC('day', x)",
write={
"mysql": "DATE(x)",
"snowflake": "DATE_TRUNC('day', x)",
},
)
self.validate_all(
@@ -576,6 +581,7 @@
"DATE_TRUNC('year', x)",
read={
"bigquery": "DATE_TRUNC(x, year)",
"snowflake": "DATE_TRUNC(year, x)",
"starrocks": "DATE_TRUNC('year', x)",
"spark": "TRUNC(x, 'year')",
},
@@ -583,6 +589,7 @@
"bigquery": "DATE_TRUNC(x, year)",
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')",
"postgres": "DATE_TRUNC('year', x)",
"snowflake": "DATE_TRUNC('year', x)",
"starrocks": "DATE_TRUNC('year', x)",
"spark": "TRUNC(x, 'year')",
},

View file

@@ -397,6 +397,12 @@ class TestSnowflake(Validator):
},
)
self.validate_all(
"CREATE TABLE a (b INT)",
read={"teradata": "CREATE MULTISET TABLE a (b INT)"},
write={"snowflake": "CREATE TABLE a (b INT)"},
)
def test_user_defined_functions(self):
self.validate_all(
"CREATE FUNCTION a(x DATE, y BIGINT) RETURNS ARRAY LANGUAGE JAVASCRIPT AS $$ SELECT 1 $$",

View file

@@ -213,6 +213,13 @@ TBLPROPERTIES (
self.validate_identity("TRIM(LEADING 'SL' FROM 'SSparkSQLS')")
self.validate_identity("TRIM(TRAILING 'SL' FROM 'SSparkSQLS')")
self.validate_all(
"SELECT DATE_ADD(my_date_column, 1)",
write={
"spark": "SELECT DATE_ADD(my_date_column, 1)",
"bigquery": "SELECT DATE_ADD(my_date_column, INTERVAL 1 DAY)",
},
)
self.validate_all(
"AGGREGATE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)",
write={

View file

@@ -35,6 +35,8 @@ class TestTeradata(Validator):
write={"teradata": "SELECT a FROM b"},
)
self.validate_identity("CREATE VOLATILE TABLE a (b INT)")
def test_insert(self):
self.validate_all(
"INS INTO x SELECT * FROM y", write={"teradata": "INSERT INTO x SELECT * FROM y"}

View file

@@ -305,6 +305,7 @@ SELECT a FROM test TABLESAMPLE(100 ROWS)
SELECT a FROM test TABLESAMPLE BERNOULLI (50)
SELECT a FROM test TABLESAMPLE SYSTEM (75)
SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q'))
SELECT 1 FROM a.b.table1 AS t UNPIVOT((c3) FOR c4 IN (a, b))
SELECT a FROM test PIVOT(SOMEAGG(x, y, z) FOR q IN (1))
SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) PIVOT(MAX(b) FOR c IN ('d'))
SELECT a FROM (SELECT a, b FROM test) PIVOT(SUM(x) FOR y IN ('z', 'q'))
@@ -557,10 +558,11 @@ CREATE TABLE a, BEFORE JOURNAL, AFTER JOURNAL, FREESPACE=1, DEFAULT DATABLOCKSIZ
CREATE TABLE a, DUAL JOURNAL, DUAL AFTER JOURNAL, MERGEBLOCKRATIO=1 PERCENT, DATABLOCKSIZE=10 KILOBYTES (a INT)
CREATE TABLE a, DUAL BEFORE JOURNAL, LOCAL AFTER JOURNAL, MAXIMUM DATABLOCKSIZE, BLOCKCOMPRESSION=AUTOTEMP(c1 INT) (a INT)
CREATE SET GLOBAL TEMPORARY TABLE a, NO BEFORE JOURNAL, NO AFTER JOURNAL, MINIMUM DATABLOCKSIZE, BLOCKCOMPRESSION=NEVER (a INT)
CREATE MULTISET VOLATILE TABLE a, NOT LOCAL AFTER JOURNAL, FREESPACE=1 PERCENT, DATABLOCKSIZE=10 BYTES, WITH NO CONCURRENT ISOLATED LOADING FOR ALL (a INT)
CREATE MULTISET TABLE a, NOT LOCAL AFTER JOURNAL, FREESPACE=1 PERCENT, DATABLOCKSIZE=10 BYTES, WITH NO CONCURRENT ISOLATED LOADING FOR ALL (a INT)
CREATE ALGORITHM=UNDEFINED DEFINER=foo@% SQL SECURITY DEFINER VIEW a AS (SELECT a FROM b)
CREATE TEMPORARY TABLE x AS SELECT a FROM d
CREATE TEMPORARY TABLE IF NOT EXISTS x AS SELECT a FROM d
CREATE TABLE a (b INT) ON COMMIT PRESERVE ROWS
CREATE VIEW x AS SELECT a FROM b
CREATE VIEW IF NOT EXISTS x AS SELECT a FROM b
CREATE VIEW z (a, b COMMENT 'b', c COMMENT 'c') AS SELECT a, b, c FROM d

View file

@@ -1,6 +1,6 @@
import unittest
from sqlglot import parse_one
from sqlglot import exp, parse_one
from sqlglot.diff import Insert, Keep, Move, Remove, Update, diff
from sqlglot.expressions import Join, to_identifier
@@ -128,6 +128,33 @@ class TestDiff(unittest.TestCase):
],
)
def test_pre_matchings(self):
expr_src = parse_one("SELECT 1")
expr_tgt = parse_one("SELECT 1, 2, 3, 4")
self._validate_delta_only(
diff(expr_src, expr_tgt),
[
Remove(expr_src),
Insert(expr_tgt),
Insert(exp.Literal.number(2)),
Insert(exp.Literal.number(3)),
Insert(exp.Literal.number(4)),
],
)
self._validate_delta_only(
diff(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt)]),
[
Insert(exp.Literal.number(2)),
Insert(exp.Literal.number(3)),
Insert(exp.Literal.number(4)),
],
)
with self.assertRaises(ValueError):
diff(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt), (expr_src, expr_tgt)])
def _validate_delta_only(self, actual_diff, expected_delta):
actual_delta = _delta_only(actual_diff)
self.assertEqual(set(actual_delta), set(expected_delta))

View file

@@ -91,6 +91,11 @@ class TestExpressions(unittest.TestCase):
self.assertIsInstance(column.parent_select, exp.Select)
self.assertIsNone(column.find_ancestor(exp.Join))
def test_root(self):
ast = parse_one("select * from (select a from x)")
self.assertIs(ast, ast.root())
self.assertIs(ast, ast.find(exp.Column).root())
def test_alias_or_name(self):
expression = parse_one(
"SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz"
@@ -767,3 +772,36 @@ FROM foo""",
exp.rename_table("t1", "t2").sql(),
"ALTER TABLE t1 RENAME TO t2",
)
def test_is_star(self):
assert parse_one("*").is_star
assert parse_one("foo.*").is_star
assert parse_one("SELECT * FROM foo").is_star
assert parse_one("(SELECT * FROM foo)").is_star
assert parse_one("SELECT *, 1 FROM foo").is_star
assert parse_one("SELECT foo.* FROM foo").is_star
assert parse_one("SELECT * EXCEPT (a, b) FROM foo").is_star
assert parse_one("SELECT foo.* EXCEPT (foo.a, foo.b) FROM foo").is_star
assert parse_one("SELECT * REPLACE (a AS b, b AS C)").is_star
assert parse_one("SELECT * EXCEPT (a, b) REPLACE (a AS b, b AS C)").is_star
assert parse_one("SELECT * INTO newevent FROM event").is_star
assert parse_one("SELECT * FROM foo UNION SELECT * FROM bar").is_star
assert parse_one("SELECT * FROM bla UNION SELECT 1 AS x").is_star
assert parse_one("SELECT 1 AS x UNION SELECT * FROM bla").is_star
assert parse_one("SELECT 1 AS x UNION SELECT 1 AS x UNION SELECT * FROM foo").is_star
def test_set_metadata(self):
ast = parse_one("SELECT foo.col FROM foo")
self.assertIsNone(ast._meta)
# calling ast.meta would lazily instantiate self._meta
self.assertEqual(ast.meta, {})
self.assertEqual(ast._meta, {})
ast.meta["some_meta_key"] = "some_meta_value"
self.assertEqual(ast.meta.get("some_meta_key"), "some_meta_value")
self.assertEqual(ast.meta.get("some_other_meta_key"), None)
ast.meta["some_other_meta_key"] = "some_other_meta_value"
self.assertEqual(ast.meta.get("some_other_meta_key"), "some_other_meta_value")

View file

@@ -31,3 +31,9 @@ class TestSerDe(unittest.TestCase):
after = self.dump_load(before)
self.assertEqual(before.type, after.type)
self.assertEqual(before.this.type, after.this.type)
def test_meta(self):
before = parse_one("SELECT * FROM X")
before.meta["x"] = 1
after = self.dump_load(before)
self.assertEqual(before.meta, after.meta)

View file

@@ -18,6 +18,18 @@ class TestTokens(unittest.TestCase):
for sql, comment in sql_comment:
self.assertEqual(tokenizer.tokenize(sql)[0].comments, comment)
def test_token_line(self):
tokens = Tokenizer().tokenize(
"""SELECT /*
line break
*/
'x
y',
x"""
)
self.assertEqual(tokens[-1].line, 6)
def test_jinja(self):
tokenizer = Tokenizer()
@@ -26,6 +38,7 @@ class TestTokens(unittest.TestCase):
SELECT
{{ x }},
{{- x -}},
{# it's a comment #}
{% for x in y -%}
a {{+ b }}
{% endfor %};

View file

@@ -28,6 +28,11 @@ class TestTranspile(unittest.TestCase):
self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime")
self.assertEqual(transpile("SELECT 1 row")[0], "SELECT 1 AS row")
self.assertEqual(
transpile("SELECT 1 FROM a.b.table1 t UNPIVOT((c3) FOR c4 IN (a, b))")[0],
"SELECT 1 FROM a.b.table1 AS t UNPIVOT((c3) FOR c4 IN (a, b))",
)
for key in ("union", "over", "from", "join"):
with self.subTest(f"alias {key}"):
self.validate(f"SELECT x AS {key}", f"SELECT x AS {key}")