Merging upstream version 10.0.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent 528822bfd4
commit b7d21c45b7
98 changed files with 4080 additions and 1666 deletions
@@ -1,5 +1,9 @@
"""## Python SQL parser, transpiler and optimizer."""

from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.dialects import Dialect, Dialects
from sqlglot.diff import diff
@@ -20,51 +24,54 @@ from sqlglot.expressions import (
    subquery,
)
from sqlglot.expressions import table_ as table
from sqlglot.expressions import union
from sqlglot.expressions import to_column, to_table, union
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType

__version__ = "9.0.6"
__version__ = "10.0.1"

pretty = False

schema = MappingSchema()


def parse(sql, read=None, **opts):
def parse(
    sql: str, read: t.Optional[str | Dialect] = None, **opts
) -> t.List[t.Optional[Expression]]:
    """
    Parses the given SQL string into a collection of syntax trees, one per
    parsed SQL statement.
    Parses the given SQL string into a collection of syntax trees, one per parsed SQL statement.

    Args:
        sql (str): the SQL code string to parse.
        read (str): the SQL dialect to apply during parsing
            (eg. "spark", "hive", "presto", "mysql").
        sql: the SQL code string to parse.
        read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
        **opts: other options.

    Returns:
        typing.List[Expression]: the list of parsed syntax trees.
        The resulting syntax tree collection.
    """
    dialect = Dialect.get_or_raise(read)()
    return dialect.parse(sql, **opts)
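For reference, the retyped `parse` entry point is called the same way as before; a minimal sketch against the 10.0.1 API merged here:

import sqlglot

# parse returns one syntax tree per statement in the input string
trees = sqlglot.parse("SELECT 1; SELECT 2", read="duckdb")
for tree in trees:
    print(repr(tree))
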
def parse_one(sql, read=None, into=None, **opts):
def parse_one(
    sql: str,
    read: t.Optional[str | Dialect] = None,
    into: t.Optional[Expression | str] = None,
    **opts,
) -> t.Optional[Expression]:
    """
    Parses the given SQL string and returns a syntax tree for the first
    parsed SQL statement.
    Parses the given SQL string and returns a syntax tree for the first parsed SQL statement.

    Args:
        sql (str): the SQL code string to parse.
        read (str): the SQL dialect to apply during parsing
            (eg. "spark", "hive", "presto", "mysql").
        into (Expression): the SQLGlot Expression to parse into
        sql: the SQL code string to parse.
        read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
        into: the SQLGlot Expression to parse into.
        **opts: other options.

    Returns:
        Expression: the syntax tree for the first parsed statement.
        The syntax tree for the first parsed statement.
    """

    dialect = Dialect.get_or_raise(read)()
@@ -77,25 +84,29 @@ def parse_one(sql, read=None, into=None, **opts):
    return result[0] if result else None
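Similarly, `parse_one` gives back a single tree that can be re-rendered in any dialect; a small illustration (output shown is approximate):

import sqlglot

tree = sqlglot.parse_one("SELECT a FROM tbl", read="hive")
print(tree.sql(dialect="presto"))  # roughly: SELECT a FROM tbl
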
def transpile(sql, read=None, write=None, identity=True, error_level=None, **opts):
def transpile(
    sql: str,
    read: t.Optional[str | Dialect] = None,
    write: t.Optional[str | Dialect] = None,
    identity: bool = True,
    error_level: t.Optional[ErrorLevel] = None,
    **opts,
) -> t.List[str]:
    """
    Parses the given SQL string using the source dialect and returns a list of SQL strings
    transformed to conform to the target dialect. Each string in the returned list represents
    a single transformed SQL statement.
    Parses the given SQL string in accordance with the source dialect and returns a list of SQL strings transformed
    to conform to the target dialect. Each string in the returned list represents a single transformed SQL statement.

    Args:
        sql (str): the SQL code string to transpile.
        read (str): the source dialect used to parse the input string
            (eg. "spark", "hive", "presto", "mysql").
        write (str): the target dialect into which the input should be transformed
            (eg. "spark", "hive", "presto", "mysql").
        identity (bool): if set to True and if the target dialect is not specified
            the source dialect will be used as both: the source and the target dialect.
        error_level (ErrorLevel): the desired error level of the parser.
        sql: the SQL code string to transpile.
        read: the source dialect used to parse the input string (eg. "spark", "hive", "presto", "mysql").
        write: the target dialect into which the input should be transformed (eg. "spark", "hive", "presto", "mysql").
        identity: if set to `True` and if the target dialect is not specified the source dialect will be used as both:
            the source and the target dialect.
        error_level: the desired error level of the parser.
        **opts: other options.

    Returns:
        typing.List[str]: the list of transpiled SQL statements / expressions.
        The list of transpiled SQL statements.
    """
    write = write or read if identity else write
    return [
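The `transpile` signature above keeps the `identity` behavior: with no `write` dialect, the source dialect is reused, so the statement is normalized rather than translated. A quick sketch (outputs not pinned down here):

import sqlglot

# translate between dialects
print(sqlglot.transpile("SELECT EPOCH_MS(1618088028295)", read="duckdb", write="hive")[0])

# identity=True and no write dialect: parse and re-render as duckdb
print(sqlglot.transpile("select 1", read="duckdb")[0])
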
@@ -49,7 +49,10 @@ args = parser.parse_args()
error_level = sqlglot.ErrorLevel[args.error_level.upper()]

if args.parse:
    sqls = [repr(expression) for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)]
    sqls = [
        repr(expression)
        for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)
    ]
else:
    sqls = sqlglot.transpile(
        args.sql,
@@ -10,11 +10,17 @@ if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.types import StructType

ColumnLiterals = t.TypeVar(
    "ColumnLiterals", bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
    "ColumnLiterals",
    bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
)
ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str])
ColumnOrLiteral = t.TypeVar(
    "ColumnOrLiteral", bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
    "ColumnOrLiteral",
    bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
)
SchemaInput = t.TypeVar(
    "SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]]
)
OutputExpressionContainer = t.TypeVar(
    "OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert]
)
SchemaInput = t.TypeVar("SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]])
OutputExpressionContainer = t.TypeVar("OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert])
@@ -18,7 +18,11 @@ class Column:
            expression = expression.expression  # type: ignore
        elif expression is None or not isinstance(expression, (str, exp.Expression)):
            expression = self._lit(expression).expression  # type: ignore
        self.expression: exp.Expression = sqlglot.maybe_parse(expression, dialect="spark")

        expression = sqlglot.maybe_parse(expression, dialect="spark")
        if expression is None:
            raise ValueError(f"Could not parse {expression}")
        self.expression: exp.Expression = expression

    def __repr__(self):
        return repr(self.expression)

@@ -135,21 +139,29 @@ class Column:
    ) -> Column:
        ensured_column = None if column is None else cls.ensure_col(column)
        ensure_expression_values = {
            k: [Column.ensure_col(x).expression for x in v] if is_iterable(v) else Column.ensure_col(v).expression
            k: [Column.ensure_col(x).expression for x in v]
            if is_iterable(v)
            else Column.ensure_col(v).expression
            for k, v in kwargs.items()
        }
        new_expression = (
            callable_expression(**ensure_expression_values)
            if ensured_column is None
            else callable_expression(this=ensured_column.column_expression, **ensure_expression_values)
            else callable_expression(
                this=ensured_column.column_expression, **ensure_expression_values
            )
        )
        return Column(new_expression)

    def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
        return Column(klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs))
        return Column(
            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
        )

    def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
        return Column(klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs))
        return Column(
            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
        )

    def unary_op(self, klass: t.Callable, **kwargs) -> Column:
        return Column(klass(this=self.column_expression, **kwargs))

@@ -188,7 +200,7 @@ class Column:
        expression.set("table", exp.to_identifier(table_name))
        return Column(expression)

    def sql(self, **kwargs) -> Column:
    def sql(self, **kwargs) -> str:
        return self.expression.sql(**{"dialect": "spark", **kwargs})

    def alias(self, name: str) -> Column:

@@ -265,10 +277,14 @@ class Column:
        )

    def like(self, other: str):
        return self.invoke_expression_over_column(self, exp.Like, expression=self._lit(other).expression)
        return self.invoke_expression_over_column(
            self, exp.Like, expression=self._lit(other).expression
        )

    def ilike(self, other: str):
        return self.invoke_expression_over_column(self, exp.ILike, expression=self._lit(other).expression)
        return self.invoke_expression_over_column(
            self, exp.ILike, expression=self._lit(other).expression
        )

    def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
        startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos

@@ -287,10 +303,18 @@ class Column:
        lowerBound: t.Union[ColumnOrLiteral],
        upperBound: t.Union[ColumnOrLiteral],
    ) -> Column:
        lower_bound_exp = self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
        upper_bound_exp = self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
        lower_bound_exp = (
            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
        )
        upper_bound_exp = (
            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
        )
        return Column(
            exp.Between(this=self.column_expression, low=lower_bound_exp.expression, high=upper_bound_exp.expression)
            exp.Between(
                this=self.column_expression,
                low=lower_bound_exp.expression,
                high=upper_bound_exp.expression,
            )
        )

    def over(self, window: WindowSpec) -> Column:
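The reflowed `between`/`like` helpers above wrap plain Python values in literals before building the expression tree. A sketch of the resulting SQL (outputs are what I'd expect, not verified against this exact revision):

from sqlglot.dataframe.sql.column import Column

col = Column("x")
print(col.between(1, 10).sql())  # roughly: x BETWEEN 1 AND 10
print(col.like("a%").sql())      # roughly: x LIKE 'a%'
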
@@ -21,7 +21,12 @@ from sqlglot.optimizer import optimize as optimize_func
from sqlglot.optimizer.qualify_columns import qualify_columns

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql._typing import ColumnLiterals, OutputExpressionContainer
    from sqlglot.dataframe.sql._typing import (
        ColumnLiterals,
        ColumnOrLiteral,
        ColumnOrName,
        OutputExpressionContainer,
    )
    from sqlglot.dataframe.sql.session import SparkSession

@@ -83,7 +88,9 @@ class DataFrame:
                return from_exp.alias_or_name
            table_alias = from_exp.find(exp.TableAlias)
            if not table_alias:
                raise RuntimeError(f"Could not find an alias name for this expression: {self.expression}")
                raise RuntimeError(
                    f"Could not find an alias name for this expression: {self.expression}"
                )
            return table_alias.alias_or_name
        return self.expression.ctes[-1].alias

@@ -132,12 +139,16 @@ class DataFrame:
        cte.set("sequence_id", sequence_id or self.sequence_id)
        return cte, name

    def _ensure_list_of_columns(
        self, cols: t.Union[str, t.Iterable[str], Column, t.Iterable[Column]]
    ) -> t.List[Column]:
        columns = ensure_list(cols)
        columns = Column.ensure_cols(columns)
        return columns
    @t.overload
    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
        ...

    @t.overload
    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
        ...

    def _ensure_list_of_columns(self, cols):
        return Column.ensure_cols(ensure_list(cols))

    def _ensure_and_normalize_cols(self, cols):
        cols = self._ensure_list_of_columns(cols)

@@ -153,10 +164,16 @@ class DataFrame:
        df = self._resolve_pending_hints()
        sequence_id = sequence_id or df.sequence_id
        expression = df.expression.copy()
        cte_expression, cte_name = df._create_cte_from_expression(expression=expression, sequence_id=sequence_id)
        new_expression = df._add_ctes_to_expression(exp.Select(), expression.ctes + [cte_expression])
        cte_expression, cte_name = df._create_cte_from_expression(
            expression=expression, sequence_id=sequence_id
        )
        new_expression = df._add_ctes_to_expression(
            exp.Select(), expression.ctes + [cte_expression]
        )
        sel_columns = df._get_outer_select_columns(cte_expression)
        new_expression = new_expression.from_(cte_name).select(*[x.alias_or_name for x in sel_columns])
        new_expression = new_expression.from_(cte_name).select(
            *[x.alias_or_name for x in sel_columns]
        )
        return df.copy(expression=new_expression, sequence_id=sequence_id)

    def _resolve_pending_hints(self) -> DataFrame:

@@ -169,16 +186,23 @@ class DataFrame:
            hint_expression.args.get("expressions").append(hint)
            df.pending_hints.remove(hint)

        join_aliases = {join_table.alias_or_name for join_table in get_tables_from_expression_with_join(expression)}
        join_aliases = {
            join_table.alias_or_name
            for join_table in get_tables_from_expression_with_join(expression)
        }
        if join_aliases:
            for hint in df.pending_join_hints:
                for sequence_id_expression in hint.expressions:
                    sequence_id_or_name = sequence_id_expression.alias_or_name
                    sequence_ids_to_match = [sequence_id_or_name]
                    if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[sequence_id_or_name]
                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
                            sequence_id_or_name
                        ]
                    matching_ctes = [
                        cte for cte in reversed(expression.ctes) if cte.args["sequence_id"] in sequence_ids_to_match
                        cte
                        for cte in reversed(expression.ctes)
                        if cte.args["sequence_id"] in sequence_ids_to_match
                    ]
                    for matching_cte in matching_ctes:
                        if matching_cte.alias_or_name in join_aliases:

@@ -193,9 +217,14 @@ class DataFrame:
    def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
        hint_name = hint_name.upper()
        hint_expression = (
            exp.JoinHint(this=hint_name, expressions=[exp.to_table(parameter.alias_or_name) for parameter in args])
            exp.JoinHint(
                this=hint_name,
                expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
            )
            if hint_name in JOIN_HINTS
            else exp.Anonymous(this=hint_name, expressions=[parameter.expression for parameter in args])
            else exp.Anonymous(
                this=hint_name, expressions=[parameter.expression for parameter in args]
            )
        )
        new_df = self.copy()
        new_df.pending_hints.append(hint_expression)

@@ -245,7 +274,9 @@ class DataFrame:
    def _get_select_expressions(
        self,
    ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
        select_expressions: t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]] = []
        select_expressions: t.List[
            t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
        ] = []
        main_select_ctes: t.List[exp.CTE] = []
        for cte in self.expression.ctes:
            cache_storage_level = cte.args.get("cache_storage_level")

@@ -279,14 +310,19 @@ class DataFrame:
                cache_table_name = df._create_hash_from_expression(select_expression)
                cache_table = exp.to_table(cache_table_name)
                original_alias_name = select_expression.args["cte_alias_name"]
                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(cache_table_name)

                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
                    cache_table_name
                )
                sqlglot.schema.add_table(cache_table_name, select_expression.named_selects)
                cache_storage_level = select_expression.args["cache_storage_level"]
                options = [
                    exp.Literal.string("storageLevel"),
                    exp.Literal.string(cache_storage_level),
                ]
                expression = exp.Cache(this=cache_table, expression=select_expression, lazy=True, options=options)
                expression = exp.Cache(
                    this=cache_table, expression=select_expression, lazy=True, options=options
                )
                # We will drop the "view" if it exists before running the cache table
                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
            elif expression_type == exp.Create:

@@ -305,7 +341,9 @@ class DataFrame:
                raise ValueError(f"Invalid expression type: {expression_type}")
            output_expressions.append(expression)

        return [expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions]
        return [
            expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
        ]

    def copy(self, **kwargs) -> DataFrame:
        return DataFrame(**object_to_dict(self, **kwargs))

@@ -317,7 +355,9 @@ class DataFrame:
        if self.expression.args.get("joins"):
            ambiguous_cols = [col for col in cols if not col.column_expression.table]
            if ambiguous_cols:
                join_table_identifiers = [x.this for x in get_tables_from_expression_with_join(self.expression)]
                join_table_identifiers = [
                    x.this for x in get_tables_from_expression_with_join(self.expression)
                ]
                cte_names_in_join = [x.this for x in join_table_identifiers]
                for ambiguous_col in ambiguous_cols:
                    ctes_with_column = [

@@ -367,14 +407,20 @@ class DataFrame:

    @operation(Operation.FROM)
    def join(
        self, other_df: DataFrame, on: t.Union[str, t.List[str], Column, t.List[Column]], how: str = "inner", **kwargs
        self,
        other_df: DataFrame,
        on: t.Union[str, t.List[str], Column, t.List[Column]],
        how: str = "inner",
        **kwargs,
    ) -> DataFrame:
        other_df = other_df._convert_leaf_to_cte()
        pre_join_self_latest_cte_name = self.latest_cte_name
        columns = self._ensure_and_normalize_cols(on)
        join_type = how.replace("_", " ")
        if isinstance(columns[0].expression, exp.Column):
            join_columns = [Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns]
            join_columns = [
                Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns
            ]
            join_clause = functools.reduce(
                lambda x, y: x & y,
                [

@@ -402,7 +448,9 @@ class DataFrame:
            for column in self._get_outer_select_columns(other_df)
        ]
        column_value_mapping = {
            column.alias_or_name if not isinstance(column.expression.this, exp.Star) else column.sql(): column
            column.alias_or_name
            if not isinstance(column.expression.this, exp.Star)
            else column.sql(): column
            for column in other_columns + self_columns + join_columns
        }
        all_columns = [

@@ -410,16 +458,22 @@ class DataFrame:
            for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
        ]
        new_df = self.copy(
            expression=self.expression.join(other_df.latest_cte_name, on=join_clause.expression, join_type=join_type)
            expression=self.expression.join(
                other_df.latest_cte_name, on=join_clause.expression, join_type=join_type
            )
        )
        new_df.expression = new_df._add_ctes_to_expression(
            new_df.expression, other_df.expression.ctes
        )
        new_df.expression = new_df._add_ctes_to_expression(new_df.expression, other_df.expression.ctes)
        new_df.pending_hints.extend(other_df.pending_hints)
        new_df = new_df.select.__wrapped__(new_df, *all_columns)
        return new_df

    @operation(Operation.ORDER_BY)
    def orderBy(
        self, *cols: t.Union[str, Column], ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None
        self,
        *cols: t.Union[str, Column],
        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
    ) -> DataFrame:
        """
        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark

@@ -429,7 +483,10 @@ class DataFrame:
        columns = self._ensure_and_normalize_cols(cols)
        pre_ordered_col_indexes = [
            x
            for x in [i if isinstance(col.expression, exp.Ordered) else None for i, col in enumerate(columns)]
            for x in [
                i if isinstance(col.expression, exp.Ordered) else None
                for i, col in enumerate(columns)
            ]
            if x is not None
        ]
        if ascending is None:

@@ -478,7 +535,9 @@ class DataFrame:
        for r_column in r_columns_unused:
            l_expressions.append(exp.alias_(exp.Null(), r_column))
            r_expressions.append(r_column)
        r_df = other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
        r_df = (
            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
        )
        l_df = self.copy()
        if allowMissingColumns:
            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))

@@ -536,7 +595,9 @@ class DataFrame:
                f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
                f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
            )
        if_null_checks = [F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns]
        if_null_checks = [
            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
        ]
        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
        num_nulls = nulls_added_together.alias("num_nulls")
        new_df = new_df.select(num_nulls, append=True)

@@ -576,11 +637,15 @@ class DataFrame:
            value_columns = [lit(value) for value in values]

        null_replacement_mapping = {
            column.alias_or_name: (F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name))
            column.alias_or_name: (
                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
            )
            for column, value in zip(columns, value_columns)
        }
        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
        null_replacement_columns = [null_replacement_mapping[column.alias_or_name] for column in all_columns]
        null_replacement_columns = [
            null_replacement_mapping[column.alias_or_name] for column in all_columns
        ]
        new_df = new_df.select(*null_replacement_columns)
        return new_df

@@ -589,12 +654,11 @@ class DataFrame:
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Union[str, t.List[str]]] = None,
        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.functions import lit

        old_values = None
        subset = ensure_list(subset)
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}

@@ -605,7 +669,9 @@ class DataFrame:
            new_values = list(to_replace.values())
        elif not old_values and isinstance(to_replace, list):
            assert isinstance(value, list), "value must be a list since the replacements are a list"
            assert len(to_replace) == len(value), "the replacements and values must be the same length"
            assert len(to_replace) == len(
                value
            ), "the replacements and values must be the same length"
            old_values = to_replace
            new_values = value
        else:

@@ -635,7 +701,9 @@ class DataFrame:
    def withColumn(self, colName: str, col: Column) -> DataFrame:
        col = self._ensure_and_normalize_col(col)
        existing_col_names = self.expression.named_selects
        existing_col_index = existing_col_names.index(colName) if colName in existing_col_names else None
        existing_col_index = (
            existing_col_names.index(colName) if colName in existing_col_names else None
        )
        if existing_col_index:
            expression = self.expression.copy()
            expression.expressions[existing_col_index] = col.expression

@@ -645,7 +713,11 @@ class DataFrame:
    @operation(Operation.SELECT)
    def withColumnRenamed(self, existing: str, new: str):
        expression = self.expression.copy()
        existing_columns = [expression for expression in expression.expressions if expression.alias_or_name == existing]
        existing_columns = [
            expression
            for expression in expression.expressions
            if expression.alias_or_name == existing
        ]
        if not existing_columns:
            raise ValueError("Tried to rename a column that doesn't exist")
        for existing_column in existing_columns:

@@ -674,15 +746,19 @@ class DataFrame:
    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
        parameter_list = ensure_list(parameters)
        parameter_columns = (
            self._ensure_list_of_columns(parameter_list) if parameters else Column.ensure_cols([self.sequence_id])
            self._ensure_list_of_columns(parameter_list)
            if parameters
            else Column.ensure_cols([self.sequence_id])
        )
        return self._hint(name, parameter_columns)

    @operation(Operation.NO_OP)
    def repartition(self, numPartitions: t.Union[int, str], *cols: t.Union[int, str]) -> DataFrame:
        num_partitions = Column.ensure_cols(ensure_list(numPartitions))
    def repartition(
        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
    ) -> DataFrame:
        num_partition_cols = self._ensure_list_of_columns(numPartitions)
        columns = self._ensure_and_normalize_cols(cols)
        args = num_partitions + columns
        args = num_partition_cols + columns
        return self._hint("repartition", args)

    @operation(Operation.NO_OP)
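For orientation, the join and repartition changes above are exercised like the PySpark API they mirror; a sketch against the merged version (table names and schemas are made up for illustration):

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("t1", {"id": "INT", "name": "STRING"})
sqlglot.schema.add_table("t2", {"id": "INT", "score": "DOUBLE"})

spark = SparkSession()
df = spark.table("t1").join(spark.table("t2"), on="id", how="inner")
print(df.sql(dialect="spark")[0])
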
@@ -45,7 +45,11 @@ def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:

def when(condition: Column, value: t.Any) -> Column:
    true_value = value if isinstance(value, Column) else lit(value)
    return Column(glotexp.Case(ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)]))
    return Column(
        glotexp.Case(
            ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)]
        )
    )


def asc(col: ColumnOrName) -> Column:

@@ -407,7 +411,9 @@ def percentile_approx(
        return Column.invoke_expression_over_column(
            col, glotexp.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy
        )
    return Column.invoke_expression_over_column(col, glotexp.ApproxQuantile, quantile=lit(percentage))
    return Column.invoke_expression_over_column(
        col, glotexp.ApproxQuantile, quantile=lit(percentage)
    )


def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column:

@@ -471,7 +477,9 @@ def factorial(col: ColumnOrName) -> Column:
    return Column.invoke_anonymous_function(col, "FACTORIAL")


def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None) -> Column:
def lag(
    col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None
) -> Column:
    if default is not None:
        return Column.invoke_anonymous_function(col, "LAG", offset, default)
    if offset != 1:

@@ -479,7 +487,9 @@ def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[Colu
    return Column.invoke_anonymous_function(col, "LAG")


def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None) -> Column:
def lead(
    col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None
) -> Column:
    if default is not None:
        return Column.invoke_anonymous_function(col, "LEAD", offset, default)
    if offset != 1:

@@ -487,7 +497,9 @@ def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.A
    return Column.invoke_anonymous_function(col, "LEAD")


def nth_value(col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None) -> Column:
def nth_value(
    col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None
) -> Column:
    if ignoreNulls is not None:
        raise NotImplementedError("There is currently not support for `ignoreNulls` parameter")
    if offset != 1:

@@ -571,7 +583,9 @@ def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Colum
    return Column.invoke_anonymous_function(start, "ADD_MONTHS", months)


def months_between(date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None) -> Column:
def months_between(
    date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None
) -> Column:
    if roundOff is None:
        return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2)
    return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2, roundOff)

@@ -611,9 +625,13 @@ def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
    return Column.invoke_expression_over_column(col, glotexp.UnixToStr)


def unix_timestamp(timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None) -> Column:
def unix_timestamp(
    timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None
) -> Column:
    if format is not None:
        return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix, format=lit(format))
        return Column.invoke_expression_over_column(
            timestamp, glotexp.StrToUnix, format=lit(format)
        )
    return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix)


@@ -642,7 +660,9 @@ def window(
        timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration), lit(startTime)
    )
    if slideDuration is not None:
        return Column.invoke_anonymous_function(timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration))
        return Column.invoke_anonymous_function(
            timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration)
        )
    if startTime is not None:
        return Column.invoke_anonymous_function(
            timeColumn, "WINDOW", lit(windowDuration), lit(windowDuration), lit(startTime)

@@ -731,7 +751,9 @@ def trim(col: ColumnOrName) -> Column:


def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
    return Column.invoke_expression_over_column(None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols))
    return Column.invoke_expression_over_column(
        None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols)
    )


def decode(col: ColumnOrName, charset: str) -> Column:

@@ -768,7 +790,9 @@ def overlay(


def sentences(
    string: ColumnOrName, language: t.Optional[ColumnOrName] = None, country: t.Optional[ColumnOrName] = None
    string: ColumnOrName,
    language: t.Optional[ColumnOrName] = None,
    country: t.Optional[ColumnOrName] = None,
) -> Column:
    if language is not None and country is not None:
        return Column.invoke_anonymous_function(string, "SENTENCES", language, country)

@@ -794,7 +818,9 @@ def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column:
def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column:
    substr_col = lit(substr)
    if pos is not None:
        return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col, position=pos)
        return Column.invoke_expression_over_column(
            str, glotexp.StrPosition, substr=substr_col, position=pos
        )
    return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col)


@@ -872,7 +898,10 @@ def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
    cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols  # type: ignore
    return Column.invoke_expression_over_column(
        None, glotexp.VarMap, keys=array(*cols[::2]).expression, values=array(*cols[1::2]).expression
        None,
        glotexp.VarMap,
        keys=array(*cols[::2]).expression,
        values=array(*cols[1::2]).expression,
    )


@@ -882,29 +911,39 @@ def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column:

def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
    value_col = value if isinstance(value, Column) else lit(value)
    return Column.invoke_expression_over_column(col, glotexp.ArrayContains, expression=value_col.expression)
    return Column.invoke_expression_over_column(
        col, glotexp.ArrayContains, expression=value_col.expression
    )


def arrays_overlap(col1: ColumnOrName, col2: ColumnOrName) -> Column:
    return Column.invoke_anonymous_function(col1, "ARRAYS_OVERLAP", Column.ensure_col(col2))


def slice(x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]) -> Column:
def slice(
    x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]
) -> Column:
    start_col = start if isinstance(start, Column) else lit(start)
    length_col = length if isinstance(length, Column) else lit(length)
    return Column.invoke_anonymous_function(x, "SLICE", start_col, length_col)


def array_join(col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None) -> Column:
def array_join(
    col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None
) -> Column:
    if null_replacement is not None:
        return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement))
        return Column.invoke_anonymous_function(
            col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement)
        )
    return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter))


def concat(*cols: ColumnOrName) -> Column:
    if len(cols) == 1:
        return Column.invoke_anonymous_function(cols[0], "CONCAT")
    return Column.invoke_anonymous_function(cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]])
    return Column.invoke_anonymous_function(
        cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]]
    )


def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:

@@ -1076,7 +1115,9 @@ def map_concat(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column
    return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT", *columns[1:])


def sequence(start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None) -> Column:
def sequence(
    start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None
) -> Column:
    if step is not None:
        return Column.invoke_anonymous_function(start, "SEQUENCE", stop, step)
    return Column.invoke_anonymous_function(start, "SEQUENCE", stop)

@@ -1103,12 +1144,15 @@ def aggregate(
    merge_exp = _get_lambda_from_func(merge)
    if finish is not None:
        finish_exp = _get_lambda_from_func(finish)
        return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp))
        return Column.invoke_anonymous_function(
            col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp)
        )
    return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp))


def transform(
    col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]
    col: ColumnOrName,
    f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
) -> Column:
    f_expression = _get_lambda_from_func(f)
    return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression))

@@ -1124,12 +1168,17 @@ def forall(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column:
    return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression))


def filter(col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]) -> Column:
def filter(
    col: ColumnOrName,
    f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
) -> Column:
    f_expression = _get_lambda_from_func(f)
    return Column.invoke_expression_over_column(col, glotexp.ArrayFilter, expression=f_expression)


def zip_with(left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]) -> Column:
def zip_with(
    left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]
) -> Column:
    f_expression = _get_lambda_from_func(f)
    return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression))


@@ -1163,7 +1212,10 @@ def _lambda_quoted(value: str) -> t.Optional[bool]:


def _get_lambda_from_func(lambda_expression: t.Callable):
    variables = [glotexp.to_identifier(x, quoted=_lambda_quoted(x)) for x in lambda_expression.__code__.co_varnames]
    variables = [
        glotexp.to_identifier(x, quoted=_lambda_quoted(x))
        for x in lambda_expression.__code__.co_varnames
    ]
    return glotexp.Lambda(
        this=lambda_expression(*[Column(x) for x in variables]).expression,
        expressions=variables,
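The reflowed helpers in this file all funnel into `Column.invoke_expression_over_column` or `Column.invoke_anonymous_function`; for instance, `when`/`otherwise` builds a CASE expression. A small sketch (output approximate):

from sqlglot.dataframe.sql import functions as F

col = F.when(F.col("x").isNull(), F.lit(0)).otherwise(F.col("x"))
print(col.sql())  # roughly: CASE WHEN x IS NULL THEN 0 ELSE x END
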
@@ -17,7 +17,9 @@ class GroupedData:
        self.last_op = last_op
        self.group_by_cols = group_by_cols

    def _get_function_applied_columns(self, func_name: str, cols: t.Tuple[str, ...]) -> t.List[Column]:
    def _get_function_applied_columns(
        self, func_name: str, cols: t.Tuple[str, ...]
    ) -> t.List[Column]:
        func_name = func_name.lower()
        return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]

@@ -30,9 +32,9 @@ class GroupedData:
        )
        cols = self._df._ensure_and_normalize_cols(columns)

        expression = self._df.expression.group_by(*[x.expression for x in self.group_by_cols]).select(
            *[x.expression for x in self.group_by_cols + cols], append=False
        )
        expression = self._df.expression.group_by(
            *[x.expression for x in self.group_by_cols]
        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
        return self._df.copy(expression=expression)

    def count(self) -> DataFrame:
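`_get_function_applied_columns` is what backs the `min`/`max`/`avg`/`sum` shorthands on grouped data, with aliases following the `func(col)` pattern above. A sketch of the calling side, in the style of the project docs:

import sqlglot
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("employee", {"employee_id": "INT", "age": "INT"})
df = (
    SparkSession()
    .table("employee")
    .groupBy(F.col("age"))
    .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
)
print(df.sql(dialect="hive")[0])
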
@@ -23,7 +23,9 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[
        replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)


def replace_alias_name_with_cte_name(spark: SparkSession, expression_context: exp.Select, id: exp.Identifier):
def replace_alias_name_with_cte_name(
    spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
):
    if id.alias_or_name in spark.name_to_sequence_id_mapping:
        for cte in reversed(expression_context.ctes):
            if cte.args["sequence_id"] in spark.name_to_sequence_id_mapping[id.alias_or_name]:

@@ -40,8 +42,12 @@ def replace_branch_and_sequence_ids_with_cte_name(
    # id then it keeps that reference. This handles the weird edge case in spark that shouldn't
    # be common in practice
    if expression_context.args.get("joins") and id.alias_or_name in spark.known_branch_ids:
        join_table_aliases = [x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)]
        ctes_in_join = [cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases]
        join_table_aliases = [
            x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)
        ]
        ctes_in_join = [
            cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases
        ]
        if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]:
            assert len(ctes_in_join) == 2
            _set_alias_name(id, ctes_in_join[0].alias_or_name)

@@ -58,7 +64,6 @@ def _set_alias_name(id: exp.Identifier, name: str):


def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]:
    values = ensure_list(values)
    results = []
    for value in values:
        if isinstance(value, str):
@@ -19,12 +19,19 @@ class DataFrameReader:
        from sqlglot.dataframe.sql.dataframe import DataFrame

        sqlglot.schema.add_table(tableName)
        return DataFrame(self.spark, exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)))
        return DataFrame(
            self.spark,
            exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)),
        )


class DataFrameWriter:
    def __init__(
        self, df: DataFrame, spark: t.Optional[SparkSession] = None, mode: t.Optional[str] = None, by_name: bool = False
        self,
        df: DataFrame,
        spark: t.Optional[SparkSession] = None,
        mode: t.Optional[str] = None,
        by_name: bool = False,
    ):
        self._df = df
        self._spark = spark or df.spark

@@ -33,7 +40,10 @@ class DataFrameWriter:

    def copy(self, **kwargs) -> DataFrameWriter:
        return DataFrameWriter(
            **{k[1:] if k.startswith("_") else k: v for k, v in object_to_dict(self, **kwargs).items()}
            **{
                k[1:] if k.startswith("_") else k: v
                for k, v in object_to_dict(self, **kwargs).items()
            }
        )

    def sql(self, **kwargs) -> t.List[str]:
@@ -67,13 +67,20 @@ class SparkSession:

        data_expressions = [
            exp.Tuple(
                expressions=list(map(lambda x: F.lit(x).expression, row if not isinstance(row, dict) else row.values()))
                expressions=list(
                    map(
                        lambda x: F.lit(x).expression,
                        row if not isinstance(row, dict) else row.values(),
                    )
                )
            )
            for row in data
        ]

        sel_columns = [
            F.col(name).cast(data_type).alias(name).expression if data_type is not None else F.col(name).expression
            F.col(name).cast(data_type).alias(name).expression
            if data_type is not None
            else F.col(name).expression
            for name, data_type in column_mapping.items()
        ]

@@ -106,10 +113,12 @@ class SparkSession:
            select_expression.set("with", expression.args.get("with"))
            expression.set("with", None)
            del expression.args["expression"]
            df = DataFrame(self, select_expression, output_expression_container=expression)
            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
            df = df._convert_leaf_to_cte()
        else:
            raise ValueError("Unknown expression type provided in the SQL. Please create an issue with the SQL.")
            raise ValueError(
                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
            )
        return df

    @property
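`createDataFrame` above builds a VALUES-style expression from plain rows (lists, tuples, or dicts, as the `isinstance(row, dict)` branch shows); a minimal sketch:

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
df = spark.createDataFrame([{"cola": 1, "colb": "x"}])
print(df.sql(dialect="spark")[0])
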
@@ -158,7 +158,11 @@ class MapType(DataType):

class StructField(DataType):
    def __init__(
        self, name: str, dataType: DataType, nullable: bool = True, metadata: t.Optional[t.Dict[str, t.Any]] = None
        self,
        name: str,
        dataType: DataType,
        nullable: bool = True,
        metadata: t.Optional[t.Dict[str, t.Any]] = None,
    ):
        self.name = name
        self.dataType = dataType
@@ -74,8 +74,13 @@ class WindowSpec:
        window_spec.expression.args["order"].set("expressions", order_by)
        return window_spec

    def _calc_start_end(self, start: int, end: int) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {"start_side": None, "end_side": None}
    def _calc_start_end(
        self, start: int, end: int
    ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
            "start_side": None,
            "end_side": None,
        }
        if start == Window.currentRow:
            kwargs["start"] = "CURRENT ROW"
        else:

@@ -83,7 +88,9 @@ class WindowSpec:
                **kwargs,
                **{
                    "start_side": "PRECEDING",
                    "start": "UNBOUNDED" if start <= Window.unboundedPreceding else F.lit(start).expression,
                    "start": "UNBOUNDED"
                    if start <= Window.unboundedPreceding
                    else F.lit(start).expression,
                },
            }
        if end == Window.currentRow:

@@ -93,7 +100,9 @@ class WindowSpec:
                **kwargs,
                **{
                    "end_side": "FOLLOWING",
                    "end": "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression,
                    "end": "UNBOUNDED"
                    if end >= Window.unboundedFollowing
                    else F.lit(end).expression,
                },
            }
        return kwargs

@@ -103,7 +112,10 @@ class WindowSpec:
        spec = self._calc_start_end(start, end)
        spec["kind"] = "ROWS"
        window_spec.expression.set(
            "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
            "spec",
            exp.WindowSpec(
                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
            ),
        )
        return window_spec

@@ -112,6 +124,9 @@ class WindowSpec:
        spec = self._calc_start_end(start, end)
        spec["kind"] = "RANGE"
        window_spec.expression.set(
            "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
            "spec",
            exp.WindowSpec(
                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
            ),
        )
        return window_spec
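`_calc_start_end` maps the PySpark frame constants onto sqlglot's WindowSpec args. A sketch of the calling side; the import path and the static `partitionBy` constructor are assumptions based on the PySpark API this module mirrors:

from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.window import Window  # assumed path, per this diff's layout

w = (
    Window.partitionBy("dept")
    .orderBy("salary")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
print(F.sum("salary").over(w).sql())
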
@@ -1,21 +1,21 @@
from sqlglot import exp
from __future__ import annotations

from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
    Dialect,
    inline_array_sql,
    no_ilike_sql,
    rename_func,
)
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType


def _date_add(expression_class):
    def func(args):
        interval = list_get(args, 1)
        interval = seq_get(args, 1)
        return expression_class(
            this=list_get(args, 0),
            this=seq_get(args, 0),
            expression=interval.this,
            unit=interval.args.get("unit"),
        )

@@ -23,6 +23,13 @@ def _date_add(expression_class):
    return func


def _date_trunc(args):
    unit = seq_get(args, 1)
    if isinstance(unit, exp.Column):
        unit = exp.Var(this=unit.name)
    return exp.DateTrunc(this=seq_get(args, 0), expression=unit)


def _date_add_sql(data_type, kind):
    def func(self, expression):
        this = self.sql(expression, "this")

@@ -40,7 +47,8 @@ def _derived_table_values_to_unnest(self, expression):
    structs = []
    for row in rows:
        aliases = [
            exp.alias_(value, column_name) for value, column_name in zip(row, expression.args["alias"].args["columns"])
            exp.alias_(value, column_name)
            for value, column_name in zip(row, expression.args["alias"].args["columns"])
        ]
        structs.append(exp.Struct(expressions=aliases))
    unnest_exp = exp.Unnest(expressions=[exp.Array(expressions=structs)])

@@ -89,18 +97,19 @@ class BigQuery(Dialect):
        "%j": "%-j",
    }

    class Tokenizer(Tokenizer):
    class Tokenizer(tokens.Tokenizer):
        QUOTES = [
            (prefix + quote, quote) if prefix else quote
            for quote in ["'", '"', '"""', "'''"]
            for prefix in ["", "r", "R"]
        ]
        COMMENTS = ["--", "#", ("/*", "*/")]
        IDENTIFIERS = ["`"]
        ESCAPE = "\\"
        ESCAPES = ["\\"]
        HEX_STRINGS = [("0x", ""), ("0X", "")]

        KEYWORDS = {
            **Tokenizer.KEYWORDS,
            **tokens.Tokenizer.KEYWORDS,
            "CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
            "CURRENT_TIME": TokenType.CURRENT_TIME,
            "GEOGRAPHY": TokenType.GEOGRAPHY,

@@ -111,35 +120,40 @@ class BigQuery(Dialect):
            "WINDOW": TokenType.WINDOW,
            "NOT DETERMINISTIC": TokenType.VOLATILE,
        }
        KEYWORDS.pop("DIV")

    class Parser(Parser):
    class Parser(parser.Parser):
        FUNCTIONS = {
            **Parser.FUNCTIONS,
            **parser.Parser.FUNCTIONS,
            "DATE_TRUNC": _date_trunc,
            "DATE_ADD": _date_add(exp.DateAdd),
            "DATETIME_ADD": _date_add(exp.DatetimeAdd),
            "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
            "TIME_ADD": _date_add(exp.TimeAdd),
            "TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
            "DATE_SUB": _date_add(exp.DateSub),
            "DATETIME_SUB": _date_add(exp.DatetimeSub),
            "TIME_SUB": _date_add(exp.TimeSub),
            "TIMESTAMP_SUB": _date_add(exp.TimestampSub),
            "PARSE_TIMESTAMP": lambda args: exp.StrToTime(this=list_get(args, 1), format=list_get(args, 0)),
            "PARSE_TIMESTAMP": lambda args: exp.StrToTime(
                this=seq_get(args, 1), format=seq_get(args, 0)
            ),
        }

        NO_PAREN_FUNCTIONS = {
            **Parser.NO_PAREN_FUNCTIONS,
            **parser.Parser.NO_PAREN_FUNCTIONS,
            TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
            TokenType.CURRENT_TIME: exp.CurrentTime,
        }

        NESTED_TYPE_TOKENS = {
            *Parser.NESTED_TYPE_TOKENS,
            *parser.Parser.NESTED_TYPE_TOKENS,
            TokenType.TABLE,
        }

    class Generator(Generator):
    class Generator(generator.Generator):
        TRANSFORMS = {
            **Generator.TRANSFORMS,
            **generator.Generator.TRANSFORMS,
            exp.Array: inline_array_sql,
            exp.ArraySize: rename_func("ARRAY_LENGTH"),
            exp.DateAdd: _date_add_sql("DATE", "ADD"),

@@ -148,6 +162,7 @@ class BigQuery(Dialect):
            exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
            exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
            exp.ILike: no_ilike_sql,
            exp.IntDiv: rename_func("DIV"),
            exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
            exp.TimeAdd: _date_add_sql("TIME", "ADD"),
            exp.TimeSub: _date_add_sql("TIME", "SUB"),

@@ -157,11 +172,13 @@ class BigQuery(Dialect):
            exp.Values: _derived_table_values_to_unnest,
            exp.ReturnsProperty: _returnsproperty_sql,
            exp.Create: _create_sql,
            exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC",
            exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
            if e.name == "IMMUTABLE"
            else "NOT DETERMINISTIC",
        }

        TYPE_MAPPING = {
            **Generator.TYPE_MAPPING,
            **generator.Generator.TYPE_MAPPING,
            exp.DataType.Type.TINYINT: "INT64",
            exp.DataType.Type.SMALLINT: "INT64",
            exp.DataType.Type.INT: "INT64",
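With the new `_date_trunc` and `DIV`/`IntDiv` handling above, BigQuery date arithmetic round-trips through the generic expression nodes. A quick check one could run (output not pinned down here):

import sqlglot

sql = "SELECT DATE_ADD(CURRENT_DATE, INTERVAL 1 DAY)"
print(sqlglot.transpile(sql, read="bigquery", write="duckdb")[0])
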
@@ -1,8 +1,9 @@
from sqlglot import exp
from __future__ import annotations

from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
from sqlglot.generator import Generator
from sqlglot.parser import Parser, parse_var_map
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.parser import parse_var_map
from sqlglot.tokens import TokenType


def _lower_func(sql):

@@ -14,11 +15,12 @@ class ClickHouse(Dialect):
    normalize_functions = None
    null_ordering = "nulls_are_last"

    class Tokenizer(Tokenizer):
    class Tokenizer(tokens.Tokenizer):
        COMMENTS = ["--", "#", "#!", ("/*", "*/")]
        IDENTIFIERS = ['"', "`"]

        KEYWORDS = {
            **Tokenizer.KEYWORDS,
            **tokens.Tokenizer.KEYWORDS,
            "FINAL": TokenType.FINAL,
            "DATETIME64": TokenType.DATETIME,
            "INT8": TokenType.TINYINT,

@@ -30,9 +32,9 @@ class ClickHouse(Dialect):
            "TUPLE": TokenType.STRUCT,
        }

    class Parser(Parser):
    class Parser(parser.Parser):
        FUNCTIONS = {
            **Parser.FUNCTIONS,
            **parser.Parser.FUNCTIONS,
            "MAP": parse_var_map,
        }

@@ -44,11 +46,11 @@ class ClickHouse(Dialect):

        return this

    class Generator(Generator):
    class Generator(generator.Generator):
        STRUCT_DELIMITER = ("(", ")")

        TYPE_MAPPING = {
            **Generator.TYPE_MAPPING,
            **generator.Generator.TYPE_MAPPING,
            exp.DataType.Type.NULLABLE: "Nullable",
            exp.DataType.Type.DATETIME: "DateTime64",
            exp.DataType.Type.MAP: "Map",

@@ -63,7 +65,7 @@ class ClickHouse(Dialect):
        }

        TRANSFORMS = {
            **Generator.TRANSFORMS,
            **generator.Generator.TRANSFORMS,
            exp.Array: inline_array_sql,
            exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
            exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
@@ -1,3 +1,5 @@
from __future__ import annotations

from sqlglot import exp
from sqlglot.dialects.dialect import parse_date_delta
from sqlglot.dialects.spark import Spark

@@ -15,7 +17,7 @@ class Databricks(Spark):

    class Generator(Spark.Generator):
        TRANSFORMS = {
            **Spark.Generator.TRANSFORMS,
            **Spark.Generator.TRANSFORMS,  # type: ignore
            exp.DateAdd: generate_date_delta_with_unit_sql,
            exp.DateDiff: generate_date_delta_with_unit_sql,
        }
sqlglot/dialects/dialect.py
@@ -1,8 +1,11 @@
from __future__ import annotations

import typing as t
from enum import Enum

from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.helper import flatten, list_get
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Tokenizer

@@ -32,7 +35,7 @@ class Dialects(str, Enum):


class _Dialect(type):
    classes = {}
    classes: t.Dict[str, Dialect] = {}

    @classmethod
    def __getitem__(cls, key):

@@ -56,19 +59,30 @@ class _Dialect(type):
        klass.generator_class = getattr(klass, "Generator", Generator)

        klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0]
        klass.identifier_start, klass.identifier_end = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        if klass.tokenizer_class._BIT_STRINGS and exp.BitString not in klass.generator_class.TRANSFORMS:
        if (
            klass.tokenizer_class._BIT_STRINGS
            and exp.BitString not in klass.generator_class.TRANSFORMS
        ):
            bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
            klass.generator_class.TRANSFORMS[
                exp.BitString
            ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
        if klass.tokenizer_class._HEX_STRINGS and exp.HexString not in klass.generator_class.TRANSFORMS:
        if (
            klass.tokenizer_class._HEX_STRINGS
            and exp.HexString not in klass.generator_class.TRANSFORMS
        ):
            hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
            klass.generator_class.TRANSFORMS[
                exp.HexString
            ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
        if klass.tokenizer_class._BYTE_STRINGS and exp.ByteString not in klass.generator_class.TRANSFORMS:
        if (
            klass.tokenizer_class._BYTE_STRINGS
            and exp.ByteString not in klass.generator_class.TRANSFORMS
        ):
            be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
            klass.generator_class.TRANSFORMS[
                exp.ByteString

@@ -81,13 +95,13 @@ class Dialect(metaclass=_Dialect):
    index_offset = 0
    unnest_column_only = False
    alias_post_tablesample = False
    normalize_functions = "upper"
    normalize_functions: t.Optional[str] = "upper"
    null_ordering = "nulls_are_small"

    date_format = "'%Y-%m-%d'"
    dateint_format = "'%Y%m%d'"
    time_format = "'%Y-%m-%d %H:%M:%S'"
    time_mapping = {}
    time_mapping: t.Dict[str, str] = {}

    # autofilled
    quote_start = None

@@ -167,7 +181,7 @@ class Dialect(metaclass=_Dialect):
                "quote_end": self.quote_end,
                "identifier_start": self.identifier_start,
                "identifier_end": self.identifier_end,
                "escape": self.tokenizer_class.ESCAPE,
                "escape": self.tokenizer_class.ESCAPES[0],
                "index_offset": self.index_offset,
                "time_mapping": self.inverse_time_mapping,
                "time_trie": self.inverse_time_trie,

@@ -195,7 +209,9 @@ def approx_count_distinct_sql(self, expression):


def if_sql(self, expression):
    expressions = self.format_args(expression.this, expression.args.get("true"), expression.args.get("false"))
    expressions = self.format_args(
        expression.this, expression.args.get("true"), expression.args.get("false")
    )
    return f"IF({expressions})"


@@ -298,9 +314,9 @@ def format_time_lambda(exp_class, dialect, default=None):

    def _format_time(args):
        return exp_class(
            this=list_get(args, 0),
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                list_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
                seq_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
            ),
        )

@@ -328,7 +344,9 @@ def create_with_partitions_sql(self, expression):
            "expressions",
            [e for e in schema.expressions if e not in partitions],
        )
        prop.replace(exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions)))
        prop.replace(
            exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions))
        )
        expression.set("this", schema)

    return self.create_sql(expression)

@@ -337,9 +355,9 @@ def create_with_partitions_sql(self, expression):
def parse_date_delta(exp_class, unit_mapping=None):
    def inner_func(args):
        unit_based = len(args) == 3
        this = list_get(args, 2) if unit_based else list_get(args, 0)
        expression = list_get(args, 1) if unit_based else list_get(args, 1)
        unit = list_get(args, 0) if unit_based else exp.Literal.string("DAY")
        this = seq_get(args, 2) if unit_based else seq_get(args, 0)
        expression = seq_get(args, 1) if unit_based else seq_get(args, 1)
        unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
        unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
        return exp_class(this=this, expression=expression, unit=unit)
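
The list_get to seq_get rename that recurs throughout this commit swaps in a helper with the same contract; a minimal sketch of its behaviour (out-of-range lookups return None instead of raising):

    from sqlglot.helper import seq_get

    args = ["a", "b"]
    assert seq_get(args, 0) == "a"
    assert seq_get(args, 5) is None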
sqlglot/dialects/duckdb.py
@@ -1,4 +1,6 @@
from sqlglot import exp
from __future__ import annotations

from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
    Dialect,
    approx_count_distinct_sql,

@@ -12,10 +14,8 @@ from sqlglot.dialects.dialect import (
    rename_func,
    str_position_sql,
)
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType


def _unix_to_time(self, expression):

@@ -61,11 +61,14 @@ def _sort_array_sql(self, expression):


def _sort_array_reverse(args):
    return exp.SortArray(this=list_get(args, 0), asc=exp.FALSE)
    return exp.SortArray(this=seq_get(args, 0), asc=exp.FALSE)


def _struct_pack_sql(self, expression):
    args = [self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e) for e in expression.expressions]
    args = [
        self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e)
        for e in expression.expressions
    ]
    return f"STRUCT_PACK({', '.join(args)})"


@@ -76,15 +79,15 @@ def _datatype_sql(self, expression):


class DuckDB(Dialect):
    class Tokenizer(Tokenizer):
    class Tokenizer(tokens.Tokenizer):
        KEYWORDS = {
            **Tokenizer.KEYWORDS,
            **tokens.Tokenizer.KEYWORDS,
            ":=": TokenType.EQ,
        }

    class Parser(Parser):
    class Parser(parser.Parser):
        FUNCTIONS = {
            **Parser.FUNCTIONS,
            **parser.Parser.FUNCTIONS,
            "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
            "ARRAY_LENGTH": exp.ArraySize.from_arg_list,
            "ARRAY_SORT": exp.SortArray.from_arg_list,

@@ -92,7 +95,7 @@ class DuckDB(Dialect):
            "EPOCH": exp.TimeToUnix.from_arg_list,
            "EPOCH_MS": lambda args: exp.UnixToTime(
                this=exp.Div(
                    this=list_get(args, 0),
                    this=seq_get(args, 0),
                    expression=exp.Literal.number(1000),
                )
            ),

@@ -112,11 +115,11 @@ class DuckDB(Dialect):
            "UNNEST": exp.Explode.from_arg_list,
        }

    class Generator(Generator):
    class Generator(generator.Generator):
        STRUCT_DELIMITER = ("(", ")")

        TRANSFORMS = {
            **Generator.TRANSFORMS,
            **generator.Generator.TRANSFORMS,
            exp.ApproxDistinct: approx_count_distinct_sql,
            exp.Array: rename_func("LIST_VALUE"),
            exp.ArraySize: rename_func("ARRAY_LENGTH"),

@@ -160,7 +163,7 @@ class DuckDB(Dialect):
        }

        TYPE_MAPPING = {
            **Generator.TYPE_MAPPING,
            **generator.Generator.TYPE_MAPPING,
            exp.DataType.Type.VARCHAR: "TEXT",
            exp.DataType.Type.NVARCHAR: "TEXT",
        }
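
Illustrative aside: the EPOCH_MS lambda above parses a millisecond epoch as exp.UnixToTime over x / 1000, so other dialects can render it with their own unix-to-time function. A sketch, assuming sqlglot 10.x; the exact output shape may vary slightly:

    import sqlglot

    # DuckDB EPOCH_MS(x) becomes roughly FROM_UNIXTIME(x / 1000) in Hive-flavored SQL
    print(sqlglot.transpile("SELECT EPOCH_MS(1618088028295)", read="duckdb", write="hive")[0])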
sqlglot/dialects/hive.py
@@ -1,4 +1,6 @@
from sqlglot import exp, transforms
from __future__ import annotations

from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
    Dialect,
    approx_count_distinct_sql,

@@ -13,10 +15,8 @@ from sqlglot.dialects.dialect import (
    struct_extract_sql,
    var_map_sql,
)
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser, parse_var_map
from sqlglot.tokens import Tokenizer
from sqlglot.helper import seq_get
from sqlglot.parser import parse_var_map

# (FuncType, Multiplier)
DATE_DELTA_INTERVAL = {

@@ -34,7 +34,9 @@ def _add_date_sql(self, expression):
    unit = expression.text("unit").upper()
    func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
    modified_increment = (
        int(expression.text("expression")) * multiplier if expression.expression.is_number else expression.expression
        int(expression.text("expression")) * multiplier
        if expression.expression.is_number
        else expression.expression
    )
    modified_increment = exp.Literal.number(modified_increment)
    return f"{func}({self.format_args(expression.this, modified_increment.this)})"

@@ -165,10 +167,10 @@ class Hive(Dialect):
    dateint_format = "'yyyyMMdd'"
    time_format = "'yyyy-MM-dd HH:mm:ss'"

    class Tokenizer(Tokenizer):
    class Tokenizer(tokens.Tokenizer):
        QUOTES = ["'", '"']
        IDENTIFIERS = ["`"]
        ESCAPE = "\\"
        ESCAPES = ["\\"]
        ENCODE = "utf-8"

        NUMERIC_LITERALS = {

@@ -180,40 +182,44 @@ class Hive(Dialect):
            "BD": "DECIMAL",
        }

    class Parser(Parser):
    class Parser(parser.Parser):
        STRICT_CAST = False

        FUNCTIONS = {
            **Parser.FUNCTIONS,
            **parser.Parser.FUNCTIONS,
            "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
            "COLLECT_LIST": exp.ArrayAgg.from_arg_list,
            "DATE_ADD": lambda args: exp.TsOrDsAdd(
                this=list_get(args, 0),
                expression=list_get(args, 1),
                this=seq_get(args, 0),
                expression=seq_get(args, 1),
                unit=exp.Literal.string("DAY"),
            ),
            "DATEDIFF": lambda args: exp.DateDiff(
                this=exp.TsOrDsToDate(this=list_get(args, 0)),
                expression=exp.TsOrDsToDate(this=list_get(args, 1)),
                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
                expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
            ),
            "DATE_SUB": lambda args: exp.TsOrDsAdd(
                this=list_get(args, 0),
                this=seq_get(args, 0),
                expression=exp.Mul(
                    this=list_get(args, 1),
                    this=seq_get(args, 1),
                    expression=exp.Literal.number(-1),
                ),
                unit=exp.Literal.string("DAY"),
            ),
            "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"),
            "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=list_get(args, 0))),
            "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
            "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
            "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
            "LOCATE": lambda args: exp.StrPosition(
                this=list_get(args, 1),
                substr=list_get(args, 0),
                position=list_get(args, 2),
                this=seq_get(args, 1),
                substr=seq_get(args, 0),
                position=seq_get(args, 2),
            ),
            "LOG": (
                lambda args: exp.Log.from_arg_list(args)
                if len(args) > 1
                else exp.Ln.from_arg_list(args)
            ),
            "LOG": (lambda args: exp.Log.from_arg_list(args) if len(args) > 1 else exp.Ln.from_arg_list(args)),
            "MAP": parse_var_map,
            "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
            "PERCENTILE": exp.Quantile.from_arg_list,

@@ -226,15 +232,16 @@ class Hive(Dialect):
            "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
        }

    class Generator(Generator):
    class Generator(generator.Generator):
        TYPE_MAPPING = {
            **Generator.TYPE_MAPPING,
            **generator.Generator.TYPE_MAPPING,
            exp.DataType.Type.TEXT: "STRING",
            exp.DataType.Type.VARBINARY: "BINARY",
        }

        TRANSFORMS = {
            **Generator.TRANSFORMS,
            **transforms.UNALIAS_GROUP,
            **generator.Generator.TRANSFORMS,
            **transforms.UNALIAS_GROUP,  # type: ignore
            exp.AnonymousProperty: _property_sql,
            exp.ApproxDistinct: approx_count_distinct_sql,
            exp.ArrayAgg: rename_func("COLLECT_LIST"),
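
Illustrative aside: Hive's LOCATE takes the substring first, and the lambda above normalizes that into exp.StrPosition, which is why the arguments swap when transpiling to a STRPOS-style dialect. A sketch, assuming sqlglot 10.x:

    import sqlglot

    # Hive LOCATE(substr, string) -> Presto STRPOS(string, substr)
    print(sqlglot.transpile("SELECT LOCATE('a', x)", read="hive", write="presto")[0])
    # SELECT STRPOS(x, 'a')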
sqlglot/dialects/mysql.py
@@ -1,4 +1,8 @@
from sqlglot import exp
from __future__ import annotations

import typing as t

from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
    Dialect,
    no_ilike_sql,

@@ -6,42 +10,47 @@ from sqlglot.dialects.dialect import (
    no_tablesample_sql,
    no_trycast_sql,
)
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType


def _show_parser(*args, **kwargs):
    def _parse(self):
        return self._parse_show_mysql(*args, **kwargs)

    return _parse


def _date_trunc_sql(self, expression):
    unit = expression.text("unit").lower()
    unit = expression.name.lower()

    this = self.sql(expression.this)
    expr = self.sql(expression.expression)

    if unit == "day":
        return f"DATE({this})"
        return f"DATE({expr})"

    if unit == "week":
        concat = f"CONCAT(YEAR({this}), ' ', WEEK({this}, 1), ' 1')"
        concat = f"CONCAT(YEAR({expr}), ' ', WEEK({expr}, 1), ' 1')"
        date_format = "%Y %u %w"
    elif unit == "month":
        concat = f"CONCAT(YEAR({this}), ' ', MONTH({this}), ' 1')"
        concat = f"CONCAT(YEAR({expr}), ' ', MONTH({expr}), ' 1')"
        date_format = "%Y %c %e"
    elif unit == "quarter":
        concat = f"CONCAT(YEAR({this}), ' ', QUARTER({this}) * 3 - 2, ' 1')"
        concat = f"CONCAT(YEAR({expr}), ' ', QUARTER({expr}) * 3 - 2, ' 1')"
        date_format = "%Y %c %e"
    elif unit == "year":
        concat = f"CONCAT(YEAR({this}), ' 1 1')"
        concat = f"CONCAT(YEAR({expr}), ' 1 1')"
        date_format = "%Y %c %e"
    else:
        self.unsupported("Unexpected interval unit: {unit}")
        return f"DATE({this})"
        return f"DATE({expr})"

    return f"STR_TO_DATE({concat}, '{date_format}')"


def _str_to_date(args):
    date_format = MySQL.format_time(list_get(args, 1))
    return exp.StrToDate(this=list_get(args, 0), format=date_format)
    date_format = MySQL.format_time(seq_get(args, 1))
    return exp.StrToDate(this=seq_get(args, 0), format=date_format)


def _str_to_date_sql(self, expression):

@@ -66,9 +75,9 @@ def _trim_sql(self, expression):

def _date_add(expression_class):
    def func(args):
        interval = list_get(args, 1)
        interval = seq_get(args, 1)
        return expression_class(
            this=list_get(args, 0),
            this=seq_get(args, 0),
            expression=interval.this,
            unit=exp.Literal.string(interval.text("unit").lower()),
        )

@@ -101,15 +110,16 @@ class MySQL(Dialect):
        "%l": "%-I",
    }

    class Tokenizer(Tokenizer):
    class Tokenizer(tokens.Tokenizer):
        QUOTES = ["'", '"']
        COMMENTS = ["--", "#", ("/*", "*/")]
        IDENTIFIERS = ["`"]
        ESCAPES = ["'", "\\"]
        BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")]
        HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")]

        KEYWORDS = {
            **Tokenizer.KEYWORDS,
            **tokens.Tokenizer.KEYWORDS,
            "SEPARATOR": TokenType.SEPARATOR,
            "_ARMSCII8": TokenType.INTRODUCER,
            "_ASCII": TokenType.INTRODUCER,

@@ -156,20 +166,23 @@ class MySQL(Dialect):
            "_UTF32": TokenType.INTRODUCER,
            "_UTF8MB3": TokenType.INTRODUCER,
            "_UTF8MB4": TokenType.INTRODUCER,
            "@@": TokenType.SESSION_PARAMETER,
        }

    class Parser(Parser):
        COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}

    class Parser(parser.Parser):
        STRICT_CAST = False

        FUNCTIONS = {
            **Parser.FUNCTIONS,
            **parser.Parser.FUNCTIONS,
            "DATE_ADD": _date_add(exp.DateAdd),
            "DATE_SUB": _date_add(exp.DateSub),
            "STR_TO_DATE": _str_to_date,
        }

        FUNCTION_PARSERS = {
            **Parser.FUNCTION_PARSERS,
            **parser.Parser.FUNCTION_PARSERS,
            "GROUP_CONCAT": lambda self: self.expression(
                exp.GroupConcat,
                this=self._parse_lambda(),

@@ -178,15 +191,212 @@ class MySQL(Dialect):
        }

        PROPERTY_PARSERS = {
            **Parser.PROPERTY_PARSERS,
            **parser.Parser.PROPERTY_PARSERS,
            TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty),
        }

    class Generator(Generator):
        STATEMENT_PARSERS = {
            **parser.Parser.STATEMENT_PARSERS,
            TokenType.SHOW: lambda self: self._parse_show(),
            TokenType.SET: lambda self: self._parse_set(),
        }

        SHOW_PARSERS = {
            "BINARY LOGS": _show_parser("BINARY LOGS"),
            "MASTER LOGS": _show_parser("BINARY LOGS"),
            "BINLOG EVENTS": _show_parser("BINLOG EVENTS"),
            "CHARACTER SET": _show_parser("CHARACTER SET"),
            "CHARSET": _show_parser("CHARACTER SET"),
            "COLLATION": _show_parser("COLLATION"),
            "FULL COLUMNS": _show_parser("COLUMNS", target="FROM", full=True),
            "COLUMNS": _show_parser("COLUMNS", target="FROM"),
            "CREATE DATABASE": _show_parser("CREATE DATABASE", target=True),
            "CREATE EVENT": _show_parser("CREATE EVENT", target=True),
            "CREATE FUNCTION": _show_parser("CREATE FUNCTION", target=True),
            "CREATE PROCEDURE": _show_parser("CREATE PROCEDURE", target=True),
            "CREATE TABLE": _show_parser("CREATE TABLE", target=True),
            "CREATE TRIGGER": _show_parser("CREATE TRIGGER", target=True),
            "CREATE VIEW": _show_parser("CREATE VIEW", target=True),
            "DATABASES": _show_parser("DATABASES"),
            "ENGINE": _show_parser("ENGINE", target=True),
            "STORAGE ENGINES": _show_parser("ENGINES"),
            "ENGINES": _show_parser("ENGINES"),
            "ERRORS": _show_parser("ERRORS"),
            "EVENTS": _show_parser("EVENTS"),
            "FUNCTION CODE": _show_parser("FUNCTION CODE", target=True),
            "FUNCTION STATUS": _show_parser("FUNCTION STATUS"),
            "GRANTS": _show_parser("GRANTS", target="FOR"),
            "INDEX": _show_parser("INDEX", target="FROM"),
            "MASTER STATUS": _show_parser("MASTER STATUS"),
            "OPEN TABLES": _show_parser("OPEN TABLES"),
            "PLUGINS": _show_parser("PLUGINS"),
            "PROCEDURE CODE": _show_parser("PROCEDURE CODE", target=True),
            "PROCEDURE STATUS": _show_parser("PROCEDURE STATUS"),
            "PRIVILEGES": _show_parser("PRIVILEGES"),
            "FULL PROCESSLIST": _show_parser("PROCESSLIST", full=True),
            "PROCESSLIST": _show_parser("PROCESSLIST"),
            "PROFILE": _show_parser("PROFILE"),
            "PROFILES": _show_parser("PROFILES"),
            "RELAYLOG EVENTS": _show_parser("RELAYLOG EVENTS"),
            "REPLICAS": _show_parser("REPLICAS"),
            "SLAVE HOSTS": _show_parser("REPLICAS"),
            "REPLICA STATUS": _show_parser("REPLICA STATUS"),
            "SLAVE STATUS": _show_parser("REPLICA STATUS"),
            "GLOBAL STATUS": _show_parser("STATUS", global_=True),
            "SESSION STATUS": _show_parser("STATUS"),
            "STATUS": _show_parser("STATUS"),
            "TABLE STATUS": _show_parser("TABLE STATUS"),
            "FULL TABLES": _show_parser("TABLES", full=True),
            "TABLES": _show_parser("TABLES"),
            "TRIGGERS": _show_parser("TRIGGERS"),
            "GLOBAL VARIABLES": _show_parser("VARIABLES", global_=True),
            "SESSION VARIABLES": _show_parser("VARIABLES"),
            "VARIABLES": _show_parser("VARIABLES"),
            "WARNINGS": _show_parser("WARNINGS"),
        }

        SET_PARSERS = {
            "GLOBAL": lambda self: self._parse_set_item_assignment("GLOBAL"),
            "PERSIST": lambda self: self._parse_set_item_assignment("PERSIST"),
            "PERSIST_ONLY": lambda self: self._parse_set_item_assignment("PERSIST_ONLY"),
            "SESSION": lambda self: self._parse_set_item_assignment("SESSION"),
            "LOCAL": lambda self: self._parse_set_item_assignment("LOCAL"),
            "CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
            "CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
            "NAMES": lambda self: self._parse_set_item_names(),
        }

        PROFILE_TYPES = {
            "ALL",
            "BLOCK IO",
            "CONTEXT SWITCHES",
            "CPU",
            "IPC",
            "MEMORY",
            "PAGE FAULTS",
            "SOURCE",
            "SWAPS",
        }

        def _parse_show_mysql(self, this, target=False, full=None, global_=None):
            if target:
                if isinstance(target, str):
                    self._match_text(target)
                target_id = self._parse_id_var()
            else:
                target_id = None

            log = self._parse_string() if self._match_text("IN") else None

            if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}:
                position = self._parse_number() if self._match_text("FROM") else None
                db = None
            else:
                position = None
                db = self._parse_id_var() if self._match_text("FROM") else None

            channel = self._parse_id_var() if self._match_text("FOR", "CHANNEL") else None

            like = self._parse_string() if self._match_text("LIKE") else None
            where = self._parse_where()

            if this == "PROFILE":
                types = self._parse_csv(self._parse_show_profile_type)
                query = self._parse_number() if self._match_text("FOR", "QUERY") else None
                offset = self._parse_number() if self._match_text("OFFSET") else None
                limit = self._parse_number() if self._match_text("LIMIT") else None
            else:
                types, query = None, None
                offset, limit = self._parse_oldstyle_limit()

            mutex = True if self._match_text("MUTEX") else None
            mutex = False if self._match_text("STATUS") else mutex

            return self.expression(
                exp.Show,
                this=this,
                target=target_id,
                full=full,
                log=log,
                position=position,
                db=db,
                channel=channel,
                like=like,
                where=where,
                types=types,
                query=query,
                offset=offset,
                limit=limit,
                mutex=mutex,
                **{"global": global_},
            )

        def _parse_show_profile_type(self):
            for type_ in self.PROFILE_TYPES:
                if self._match_text(*type_.split(" ")):
                    return exp.Var(this=type_)
            return None

        def _parse_oldstyle_limit(self):
            limit = None
            offset = None
            if self._match_text("LIMIT"):
                parts = self._parse_csv(self._parse_number)
                if len(parts) == 1:
                    limit = parts[0]
                elif len(parts) == 2:
                    limit = parts[1]
                    offset = parts[0]
            return offset, limit

        def _default_parse_set_item(self):
            return self._parse_set_item_assignment(kind=None)

        def _parse_set_item_assignment(self, kind):
            left = self._parse_primary() or self._parse_id_var()
            if not self._match(TokenType.EQ):
                self.raise_error("Expected =")
            right = self._parse_statement() or self._parse_id_var()

            this = self.expression(
                exp.EQ,
                this=left,
                expression=right,
            )

            return self.expression(
                exp.SetItem,
                this=this,
                kind=kind,
            )

        def _parse_set_item_charset(self, kind):
            this = self._parse_string() or self._parse_id_var()

            return self.expression(
                exp.SetItem,
                this=this,
                kind=kind,
            )

        def _parse_set_item_names(self):
            charset = self._parse_string() or self._parse_id_var()
            if self._match_text("COLLATE"):
                collate = self._parse_string() or self._parse_id_var()
            else:
                collate = None
            return self.expression(
                exp.SetItem,
                this=charset,
                collate=collate,
                kind="NAMES",
            )

    class Generator(generator.Generator):
        NULL_ORDERING_SUPPORTED = False

        TRANSFORMS = {
            **Generator.TRANSFORMS,
            **generator.Generator.TRANSFORMS,
            exp.CurrentDate: no_paren_current_date_sql,
            exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
            exp.ILike: no_ilike_sql,

@@ -199,6 +409,8 @@ class MySQL(Dialect):
            exp.StrToDate: _str_to_date_sql,
            exp.StrToTime: _str_to_date_sql,
            exp.Trim: _trim_sql,
            exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
            exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
        }

        ROOT_PROPERTIES = {

@@ -209,4 +421,69 @@ class MySQL(Dialect):
            exp.SchemaCommentProperty,
        }

        WITH_PROPERTIES = {}
        WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()

        def show_sql(self, expression):
            this = f" {expression.name}"
            full = " FULL" if expression.args.get("full") else ""
            global_ = " GLOBAL" if expression.args.get("global") else ""

            target = self.sql(expression, "target")
            target = f" {target}" if target else ""
            if expression.name in {"COLUMNS", "INDEX"}:
                target = f" FROM{target}"
            elif expression.name == "GRANTS":
                target = f" FOR{target}"

            db = self._prefixed_sql("FROM", expression, "db")

            like = self._prefixed_sql("LIKE", expression, "like")
            where = self.sql(expression, "where")

            types = self.expressions(expression, key="types")
            types = f" {types}" if types else types
            query = self._prefixed_sql("FOR QUERY", expression, "query")

            if expression.name == "PROFILE":
                offset = self._prefixed_sql("OFFSET", expression, "offset")
                limit = self._prefixed_sql("LIMIT", expression, "limit")
            else:
                offset = ""
                limit = self._oldstyle_limit_sql(expression)

            log = self._prefixed_sql("IN", expression, "log")
            position = self._prefixed_sql("FROM", expression, "position")

            channel = self._prefixed_sql("FOR CHANNEL", expression, "channel")

            if expression.name == "ENGINE":
                mutex_or_status = " MUTEX" if expression.args.get("mutex") else " STATUS"
            else:
                mutex_or_status = ""

            return f"SHOW{full}{global_}{this}{target}{types}{db}{query}{log}{position}{channel}{mutex_or_status}{like}{where}{offset}{limit}"

        def _prefixed_sql(self, prefix, expression, arg):
            sql = self.sql(expression, arg)
            if not sql:
                return ""
            return f" {prefix} {sql}"

        def _oldstyle_limit_sql(self, expression):
            limit = self.sql(expression, "limit")
            offset = self.sql(expression, "offset")
            if limit:
                limit_offset = f"{offset}, {limit}" if offset else limit
                return f" LIMIT {limit_offset}"
            return ""

        def setitem_sql(self, expression):
            kind = self.sql(expression, "kind")
            kind = f"{kind} " if kind else ""
            this = self.sql(expression, "this")
            collate = self.sql(expression, "collate")
            collate = f" COLLATE {collate}" if collate else ""
            return f"{kind}{this}{collate}"

        def set_sql(self, expression):
            return f"SET {self.expressions(expression)}"
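
Illustrative aside: the SHOW/SET machinery added above parses these MySQL commands into proper exp.Show / exp.SetItem nodes instead of opaque command strings, and show_sql renders them back. A minimal round trip, assuming sqlglot 10.x:

    import sqlglot

    ast = sqlglot.parse_one("SHOW FULL COLUMNS FROM tbl LIKE '%foo%'", read="mysql")
    print(type(ast).__name__)  # Show
    print(ast.sql("mysql"))    # SHOW FULL COLUMNS FROM tbl LIKE '%foo%'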
sqlglot/dialects/oracle.py
@@ -1,8 +1,9 @@
from sqlglot import exp, transforms
from __future__ import annotations

from sqlglot import exp, generator, tokens, transforms
from sqlglot.dialects.dialect import Dialect, no_ilike_sql
from sqlglot.generator import Generator
from sqlglot.helper import csv
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.tokens import TokenType


def _limit_sql(self, expression):

@@ -36,9 +37,9 @@ class Oracle(Dialect):
        "YYYY": "%Y",  # 2015
    }

    class Generator(Generator):
    class Generator(generator.Generator):
        TYPE_MAPPING = {
            **Generator.TYPE_MAPPING,
            **generator.Generator.TYPE_MAPPING,
            exp.DataType.Type.TINYINT: "NUMBER",
            exp.DataType.Type.SMALLINT: "NUMBER",
            exp.DataType.Type.INT: "NUMBER",

@@ -49,11 +50,12 @@ class Oracle(Dialect):
            exp.DataType.Type.NVARCHAR: "NVARCHAR2",
            exp.DataType.Type.TEXT: "CLOB",
            exp.DataType.Type.BINARY: "BLOB",
            exp.DataType.Type.VARBINARY: "BLOB",
        }

        TRANSFORMS = {
            **Generator.TRANSFORMS,
            **transforms.UNALIAS_GROUP,
            **generator.Generator.TRANSFORMS,
            **transforms.UNALIAS_GROUP,  # type: ignore
            exp.ILike: no_ilike_sql,
            exp.Limit: _limit_sql,
            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",

@@ -86,9 +88,9 @@ class Oracle(Dialect):
        def table_sql(self, expression):
            return super().table_sql(expression, sep=" ")

    class Tokenizer(Tokenizer):
    class Tokenizer(tokens.Tokenizer):
        KEYWORDS = {
            **Tokenizer.KEYWORDS,
            **tokens.Tokenizer.KEYWORDS,
            "TOP": TokenType.TOP,
            "VARCHAR2": TokenType.VARCHAR,
            "NVARCHAR2": TokenType.NVARCHAR,
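
Illustrative aside: VARCHAR2/NVARCHAR2 tokenize to the generic VARCHAR/NVARCHAR types, and the TYPE_MAPPING above maps them back on generation, so Oracle DDL round-trips. A sketch, assuming sqlglot 10.x:

    import sqlglot

    print(sqlglot.transpile("CREATE TABLE t (x VARCHAR2(10))", read="oracle", write="oracle")[0])
    # CREATE TABLE t (x VARCHAR2(10))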
sqlglot/dialects/postgres.py
@@ -1,4 +1,6 @@
from sqlglot import exp
from __future__ import annotations

from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
    Dialect,
    arrow_json_extract_scalar_sql,

@@ -9,9 +11,7 @@ from sqlglot.dialects.dialect import (
    no_trycast_sql,
    str_position_sql,
)
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.tokens import TokenType
from sqlglot.transforms import delegate, preprocess


@@ -160,12 +160,12 @@ class Postgres(Dialect):
        "YYYY": "%Y",  # 2015
    }

    class Tokenizer(Tokenizer):
    class Tokenizer(tokens.Tokenizer):
        BIT_STRINGS = [("b'", "'"), ("B'", "'")]
        HEX_STRINGS = [("x'", "'"), ("X'", "'")]
        BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
        KEYWORDS = {
            **Tokenizer.KEYWORDS,
            **tokens.Tokenizer.KEYWORDS,
            "ALWAYS": TokenType.ALWAYS,
            "BY DEFAULT": TokenType.BY_DEFAULT,
            "COMMENT ON": TokenType.COMMENT_ON,

@@ -179,31 +179,32 @@ class Postgres(Dialect):
        }
        QUOTES = ["'", "$$"]
        SINGLE_TOKENS = {
            **Tokenizer.SINGLE_TOKENS,
            **tokens.Tokenizer.SINGLE_TOKENS,
            "$": TokenType.PARAMETER,
        }

    class Parser(Parser):
    class Parser(parser.Parser):
        STRICT_CAST = False

        FUNCTIONS = {
            **Parser.FUNCTIONS,
            **parser.Parser.FUNCTIONS,
            "TO_TIMESTAMP": _to_timestamp,
            "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
        }

    class Generator(Generator):
    class Generator(generator.Generator):
        TYPE_MAPPING = {
            **Generator.TYPE_MAPPING,
            **generator.Generator.TYPE_MAPPING,
            exp.DataType.Type.TINYINT: "SMALLINT",
            exp.DataType.Type.FLOAT: "REAL",
            exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
            exp.DataType.Type.BINARY: "BYTEA",
            exp.DataType.Type.VARBINARY: "BYTEA",
            exp.DataType.Type.DATETIME: "TIMESTAMP",
        }

        TRANSFORMS = {
            **Generator.TRANSFORMS,
            **generator.Generator.TRANSFORMS,
            exp.ColumnDef: preprocess(
                [
                    _auto_increment_to_serial,
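
Illustrative aside: TO_CHAR is parsed through format_time_lambda into exp.TimeToStr with Postgres' format tokens translated to strftime-style ones, so other dialects can re-render it with their own formatting function. A sketch, assuming sqlglot 10.x:

    import sqlglot

    # Postgres TO_CHAR(x, 'YYYY-MM-DD') -> Presto DATE_FORMAT(x, '%Y-%m-%d')
    print(sqlglot.transpile("SELECT TO_CHAR(x, 'YYYY-MM-DD')", read="postgres", write="presto")[0])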
sqlglot/dialects/presto.py
@@ -1,4 +1,6 @@
from sqlglot import exp, transforms
from __future__ import annotations

from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
    Dialect,
    format_time_lambda,

@@ -10,10 +12,8 @@ from sqlglot.dialects.dialect import (
    struct_extract_sql,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType


def _approx_distinct_sql(self, expression):

@@ -110,30 +110,29 @@ class Presto(Dialect):
    index_offset = 1
    null_ordering = "nulls_are_last"
    time_format = "'%Y-%m-%d %H:%i:%S'"
    time_mapping = MySQL.time_mapping
    time_mapping = MySQL.time_mapping  # type: ignore

    class Tokenizer(Tokenizer):
    class Tokenizer(tokens.Tokenizer):
        KEYWORDS = {
            **Tokenizer.KEYWORDS,
            "VARBINARY": TokenType.BINARY,
            **tokens.Tokenizer.KEYWORDS,
            "ROW": TokenType.STRUCT,
        }

    class Parser(Parser):
    class Parser(parser.Parser):
        FUNCTIONS = {
            **Parser.FUNCTIONS,
            **parser.Parser.FUNCTIONS,
            "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
            "CARDINALITY": exp.ArraySize.from_arg_list,
            "CONTAINS": exp.ArrayContains.from_arg_list,
            "DATE_ADD": lambda args: exp.DateAdd(
                this=list_get(args, 2),
                expression=list_get(args, 1),
                unit=list_get(args, 0),
                this=seq_get(args, 2),
                expression=seq_get(args, 1),
                unit=seq_get(args, 0),
            ),
            "DATE_DIFF": lambda args: exp.DateDiff(
                this=list_get(args, 2),
                expression=list_get(args, 1),
                unit=list_get(args, 0),
                this=seq_get(args, 2),
                expression=seq_get(args, 1),
                unit=seq_get(args, 0),
            ),
            "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
            "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),

@@ -143,7 +142,7 @@ class Presto(Dialect):
            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
        }

    class Generator(Generator):
    class Generator(generator.Generator):

        STRUCT_DELIMITER = ("(", ")")


@@ -159,7 +158,7 @@ class Presto(Dialect):
        }

        TYPE_MAPPING = {
            **Generator.TYPE_MAPPING,
            **generator.Generator.TYPE_MAPPING,
            exp.DataType.Type.INT: "INTEGER",
            exp.DataType.Type.FLOAT: "REAL",
            exp.DataType.Type.BINARY: "VARBINARY",

@@ -169,8 +168,8 @@ class Presto(Dialect):
        }

        TRANSFORMS = {
            **Generator.TRANSFORMS,
            **transforms.UNALIAS_GROUP,
            **generator.Generator.TRANSFORMS,
            **transforms.UNALIAS_GROUP,  # type: ignore
            exp.ApproxDistinct: _approx_distinct_sql,
            exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
            exp.ArrayConcat: rename_func("CONCAT"),
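
Illustrative aside: Presto's DATE_ADD puts the unit first, which is why the lambdas above index the date with seq_get(args, 2). Transpiling to Hive reorders the arguments accordingly. A sketch, assuming sqlglot 10.x:

    import sqlglot

    # Presto DATE_ADD('day', 1, x) -> Hive DATE_ADD(x, 1)
    print(sqlglot.transpile("SELECT DATE_ADD('day', 1, x)", read="presto", write="hive")[0])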
sqlglot/dialects/redshift.py
@@ -1,3 +1,5 @@
from __future__ import annotations

from sqlglot import exp
from sqlglot.dialects.postgres import Postgres
from sqlglot.tokens import TokenType

@@ -6,29 +8,30 @@ from sqlglot.tokens import TokenType
class Redshift(Postgres):
    time_format = "'YYYY-MM-DD HH:MI:SS'"
    time_mapping = {
        **Postgres.time_mapping,
        **Postgres.time_mapping,  # type: ignore
        "MON": "%b",
        "HH": "%H",
    }

    class Tokenizer(Postgres.Tokenizer):
        ESCAPE = "\\"
        ESCAPES = ["\\"]

        KEYWORDS = {
            **Postgres.Tokenizer.KEYWORDS,
            **Postgres.Tokenizer.KEYWORDS,  # type: ignore
            "GEOMETRY": TokenType.GEOMETRY,
            "GEOGRAPHY": TokenType.GEOGRAPHY,
            "HLLSKETCH": TokenType.HLLSKETCH,
            "SUPER": TokenType.SUPER,
            "TIME": TokenType.TIMESTAMP,
            "TIMETZ": TokenType.TIMESTAMPTZ,
            "VARBYTE": TokenType.BINARY,
            "VARBYTE": TokenType.VARBINARY,
            "SIMILAR TO": TokenType.SIMILAR_TO,
        }

    class Generator(Postgres.Generator):
        TYPE_MAPPING = {
            **Postgres.Generator.TYPE_MAPPING,
            **Postgres.Generator.TYPE_MAPPING,  # type: ignore
            exp.DataType.Type.BINARY: "VARBYTE",
            exp.DataType.Type.VARBINARY: "VARBYTE",
            exp.DataType.Type.INT: "INTEGER",
        }
sqlglot/dialects/snowflake.py
@@ -1,4 +1,6 @@
from sqlglot import exp
from __future__ import annotations

from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
    Dialect,
    format_time_lambda,

@@ -6,10 +8,8 @@ from sqlglot.dialects.dialect import (
    rename_func,
)
from sqlglot.expressions import Literal
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType


def _check_int(s):

@@ -28,7 +28,9 @@ def _snowflake_to_timestamp(args):

        # case: <numeric_expr> [ , <scale> ]
        if second_arg.name not in ["0", "3", "9"]:
            raise ValueError(f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9")
            raise ValueError(
                f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9"
            )

        if second_arg.name == "0":
            timescale = exp.UnixToTime.SECONDS

@@ -39,7 +41,7 @@ def _snowflake_to_timestamp(args):

        return exp.UnixToTime(this=first_arg, scale=timescale)

    first_arg = list_get(args, 0)
    first_arg = seq_get(args, 0)
    if not isinstance(first_arg, Literal):
        # case: <variant_expr>
        return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)

@@ -56,7 +58,7 @@ def _snowflake_to_timestamp(args):
    return exp.UnixToTime.from_arg_list(args)


def _unix_to_time(self, expression):
def _unix_to_time_sql(self, expression):
    scale = expression.args.get("scale")
    timestamp = self.sql(expression, "this")
    if scale in [None, exp.UnixToTime.SECONDS]:

@@ -132,9 +134,9 @@ class Snowflake(Dialect):
        "ff6": "%f",
    }

    class Parser(Parser):
    class Parser(parser.Parser):
        FUNCTIONS = {
            **Parser.FUNCTIONS,
            **parser.Parser.FUNCTIONS,
            "ARRAYAGG": exp.ArrayAgg.from_arg_list,
            "IFF": exp.If.from_arg_list,
            "TO_TIMESTAMP": _snowflake_to_timestamp,

@@ -143,18 +145,18 @@ class Snowflake(Dialect):
        }

        FUNCTION_PARSERS = {
            **Parser.FUNCTION_PARSERS,
            **parser.Parser.FUNCTION_PARSERS,
            "DATE_PART": _parse_date_part,
        }

        FUNC_TOKENS = {
            *Parser.FUNC_TOKENS,
            *parser.Parser.FUNC_TOKENS,
            TokenType.RLIKE,
            TokenType.TABLE,
        }

        COLUMN_OPERATORS = {
            **Parser.COLUMN_OPERATORS,
            **parser.Parser.COLUMN_OPERATORS,  # type: ignore
            TokenType.COLON: lambda self, this, path: self.expression(
                exp.Bracket,
                this=this,

@@ -163,21 +165,21 @@ class Snowflake(Dialect):
        }

        PROPERTY_PARSERS = {
            **Parser.PROPERTY_PARSERS,
            **parser.Parser.PROPERTY_PARSERS,
            TokenType.PARTITION_BY: lambda self: self._parse_partitioned_by(),
        }

    class Tokenizer(Tokenizer):
    class Tokenizer(tokens.Tokenizer):
        QUOTES = ["'", "$$"]
        ESCAPE = "\\"
        ESCAPES = ["\\"]

        SINGLE_TOKENS = {
            **Tokenizer.SINGLE_TOKENS,
            **tokens.Tokenizer.SINGLE_TOKENS,
            "$": TokenType.PARAMETER,
        }

        KEYWORDS = {
            **Tokenizer.KEYWORDS,
            **tokens.Tokenizer.KEYWORDS,
            "QUALIFY": TokenType.QUALIFY,
            "DOUBLE PRECISION": TokenType.DOUBLE,
            "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,

@@ -187,15 +189,15 @@ class Snowflake(Dialect):
            "SAMPLE": TokenType.TABLE_SAMPLE,
        }

    class Generator(Generator):
    class Generator(generator.Generator):
        CREATE_TRANSIENT = True

        TRANSFORMS = {
            **Generator.TRANSFORMS,
            **generator.Generator.TRANSFORMS,
            exp.ArrayConcat: rename_func("ARRAY_CAT"),
            exp.If: rename_func("IFF"),
            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.UnixToTime: _unix_to_time,
            exp.UnixToTime: _unix_to_time_sql,
            exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
            exp.Array: inline_array_sql,
            exp.StrPosition: rename_func("POSITION"),

@@ -204,7 +206,7 @@ class Snowflake(Dialect):
        }

        TYPE_MAPPING = {
            **Generator.TYPE_MAPPING,
            **generator.Generator.TYPE_MAPPING,
            exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
        }
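
Illustrative aside: _snowflake_to_timestamp inspects the second argument, so a numeric epoch with scale 3 (milliseconds) becomes exp.UnixToTime with that scale, and _unix_to_time_sql renders it back. A round-trip sketch, assuming sqlglot 10.x:

    import sqlglot

    print(sqlglot.transpile("SELECT TO_TIMESTAMP(1659981729000, 3)", read="snowflake", write="snowflake")[0])
    # SELECT TO_TIMESTAMP(1659981729000, 3)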
sqlglot/dialects/spark.py
@@ -1,8 +1,9 @@
from sqlglot import exp
from __future__ import annotations

from sqlglot import exp, parser
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func
from sqlglot.dialects.hive import Hive
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.helper import seq_get


def _create_sql(self, e):

@@ -46,36 +47,36 @@ def _unix_to_time(self, expression):
class Spark(Hive):
    class Parser(Hive.Parser):
        FUNCTIONS = {
            **Hive.Parser.FUNCTIONS,
            **Hive.Parser.FUNCTIONS,  # type: ignore
            "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
            "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
            "LEFT": lambda args: exp.Substring(
                this=list_get(args, 0),
                this=seq_get(args, 0),
                start=exp.Literal.number(1),
                length=list_get(args, 1),
                length=seq_get(args, 1),
            ),
            "SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
                this=list_get(args, 0),
                expression=list_get(args, 1),
                this=seq_get(args, 0),
                expression=seq_get(args, 1),
            ),
            "SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
                this=list_get(args, 0),
                expression=list_get(args, 1),
                this=seq_get(args, 0),
                expression=seq_get(args, 1),
            ),
            "RIGHT": lambda args: exp.Substring(
                this=list_get(args, 0),
                this=seq_get(args, 0),
                start=exp.Sub(
                    this=exp.Length(this=list_get(args, 0)),
                    expression=exp.Add(this=list_get(args, 1), expression=exp.Literal.number(1)),
                    this=exp.Length(this=seq_get(args, 0)),
                    expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
                ),
                length=list_get(args, 1),
                length=seq_get(args, 1),
            ),
            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
            "IIF": exp.If.from_arg_list,
        }

        FUNCTION_PARSERS = {
            **Parser.FUNCTION_PARSERS,
            **parser.Parser.FUNCTION_PARSERS,
            "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
            "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
            "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),

@@ -88,14 +89,14 @@ class Spark(Hive):

    class Generator(Hive.Generator):
        TYPE_MAPPING = {
            **Hive.Generator.TYPE_MAPPING,
            **Hive.Generator.TYPE_MAPPING,  # type: ignore
            exp.DataType.Type.TINYINT: "BYTE",
            exp.DataType.Type.SMALLINT: "SHORT",
            exp.DataType.Type.BIGINT: "LONG",
        }

        TRANSFORMS = {
            **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort, exp.ILike}},
            **Hive.Generator.TRANSFORMS,  # type: ignore
            exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
            exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}",
            exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",

@@ -114,6 +115,8 @@ class Spark(Hive):
            exp.VariancePop: rename_func("VAR_POP"),
            exp.DateFromParts: rename_func("MAKE_DATE"),
        }
        TRANSFORMS.pop(exp.ArraySort)
        TRANSFORMS.pop(exp.ILike)

        WRAP_DERIVED_VALUES = False
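
Small aside on the refactor above: instead of filtering the inherited transforms with a dict comprehension, the new code splats the parent dict and pops the unwanted keys afterwards; the two forms are equivalent:

    parent = {"a": 1, "b": 2, "c": 3}

    child = {**parent}
    child.pop("b")

    assert child == {k: v for k, v in parent.items() if k != "b"}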
sqlglot/dialects/sqlite.py
@@ -1,4 +1,6 @@
from sqlglot import exp
from __future__ import annotations

from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
    Dialect,
    arrow_json_extract_scalar_sql,

@@ -8,31 +10,28 @@ from sqlglot.dialects.dialect import (
    no_trycast_sql,
    rename_func,
)
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.tokens import TokenType


class SQLite(Dialect):
    class Tokenizer(Tokenizer):
    class Tokenizer(tokens.Tokenizer):
        IDENTIFIERS = ['"', ("[", "]"), "`"]
        HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]

        KEYWORDS = {
            **Tokenizer.KEYWORDS,
            "VARBINARY": TokenType.BINARY,
            **tokens.Tokenizer.KEYWORDS,
            "AUTOINCREMENT": TokenType.AUTO_INCREMENT,
        }

    class Parser(Parser):
    class Parser(parser.Parser):
        FUNCTIONS = {
            **Parser.FUNCTIONS,
            **parser.Parser.FUNCTIONS,
            "EDITDIST3": exp.Levenshtein.from_arg_list,
        }

    class Generator(Generator):
    class Generator(generator.Generator):
        TYPE_MAPPING = {
            **Generator.TYPE_MAPPING,
            **generator.Generator.TYPE_MAPPING,
            exp.DataType.Type.BOOLEAN: "INTEGER",
            exp.DataType.Type.TINYINT: "INTEGER",
            exp.DataType.Type.SMALLINT: "INTEGER",

@@ -46,6 +45,7 @@ class SQLite(Dialect):
            exp.DataType.Type.VARCHAR: "TEXT",
            exp.DataType.Type.NVARCHAR: "TEXT",
            exp.DataType.Type.BINARY: "BLOB",
            exp.DataType.Type.VARBINARY: "BLOB",
        }

        TOKEN_MAPPING = {

@@ -53,7 +53,7 @@ class SQLite(Dialect):
        }

        TRANSFORMS = {
            **Generator.TRANSFORMS,
            **generator.Generator.TRANSFORMS,
            exp.ILike: no_ilike_sql,
            exp.JSONExtract: arrow_json_extract_sql,
            exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
sqlglot/dialects/starrocks.py
@@ -1,10 +1,12 @@
from __future__ import annotations

from sqlglot import exp
from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func
from sqlglot.dialects.mysql import MySQL


class StarRocks(MySQL):
    class Generator(MySQL.Generator):
    class Generator(MySQL.Generator):  # type: ignore
        TYPE_MAPPING = {
            **MySQL.Generator.TYPE_MAPPING,
            exp.DataType.Type.TEXT: "STRING",

@@ -13,7 +15,7 @@ class StarRocks(MySQL):
        }

        TRANSFORMS = {
            **MySQL.Generator.TRANSFORMS,
            **MySQL.Generator.TRANSFORMS,  # type: ignore
            exp.JSONExtractScalar: arrow_json_extract_sql,
            exp.JSONExtract: arrow_json_extract_sql,
            exp.DateDiff: rename_func("DATEDIFF"),

@@ -22,3 +24,4 @@ class StarRocks(MySQL):
            exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.UnixToTime: rename_func("FROM_UNIXTIME"),
        }
        TRANSFORMS.pop(exp.DateTrunc)
sqlglot/dialects/tableau.py
@@ -1,7 +1,7 @@
from sqlglot import exp
from __future__ import annotations

from sqlglot import exp, generator, parser
from sqlglot.dialects.dialect import Dialect
from sqlglot.generator import Generator
from sqlglot.parser import Parser


def _if_sql(self, expression):

@@ -20,17 +20,17 @@ def _count_sql(self, expression):


class Tableau(Dialect):
    class Generator(Generator):
    class Generator(generator.Generator):
        TRANSFORMS = {
            **Generator.TRANSFORMS,
            **generator.Generator.TRANSFORMS,  # type: ignore
            exp.If: _if_sql,
            exp.Coalesce: _coalesce_sql,
            exp.Count: _count_sql,
        }

    class Parser(Parser):
    class Parser(parser.Parser):
        FUNCTIONS = {
            **Parser.FUNCTIONS,
            **parser.Parser.FUNCTIONS,
            "IFNULL": exp.Coalesce.from_arg_list,
            "COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
        }
sqlglot/dialects/trino.py
@@ -1,3 +1,5 @@
from __future__ import annotations

from sqlglot import exp
from sqlglot.dialects.presto import Presto


@@ -5,7 +7,7 @@ from sqlglot.dialects.presto import Presto
class Trino(Presto):
    class Generator(Presto.Generator):
        TRANSFORMS = {
            **Presto.Generator.TRANSFORMS,
            **Presto.Generator.TRANSFORMS,  # type: ignore
            exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
        }

sqlglot/dialects/tsql.py
@@ -1,15 +1,22 @@
from __future__ import annotations

import re

from sqlglot import exp
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func
from sqlglot.expressions import DataType
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.helper import seq_get
from sqlglot.time import format_time
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.tokens import TokenType

FULL_FORMAT_TIME_MAPPING = {"weekday": "%A", "dw": "%A", "w": "%A", "month": "%B", "mm": "%B", "m": "%B"}
FULL_FORMAT_TIME_MAPPING = {
    "weekday": "%A",
    "dw": "%A",
    "w": "%A",
    "month": "%B",
    "mm": "%B",
    "m": "%B",
}
DATE_DELTA_INTERVAL = {
    "year": "year",
    "yyyy": "year",

@@ -37,11 +44,13 @@ TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
    def _format_time(args):
        return exp_class(
            this=list_get(args, 1),
            this=seq_get(args, 1),
            format=exp.Literal.string(
                format_time(
                    list_get(args, 0).name or (TSQL.time_format if default is True else default),
                    {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING} if full_format_mapping else TSQL.time_mapping,
                    seq_get(args, 0).name or (TSQL.time_format if default is True else default),
                    {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING}
                    if full_format_mapping
                    else TSQL.time_mapping,
                )
            ),
        )

@@ -50,12 +59,12 @@ def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):


def parse_format(args):
    fmt = list_get(args, 1)
    fmt = seq_get(args, 1)
    number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this)
    if number_fmt:
        return exp.NumberToStr(this=list_get(args, 0), format=fmt)
        return exp.NumberToStr(this=seq_get(args, 0), format=fmt)
    return exp.TimeToStr(
        this=list_get(args, 0),
        this=seq_get(args, 0),
        format=exp.Literal.string(
            format_time(fmt.name, TSQL.format_time_mapping)
            if len(fmt.name) == 1

@@ -188,11 +197,11 @@ class TSQL(Dialect):
        "Y": "%a %Y",
    }

    class Tokenizer(Tokenizer):
    class Tokenizer(tokens.Tokenizer):
        IDENTIFIERS = ['"', ("[", "]")]

        KEYWORDS = {
            **Tokenizer.KEYWORDS,
            **tokens.Tokenizer.KEYWORDS,
            "BIT": TokenType.BOOLEAN,
            "REAL": TokenType.FLOAT,
            "NTEXT": TokenType.TEXT,

@@ -200,7 +209,6 @@ class TSQL(Dialect):
            "DATETIME2": TokenType.DATETIME,
            "DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
            "TIME": TokenType.TIMESTAMP,
            "VARBINARY": TokenType.BINARY,
            "IMAGE": TokenType.IMAGE,
            "MONEY": TokenType.MONEY,
            "SMALLMONEY": TokenType.SMALLMONEY,

@@ -213,9 +221,9 @@ class TSQL(Dialect):
            "TOP": TokenType.TOP,
        }

    class Parser(Parser):
    class Parser(parser.Parser):
        FUNCTIONS = {
            **Parser.FUNCTIONS,
            **parser.Parser.FUNCTIONS,
            "CHARINDEX": exp.StrPosition.from_arg_list,
            "ISNULL": exp.Coalesce.from_arg_list,
            "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),

@@ -243,14 +251,16 @@ class TSQL(Dialect):
            this = self._parse_column()

            # Retrieve length of datatype and override to default if not specified
            if list_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
            if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
                to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False)

            # Check whether a conversion with format is applicable
            if self._match(TokenType.COMMA):
                format_val = self._parse_number().name
                if format_val not in TSQL.convert_format_mapping:
                    raise ValueError(f"CONVERT function at T-SQL does not support format style {format_val}")
                    raise ValueError(
                        f"CONVERT function at T-SQL does not support format style {format_val}"
                    )
                format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val])

                # Check whether the convert entails a string to date format

@@ -272,9 +282,9 @@ class TSQL(Dialect):
            # Entails a simple cast without any format requirement
            return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)

    class Generator(Generator):
    class Generator(generator.Generator):
        TYPE_MAPPING = {
            **Generator.TYPE_MAPPING,
            **generator.Generator.TYPE_MAPPING,
            exp.DataType.Type.BOOLEAN: "BIT",
            exp.DataType.Type.INT: "INTEGER",
            exp.DataType.Type.DECIMAL: "NUMERIC",

@@ -283,7 +293,7 @@ class TSQL(Dialect):
        }

        TRANSFORMS = {
            **Generator.TRANSFORMS,
            **generator.Generator.TRANSFORMS,  # type: ignore
            exp.DateAdd: generate_date_delta_with_unit_sql,
            exp.DateDiff: generate_date_delta_with_unit_sql,
            exp.CurrentDate: rename_func("GETDATE"),
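
Illustrative aside: the CONVERT handling above maps a known style number through convert_format_mapping into a time-format expression, and raises the ValueError shown for unknown styles. A sketch, assuming sqlglot 10.x (style 112 is yyyymmdd); the exact rendered output may vary by target dialect:

    import sqlglot

    # CONVERT(VARCHAR(10), x, 112) is treated as a date-to-string conversion
    print(sqlglot.transpile("SELECT CONVERT(VARCHAR(10), x, 112)", read="tsql", write="spark")[0])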
@@ -4,7 +4,7 @@ from heapq import heappop, heappush

from sqlglot import Dialect
from sqlglot import expressions as exp
from sqlglot.helper import ensure_list
from sqlglot.helper import ensure_collection


@dataclass(frozen=True)

@@ -116,7 +116,9 @@ class ChangeDistiller:
            source_node = self._source_index[kept_source_node_id]
            target_node = self._target_index[kept_target_node_id]
            if not isinstance(source_node, LEAF_EXPRESSION_TYPES) or source_node == target_node:
                edit_script.extend(self._generate_move_edits(source_node, target_node, matching_set))
                edit_script.extend(
                    self._generate_move_edits(source_node, target_node, matching_set)
                )
                edit_script.append(Keep(source_node, target_node))
            else:
                edit_script.append(Update(source_node, target_node))

@@ -158,13 +160,16 @@ class ChangeDistiller:
            max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids))
            if max_leaves_num:
                common_leaves_num = sum(
                    1 if s in source_leaf_ids and t in target_leaf_ids else 0 for s, t in leaves_matching_set
                    1 if s in source_leaf_ids and t in target_leaf_ids else 0
                    for s, t in leaves_matching_set
                )
                leaf_similarity_score = common_leaves_num / max_leaves_num
            else:
                leaf_similarity_score = 0.0

            adjusted_t = self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4
            adjusted_t = (
                self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4
            )

            if leaf_similarity_score >= 0.8 or (
                leaf_similarity_score >= adjusted_t

@@ -201,7 +206,10 @@ class ChangeDistiller:
        matching_set = set()
        while candidate_matchings:
            _, _, source_leaf, target_leaf = heappop(candidate_matchings)
            if id(source_leaf) in self._unmatched_source_nodes and id(target_leaf) in self._unmatched_target_nodes:
            if (
                id(source_leaf) in self._unmatched_source_nodes
                and id(target_leaf) in self._unmatched_target_nodes
            ):
                matching_set.add((id(source_leaf), id(target_leaf)))
                self._unmatched_source_nodes.remove(id(source_leaf))
                self._unmatched_target_nodes.remove(id(target_leaf))

@@ -241,8 +249,7 @@ def _get_leaves(expression):
    has_child_exprs = False

    for a in expression.args.values():
        nodes = ensure_list(a)
        for node in nodes:
        for node in ensure_collection(a):
            if isinstance(node, exp.Expression):
                has_child_exprs = True
                yield from _get_leaves(node)

@@ -268,7 +275,7 @@ def _expression_only_args(expression):
    args = []
    if expression:
        for a in expression.args.values():
            args.extend(ensure_list(a))
            args.extend(ensure_collection(a))
    return [a for a in args if isinstance(a, exp.Expression)]
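For context, the matcher above is driven by sqlglot's public diff entry point; a minimal run (edit classes per this module):

    from sqlglot import diff, parse_one

    edits = diff(parse_one("SELECT a + b"), parse_one("SELECT a - b"))
    # expect mostly Keep edits, plus a Remove/Insert (or Update) around the
    # arithmetic operator node
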
@@ -1,3 +1,6 @@
from __future__ import annotations

import typing as t
from enum import auto

from sqlglot.helper import AutoName

@@ -30,7 +33,11 @@ class OptimizeError(SqlglotError):
    pass


def concat_errors(errors, maximum):
class SchemaError(SqlglotError):
    pass


def concat_errors(errors: t.Sequence[t.Any], maximum: int) -> str:
    msg = [str(e) for e in errors[:maximum]]
    remaining = len(errors) - maximum
    if remaining > 0:
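The hunk is cut off inside the truncation branch; a sketch of the visible logic, with the parts the diff does not show marked as assumptions:

    def concat_errors_sketch(errors, maximum):
        msgs = [str(e) for e in errors[:maximum]]
        remaining = len(errors) - maximum
        if remaining > 0:
            msgs.append(f"... and {remaining} more")  # assumed wording; the hunk ends before this line
        return "\n\n".join(msgs)  # the join separator is an assumption as well
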
@@ -19,6 +19,7 @@ class Context:
            env (Optional[dict]): dictionary of functions within the execution context
        """
        self.tables = tables
        self._table = None
        self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
        self.row_readers = {name: table.reader for name, table in tables.items()}
        self.env = {**(env or {}), "scope": self.row_readers}

@@ -29,8 +30,27 @@ class Context:
    def eval_tuple(self, codes):
        return tuple(self.eval(code) for code in codes)

    @property
    def table(self):
        if self._table is None:
            self._table = list(self.tables.values())[0]
            for other in self.tables.values():
                if self._table.columns != other.columns:
                    raise Exception(f"Columns are different.")
                if len(self._table.rows) != len(other.rows):
                    raise Exception(f"Rows are different.")
        return self._table

    @property
    def columns(self):
        return self.table.columns

    def __iter__(self):
        return self.table_iter(list(self.tables)[0])
        self.env["scope"] = self.row_readers
        for i in range(len(self.table.rows)):
            for table in self.tables.values():
                reader = table[i]
            yield reader, self

    def table_iter(self, table):
        self.env["scope"] = self.row_readers

@@ -38,8 +58,8 @@ class Context:
        for reader in self.tables[table]:
            yield reader, self

    def sort(self, table, key):
        table = self.tables[table]
    def sort(self, key):
        table = self.table

        def sort_key(row):
            table.reader.row = row

@@ -47,20 +67,20 @@ class Context:

        table.rows.sort(key=sort_key)

    def set_row(self, table, row):
        self.row_readers[table].row = row
    def set_row(self, row):
        for table in self.tables.values():
            table.reader.row = row
        self.env["scope"] = self.row_readers

    def set_index(self, table, index):
        self.row_readers[table].row = self.tables[table].rows[index]
    def set_index(self, index):
        for table in self.tables.values():
            table[index]
        self.env["scope"] = self.row_readers

    def set_range(self, table, start, end):
        self.range_readers[table].range = range(start, end)
    def set_range(self, start, end):
        for name in self.tables:
            self.range_readers[name].range = range(start, end)
        self.env["scope"] = self.range_readers

    def __getitem__(self, table):
        return self.env["scope"][table]

    def __contains__(self, table):
        return table in self.tables
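These Context changes drop the per-table arguments: every table in a Context must now be row-aligned, which is exactly what the table property checks. They are exercised through the Python executor; a minimal end-to-end run, assuming the executor's execute() entry point accepts a tables mapping in this release:

    from sqlglot.executor import execute

    result = execute(
        "SELECT x.a, COUNT(1) AS n FROM x GROUP BY x.a",
        tables={"x": [{"a": 1}, {"a": 1}, {"a": 2}]},
    )
    # result is a Table; sort/set_row/set_index above now apply to every
    # table in the context at once
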
@@ -2,6 +2,8 @@ import datetime
import re
import statistics

from sqlglot.helper import PYTHON_VERSION


class reverse_key:
    def __init__(self, obj):

@@ -25,7 +27,7 @@ ENV = {
    "str": str,
    "desc": reverse_key,
    "SUM": sum,
    "AVG": statistics.fmean if hasattr(statistics, "fmean") else statistics.mean,
    "AVG": statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean,  # type: ignore
    "COUNT": lambda acc: sum(1 for e in acc if e is not None),
    "MAX": max,
    "MIN": min,
@@ -1,15 +1,14 @@
import ast
import collections
import itertools
import math

from sqlglot import exp, planner
from sqlglot import exp, generator, planner, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql
from sqlglot.executor.context import Context
from sqlglot.executor.env import ENV
from sqlglot.executor.table import Table
from sqlglot.generator import Generator
from sqlglot.helper import csv_reader
from sqlglot.tokens import Tokenizer


class PythonExecutor:

@@ -26,7 +25,11 @@ class PythonExecutor:
        while queue:
            node = queue.pop()
            context = self.context(
                {name: table for dep in node.dependencies for name, table in contexts[dep].tables.items()}
                {
                    name: table
                    for dep in node.dependencies
                    for name, table in contexts[dep].tables.items()
                }
            )
            running.add(node)

@@ -76,13 +79,10 @@ class PythonExecutor:
        return Table(expression.alias_or_name for expression in expressions)

    def scan(self, step, context):
        if hasattr(step, "source"):
            source = step.source
        source = step.source

        if isinstance(source, exp.Expression):
            source = source.name or source.alias
        else:
            source = step.name
        if isinstance(source, exp.Expression):
            source = source.name or source.alias

        condition = self.generate(step.condition)
        projections = self.generate_tuple(step.projections)

@@ -96,14 +96,12 @@ class PythonExecutor:

        if projections:
            sink = self.table(step.projections)
        elif source in context:
            sink = Table(context[source].columns)
        else:
            sink = None

        for reader, ctx in table_iter:
            if sink is None:
                sink = Table(ctx[source].columns)
                sink = Table(reader.columns)

            if condition and not ctx.eval(condition):
                continue

@@ -135,119 +133,102 @@ class PythonExecutor:
                    types.append(type(ast.literal_eval(v)))
                except (ValueError, SyntaxError):
                    types.append(str)
            context.set_row(alias, tuple(t(v) for t, v in zip(types, row)))
            yield context[alias], context
            context.set_row(tuple(t(v) for t, v in zip(types, row)))
            yield context.table.reader, context

    def join(self, step, context):
        source = step.name

        join_context = self.context({source: context.tables[source]})

        def merge_context(ctx, table):
            # create a new context where all existing tables are mapped to a new one
            return self.context({name: table for name in ctx.tables})
        source_table = context.tables[source]
        source_context = self.context({source: source_table})
        column_ranges = {source: range(0, len(source_table.columns))}

        for name, join in step.joins.items():
            join_context = self.context({**join_context.tables, name: context.tables[name]})
            table = context.tables[name]
            start = max(r.stop for r in column_ranges.values())
            column_ranges[name] = range(start, len(table.columns) + start)
            join_context = self.context({name: table})

            if join.get("source_key"):
                table = self.hash_join(join, source, name, join_context)
                table = self.hash_join(join, source_context, join_context)
            else:
                table = self.nested_loop_join(join, source, name, join_context)
                table = self.nested_loop_join(join, source_context, join_context)

            join_context = merge_context(join_context, table)
            source_context = self.context(
                {
                    name: Table(table.columns, table.rows, column_range)
                    for name, column_range in column_ranges.items()
                }
            )

        # apply projections or conditions
        context = self.scan(step, join_context)
        condition = self.generate(step.condition)
        projections = self.generate_tuple(step.projections)

        # use the scan context since it returns a single table
        # otherwise there are no projections so all other tables are still in scope
        if step.projections:
            return context
        if not condition or not projections:
            return source_context

        return merge_context(join_context, context.tables[source])
        sink = self.table(step.projections if projections else source_context.columns)

    def nested_loop_join(self, _join, a, b, context):
        table = Table(context.tables[a].columns + context.tables[b].columns)
        for reader, ctx in join_context:
            if condition and not ctx.eval(condition):
                continue

        for reader_a, _ in context.table_iter(a):
            for reader_b, _ in context.table_iter(b):
            if projections:
                sink.append(ctx.eval_tuple(projections))
            else:
                sink.append(reader.row)

            if len(sink) >= step.limit:
                break

        return self.context({step.name: sink})

    def nested_loop_join(self, _join, source_context, join_context):
        table = Table(source_context.columns + join_context.columns)

        for reader_a, _ in source_context:
            for reader_b, _ in join_context:
                table.append(reader_a.row + reader_b.row)

        return table

    def hash_join(self, join, a, b, context):
        a_key = self.generate_tuple(join["source_key"])
        b_key = self.generate_tuple(join["join_key"])
    def hash_join(self, join, source_context, join_context):
        source_key = self.generate_tuple(join["source_key"])
        join_key = self.generate_tuple(join["join_key"])

        results = collections.defaultdict(lambda: ([], []))

        for reader, ctx in context.table_iter(a):
            results[ctx.eval_tuple(a_key)][0].append(reader.row)
        for reader, ctx in context.table_iter(b):
            results[ctx.eval_tuple(b_key)][1].append(reader.row)
        for reader, ctx in source_context:
            results[ctx.eval_tuple(source_key)][0].append(reader.row)
        for reader, ctx in join_context:
            results[ctx.eval_tuple(join_key)][1].append(reader.row)

        table = Table(source_context.columns + join_context.columns)

        table = Table(context.tables[a].columns + context.tables[b].columns)
        for a_group, b_group in results.values():
            for a_row, b_row in itertools.product(a_group, b_group):
                table.append(a_row + b_row)

        return table
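Reduced to plain Python, the hash join above does the following; a sketch, not the executor's exact code:

    import collections
    import itertools

    def hash_join(left_rows, right_rows, left_key, right_key):
        # bucket both sides by join key, then emit the cross product per bucket
        buckets = collections.defaultdict(lambda: ([], []))
        for row in left_rows:
            buckets[left_key(row)][0].append(row)
        for row in right_rows:
            buckets[right_key(row)][1].append(row)
        for left_group, right_group in buckets.values():
            for l, r in itertools.product(left_group, right_group):
                yield l + r
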
    def sort_merge_join(self, join, a, b, context):
        a_key = self.generate_tuple(join["source_key"])
        b_key = self.generate_tuple(join["join_key"])

        context.sort(a, a_key)
        context.sort(b, b_key)

        a_i = 0
        b_i = 0
        a_n = len(context.tables[a])
        b_n = len(context.tables[b])

        table = Table(context.tables[a].columns + context.tables[b].columns)

        def get_key(source, key, i):
            context.set_index(source, i)
            return context.eval_tuple(key)

        while a_i < a_n and b_i < b_n:
            key = min(get_key(a, a_key, a_i), get_key(b, b_key, b_i))

            a_group = []

            while a_i < a_n and key == get_key(a, a_key, a_i):
                a_group.append(context[a].row)
                a_i += 1

            b_group = []

            while b_i < b_n and key == get_key(b, b_key, b_i):
                b_group.append(context[b].row)
                b_i += 1

            for a_row, b_row in itertools.product(a_group, b_group):
                table.append(a_row + b_row)

        return table

    def aggregate(self, step, context):
        source = step.source
        group_by = self.generate_tuple(step.group)
        aggregations = self.generate_tuple(step.aggregations)
        operands = self.generate_tuple(step.operands)

        context.sort(source, group_by)

        if step.operands:
        if operands:
            source_table = context.tables[source]
            operand_table = Table(source_table.columns + self.table(step.operands).columns)

            for reader, ctx in context:
                operand_table.append(reader.row + ctx.eval_tuple(operands))

            context = self.context({source: operand_table})
            context = self.context(
                {None: operand_table, **{table: operand_table for table in context.tables}}
            )

        context.sort(group_by)

        group = None
        start = 0

@@ -256,15 +237,15 @@ class PythonExecutor:
        table = self.table(step.group + step.aggregations)

        for i in range(length):
            context.set_index(source, i)
            context.set_index(i)
            key = context.eval_tuple(group_by)
            group = key if group is None else group
            end += 1

            if i == length - 1:
                context.set_range(source, start, end - 1)
                context.set_range(start, end - 1)
            elif key != group:
                context.set_range(source, start, end - 2)
                context.set_range(start, end - 2)
            else:
                continue

@@ -272,13 +253,32 @@ class PythonExecutor:
            group = key
            start = end - 2

        return self.scan(step, self.context({source: table}))
        context = self.context({step.name: table, **{name: table for name in context.tables}})

        if step.projections:
            return self.scan(step, context)
        return context

    def sort(self, step, context):
        table = list(context.tables)[0]
        key = self.generate_tuple(step.key)
        context.sort(table, key)
        return self.scan(step, context)
        projections = self.generate_tuple(step.projections)

        sink = self.table(step.projections)

        for reader, ctx in context:
            sink.append(ctx.eval_tuple(projections))

        context = self.context(
            {
                None: sink,
                **{table: sink for table in context.tables},
            }
        )
        context.sort(self.generate_tuple(step.key))

        if not math.isinf(step.limit):
            context.table.rows = context.table.rows[0 : step.limit]

        return self.context({step.name: context.table})
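The aggregate loop above depends on context.sort(group_by) having made each group a contiguous run of rows, which is why it only tracks start/end indices. The same idea in isolation:

    def grouped_runs(rows, key):
        # assumes rows are already sorted by key, as the executor guarantees
        runs, start = [], 0
        for i in range(1, len(rows) + 1):
            if i == len(rows) or key(rows[i]) != key(rows[start]):
                runs.append((start, i))
                start = i
        return runs

    print(grouped_runs([1, 1, 2, 3, 3], key=lambda r: r))  # [(0, 2), (2, 3), (3, 5)]
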
def _cast_py(self, expression):

@@ -293,7 +293,7 @@ def _cast_py(self, expression):


def _column_py(self, expression):
    table = self.sql(expression, "table")
    table = self.sql(expression, "table") or None
    this = self.sql(expression, "this")
    return f"scope[{table}][{this}]"

@@ -319,10 +319,10 @@ def _ordered_py(self, expression):


class Python(Dialect):
    class Tokenizer(Tokenizer):
        ESCAPE = "\\"
    class Tokenizer(tokens.Tokenizer):
        ESCAPES = ["\\"]

    class Generator(Generator):
    class Generator(generator.Generator):
        TRANSFORMS = {
            exp.Alias: lambda self, e: self.sql(e.this),
            exp.Array: inline_array_sql,

@@ -1,10 +1,12 @@
class Table:
    def __init__(self, *columns, rows=None):
        self.columns = tuple(columns if isinstance(columns[0], str) else columns[0])
    def __init__(self, columns, rows=None, column_range=None):
        self.columns = tuple(columns)
        self.column_range = column_range
        self.reader = RowReader(self.columns, self.column_range)

        self.rows = rows or []
        if rows:
            assert len(rows[0]) == len(self.columns)
        self.reader = RowReader(self.columns)
        self.range_reader = RangeReader(self)

    def append(self, row):

@@ -29,15 +31,22 @@ class Table:
        return self.reader

    def __repr__(self):
        widths = {column: len(column) for column in self.columns}
        lines = [" ".join(column for column in self.columns)]
        columns = tuple(
            column
            for i, column in enumerate(self.columns)
            if not self.column_range or i in self.column_range
        )
        widths = {column: len(column) for column in columns}
        lines = [" ".join(column for column in columns)]

        for i, row in enumerate(self):
            if i > 10:
                break

            lines.append(
                " ".join(str(row[column]).rjust(widths[column])[0 : widths[column]] for column in self.columns)
                " ".join(
                    str(row[column]).rjust(widths[column])[0 : widths[column]] for column in columns
                )
            )
        return "\n".join(lines)

@@ -70,8 +79,10 @@ class RangeReader:


class RowReader:
    def __init__(self, columns):
        self.columns = {column: i for i, column in enumerate(columns)}
    def __init__(self, columns, column_range=None):
        self.columns = {
            column: i for i, column in enumerate(columns) if not column_range or i in column_range
        }
        self.row = None

    def __getitem__(self, column):
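A small illustration of what column_range adds: after a join, one combined row buffer can back several logical tables, each seeing only its slice of columns. A sketch using the classes above:

    combined = Table(["a", "b"], rows=[(1, 2), (3, 4)], column_range=range(0, 1))
    combined.reader.row = combined.rows[0]
    print(combined.reader["a"])  # 1; "b" is outside this table's column_range
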
@@ -1,6 +1,9 @@
from __future__ import annotations

import datetime
import numbers
import re
import typing as t
from collections import deque
from copy import deepcopy
from enum import auto

@@ -9,12 +12,15 @@ from sqlglot.errors import ParseError
from sqlglot.helper import (
    AutoName,
    camel_to_snake_case,
    ensure_list,
    list_get,
    ensure_collection,
    seq_get,
    split_num_words,
    subclasses,
)

if t.TYPE_CHECKING:
    from sqlglot.dialects.dialect import Dialect


class _Expression(type):
    def __new__(cls, clsname, bases, attrs):

@@ -35,27 +41,30 @@ class Expression(metaclass=_Expression):
        or optional (False).
    """

    key = None
    key = "Expression"
    arg_types = {"this": True}
    __slots__ = ("args", "parent", "arg_key", "type")
    __slots__ = ("args", "parent", "arg_key", "type", "comment")

    def __init__(self, **args):
        self.args = args
        self.parent = None
        self.arg_key = None
        self.type = None
        self.comment = None

        for arg_key, value in self.args.items():
            self._set_parent(arg_key, value)

    def __eq__(self, other):
    def __eq__(self, other) -> bool:
        return type(self) is type(other) and _norm_args(self) == _norm_args(other)

    def __hash__(self):
    def __hash__(self) -> int:
        return hash(
            (
                self.key,
                tuple((k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()),
                tuple(
                    (k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()
                ),
            )
        )

@@ -79,6 +88,19 @@ class Expression(metaclass=_Expression):
            return field.this
        return ""

    def find_comment(self, key: str) -> str:
        """
        Finds the comment that is attached to a specified child node.

        Args:
            key: the key of the target child node (e.g. "this", "expression", etc).

        Returns:
            The comment attached to the child node, or the empty string, if it doesn't exist.
        """
        field = self.args.get(key)
        return field.comment if isinstance(field, Expression) else ""
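With the new single comment slot, a leading SQL comment can survive a parse/generate round trip; a sketch, assuming the tokenizer in this release attaches comments the way the generator's maybe_comment expects:

    import sqlglot

    e = sqlglot.parse_one("SELECT 1 /* one */")
    # the comment rides along on a node and is re-emitted on generation
    print(e.sql())  # expected shape: SELECT 1 /* one */
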
    @property
    def is_string(self):
        return isinstance(self, Literal) and self.args["is_string"]

@@ -114,7 +136,10 @@ class Expression(metaclass=_Expression):
        return self.alias or self.name

    def __deepcopy__(self, memo):
        return self.__class__(**deepcopy(self.args))
        copy = self.__class__(**deepcopy(self.args))
        copy.comment = self.comment
        copy.type = self.type
        return copy

    def copy(self):
        new = deepcopy(self)

@@ -249,9 +274,7 @@ class Expression(metaclass=_Expression):
            return

        for k, v in self.args.items():
            nodes = ensure_list(v)

            for node in nodes:
            for node in ensure_collection(v):
                if isinstance(node, Expression):
                    yield from node.dfs(self, k, prune)

@@ -274,9 +297,7 @@ class Expression(metaclass=_Expression):

            if isinstance(item, Expression):
                for k, v in item.args.items():
                    nodes = ensure_list(v)

                    for node in nodes:
                    for node in ensure_collection(v):
                        if isinstance(node, Expression):
                            queue.append((node, item, k))

@@ -319,7 +340,7 @@ class Expression(metaclass=_Expression):
    def __repr__(self):
        return self.to_s()

    def sql(self, dialect=None, **opts):
    def sql(self, dialect: Dialect | str | None = None, **opts) -> str:
        """
        Returns SQL string representation of this tree.

@@ -335,7 +356,7 @@ class Expression(metaclass=_Expression):

        return Dialect.get_or_raise(dialect)().generate(self, **opts)

    def to_s(self, hide_missing=True, level=0):
    def to_s(self, hide_missing: bool = True, level: int = 0) -> str:
        indent = "" if not level else "\n"
        indent += "".join(["  "] * level)
        left = f"({self.key.upper()} "

@@ -343,11 +364,13 @@ class Expression(metaclass=_Expression):
        args = {
            k: ", ".join(
                v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v)
                for v in ensure_list(vs)
                for v in ensure_collection(vs)
                if v is not None
            )
            for k, vs in self.args.items()
        }
        args["comment"] = self.comment
        args["type"] = self.type
        args = {k: v for k, v in args.items() if v or not hide_missing}

        right = ", ".join(f"{k}: {v}" for k, v in args.items())

@@ -578,17 +601,6 @@ class UDTF(DerivedTable, Unionable):
    pass


class Annotation(Expression):
    arg_types = {
        "this": True,
        "expression": True,
    }

    @property
    def alias(self):
        return self.expression.alias_or_name


class Cache(Expression):
    arg_types = {
        "with": False,

@@ -623,6 +635,38 @@ class Describe(Expression):
    pass


class Set(Expression):
    arg_types = {"expressions": True}


class SetItem(Expression):
    arg_types = {
        "this": True,
        "kind": False,
        "collate": False,  # MySQL SET NAMES statement
    }


class Show(Expression):
    arg_types = {
        "this": True,
        "target": False,
        "offset": False,
        "limit": False,
        "like": False,
        "where": False,
        "db": False,
        "full": False,
        "mutex": False,
        "query": False,
        "channel": False,
        "global": False,
        "log": False,
        "position": False,
        "types": False,
    }


class UserDefinedFunction(Expression):
    arg_types = {"this": True, "expressions": False}

@@ -864,18 +908,20 @@ class Literal(Condition):

    def __eq__(self, other):
        return (
            isinstance(other, Literal) and self.this == other.this and self.args["is_string"] == other.args["is_string"]
            isinstance(other, Literal)
            and self.this == other.this
            and self.args["is_string"] == other.args["is_string"]
        )

    def __hash__(self):
        return hash((self.key, self.this, self.args["is_string"]))

    @classmethod
    def number(cls, number):
    def number(cls, number) -> Literal:
        return cls(this=str(number), is_string=False)

    @classmethod
    def string(cls, string):
    def string(cls, string) -> Literal:
        return cls(this=str(string), is_string=True)

@@ -1087,7 +1133,7 @@ class Properties(Expression):
    }

    @classmethod
    def from_dict(cls, properties_dict):
    def from_dict(cls, properties_dict) -> Properties:
        expressions = []
        for key, value in properties_dict.items():
            property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty)

@@ -1323,7 +1369,7 @@ class Select(Subqueryable):
        **QUERY_MODIFIERS,
    }

    def from_(self, *expressions, append=True, dialect=None, copy=True, **opts):
    def from_(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
        """
        Set the FROM expression.

@@ -1356,7 +1402,7 @@ class Select(Subqueryable):
            **opts,
        )

    def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
    def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
        """
        Set the GROUP BY expression.

@@ -1392,7 +1438,7 @@ class Select(Subqueryable):
            **opts,
        )

    def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
    def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
        """
        Set the ORDER BY expression.

@@ -1425,7 +1471,7 @@ class Select(Subqueryable):
            **opts,
        )

    def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
    def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
        """
        Set the SORT BY expression.

@@ -1458,7 +1504,7 @@ class Select(Subqueryable):
            **opts,
        )

    def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
    def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
        """
        Set the CLUSTER BY expression.

@@ -1491,7 +1537,7 @@ class Select(Subqueryable):
            **opts,
        )

    def limit(self, expression, dialect=None, copy=True, **opts):
    def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
        """
        Set the LIMIT expression.

@@ -1522,7 +1568,7 @@ class Select(Subqueryable):
            **opts,
        )

    def offset(self, expression, dialect=None, copy=True, **opts):
    def offset(self, expression, dialect=None, copy=True, **opts) -> Select:
        """
        Set the OFFSET expression.

@@ -1553,7 +1599,7 @@ class Select(Subqueryable):
            **opts,
        )

    def select(self, *expressions, append=True, dialect=None, copy=True, **opts):
    def select(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
        """
        Append to or set the SELECT expressions.

@@ -1583,7 +1629,7 @@ class Select(Subqueryable):
            **opts,
        )

    def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts):
    def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
        """
        Append to or set the LATERAL expressions.

@@ -1626,7 +1672,7 @@ class Select(Subqueryable):
        dialect=None,
        copy=True,
        **opts,
    ):
    ) -> Select:
        """
        Append to or set the JOIN expressions.

@@ -1672,7 +1718,7 @@ class Select(Subqueryable):
            join.this.replace(join.this.subquery())

        if join_type:
            natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)
            natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)  # type: ignore
            if natural:
                join.set("natural", True)
            if side:

@@ -1681,12 +1727,12 @@ class Select(Subqueryable):
                join.set("kind", kind.text)

        if on:
            on = and_(*ensure_list(on), dialect=dialect, **opts)
            on = and_(*ensure_collection(on), dialect=dialect, **opts)
            join.set("on", on)

        if using:
            join = _apply_list_builder(
                *ensure_list(using),
                *ensure_collection(using),
                instance=join,
                arg="using",
                append=append,

@@ -1705,7 +1751,7 @@ class Select(Subqueryable):
            **opts,
        )

    def where(self, *expressions, append=True, dialect=None, copy=True, **opts):
    def where(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
        """
        Append to or set the WHERE expressions.

@@ -1737,7 +1783,7 @@ class Select(Subqueryable):
            **opts,
        )

    def having(self, *expressions, append=True, dialect=None, copy=True, **opts):
    def having(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
        """
        Append to or set the HAVING expressions.

@@ -1769,7 +1815,7 @@ class Select(Subqueryable):
            **opts,
        )

    def distinct(self, distinct=True, copy=True):
    def distinct(self, distinct=True, copy=True) -> Select:
        """
        Set the OFFSET expression.

@@ -1788,7 +1834,7 @@ class Select(Subqueryable):
        instance.set("distinct", Distinct() if distinct else None)
        return instance

    def ctas(self, table, properties=None, dialect=None, copy=True, **opts):
    def ctas(self, table, properties=None, dialect=None, copy=True, **opts) -> Create:
        """
        Convert this expression to a CREATE TABLE AS statement.

@@ -1826,11 +1872,11 @@ class Select(Subqueryable):
        )

    @property
    def named_selects(self):
    def named_selects(self) -> t.List[str]:
        return [e.alias_or_name for e in self.expressions if e.alias_or_name]

    @property
    def selects(self):
    def selects(self) -> t.List[Expression]:
        return self.expressions

@@ -1910,12 +1956,16 @@ class Parameter(Expression):
    pass


class SessionParameter(Expression):
    arg_types = {"this": True, "kind": False}


class Placeholder(Expression):
    arg_types = {"this": False}


class Null(Condition):
    arg_types = {}
    arg_types: t.Dict[str, t.Any] = {}


class Boolean(Condition):

@@ -1936,6 +1986,7 @@ class DataType(Expression):
        NVARCHAR = auto()
        TEXT = auto()
        BINARY = auto()
        VARBINARY = auto()
        INT = auto()
        TINYINT = auto()
        SMALLINT = auto()

@@ -1975,7 +2026,7 @@ class DataType(Expression):
        UNKNOWN = auto()  # Sentinel value, useful for type annotation

    @classmethod
    def build(cls, dtype, **kwargs):
    def build(cls, dtype, **kwargs) -> DataType:
        return DataType(
            this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()],
            **kwargs,

@@ -2077,6 +2128,18 @@ class EQ(Binary, Predicate):
    pass


class NullSafeEQ(Binary, Predicate):
    pass


class NullSafeNEQ(Binary, Predicate):
    pass


class Distance(Binary):
    pass


class Escape(Binary):
    pass

@@ -2101,18 +2164,14 @@ class Is(Binary, Predicate):
    pass


class Kwarg(Binary):
    """Kwarg in special functions like func(kwarg => y)."""


class Like(Binary, Predicate):
    pass


class SimilarTo(Binary, Predicate):
    pass


class Distance(Binary):
    pass


class LT(Binary, Predicate):
    pass

@@ -2133,6 +2192,10 @@ class NEQ(Binary, Predicate):
    pass


class SimilarTo(Binary, Predicate):
    pass


class Sub(Binary):
    pass

@@ -2189,7 +2252,13 @@ class Distinct(Expression):


class In(Predicate):
    arg_types = {"this": True, "expressions": False, "query": False, "unnest": False, "field": False}
    arg_types = {
        "this": True,
        "expressions": False,
        "query": False,
        "unnest": False,
        "field": False,
    }


class TimeUnit(Expression):

@@ -2255,7 +2324,9 @@ class Func(Condition):
    @classmethod
    def sql_names(cls):
        if cls is Func:
            raise NotImplementedError("SQL name is only supported by concrete function implementations")
            raise NotImplementedError(
                "SQL name is only supported by concrete function implementations"
            )
        if not hasattr(cls, "_sql_names"):
            cls._sql_names = [camel_to_snake_case(cls.__name__)]
        return cls._sql_names

@@ -2408,8 +2479,8 @@ class DateDiff(Func, TimeUnit):
    arg_types = {"this": True, "expression": True, "unit": False}


class DateTrunc(Func, TimeUnit):
    arg_types = {"this": True, "unit": True, "zone": False}
class DateTrunc(Func):
    arg_types = {"this": True, "expression": True, "zone": False}


class DatetimeAdd(Func, TimeUnit):

@@ -2791,6 +2862,10 @@ class Year(Func):
    pass


class Use(Expression):
    pass


def _norm_args(expression):
    args = {}

@@ -2822,7 +2897,7 @@ def maybe_parse(
    dialect=None,
    prefix=None,
    **opts,
):
) -> t.Optional[Expression]:
    """Gracefully handle a possible string or expression.

    Example:

@@ -3073,7 +3148,7 @@ def except_(left, right, distinct=True, dialect=None, **opts):
    return Except(this=left, expression=right, distinct=distinct)


def select(*expressions, dialect=None, **opts):
def select(*expressions, dialect=None, **opts) -> Select:
    """
    Initializes a syntax tree from one or multiple SELECT expressions.

@@ -3095,7 +3170,7 @@ def select(*expressions, dialect=None, **opts):
    return Select().select(*expressions, dialect=dialect, **opts)


def from_(*expressions, dialect=None, **opts):
def from_(*expressions, dialect=None, **opts) -> Select:
    """
    Initializes a syntax tree from a FROM expression.

@@ -3117,7 +3192,7 @@ def from_(*expressions, dialect=None, **opts):
    return Select().from_(*expressions, dialect=dialect, **opts)


def update(table, properties, where=None, from_=None, dialect=None, **opts):
def update(table, properties, where=None, from_=None, dialect=None, **opts) -> Update:
    """
    Creates an update statement.

@@ -3139,7 +3214,10 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts):
    update = Update(this=maybe_parse(table, into=Table, dialect=dialect))
    update.set(
        "expressions",
        [EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) for k, v in properties.items()],
        [
            EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v))
            for k, v in properties.items()
        ],
    )
    if from_:
        update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts))
@@ -3150,7 +3228,7 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts):
    return update


def delete(table, where=None, dialect=None, **opts):
def delete(table, where=None, dialect=None, **opts) -> Delete:
    """
    Builds a delete statement.

@@ -3174,7 +3252,7 @@ def delete(table, where=None, dialect=None, **opts):
    )


def condition(expression, dialect=None, **opts):
def condition(expression, dialect=None, **opts) -> Condition:
    """
    Initialize a logical condition expression.

@@ -3199,7 +3277,7 @@ def condition(expression, dialect=None, **opts):
    Returns:
        Condition: the expression
    """
    return maybe_parse(
    return maybe_parse(  # type: ignore
        expression,
        into=Condition,
        dialect=dialect,

@@ -3207,7 +3285,7 @@ def condition(expression, dialect=None, **opts):
    )


def and_(*expressions, dialect=None, **opts):
def and_(*expressions, dialect=None, **opts) -> And:
    """
    Combine multiple conditions with an AND logical operator.

@@ -3227,7 +3305,7 @@ def and_(*expressions, dialect=None, **opts):
    return _combine(expressions, And, dialect, **opts)


def or_(*expressions, dialect=None, **opts):
def or_(*expressions, dialect=None, **opts) -> Or:
    """
    Combine multiple conditions with an OR logical operator.

@@ -3247,7 +3325,7 @@ def or_(*expressions, dialect=None, **opts):
    return _combine(expressions, Or, dialect, **opts)


def not_(expression, dialect=None, **opts):
def not_(expression, dialect=None, **opts) -> Not:
    """
    Wrap a condition with a NOT operator.

@@ -3272,14 +3350,14 @@ def not_(expression, dialect=None, **opts):
    return Not(this=_wrap_operator(this))


def paren(expression):
def paren(expression) -> Paren:
    return Paren(this=expression)


SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$")


def to_identifier(alias, quoted=None):
def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
    if alias is None:
        return None
    if isinstance(alias, Identifier):

@@ -3293,16 +3371,16 @@ def to_identifier(alias, quoted=None):
    return identifier


def to_table(sql_path: str, **kwargs) -> Table:
def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]:
    """
    Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.

    If a table is passed in then that table is returned.

    Args:
        sql_path(str|Table): `[catalog].[schema].[table]` string
        sql_path: a `[catalog].[schema].[table]` string.

    Returns:
        Table: A table expression
        A table expression.
    """
    if sql_path is None or isinstance(sql_path, Table):
        return sql_path
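The widened signature makes to_table idempotent over Table inputs and None; for strings it still splits the dotted path. A sketch:

    from sqlglot.expressions import to_table

    t = to_table("catalog.db.tbl")
    print(t.name)            # tbl
    print(to_table(t) is t)  # True: Table instances now pass straight through
    print(to_table(None))    # None
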
@@ -3393,7 +3471,7 @@ def subquery(expression, alias=None, dialect=None, **opts):
    return Select().from_(expression, dialect=dialect, **opts)


def column(col, table=None, quoted=None):
def column(col, table=None, quoted=None) -> Column:
    """
    Build a Column.
    Args:

@@ -3408,7 +3486,7 @@ def column(col, table=None, quoted=None):
    )


def table_(table, db=None, catalog=None, quoted=None, alias=None):
def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
    """Build a Table.

    Args:

@@ -3427,7 +3505,7 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None):
    )


def values(values, alias=None):
def values(values, alias=None) -> Values:
    """Build VALUES statement.

    Example:

@@ -3449,7 +3527,7 @@ def values(values, alias=None):
    )


def convert(value):
def convert(value) -> Expression:
    """Convert a python value into an expression object.

    Raises an error if a conversion is not possible.

@@ -3500,15 +3578,14 @@ def replace_children(expression, fun):

    for cn in child_nodes:
        if isinstance(cn, Expression):
            cns = ensure_list(fun(cn))
            for child_node in cns:
            for child_node in ensure_collection(fun(cn)):
                new_child_nodes.append(child_node)
                child_node.parent = expression
                child_node.arg_key = k
        else:
            new_child_nodes.append(cn)

    expression.args[k] = new_child_nodes if is_list_arg else list_get(new_child_nodes, 0)
    expression.args[k] = new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0)


def column_table_names(expression):

@@ -3529,7 +3606,7 @@ def column_table_names(expression):
    return list(dict.fromkeys(column.table for column in expression.find_all(Column)))


def table_name(table):
def table_name(table) -> str:
    """Get the full name of a table as a string.

    Args:

@@ -3546,6 +3623,9 @@ def table_name(table):

    table = maybe_parse(table, into=Table)

    if not table:
        raise ValueError(f"Cannot parse {table}")

    return ".".join(
        part
        for part in (

@@ -1,4 +1,8 @@
from __future__ import annotations

import logging
import re
import typing as t

from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors

@@ -8,6 +12,8 @@ from sqlglot.tokens import TokenType

logger = logging.getLogger("sqlglot")

NEWLINE_RE = re.compile("\r\n?|\n")


class Generator:
    """

@@ -47,8 +53,7 @@ class Generator:
            The default is on the smaller end because the length only represents a segment and not the true
            line length.
            Default: 80
        annotations: Whether or not to show annotations in the SQL when `pretty` is True.
            Annotations can only be shown in pretty mode otherwise they may clobber resulting sql.
        comments: Whether or not to preserve comments in the ouput SQL code.
            Default: True
    """

@@ -65,14 +70,16 @@ class Generator:
        exp.VolatilityProperty: lambda self, e: self.sql(e.name),
    }

    # whether 'CREATE ... TRANSIENT ... TABLE' is allowed
    # can override in dialects
    # Whether 'CREATE ... TRANSIENT ... TABLE' is allowed
    CREATE_TRANSIENT = False

    # whether or not null ordering is supported in order by
    # Whether or not null ordering is supported in order by
    NULL_ORDERING_SUPPORTED = True

    # always do union distinct or union all
    # Always do union distinct or union all
    EXPLICIT_UNION = False

    # wrap derived values in parens, usually standard but spark doesn't support it
    # Wrap derived values in parens, usually standard but spark doesn't support it
    WRAP_DERIVED_VALUES = True

    TYPE_MAPPING = {

@@ -80,7 +87,7 @@ class Generator:
        exp.DataType.Type.NVARCHAR: "VARCHAR",
    }

    TOKEN_MAPPING = {}
    TOKEN_MAPPING: t.Dict[TokenType, str] = {}

    STRUCT_DELIMITER = ("<", ">")

@@ -96,6 +103,8 @@ class Generator:
        exp.TableFormatProperty,
    }

    WITH_SEPARATED_COMMENTS = (exp.Select,)

    __slots__ = (
        "time_mapping",
        "time_trie",

@@ -122,7 +131,7 @@ class Generator:
        "_escaped_quote_end",
        "_leading_comma",
        "_max_text_width",
        "_annotations",
        "_comments",
    )

    def __init__(

@@ -148,7 +157,7 @@ class Generator:
        max_unsupported=3,
        leading_comma=False,
        max_text_width=80,
        annotations=True,
        comments=True,
    ):
        import sqlglot

@@ -177,7 +186,7 @@ class Generator:
        self._escaped_quote_end = self.escape + self.quote_end
        self._leading_comma = leading_comma
        self._max_text_width = max_text_width
        self._annotations = annotations
        self._comments = comments

    def generate(self, expression):
        """

@@ -204,7 +213,6 @@ class Generator:
        return sql

    def unsupported(self, message):

        if self.unsupported_level == ErrorLevel.IMMEDIATE:
            raise UnsupportedError(message)
        self.unsupported_messages.append(message)

@@ -215,9 +223,31 @@ class Generator:
    def seg(self, sql, sep=" "):
        return f"{self.sep(sep)}{sql}"

    def maybe_comment(self, sql, expression, single_line=False):
        comment = expression.comment if self._comments else None

        if not comment:
            return sql

        comment = " " + comment if comment[0].strip() else comment
        comment = comment + " " if comment[-1].strip() else comment

        if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
            return f"/*{comment}*/{self.sep()}{sql}"

        if not self.pretty:
            return f"{sql} /*{comment}*/"

        if not NEWLINE_RE.search(comment):
            return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/"

        return f"/*{comment}*/\n{sql}"
    def wrap(self, expression):
        this_sql = self.indent(
            self.sql(expression) if isinstance(expression, (exp.Select, exp.Union)) else self.sql(expression, "this"),
            self.sql(expression)
            if isinstance(expression, (exp.Select, exp.Union))
            else self.sql(expression, "this"),
            level=1,
            pad=0,
        )

@@ -251,7 +281,7 @@ class Generator:
            for i, line in enumerate(lines)
        )

    def sql(self, expression, key=None):
    def sql(self, expression, key=None, comment=True):
        if not expression:
            return ""

@@ -264,29 +294,24 @@ class Generator:
        transform = self.TRANSFORMS.get(expression.__class__)

        if callable(transform):
            return transform(self, expression)
        if transform:
            return transform
            sql = transform(self, expression)
        elif transform:
            sql = transform
        elif isinstance(expression, exp.Expression):
            exp_handler_name = f"{expression.key}_sql"

        if not isinstance(expression, exp.Expression):
            if hasattr(self, exp_handler_name):
                sql = getattr(self, exp_handler_name)(expression)
            elif isinstance(expression, exp.Func):
                sql = self.function_fallback_sql(expression)
            elif isinstance(expression, exp.Property):
                sql = self.property_sql(expression)
            else:
                raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
        else:
            raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")

        exp_handler_name = f"{expression.key}_sql"
        if hasattr(self, exp_handler_name):
            return getattr(self, exp_handler_name)(expression)

        if isinstance(expression, exp.Func):
            return self.function_fallback_sql(expression)

        if isinstance(expression, exp.Property):
            return self.property_sql(expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")

    def annotation_sql(self, expression):
        if self._annotations and self.pretty:
            return f"{self.sql(expression, 'expression')} # {expression.name}"
        return self.sql(expression, "expression")
        return self.maybe_comment(sql, expression) if self._comments and comment else sql

    def uncache_sql(self, expression):
        table = self.sql(expression, "this")

@@ -371,7 +396,9 @@ class Generator:
        expression_sql = self.sql(expression, "expression")
        expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else ""
        temporary = " TEMPORARY" if expression.args.get("temporary") else ""
        transient = " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
        transient = (
            " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
        )
        replace = " OR REPLACE" if expression.args.get("replace") else ""
        exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
        unique = " UNIQUE" if expression.args.get("unique") else ""

@@ -434,7 +461,9 @@ class Generator:
    def delete_sql(self, expression):
        this = self.sql(expression, "this")
        using_sql = (
            f" USING {self.expressions(expression, 'using', sep=', USING ')}" if expression.args.get("using") else ""
            f" USING {self.expressions(expression, 'using', sep=', USING ')}"
            if expression.args.get("using")
            else ""
        )
        where_sql = self.sql(expression, "where")
        sql = f"DELETE FROM {this}{using_sql}{where_sql}"

@@ -481,15 +510,18 @@ class Generator:
        return f"{this} ON {table} {columns}"

    def identifier_sql(self, expression):
        value = expression.name
        value = value.lower() if self.normalize else value
        text = expression.name
        text = text.lower() if self.normalize else text
        if expression.args.get("quoted") or self.identify:
            return f"{self.identifier_start}{value}{self.identifier_end}"
        return value
            text = f"{self.identifier_start}{text}{self.identifier_end}"
        return text

    def partition_sql(self, expression):
        keys = csv(
            *[f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] for k, v in expression.args.get("this")]
            *[
                f"""{prop.name}='{prop.text("value")}'""" if prop.text("value") else prop.name
                for prop in expression.this
            ]
        )
        return f"PARTITION({keys})"

@@ -504,9 +536,9 @@ class Generator:
            elif p_class in self.ROOT_PROPERTIES:
                root_properties.append(p)

        return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties(
            exp.Properties(expressions=with_properties)
        )
        return self.root_properties(
            exp.Properties(expressions=root_properties)
        ) + self.with_properties(exp.Properties(expressions=with_properties))

    def root_properties(self, properties):
        if properties.expressions:

@@ -551,7 +583,9 @@ class Generator:

        this = f"{this}{self.sql(expression, 'this')}"
        exists = " IF EXISTS " if expression.args.get("exists") else " "
        partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else ""
        partition_sql = (
            self.sql(expression, "partition") if expression.args.get("partition") else ""
        )
        expression_sql = self.sql(expression, "expression")
        sep = self.sep() if partition_sql else ""
        sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}"

@@ -669,7 +703,9 @@ class Generator:
    def group_sql(self, expression):
        group_by = self.op_expressions("GROUP BY", expression)
        grouping_sets = self.expressions(expression, key="grouping_sets", indent=False)
        grouping_sets = f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
        grouping_sets = (
            f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
        )
        cube = self.expressions(expression, key="cube", indent=False)
        cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else ""
        rollup = self.expressions(expression, key="rollup", indent=False)

@@ -711,10 +747,10 @@ class Generator:
        this_sql = self.sql(expression, "this")
        return f"{expression_sql}{op_sql} {this_sql}{on_sql}"

    def lambda_sql(self, expression):
    def lambda_sql(self, expression, arrow_sep="->"):
        args = self.expressions(expression, flat=True)
        args = f"({args})" if len(args.split(",")) > 1 else args
        return self.no_identify(lambda: f"{args} -> {self.sql(expression, 'this')}")
        return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}")

    def lateral_sql(self, expression):
        this = self.sql(expression, "this")

@@ -748,7 +784,7 @@ class Generator:
        if self._replace_backslash:
            text = text.replace("\\", "\\\\")
        text = text.replace(self.quote_end, self._escaped_quote_end)
        return f"{self.quote_start}{text}{self.quote_end}"
        text = f"{self.quote_start}{text}{self.quote_end}"
        return text

    def loaddata_sql(self, expression):

@@ -796,13 +832,21 @@ class Generator:

        sort_order = " DESC" if desc else ""
        nulls_sort_change = ""
        if nulls_first and ((asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last):
        if nulls_first and (
            (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last
        ):
            nulls_sort_change = " NULLS FIRST"
        elif nulls_last and ((asc and nulls_are_small) or (desc and nulls_are_large)) and not nulls_are_last:
        elif (
            nulls_last
            and ((asc and nulls_are_small) or (desc and nulls_are_large))
            and not nulls_are_last
        ):
            nulls_sort_change = " NULLS LAST"

        if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED:
            self.unsupported("Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect")
            self.unsupported(
                "Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect"
            )
            nulls_sort_change = ""

        return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}"

@@ -835,7 +879,7 @@ class Generator:
        sql = self.query_modifiers(
            expression,
            f"SELECT{hint}{distinct}{expressions}",
            self.sql(expression, "from"),
            self.sql(expression, "from", comment=False),
        )
        return self.prepend_ctes(expression, sql)

@@ -858,6 +902,13 @@ class Generator:
    def parameter_sql(self, expression):
        return f"@{self.sql(expression, 'this')}"

    def sessionparameter_sql(self, expression):
        this = self.sql(expression, "this")
        kind = expression.text("kind")
        if kind:
            kind = f"{kind}."
        return f"@@{kind}{this}"

    def placeholder_sql(self, expression):
        return f":{expression.name}" if expression.name else "?"

@@ -931,7 +982,10 @@ class Generator:
    def window_spec_sql(self, expression):
        kind = self.sql(expression, "kind")
        start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
        end = csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") or "CURRENT ROW"
        end = (
            csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ")
            or "CURRENT ROW"
        )
        return f"{kind} BETWEEN {start} AND {end}"

    def withingroup_sql(self, expression):

@@ -1020,7 +1074,9 @@ class Generator:
        return f"UNIQUE ({columns})"

    def if_sql(self, expression):
        return self.case_sql(exp.Case(ifs=[expression.copy()], default=expression.args.get("false")))
        return self.case_sql(
            exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
        )

    def in_sql(self, expression):
        query = expression.args.get("query")

@@ -1196,6 +1252,12 @@ class Generator:
    def neq_sql(self, expression):
        return self.binary(expression, "<>")

    def nullsafeeq_sql(self, expression):
        return self.binary(expression, "IS NOT DISTINCT FROM")

    def nullsafeneq_sql(self, expression):
        return self.binary(expression, "IS DISTINCT FROM")
    def or_sql(self, expression):
        return self.connector_sql(expression, "OR")

@@ -1205,6 +1267,9 @@ class Generator:
    def trycast_sql(self, expression):
        return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"

    def use_sql(self, expression):
        return f"USE {self.sql(expression, 'this')}"

    def binary(self, expression, op):
        return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"

@@ -1240,17 +1305,27 @@ class Generator:
        if flat:
            return sep.join(self.sql(e) for e in expressions)

        sql = (self.sql(e) for e in expressions)
        # the only time leading_comma changes the output is if pretty print is enabled
        if self._leading_comma and self.pretty:
            pad = " " * self.pad
            expressions = "\n".join(f"{sep}{s}" if i > 0 else f"{pad}{s}" for i, s in enumerate(sql))
        else:
            expressions = self.sep(sep).join(sql)
        num_sqls = len(expressions)

        if indent:
            return self.indent(expressions, skip_first=False)
        return expressions
        # These are calculated once in case we have the leading_comma / pretty option set, correspondingly
        pad = " " * self.pad
        stripped_sep = sep.strip()

        result_sqls = []
        for i, e in enumerate(expressions):
            sql = self.sql(e, comment=False)
            comment = self.maybe_comment("", e, single_line=True)

            if self.pretty:
                if self._leading_comma:
                    result_sqls.append(f"{sep if i > 0 else pad}{sql}{comment}")
                else:
                    result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comment}")
            else:
                result_sqls.append(f"{sql}{comment}{sep if i + 1 < num_sqls else ''}")

        result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
        return self.indent(result_sqls, skip_first=False) if indent else result_sqls
|
||||
|
||||
def op_expressions(self, op, expression, flat=False):
|
||||
expressions_sql = self.expressions(expression, flat=flat)
|
||||
|
@ -1264,7 +1339,9 @@ class Generator:
|
|||
def set_operation(self, expression, op):
|
||||
this = self.sql(expression, "this")
|
||||
op = self.seg(op)
|
||||
return self.query_modifiers(expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}")
|
||||
return self.query_modifiers(
|
||||
expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
|
||||
)
|
||||
|
||||
def token_sql(self, token_type):
|
||||
return self.TOKEN_MAPPING.get(token_type, token_type.name)
|
||||
|
@ -1283,3 +1360,6 @@ class Generator:
|
|||
this = self.sql(expression, "this")
|
||||
expressions = self.expressions(expression, flat=True)
|
||||
return f"{this}({expressions})"
|
||||
|
||||
def kwarg_sql(self, expression):
|
||||
return self.binary(expression, "=>")
|
||||
|
|
|
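The reworked expressions() above changes how pretty-printed lists are padded and where the separator lands when leading commas are requested. A minimal sketch of how this surfaces through the public transpile API, assuming the pretty and leading_comma generator options of this release (output shown in comments is indicative, not verified against this exact build):

# Sketch: Generator.expressions() as seen from sqlglot.transpile().
import sqlglot

# Default: trailing commas between projections when pretty-printing.
print(sqlglot.transpile("SELECT a, b, c FROM t", pretty=True)[0])

# With leading_comma=True the comma moves to the start of each line,
# which is the branch guarded by `self._leading_comma and self.pretty`.
print(sqlglot.transpile("SELECT a, b, c FROM t", pretty=True, leading_comma=True)[0])
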
@@ -1,48 +1,125 @@
from __future__ import annotations

import inspect
import logging
import re
import sys
import typing as t
from collections.abc import Collection
from contextlib import contextmanager
from copy import copy
from enum import Enum

if t.TYPE_CHECKING:
from sqlglot.expressions import Expression, Table

T = t.TypeVar("T")
E = t.TypeVar("E", bound=Expression)

CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
PYTHON_VERSION = sys.version_info[:2]
logger = logging.getLogger("sqlglot")

class AutoName(Enum):
def _generate_next_value_(name, _start, _count, _last_values):
"""This is used for creating enum classes where `auto()` is the string form of the corresponding value's name."""

def _generate_next_value_(name, _start, _count, _last_values): # type: ignore
return name

def list_get(arr, index):
def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
"""Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
try:
return arr[index]
return seq[index]
except IndexError:
return None

@t.overload
def ensure_list(value: t.Collection[T]) -> t.List[T]:
...

@t.overload
def ensure_list(value: T) -> t.List[T]:
...

def ensure_list(value):
"""
Ensures that a value is a list, otherwise casts or wraps it into one.

Args:
value: the value of interest.

Returns:
The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
"""
if value is None:
return []
return value if isinstance(value, (list, tuple, set)) else [value]
elif isinstance(value, (list, tuple)):
return list(value)

return [value]

def csv(*args, sep=", "):
@t.overload
def ensure_collection(value: t.Collection[T]) -> t.Collection[T]:
...

@t.overload
def ensure_collection(value: T) -> t.Collection[T]:
...

def ensure_collection(value):
"""
Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.

Args:
value: the value of interest.

Returns:
The value if it's a collection, or else the value wrapped in a list.
"""
if value is None:
return []
return (
value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
)

def csv(*args, sep: str = ", ") -> str:
"""
Formats any number of string arguments as CSV.

Args:
args: the string arguments to format.
sep: the argument separator.

Returns:
The arguments formatted as a CSV string.
"""
return sep.join(arg for arg in args if arg)

def subclasses(module_name, classes, exclude=()):
def subclasses(
module_name: str,
classes: t.Type | t.Tuple[t.Type, ...],
exclude: t.Type | t.Tuple[t.Type, ...] = (),
) -> t.List[t.Type]:
"""
Returns a list of all subclasses for a specified class set, posibly excluding some of them.
Returns all subclasses for a collection of classes, possibly excluding some of them.

Args:
module_name (str): The name of the module to search for subclasses in.
classes (type|tuple[type]): Class(es) we want to find the subclasses of.
exclude (type|tuple[type]): Class(es) we want to exclude from the returned list.
module_name: the name of the module to search for subclasses in.
classes: class(es) we want to find the subclasses of.
exclude: class(es) we want to exclude from the returned list.

Returns:
A list of all the target subclasses.
The target subclasses.
"""
return [
obj

@@ -53,7 +130,18 @@ def subclasses(module_name, classes, exclude=()):
]

def apply_index_offset(expressions, offset):
def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]:
"""
Applies an offset to a given integer literal expression.

Args:
expressions: the expression the offset will be applied to, wrapped in a list.
offset: the offset that will be applied.

Returns:
The original expression with the offset applied to it, wrapped in a list. If the provided
`expressions` argument contains more than one expressions, it's returned unaffected.
"""
if not offset or len(expressions) != 1:
return expressions

@@ -64,14 +152,28 @@ def apply_index_offset(expressions, offset):
logger.warning("Applying array index offset (%s)", offset)
expression.args["this"] = str(int(expression.args["this"]) + offset)
return [expression]

return expressions

def camel_to_snake_case(name):
def camel_to_snake_case(name: str) -> str:
"""Converts `name` from camelCase to snake_case and returns the result."""
return CAMEL_CASE_PATTERN.sub("_", name).upper()

def while_changing(expression, func):
def while_changing(
expression: t.Optional[Expression], func: t.Callable[[t.Optional[Expression]], E]
) -> E:
"""
Applies a transformation to a given expression until a fix point is reached.

Args:
expression: the expression to be transformed.
func: the transformation to be applied.

Returns:
The transformed expression.
"""
while True:
start = hash(expression)
expression = func(expression)

@@ -80,10 +182,19 @@ def while_changing(expression, func):
return expression

def tsort(dag):
def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]:
"""
Sorts a given directed acyclic graph in topological order.

Args:
dag: the graph to be sorted.

Returns:
A list that contains all of the graph's nodes in topological order.
"""
result = []

def visit(node, visited):
def visit(node: T, visited: t.Set[T]) -> None:
if node in result:
return
if node in visited:

@@ -103,10 +214,8 @@ def tsort(dag):
return result

def open_file(file_name):
"""
Open a file that may be compressed as gzip and return in newline mode.
"""
def open_file(file_name: str) -> t.TextIO:
"""Open a file that may be compressed as gzip and return it in universal newline mode."""
with open(file_name, "rb") as f:
gzipped = f.read(2) == b"\x1f\x8b"

@@ -119,14 +228,14 @@ def open_file(file_name):

@contextmanager
def csv_reader(table):
def csv_reader(table: Table) -> t.Any:
"""
Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...])
Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.

Args:
table (exp.Table): A table expression with an anonymous function READ_CSV in it
table: a `Table` expression with an anonymous function `READ_CSV` in it.

Returns:
Yields:
A python csv reader.
"""
file, *args = table.this.expressions

@@ -147,13 +256,16 @@ def csv_reader(table):
file.close()

def find_new_name(taken, base):
def find_new_name(taken: t.Sequence[str], base: str) -> str:
"""
Searches for a new name.

Args:
taken (Sequence[str]): set of taken names
base (str): base name to alter
taken: a collection of taken names.
base: base name to alter.

Returns:
The new, available name.
"""
if base not in taken:
return base

@@ -163,22 +275,26 @@ def find_new_name(taken, base):
while new in taken:
i += 1
new = f"{base}_{i}"

return new

def object_to_dict(obj, **kwargs):
def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
"""Returns a dictionary created from an object's attributes."""
return {**{k: copy(v) for k, v in vars(obj).copy().items()}, **kwargs}

def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> t.List[t.Optional[str]]:
def split_num_words(
value: str, sep: str, min_num_words: int, fill_from_start: bool = True
) -> t.List[t.Optional[str]]:
"""
Perform a split on a value and return N words as a result with None used for words that don't exist.
Perform a split on a value and return N words as a result with `None` used for words that don't exist.

Args:
value: The value to be split
sep: The value to use to split on
min_num_words: The minimum number of words that are going to be in the result
fill_from_start: Indicates that if None values should be inserted at the start or end of the list
value: the value to be split.
sep: the value to use to split on.
min_num_words: the minimum number of words that are going to be in the result.
fill_from_start: indicates that if `None` values should be inserted at the start or end of the list.

Examples:
>>> split_num_words("db.table", ".", 3)

@@ -187,6 +303,9 @@ def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: b
['db', 'table', None]
>>> split_num_words("db.table", ".", 1)
['db', 'table']

Returns:
The list of words returned by `split`, possibly augmented by a number of `None` values.
"""
words = value.split(sep)
if fill_from_start:

@@ -196,7 +315,7 @@ def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: b

def is_iterable(value: t.Any) -> bool:
"""
Checks if the value is an iterable but does not include strings and bytes
Checks if the value is an iterable, excluding the types `str` and `bytes`.

Examples:
>>> is_iterable([1,2])

@@ -205,28 +324,30 @@ def is_iterable(value: t.Any) -> bool:
False

Args:
value: The value to check if it is an interable
value: the value to check if it is an iterable.

Returns: Bool indicating if it is an iterable
Returns:
A `bool` value indicating if it is an iterable.
"""
return hasattr(value, "__iter__") and not isinstance(value, (str, bytes))

def flatten(values: t.Iterable[t.Union[t.Iterable[t.Any], t.Any]]) -> t.Generator[t.Any, None, None]:
def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any, None, None]:
"""
Flattens a list that can contain both iterables and non-iterable elements
Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
type `str` and `bytes` are not regarded as iterables.

Examples:
>>> list(flatten([[1, 2], 3]))
[1, 2, 3]
>>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
[1, 2, 3, 4, 5, 'bla']
>>> list(flatten([1, 2, 3]))
[1, 2, 3]

Args:
values: The value to be flattened
values: the value to be flattened.

Returns:
Yields non-iterable elements (not including str or byte as iterable)
Yields:
Non-iterable elements in `values`.
"""
for value in values:
if is_iterable(value):

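The renamed and newly typed helpers behave as their docstrings above describe. A small usage sketch, with results taken from the docstrings and doctests shown in the diff:

# Usage sketch for the reworked helpers; seq_get replaces list_get and
# ensure_collection complements ensure_list, per the hunks above.
from sqlglot.helper import ensure_collection, ensure_list, flatten, seq_get

assert seq_get([1, 2, 3], 5) is None      # out of bounds -> None, no IndexError
assert ensure_list((1, 2)) == [1, 2]      # tuples are cast to lists
assert ensure_collection("ab") == ["ab"]  # str is not treated as a collection
assert list(flatten([[1, 2], 3])) == [1, 2, 3]
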
@@ -1,5 +1,5 @@
from sqlglot import exp
from sqlglot.helper import ensure_list, subclasses
from sqlglot.helper import ensure_collection, ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema

@@ -48,35 +48,65 @@ class TypeAnnotator:
exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.BIGINT
),
exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
exp.CurrentTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATETIME
),
exp.CurrentTime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
exp.DatetimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATETIME
),
exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATETIME
),
exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.TimestampAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimestampSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.TimestampSub: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.DateStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.DateToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATE
),
exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.If: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.GroupConcat: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),

@@ -88,32 +118,52 @@ class TypeAnnotator:
exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.RegexpLike: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DOUBLE
),
exp.RegexpLike: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.BOOLEAN
),
exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.StrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.StrToTime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATE
),
exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.UnixToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.UnixToTime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.VariancePop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.VariancePop: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DOUBLE
),
exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
}

@@ -124,7 +174,11 @@ class TypeAnnotator:
exp.DataType.Type.TEXT: set(),
exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
exp.DataType.Type.NCHAR: {exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
exp.DataType.Type.NCHAR: {
exp.DataType.Type.VARCHAR,
exp.DataType.Type.NVARCHAR,
exp.DataType.Type.TEXT,
},
exp.DataType.Type.CHAR: {
exp.DataType.Type.NCHAR,
exp.DataType.Type.VARCHAR,

@@ -135,7 +189,11 @@ class TypeAnnotator:
exp.DataType.Type.DOUBLE: set(),
exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
exp.DataType.Type.BIGINT: {exp.DataType.Type.DECIMAL, exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
exp.DataType.Type.BIGINT: {
exp.DataType.Type.DECIMAL,
exp.DataType.Type.FLOAT,
exp.DataType.Type.DOUBLE,
},
exp.DataType.Type.INT: {
exp.DataType.Type.BIGINT,
exp.DataType.Type.DECIMAL,

@@ -160,7 +218,10 @@ class TypeAnnotator:
# DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
exp.DataType.Type.TIMESTAMPLTZ: set(),
exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
exp.DataType.Type.TIMESTAMP: {exp.DataType.Type.TIMESTAMPTZ, exp.DataType.Type.TIMESTAMPLTZ},
exp.DataType.Type.TIMESTAMP: {
exp.DataType.Type.TIMESTAMPTZ,
exp.DataType.Type.TIMESTAMPLTZ,
},
exp.DataType.Type.DATETIME: {
exp.DataType.Type.TIMESTAMP,
exp.DataType.Type.TIMESTAMPTZ,

@@ -219,7 +280,7 @@ class TypeAnnotator:

def _annotate_args(self, expression):
for value in expression.args.values():
for v in ensure_list(value):
for v in ensure_collection(value):
self._maybe_annotate(v)

return expression

@@ -243,7 +304,9 @@ class TypeAnnotator:
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
expression.type = exp.DataType.Type.NULL
elif exp.DataType.Type.NULL in (left_type, right_type):
expression.type = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
expression.type = exp.DataType.build(
"NULLABLE", expressions=exp.DataType.build("BOOLEAN")
)
else:
expression.type = exp.DataType.Type.BOOLEAN
elif isinstance(expression, (exp.Condition, exp.Predicate)):

@@ -276,3 +339,17 @@ class TypeAnnotator:
def _annotate_with_type(self, expression, target_type):
expression.type = target_type
return self._annotate_args(expression)

def _annotate_by_args(self, expression, *args):
self._annotate_args(expression)
expressions = []
for arg in args:
arg_expr = expression.args.get(arg)
expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

last_datatype = None
for expr in expressions:
last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

expression.type = last_datatype or exp.DataType.Type.UNKNOWN
return expression

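The new _annotate_by_args entries mean CASE, IF, COALESCE and IFNULL now take their type from their argument expressions instead of a fixed type. A rough sketch, assuming the TypeAnnotator(schema).annotate(...) entry point suggested by this module (the exact constructor signature is an assumption from context):

# Rough sketch of the new argument-driven annotation for CASE/COALESCE;
# the TypeAnnotator constructor and annotate() entry point are assumed here.
from sqlglot import parse_one
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.schema import ensure_schema

annotator = TypeAnnotator(ensure_schema({"x": {"a": "INT"}}))
expression = annotator.annotate(parse_one("SELECT COALESCE(a, 1) AS c FROM x"))
# COALESCE is now typed from its arguments (INT here), not hardcoded.
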
@@ -60,7 +60,9 @@ def _join_is_used(scope, join, alias):
on_clause_columns = set(id(column) for column in on.find_all(exp.Column))
else:
on_clause_columns = set()
return any(column for column in scope.source_columns(alias) if id(column) not in on_clause_columns)
return any(
column for column in scope.source_columns(alias) if id(column) not in on_clause_columns
)

def _is_joined_on_all_unique_outputs(scope, join):

@@ -45,7 +45,13 @@ def eliminate_subqueries(expression):

# All table names are taken
for scope in root.traverse():
taken.update({source.name: source for _, source in scope.sources.items() if isinstance(source, exp.Table)})
taken.update(
{
source.name: source
for _, source in scope.sources.items()
if isinstance(source, exp.Table)
}
)

# Map of Expression->alias
# Existing CTES in the root expression. We'll use this for deduplication.

@@ -70,7 +76,9 @@ def eliminate_subqueries(expression):
new_ctes.append(cte_scope.expression.parent)

# Now append the rest
for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.derived_table_scopes):
for scope in itertools.chain(
root.union_scopes, root.subquery_scopes, root.derived_table_scopes
):
for child_scope in scope.traverse():
new_cte = _eliminate(child_scope, existing_ctes, taken)
if new_cte:

@@ -122,7 +122,9 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
unmergable_window_columns = [
column
for column in outer_scope.columns
if column.find_ancestor(exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc)
if column.find_ancestor(
exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc
)
]
window_expressions_in_unmergable = [
column

@@ -147,7 +149,9 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
and not (
isinstance(from_or_join, exp.From)
and inner_select.args.get("where")
and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []))
and any(
j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])
)
)
and not _is_a_window_expression_in_unmergable_operation()
)

@@ -203,7 +207,9 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
if table.alias_or_name == node_to_replace.alias_or_name:
table.set("this", exp.to_identifier(new_subquery.alias_or_name))
outer_scope.remove_source(alias)
outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
outer_scope.add_source(
new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]
)

def _merge_joins(outer_scope, inner_scope, from_or_join):

@@ -296,7 +302,9 @@ def _merge_order(outer_scope, inner_scope):
inner_scope (sqlglot.optimizer.scope.Scope)
"""
if (
any(outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"])
any(
outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"]
)
or len(outer_scope.selected_sources) != 1
or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions)
):

@@ -50,7 +50,9 @@ def normalization_distance(expression, dnf=False):
Returns:
int: difference
"""
return sum(_predicate_lengths(expression, dnf)) - (len(list(expression.find_all(exp.Connector))) + 1)
return sum(_predicate_lengths(expression, dnf)) - (
len(list(expression.find_all(exp.Connector))) + 1
)

def _predicate_lengths(expression, dnf):

@@ -68,4 +68,8 @@ def normalize(expression):

def other_table_names(join, exclude):
return [name for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) if name != exclude]
return [
name
for name in (exp.column_table_names(join.args.get("on") or exp.TRUE))
if name != exclude
]

@@ -58,6 +58,8 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar

# Find any additional rule parameters, beyond `expression`
rule_params = rule.__code__.co_varnames
rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs}
rule_kwargs = {
param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
}
expression = rule(expression, **rule_kwargs)
return expression

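The rule_kwargs change above forwards to each rule only the keyword arguments its signature actually declares. A short sketch of calling the optimizer with an extra parameter (schema) that only some rules consume, assuming the optimize() entry point shown in the hunk header:

# Sketch: optimize() inspects each rule's parameters and forwards only the
# kwargs it declares, so rules that ignore `schema` are called without it.
from sqlglot import parse_one
from sqlglot.optimizer.optimizer import optimize

optimized = optimize(
    parse_one("SELECT a FROM x WHERE a > 1"),
    schema={"x": {"a": "INT"}},
)
print(optimized.sql(pretty=True))
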
@@ -45,7 +45,11 @@ def pushdown(condition, sources, scope_ref_count):
condition = condition.replace(simplify(condition))
cnf_like = normalized(condition) or not normalized(condition, dnf=True)

predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition])
predicates = list(
condition.flatten()
if isinstance(condition, exp.And if cnf_like else exp.Or)
else [condition]
)

if cnf_like:
pushdown_cnf(predicates, sources, scope_ref_count)

@@ -104,11 +108,17 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
for column in predicate.find_all(exp.Column):
if column.table == table:
condition = column.find_ancestor(exp.Condition)
predicate_condition = exp.and_(predicate_condition, condition) if predicate_condition else condition
predicate_condition = (
exp.and_(predicate_condition, condition)
if predicate_condition
else condition
)

if predicate_condition:
conditions[table] = (
exp.or_(conditions[table], predicate_condition) if table in conditions else predicate_condition
exp.or_(conditions[table], predicate_condition)
if table in conditions
else predicate_condition
)

for name, node in nodes.items():

@@ -146,10 +156,16 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
nodes[table] = node
elif isinstance(node, exp.Select) and len(tables) == 1:
# We can't push down window expressions
has_window_expression = any(select for select in node.selects if select.find(exp.Window))
has_window_expression = any(
select for select in node.selects if select.find(exp.Window)
)
# we can't push down predicates to select statements if they are referenced in
# multiple places.
if not node.args.get("group") and scope_ref_count[id(source)] < 2 and not has_window_expression:
if (
not node.args.get("group")
and scope_ref_count[id(source)] < 2
and not has_window_expression
):
nodes[table] = node
return nodes

@@ -165,7 +181,7 @@ def replace_aliases(source, predicate):

def _replace_alias(column):
if isinstance(column, exp.Column) and column.name in aliases:
return aliases[column.name]
return aliases[column.name].copy()
return column

return predicate.transform(_replace_alias)

@@ -98,7 +98,9 @@ def _remove_unused_selections(scope, parent_selections):

def _remove_indexed_selections(scope, indexes_to_remove):
new_selections = [selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove]
new_selections = [
selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
]
if not new_selections:
new_selections.append(DEFAULT_SELECTION)
scope.expression.set("expressions", new_selections)

@@ -215,13 +215,21 @@ def _qualify_columns(scope, resolver):
# Determine whether each reference in the order by clause is to a column or an alias.
for ordered in scope.find_all(exp.Ordered):
for column in ordered.find_all(exp.Column):
if not column.table and column.parent is not ordered and column.name in resolver.all_columns:
if (
not column.table
and column.parent is not ordered
and column.name in resolver.all_columns
):
columns_missing_from_scope.append(column)

# Determine whether each reference in the having clause is to a column or an alias.
for having in scope.find_all(exp.Having):
for column in having.find_all(exp.Column):
if not column.table and column.find_ancestor(exp.AggFunc) and column.name in resolver.all_columns:
if (
not column.table
and column.find_ancestor(exp.AggFunc)
and column.name in resolver.all_columns
):
columns_missing_from_scope.append(column)

for column in columns_missing_from_scope:

@@ -295,7 +303,9 @@ def _qualify_outputs(scope):
"""Ensure all output columns are aliased"""
new_selections = []

for i, (selection, aliased_column) in enumerate(itertools.zip_longest(scope.selects, scope.outer_column_list)):
for i, (selection, aliased_column) in enumerate(
itertools.zip_longest(scope.selects, scope.outer_column_list)
):
if isinstance(selection, exp.Column):
# convoluted setter because a simple selection.replace(alias) would require a copy
alias_ = alias(exp.column(""), alias=selection.name)

@@ -343,14 +353,18 @@ class _Resolver:
(str) table name
"""
if self._unambiguous_columns is None:
self._unambiguous_columns = self._get_unambiguous_columns(self._get_all_source_columns())
self._unambiguous_columns = self._get_unambiguous_columns(
self._get_all_source_columns()
)
return self._unambiguous_columns.get(column_name)

@property
def all_columns(self):
"""All available columns of all sources in this scope"""
if self._all_columns is None:
self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns)
self._all_columns = set(
column for columns in self._get_all_source_columns().values() for column in columns
)
return self._all_columns

def get_source_columns(self, name, only_visible=False):

@@ -377,7 +391,9 @@ class _Resolver:

def _get_all_source_columns(self):
if self._source_columns is None:
self._source_columns = {k: self.get_source_columns(k) for k in self.scope.selected_sources}
self._source_columns = {
k: self.get_source_columns(k) for k in self.scope.selected_sources
}
return self._source_columns

def _get_unambiguous_columns(self, source_columns):

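The hunks above split two long qualification conditions for ORDER BY and HAVING references into multi-line form. A compact sketch of what qualification does end to end, assuming the module's qualify_columns(expression, schema) signature from the hunk context:

# Sketch of column qualification; ORDER BY / HAVING references like the ones
# handled above resolve against the schema or the select aliases.
from sqlglot import parse_one
from sqlglot.optimizer.qualify_columns import qualify_columns

sql = "SELECT a AS b FROM x ORDER BY b"
print(qualify_columns(parse_one(sql), schema={"x": {"a": "INT"}}).sql())
# -> SELECT x.a AS b FROM x ORDER BY b  (indicative output)
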
@@ -226,7 +226,9 @@ class Scope:
self._ensure_collected()
columns = self._raw_columns

external_columns = [column for scope in self.subquery_scopes for column in scope.external_columns]
external_columns = [
column for scope in self.subquery_scopes for column in scope.external_columns
]

named_outputs = {e.alias_or_name for e in self.expression.expressions}

@@ -278,7 +280,11 @@ class Scope:
Returns:
dict[str, Scope]: Mapping of source alias to Scope
"""
return {alias: scope for alias, scope in self.sources.items() if isinstance(scope, Scope) and scope.is_cte}
return {
alias: scope
for alias, scope in self.sources.items()
if isinstance(scope, Scope) and scope.is_cte
}

@property
def selects(self):

@@ -307,7 +313,9 @@ class Scope:
sources in the current scope.
"""
if self._external_columns is None:
self._external_columns = [c for c in self.columns if c.table not in self.selected_sources]
self._external_columns = [
c for c in self.columns if c.table not in self.selected_sources
]
return self._external_columns

@property

@@ -229,7 +229,9 @@ def simplify_literals(expression):
operands.append(a)

if len(operands) < size:
return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
return functools.reduce(
lambda a, b: expression.__class__(this=a, expression=b), operands
)
elif isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:

@@ -255,6 +257,12 @@ def _simplify_binary(expression, a, b):
return TRUE if not_ else FALSE
if a == NULL:
return FALSE if not_ else TRUE
elif isinstance(expression, exp.NullSafeEQ):
if a == b:
return TRUE
elif isinstance(expression, exp.NullSafeNEQ):
if a == b:
return FALSE
elif NULL in (a, b):
return NULL

@@ -357,7 +365,7 @@ def extract_date(cast):

def extract_interval(interval):
try:
from dateutil.relativedelta import relativedelta
from dateutil.relativedelta import relativedelta # type: ignore
except ModuleNotFoundError:
return None

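With the new NullSafeEQ/NullSafeNEQ branches above, null-safe comparisons over structurally equal operands fold to boolean literals during simplification. A minimal sketch, assuming the simplify(expression) entry point named in the hunk context:

# Sketch: the new NullSafeEQ/NullSafeNEQ branches fold literal comparisons.
from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

print(simplify(parse_one("SELECT 1 IS NOT DISTINCT FROM 1")).sql())
# -> SELECT TRUE  (indicative; IS NOT DISTINCT FROM parses to NullSafeEQ)
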
@ -89,7 +89,11 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
|||
return
|
||||
|
||||
if isinstance(predicate, exp.Binary):
|
||||
key = predicate.right if any(node is column for node, *_ in predicate.left.walk()) else predicate.left
|
||||
key = (
|
||||
predicate.right
|
||||
if any(node is column for node, *_ in predicate.left.walk())
|
||||
else predicate.left
|
||||
)
|
||||
else:
|
||||
return
|
||||
|
||||
|
@ -145,7 +149,9 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
|||
else:
|
||||
parent_predicate = _replace(parent_predicate, "TRUE")
|
||||
elif isinstance(parent_predicate, exp.All):
|
||||
parent_predicate = _replace(parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})")
|
||||
parent_predicate = _replace(
|
||||
parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
|
||||
)
|
||||
elif isinstance(parent_predicate, exp.Any):
|
||||
if value.this in group_by:
|
||||
parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
|
||||
|
@ -168,7 +174,9 @@ def decorrelate(select, parent_select, external_columns, sequence):
|
|||
|
||||
if key in group_by:
|
||||
key.replace(nested)
|
||||
parent_predicate = _replace(parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)")
|
||||
parent_predicate = _replace(
|
||||
parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
|
||||
)
|
||||
elif isinstance(predicate, exp.EQ):
|
||||
parent_predicate = _replace(
|
||||
parent_predicate,
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
from sqlglot import exp
|
||||
from sqlglot.errors import ErrorLevel, ParseError, concat_errors
|
||||
from sqlglot.helper import apply_index_offset, ensure_list, list_get
|
||||
from sqlglot.helper import apply_index_offset, ensure_collection, seq_get
|
||||
from sqlglot.tokens import Token, Tokenizer, TokenType
|
||||
from sqlglot.trie import in_trie, new_trie
|
||||
|
||||
logger = logging.getLogger("sqlglot")
|
||||
|
||||
|
@ -20,7 +24,15 @@ def parse_var_map(args):
|
|||
)
|
||||
|
||||
|
||||
class Parser:
|
||||
class _Parser(type):
|
||||
def __new__(cls, clsname, bases, attrs):
|
||||
klass = super().__new__(cls, clsname, bases, attrs)
|
||||
klass._show_trie = new_trie(key.split(" ") for key in klass.SHOW_PARSERS)
|
||||
klass._set_trie = new_trie(key.split(" ") for key in klass.SET_PARSERS)
|
||||
return klass
|
||||
|
||||
|
||||
class Parser(metaclass=_Parser):
|
||||
"""
|
||||
Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer`
|
||||
and produces a parsed syntax tree.
|
||||
|
@ -45,16 +57,16 @@ class Parser:
|
|||
FUNCTIONS = {
|
||||
**{name: f.from_arg_list for f in exp.ALL_FUNCTIONS for name in f.sql_names()},
|
||||
"DATE_TO_DATE_STR": lambda args: exp.Cast(
|
||||
this=list_get(args, 0),
|
||||
this=seq_get(args, 0),
|
||||
to=exp.DataType(this=exp.DataType.Type.TEXT),
|
||||
),
|
||||
"TIME_TO_TIME_STR": lambda args: exp.Cast(
|
||||
this=list_get(args, 0),
|
||||
this=seq_get(args, 0),
|
||||
to=exp.DataType(this=exp.DataType.Type.TEXT),
|
||||
),
|
||||
"TS_OR_DS_TO_DATE_STR": lambda args: exp.Substring(
|
||||
this=exp.Cast(
|
||||
this=list_get(args, 0),
|
||||
this=seq_get(args, 0),
|
||||
to=exp.DataType(this=exp.DataType.Type.TEXT),
|
||||
),
|
||||
start=exp.Literal.number(1),
|
||||
|
@ -90,6 +102,7 @@ class Parser:
|
|||
TokenType.NVARCHAR,
|
||||
TokenType.TEXT,
|
||||
TokenType.BINARY,
|
||||
TokenType.VARBINARY,
|
||||
TokenType.JSON,
|
||||
TokenType.INTERVAL,
|
||||
TokenType.TIMESTAMP,
|
||||
|
@ -243,6 +256,7 @@ class Parser:
|
|||
EQUALITY = {
|
||||
TokenType.EQ: exp.EQ,
|
||||
TokenType.NEQ: exp.NEQ,
|
||||
TokenType.NULLSAFE_EQ: exp.NullSafeEQ,
|
||||
}
|
||||
|
||||
COMPARISON = {
|
||||
|
@ -298,6 +312,21 @@ class Parser:
|
|||
TokenType.ANTI,
|
||||
}
|
||||
|
||||
LAMBDAS = {
|
||||
TokenType.ARROW: lambda self, expressions: self.expression(
|
||||
exp.Lambda,
|
||||
this=self._parse_conjunction().transform(
|
||||
self._replace_lambda, {node.name for node in expressions}
|
||||
),
|
||||
expressions=expressions,
|
||||
),
|
||||
TokenType.FARROW: lambda self, expressions: self.expression(
|
||||
exp.Kwarg,
|
||||
this=exp.Var(this=expressions[0].name),
|
||||
expression=self._parse_conjunction(),
|
||||
),
|
||||
}
|
||||
|
||||
COLUMN_OPERATORS = {
|
||||
TokenType.DOT: None,
|
||||
TokenType.DCOLON: lambda self, this, to: self.expression(
|
||||
|
@ -362,20 +391,30 @@ class Parser:
|
|||
TokenType.DELETE: lambda self: self._parse_delete(),
|
||||
TokenType.CACHE: lambda self: self._parse_cache(),
|
||||
TokenType.UNCACHE: lambda self: self._parse_uncache(),
|
||||
TokenType.USE: lambda self: self._parse_use(),
|
||||
}
|
||||
|
||||
PRIMARY_PARSERS = {
|
||||
TokenType.STRING: lambda _, token: exp.Literal.string(token.text),
|
||||
TokenType.NUMBER: lambda _, token: exp.Literal.number(token.text),
|
||||
TokenType.STAR: lambda self, _: exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()}),
|
||||
TokenType.NULL: lambda *_: exp.Null(),
|
||||
TokenType.TRUE: lambda *_: exp.Boolean(this=True),
|
||||
TokenType.FALSE: lambda *_: exp.Boolean(this=False),
|
||||
TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()),
|
||||
TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text),
|
||||
TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text),
|
||||
TokenType.BYTE_STRING: lambda _, token: exp.ByteString(this=token.text),
|
||||
TokenType.STRING: lambda self, token: self.expression(
|
||||
exp.Literal, this=token.text, is_string=True
|
||||
),
|
||||
TokenType.NUMBER: lambda self, token: self.expression(
|
||||
exp.Literal, this=token.text, is_string=False
|
||||
),
|
||||
TokenType.STAR: lambda self, _: self.expression(
|
||||
exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()}
|
||||
),
|
||||
TokenType.NULL: lambda self, _: self.expression(exp.Null),
|
||||
TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True),
|
||||
TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False),
|
||||
TokenType.PARAMETER: lambda self, _: self.expression(
|
||||
exp.Parameter, this=self._parse_var() or self._parse_primary()
|
||||
),
|
||||
TokenType.BIT_STRING: lambda self, token: self.expression(exp.BitString, this=token.text),
|
||||
TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text),
|
||||
TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text),
|
||||
TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token),
|
||||
TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
|
||||
}
|
||||
|
||||
RANGE_PARSERS = {
|
||||
|
@ -411,16 +450,24 @@ class Parser:
|
|||
TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
|
||||
TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
|
||||
TokenType.FORMAT: lambda self: self._parse_property_assignment(exp.FileFormatProperty),
|
||||
TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
|
||||
TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment(
|
||||
exp.TableFormatProperty
|
||||
),
|
||||
TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
|
||||
TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty),
|
||||
TokenType.EXECUTE: lambda self: self._parse_execute_as(),
|
||||
TokenType.DETERMINISTIC: lambda self: self.expression(
|
||||
exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
|
||||
),
|
||||
TokenType.IMMUTABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")),
|
||||
TokenType.STABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("STABLE")),
|
||||
TokenType.VOLATILE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")),
|
||||
TokenType.IMMUTABLE: lambda self: self.expression(
|
||||
exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
|
||||
),
|
||||
TokenType.STABLE: lambda self: self.expression(
|
||||
exp.VolatilityProperty, this=exp.Literal.string("STABLE")
|
||||
),
|
||||
TokenType.VOLATILE: lambda self: self.expression(
|
||||
exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
|
||||
),
|
||||
}
|
||||
|
||||
CONSTRAINT_PARSERS = {
|
||||
|
@ -450,7 +497,8 @@ class Parser:
|
|||
"group": lambda self: self._parse_group(),
|
||||
"having": lambda self: self._parse_having(),
|
||||
"qualify": lambda self: self._parse_qualify(),
|
||||
"window": lambda self: self._match(TokenType.WINDOW) and self._parse_window(self._parse_id_var(), alias=True),
|
||||
"window": lambda self: self._match(TokenType.WINDOW)
|
||||
and self._parse_window(self._parse_id_var(), alias=True),
|
||||
"distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute),
|
||||
"sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort),
|
||||
"cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster),
|
||||
|
@ -459,6 +507,9 @@ class Parser:
|
|||
"offset": lambda self: self._parse_offset(),
|
||||
}
|
||||
|
||||
SHOW_PARSERS: t.Dict[str, t.Callable] = {}
|
||||
SET_PARSERS: t.Dict[str, t.Callable] = {}
|
||||
|
||||
MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
|
||||
|
||||
CREATABLES = {
|
||||
|
@ -488,7 +539,9 @@ class Parser:
|
|||
"_curr",
|
||||
"_next",
|
||||
"_prev",
|
||||
"_greedy_subqueries",
|
||||
"_prev_comment",
|
||||
"_show_trie",
|
||||
"_set_trie",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
|
@ -519,7 +572,7 @@ class Parser:
|
|||
self._curr = None
|
||||
self._next = None
|
||||
self._prev = None
|
||||
self._greedy_subqueries = False
|
||||
self._prev_comment = None
|
||||
|
||||
def parse(self, raw_tokens, sql=None):
|
||||
"""
|
||||
|
@ -533,10 +586,12 @@ class Parser:
|
|||
Returns
|
||||
the list of syntax trees (:class:`~sqlglot.expressions.Expression`).
|
||||
"""
|
||||
return self._parse(parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql)
|
||||
return self._parse(
|
||||
parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql
|
||||
)
|
||||
|
||||
def parse_into(self, expression_types, raw_tokens, sql=None):
|
||||
for expression_type in ensure_list(expression_types):
|
||||
for expression_type in ensure_collection(expression_types):
|
||||
parser = self.EXPRESSION_PARSERS.get(expression_type)
|
||||
if not parser:
|
||||
raise TypeError(f"No parser registered for {expression_type}")
|
||||
|
@ -597,6 +652,9 @@ class Parser:
|
|||
|
||||
def expression(self, exp_class, **kwargs):
|
||||
instance = exp_class(**kwargs)
|
||||
if self._prev_comment:
|
||||
instance.comment = self._prev_comment
|
||||
self._prev_comment = None
|
||||
self.validate_expression(instance)
|
||||
return instance
|
||||
|
||||
|
@ -633,14 +691,16 @@ class Parser:
|
|||
|
||||
return index
|
||||
|
||||
def _get_token(self, index):
|
||||
return list_get(self._tokens, index)
|
||||
|
||||
def _advance(self, times=1):
|
||||
self._index += times
|
||||
self._curr = self._get_token(self._index)
|
||||
self._next = self._get_token(self._index + 1)
|
||||
self._prev = self._get_token(self._index - 1) if self._index > 0 else None
|
||||
self._curr = seq_get(self._tokens, self._index)
|
||||
self._next = seq_get(self._tokens, self._index + 1)
|
||||
if self._index > 0:
|
||||
self._prev = self._tokens[self._index - 1]
|
||||
self._prev_comment = self._prev.comment
|
||||
else:
|
||||
self._prev = None
|
||||
self._prev_comment = None
|
||||
|
||||
def _retreat(self, index):
|
||||
self._advance(index - self._index)
|
||||
|
@ -661,6 +721,7 @@ class Parser:
|
|||
|
||||
expression = self._parse_expression()
|
||||
expression = self._parse_set_operations(expression) if expression else self._parse_select()
|
||||
|
||||
self._parse_query_modifiers(expression)
|
||||
return expression
|
||||
|
||||
|
@ -682,7 +743,11 @@ class Parser:
|
|||
)
|
||||
|
||||
def _parse_exists(self, not_=False):
|
||||
return self._match(TokenType.IF) and (not not_ or self._match(TokenType.NOT)) and self._match(TokenType.EXISTS)
|
||||
return (
|
||||
self._match(TokenType.IF)
|
||||
and (not not_ or self._match(TokenType.NOT))
|
||||
and self._match(TokenType.EXISTS)
|
||||
)
|
||||
|
||||
def _parse_create(self):
|
||||
replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE)
|
||||
|
@ -931,7 +996,9 @@ class Parser:
|
|||
return self.expression(
|
||||
exp.Delete,
|
||||
this=self._parse_table(schema=True),
|
||||
using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table(schema=True)),
|
||||
using=self._parse_csv(
|
||||
lambda: self._match(TokenType.USING) and self._parse_table(schema=True)
|
||||
),
|
||||
where=self._parse_where(),
|
||||
)
|
||||
|
||||
|
@ -983,11 +1050,13 @@ class Parser:
|
|||
return None
|
||||
|
||||
def parse_values():
|
||||
k = self._parse_var()
|
||||
key = self._parse_var()
|
||||
value = None
|
||||
|
||||
if self._match(TokenType.EQ):
|
||||
v = self._parse_string()
|
||||
return (k, v)
|
||||
return (k, None)
|
||||
value = self._parse_string()
|
||||
|
||||
return exp.Property(this=key, value=value)
|
||||
|
||||
self._match_l_paren()
|
||||
values = self._parse_csv(parse_values)
|
||||
|
@ -1019,6 +1088,8 @@ class Parser:
|
|||
self.raise_error(f"{this.key} does not support CTE")
|
||||
this = cte
|
||||
elif self._match(TokenType.SELECT):
|
||||
comment = self._prev_comment
|
||||
|
||||
hint = self._parse_hint()
|
||||
all_ = self._match(TokenType.ALL)
|
||||
distinct = self._match(TokenType.DISTINCT)
|
||||
|
@ -1033,7 +1104,7 @@ class Parser:
|
|||
self.raise_error("Cannot specify both ALL and DISTINCT after SELECT")
|
||||
|
||||
limit = self._parse_limit(top=True)
|
||||
expressions = self._parse_csv(lambda: self._parse_annotation(self._parse_expression()))
|
||||
expressions = self._parse_csv(self._parse_expression)
|
||||
|
||||
this = self.expression(
|
||||
exp.Select,
|
||||
|
@ -1042,6 +1113,7 @@ class Parser:
|
|||
expressions=expressions,
|
||||
limit=limit,
|
||||
)
|
||||
this.comment = comment
|
||||
from_ = self._parse_from()
|
||||
if from_:
|
||||
this.set("from", from_)
|
||||
|
@ -1072,8 +1144,10 @@ class Parser:
|
|||
while True:
|
||||
expressions.append(self._parse_cte())
|
||||
|
||||
if not self._match(TokenType.COMMA):
|
||||
if not self._match(TokenType.COMMA) and not self._match(TokenType.WITH):
|
||||
break
|
||||
else:
|
||||
self._match(TokenType.WITH)
|
||||
|
||||
return self.expression(
|
||||
exp.With,
|
||||
|
@@ -1111,11 +1185,7 @@ class Parser:
         if not alias and not columns:
             return None

-        return self.expression(
-            exp.TableAlias,
-            this=alias,
-            columns=columns,
-        )
+        return self.expression(exp.TableAlias, this=alias, columns=columns)

     def _parse_subquery(self, this):
         return self.expression(
@@ -1150,12 +1220,6 @@ class Parser:
             if expression:
                 this.set(key, expression)

-    def _parse_annotation(self, expression):
-        if self._match(TokenType.ANNOTATION):
-            return self.expression(exp.Annotation, this=self._prev.text.strip(), expression=expression)
-
-        return expression
-
     def _parse_hint(self):
         if self._match(TokenType.HINT):
             hints = self._parse_csv(self._parse_function)
@@ -1295,7 +1359,9 @@ class Parser:
         if not table:
             self.raise_error("Expected table name")

-        this = self.expression(exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots())
+        this = self.expression(
+            exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
+        )

         if schema:
             return self._parse_schema(this=this)
@@ -1500,7 +1566,9 @@ class Parser:
         if not skip_order_token and not self._match(TokenType.ORDER_BY):
             return this

-        return self.expression(exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered))
+        return self.expression(
+            exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)
+        )

     def _parse_sort(self, token_type, exp_class):
         if not self._match(token_type):
@@ -1521,7 +1589,8 @@ class Parser:
         if (
             not explicitly_null_ordered
             and (
-                (asc and self.null_ordering == "nulls_are_small") or (desc and self.null_ordering != "nulls_are_small")
+                (asc and self.null_ordering == "nulls_are_small")
+                or (desc and self.null_ordering != "nulls_are_small")
             )
             and self.null_ordering != "nulls_are_last"
         ):
@@ -1606,6 +1675,9 @@ class Parser:

     def _parse_is(self, this):
         negate = self._match(TokenType.NOT)
+        if self._match(TokenType.DISTINCT_FROM):
+            klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
+            return self.expression(klass, this=this, expression=self._parse_expression())
         this = self.expression(
             exp.Is,
             this=this,
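Based on the branch above, `IS [NOT] DISTINCT FROM` now parses into the null-safe comparison nodes introduced by this release; a sketch (the class names come straight from the diff):

    import sqlglot
    from sqlglot import exp

    eq = sqlglot.parse_one("a IS NOT DISTINCT FROM b")
    neq = sqlglot.parse_one("a IS DISTINCT FROM b")
    # expected: True True, per the negate/klass mapping above
    print(isinstance(eq, exp.NullSafeEQ), isinstance(neq, exp.NullSafeNEQ))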
@@ -1653,9 +1725,13 @@ class Parser:
                 expression=self._parse_term(),
             )
         elif self._match_pair(TokenType.LT, TokenType.LT):
-            this = self.expression(exp.BitwiseLeftShift, this=this, expression=self._parse_term())
+            this = self.expression(
+                exp.BitwiseLeftShift, this=this, expression=self._parse_term()
+            )
         elif self._match_pair(TokenType.GT, TokenType.GT):
-            this = self.expression(exp.BitwiseRightShift, this=this, expression=self._parse_term())
+            this = self.expression(
+                exp.BitwiseRightShift, this=this, expression=self._parse_term()
+            )
         else:
             break
@@ -1685,7 +1761,7 @@ class Parser:
         )

         index = self._index
-        type_token = self._parse_types()
+        type_token = self._parse_types(check_func=True)
         this = self._parse_column()

         if type_token:
@@ -1698,7 +1774,7 @@ class Parser:

         return this

-    def _parse_types(self):
+    def _parse_types(self, check_func=False):
         index = self._index

         if not self._match_set(self.TYPE_TOKENS):
@@ -1708,10 +1784,13 @@ class Parser:
         nested = type_token in self.NESTED_TYPE_TOKENS
+        is_struct = type_token == TokenType.STRUCT
         expressions = None
+        maybe_func = False

         if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
             return exp.DataType(
-                this=exp.DataType.Type.ARRAY, expressions=[exp.DataType.build(type_token.value)], nested=True
+                this=exp.DataType.Type.ARRAY,
+                expressions=[exp.DataType.build(type_token.value)],
+                nested=True,
             )

         if self._match(TokenType.L_BRACKET):
@@ -1731,6 +1810,7 @@ class Parser:
                 return None

             self._match_r_paren()
+            maybe_func = True

         if nested and self._match(TokenType.LT):
             if is_struct:
@@ -1741,25 +1821,46 @@ class Parser:
             if not self._match(TokenType.GT):
                 self.raise_error("Expecting >")

+        value = None
         if type_token in self.TIMESTAMPS:
-            tz = self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ
-            if tz:
-                return exp.DataType(
+            if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ:
+                value = exp.DataType(
                     this=exp.DataType.Type.TIMESTAMPTZ,
                     expressions=expressions,
                 )
-            ltz = self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ
-            if ltz:
-                return exp.DataType(
+            elif (
+                self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ
+            ):
+                value = exp.DataType(
                     this=exp.DataType.Type.TIMESTAMPLTZ,
                     expressions=expressions,
                 )
-            self._match(TokenType.WITHOUT_TIME_ZONE)
-
-            return exp.DataType(
-                this=exp.DataType.Type.TIMESTAMP,
-                expressions=expressions,
-            )
+            elif self._match(TokenType.WITHOUT_TIME_ZONE):
+                value = exp.DataType(
+                    this=exp.DataType.Type.TIMESTAMP,
+                    expressions=expressions,
+                )
+
+            maybe_func = maybe_func and value is None
+
+            if value is None:
+                value = exp.DataType(
+                    this=exp.DataType.Type.TIMESTAMP,
+                    expressions=expressions,
+                )
+
+        if maybe_func and check_func:
+            index2 = self._index
+            peek = self._parse_string()
+
+            if not peek:
+                self._retreat(index)
+                return None
+
+            self._retreat(index2)
+
+        if value:
+            return value

         return exp.DataType(
             this=exp.DataType.Type[type_token.value.upper()],
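The restructured branch folds the `WITH [LOCAL] TIME ZONE` / `WITHOUT TIME ZONE` modifiers into a single `value` instead of early returns; a sketch of the observable behaviour:

    import sqlglot

    # TIMESTAMP WITH TIME ZONE should come out as the TIMESTAMPTZ data type
    expr = sqlglot.parse_one("CAST(x AS TIMESTAMP WITH TIME ZONE)")
    print(expr.args["to"])  # expected: a DataType whose this is Type.TIMESTAMPTZ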
@@ -1826,22 +1927,29 @@ class Parser:
             return exp.Literal.number(f"0.{self._prev.text}")

         if self._match(TokenType.L_PAREN):
+            comment = self._prev_comment
             query = self._parse_select()

             if query:
                 expressions = [query]
             else:
-                expressions = self._parse_csv(lambda: self._parse_alias(self._parse_conjunction(), explicit=True))
+                expressions = self._parse_csv(
+                    lambda: self._parse_alias(self._parse_conjunction(), explicit=True)
+                )

-            this = list_get(expressions, 0)
+            this = seq_get(expressions, 0)
             self._parse_query_modifiers(this)
             self._match_r_paren()

             if isinstance(this, exp.Subqueryable):
-                return self._parse_set_operations(self._parse_subquery(this))
-            if len(expressions) > 1:
-                return self.expression(exp.Tuple, expressions=expressions)
-            return self.expression(exp.Paren, this=this)
+                this = self._parse_set_operations(self._parse_subquery(this))
+            elif len(expressions) > 1:
+                this = self.expression(exp.Tuple, expressions=expressions)
+            else:
+                this = self.expression(exp.Paren, this=this)
+
+            if comment:
+                this.comment = comment
+            return this

         return None
@@ -1894,7 +2002,8 @@ class Parser:
             self.validate_expression(this, args)
         else:
             this = self.expression(exp.Anonymous, this=this, expressions=args)
-        self._match_r_paren()
+
+        self._match_r_paren(this)
         return self._parse_window(this)

     def _parse_user_defined_function(self):
|
@ -1920,6 +2029,18 @@ class Parser:
|
|||
|
||||
return self.expression(exp.Identifier, this=token.text)
|
||||
|
||||
def _parse_session_parameter(self):
|
||||
kind = None
|
||||
this = self._parse_id_var() or self._parse_primary()
|
||||
if self._match(TokenType.DOT):
|
||||
kind = this.name
|
||||
this = self._parse_var() or self._parse_primary()
|
||||
return self.expression(
|
||||
exp.SessionParameter,
|
||||
this=this,
|
||||
kind=kind,
|
||||
)
|
||||
|
||||
def _parse_udf_kwarg(self):
|
||||
this = self._parse_id_var()
|
||||
kind = self._parse_types()
|
||||
|
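`_parse_session_parameter` backs `@@`-style session variables; a hedged sketch (the dialect and the exact SQL accepted here are assumptions about typical usage, not something this hunk guarantees):

    import sqlglot

    expr = sqlglot.parse_one("SELECT @@session.sql_mode", read="mysql")
    param = expr.expressions[0]
    # expected: a SessionParameter with this=sql_mode and kind="session"
    print(param.name, param.args.get("kind"))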
@@ -1938,27 +2059,24 @@ class Parser:
         else:
             expressions = [self._parse_id_var()]

-        if not self._match(TokenType.ARROW):
-            self._retreat(index)
+        if self._match_set(self.LAMBDAS):
+            return self.LAMBDAS[self._prev.token_type](self, expressions)

-            if self._match(TokenType.DISTINCT):
-                this = self.expression(exp.Distinct, expressions=self._parse_csv(self._parse_conjunction))
-            else:
-                this = self._parse_conjunction()
+        self._retreat(index)

-            if self._match(TokenType.IGNORE_NULLS):
-                this = self.expression(exp.IgnoreNulls, this=this)
-            else:
-                self._match(TokenType.RESPECT_NULLS)
+        if self._match(TokenType.DISTINCT):
+            this = self.expression(
+                exp.Distinct, expressions=self._parse_csv(self._parse_conjunction)
+            )
+        else:
+            this = self._parse_conjunction()

-            return self._parse_alias(self._parse_limit(self._parse_order(this)))
+        if self._match(TokenType.IGNORE_NULLS):
+            this = self.expression(exp.IgnoreNulls, this=this)
+        else:
+            self._match(TokenType.RESPECT_NULLS)

-        conjunction = self._parse_conjunction().transform(self._replace_lambda, {node.name for node in expressions})
-        return self.expression(
-            exp.Lambda,
-            this=conjunction,
-            expressions=expressions,
-        )
+        return self._parse_alias(self._parse_limit(self._parse_order(this)))

     def _parse_schema(self, this=None):
         index = self._index
@@ -1966,7 +2084,9 @@ class Parser:
             self._retreat(index)
             return this

-        args = self._parse_csv(lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True)))
+        args = self._parse_csv(
+            lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True))
+        )
         self._match_r_paren()
         return self.expression(exp.Schema, this=this, expressions=args)
@@ -2104,6 +2224,7 @@ class Parser:
         if not self._match(TokenType.R_BRACKET):
             self.raise_error("Expected ]")

+        this.comment = self._prev_comment
         return self._parse_bracket(this)

     def _parse_case(self):
@@ -2124,7 +2245,9 @@ class Parser:
         if not self._match(TokenType.END):
             self.raise_error("Expected END after CASE", self._prev)

-        return self._parse_window(self.expression(exp.Case, this=expression, ifs=ifs, default=default))
+        return self._parse_window(
+            self.expression(exp.Case, this=expression, ifs=ifs, default=default)
+        )

     def _parse_if(self):
         if self._match(TokenType.L_PAREN):
@@ -2331,7 +2454,9 @@ class Parser:
         self._match(TokenType.BETWEEN)

         return {
-            "value": (self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text)
+            "value": (
+                self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text
+            )
             or self._parse_bitwise(),
             "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text,
         }
@@ -2348,7 +2473,7 @@ class Parser:
                 this=this,
                 expressions=self._parse_csv(lambda: self._parse_id_var(any_token)),
             )
-            self._match_r_paren()
+            self._match_r_paren(aliases)
             return aliases

         alias = self._parse_id_var(any_token)
@@ -2365,28 +2490,29 @@ class Parser:
             return identifier

         if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS:
-            return self._advance() or exp.Identifier(this=self._prev.text, quoted=False)
-
-        return self._match_set(tokens or self.ID_VAR_TOKENS) and exp.Identifier(this=self._prev.text, quoted=False)
+            self._advance()
+        elif not self._match_set(tokens or self.ID_VAR_TOKENS):
+            return None
+        return exp.Identifier(this=self._prev.text, quoted=False)

     def _parse_string(self):
         if self._match(TokenType.STRING):
-            return exp.Literal.string(self._prev.text)
+            return self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev)
         return self._parse_placeholder()

     def _parse_number(self):
         if self._match(TokenType.NUMBER):
-            return exp.Literal.number(self._prev.text)
+            return self.PRIMARY_PARSERS[TokenType.NUMBER](self, self._prev)
         return self._parse_placeholder()

     def _parse_identifier(self):
         if self._match(TokenType.IDENTIFIER):
-            return exp.Identifier(this=self._prev.text, quoted=True)
+            return self.expression(exp.Identifier, this=self._prev.text, quoted=True)
         return self._parse_placeholder()

     def _parse_var(self):
         if self._match(TokenType.VAR):
-            return exp.Var(this=self._prev.text)
+            return self.expression(exp.Var, this=self._prev.text)
         return self._parse_placeholder()

     def _parse_var_or_string(self):
@@ -2394,27 +2520,27 @@ class Parser:

     def _parse_null(self):
         if self._match(TokenType.NULL):
-            return exp.Null()
+            return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev)
         return None

     def _parse_boolean(self):
         if self._match(TokenType.TRUE):
-            return exp.Boolean(this=True)
+            return self.PRIMARY_PARSERS[TokenType.TRUE](self, self._prev)
         if self._match(TokenType.FALSE):
-            return exp.Boolean(this=False)
+            return self.PRIMARY_PARSERS[TokenType.FALSE](self, self._prev)
         return None

     def _parse_star(self):
         if self._match(TokenType.STAR):
-            return exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()})
+            return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev)
         return None

     def _parse_placeholder(self):
         if self._match(TokenType.PLACEHOLDER):
-            return exp.Placeholder()
+            return self.expression(exp.Placeholder)
         elif self._match(TokenType.COLON):
             self._advance()
-            return exp.Placeholder(this=self._prev.text)
+            return self.expression(exp.Placeholder, this=self._prev.text)
         return None

     def _parse_except(self):
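The `COLON` branch of `_parse_placeholder` gives named placeholders alongside the bare `?` form; a sketch:

    import sqlglot
    from sqlglot import exp

    anonymous = sqlglot.parse_one("SELECT * FROM t WHERE id = ?")
    named = sqlglot.parse_one("SELECT * FROM t WHERE id = :id")
    # expected: a Placeholder node carrying this="id"
    print(named.find(exp.Placeholder))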
@@ -2432,22 +2558,27 @@ class Parser:
         self._match_r_paren()
         return columns

-    def _parse_csv(self, parse):
-        parse_result = parse()
+    def _parse_csv(self, parse_method):
+        parse_result = parse_method()
         items = [parse_result] if parse_result is not None else []

         while self._match(TokenType.COMMA):
-            parse_result = parse()
+            if parse_result and self._prev_comment is not None:
+                parse_result.comment = self._prev_comment
+
+            parse_result = parse_method()
             if parse_result is not None:
                 items.append(parse_result)

         return items

-    def _parse_tokens(self, parse, expressions):
-        this = parse()
+    def _parse_tokens(self, parse_method, expressions):
+        this = parse_method()

         while self._match_set(expressions):
-            this = self.expression(expressions[self._prev.token_type], this=this, expression=parse())
+            this = self.expression(
+                expressions[self._prev.token_type], this=this, expression=parse_method()
+            )

         return this
@@ -2460,6 +2591,47 @@ class Parser:
     def _parse_select_or_expression(self):
         return self._parse_select() or self._parse_expression()

+    def _parse_use(self):
+        return self.expression(exp.Use, this=self._parse_id_var())
+
+    def _parse_show(self):
+        parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
+        if parser:
+            return parser(self)
+        self._advance()
+        return self.expression(exp.Show, this=self._prev.text.upper())
+
+    def _default_parse_set_item(self):
+        return self.expression(
+            exp.SetItem,
+            this=self._parse_statement(),
+        )
+
+    def _parse_set_item(self):
+        parser = self._find_parser(self.SET_PARSERS, self._set_trie)
+        return parser(self) if parser else self._default_parse_set_item()
+
+    def _parse_set(self):
+        return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
+
+    def _find_parser(self, parsers, trie):
+        index = self._index
+        this = []
+        while True:
+            # The current token might be multiple words
+            curr = self._curr.text.upper()
+            key = curr.split(" ")
+            this.append(curr)
+            self._advance()
+            result, trie = in_trie(trie, key)
+            if result == 0:
+                break
+            if result == 2:
+                subparser = parsers[" ".join(this)]
+                return subparser
+        self._retreat(index)
+        return None
+
     def _match(self, token_type):
         if not self._curr:
             return None
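`_find_parser` leans on the shared trie helpers, where `in_trie` returns 0 (no match), 1 (prefix of a longer key), or 2 (exact key); a standalone sketch:

    from sqlglot.trie import in_trie, new_trie

    trie = new_trie(key.split(" ") for key in ("SHOW TABLES", "SHOW COLUMNS"))
    result, node = in_trie(trie, ["SHOW"])    # result == 1: prefix of a longer phrase
    result, node = in_trie(node, ["TABLES"])  # result == 2: full phrase matched
    result, _ = in_trie(trie, ["DESCRIBE"])   # result == 0: no entry starts with this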
@@ -2491,13 +2663,17 @@ class Parser:

         return None

-    def _match_l_paren(self):
+    def _match_l_paren(self, expression=None):
         if not self._match(TokenType.L_PAREN):
             self.raise_error("Expecting (")
+        if expression and self._prev_comment:
+            expression.comment = self._prev_comment

-    def _match_r_paren(self):
+    def _match_r_paren(self, expression=None):
         if not self._match(TokenType.R_PAREN):
             self.raise_error("Expecting )")
+        if expression and self._prev_comment:
+            expression.comment = self._prev_comment

     def _match_text(self, *texts):
         index = self._index
@@ -72,7 +72,9 @@ class Step:
         if from_:
             from_ = from_.expressions
             if len(from_) > 1:
-                raise UnsupportedError("Multi-from statements are unsupported. Run it through the optimizer")
+                raise UnsupportedError(
+                    "Multi-from statements are unsupported. Run it through the optimizer"
+                )

             step = Scan.from_expression(from_[0], ctes)
         else:
@@ -102,7 +104,7 @@ class Step:
                     continue
                 if operand not in operands:
                     operands[operand] = f"_a_{next(sequence)}"
-                operand.replace(exp.column(operands[operand], step.name, quoted=True))
+                operand.replace(exp.column(operands[operand], quoted=True))
             else:
                 projections.append(e)

@@ -117,9 +119,11 @@ class Step:
             aggregate = Aggregate()
             aggregate.source = step.name
             aggregate.name = step.name
-            aggregate.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
+            aggregate.operands = tuple(
+                alias(operand, alias_) for operand, alias_ in operands.items()
+            )
             aggregate.aggregations = aggregations
-            aggregate.group = [exp.column(e.alias_or_name, step.name, quoted=True) for e in group.expressions]
+            aggregate.group = group.expressions
             aggregate.add_dependency(step)
             step = aggregate

@@ -136,9 +140,6 @@ class Step:
             sort.key = order.expressions
             sort.add_dependency(step)
             step = sort
-            for k in sort.key + projections:
-                for column in k.find_all(exp.Column):
-                    column.set("table", exp.to_identifier(step.name, quoted=True))

         step.projections = projections

@@ -203,7 +204,9 @@ class Scan(Step):
         alias_ = expression.alias

         if not alias_:
-            raise UnsupportedError("Tables/Subqueries must be aliased. Run it through the optimizer")
+            raise UnsupportedError(
+                "Tables/Subqueries must be aliased. Run it through the optimizer"
+            )

         if isinstance(expression, exp.Subquery):
             table = expression.this
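These Step changes sit behind sqlglot's simple executor; a hedged sketch of driving the planner (the schema literal and the optimizer round-trip are assumptions about typical usage, not part of this diff):

    import sqlglot
    from sqlglot.optimizer import optimize
    from sqlglot.planner import Plan

    expression = optimize(
        sqlglot.parse_one("SELECT a, SUM(b) AS s FROM x GROUP BY a"),
        schema={"x": {"a": "INT", "b": "INT"}},
    )
    # the plan root chains the Scan/Aggregate/Sort steps built by Step.from_expression
    plan = Plan(expression)
    print(plan.root)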
(new file) sqlglot/py.typed: empty marker file
@@ -1,44 +1,60 @@
+from __future__ import annotations
+
 import abc
+import typing as t

 from sqlglot import expressions as exp
-from sqlglot.errors import OptimizeError
+from sqlglot.errors import SchemaError
 from sqlglot.helper import csv_reader
+from sqlglot.trie import in_trie, new_trie
+
+if t.TYPE_CHECKING:
+    from sqlglot.dataframe.sql.types import StructType
+
+    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
+
+TABLE_ARGS = ("this", "db", "catalog")


 class Schema(abc.ABC):
     """Abstract base class for database schemas"""

     @abc.abstractmethod
-    def add_table(self, table, column_mapping=None):
+    def add_table(
+        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
+    ) -> None:
         """
-        Register or update a table. Some implementing classes may require column information to also be provided
+        Register or update a table. Some implementing classes may require column information to also be provided.

         Args:
-            table (sqlglot.expressions.Table|str): Table expression instance or string representing the table
-            column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table
+            table: table expression instance or string representing the table.
+            column_mapping: a column mapping that describes the structure of the table.
         """

     @abc.abstractmethod
-    def column_names(self, table, only_visible=False):
+    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
         """
         Get the column names for a table.

         Args:
-            table (sqlglot.expressions.Table): Table expression instance
-            only_visible (bool): Whether to include invisible columns
+            table: the `Table` expression instance.
+            only_visible: whether to include invisible columns.

         Returns:
-            list[str]: list of column names
+            The list of column names.
         """

     @abc.abstractmethod
-    def get_column_type(self, table, column):
+    def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType.Type:
         """
-        Get the exp.DataType type of a column in the schema.
+        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

         Args:
-            table (sqlglot.expressions.Table): The source table.
-            column (sqlglot.expressions.Column): The target column.
+            table: the source table.
+            column: the target column.

         Returns:
-            sqlglot.expressions.DataType.Type: The resulting column type.
+            The resulting column type.
         """
@@ -60,132 +76,179 @@ class MappingSchema(Schema):
         dialect (str): The dialect to be used for custom type mappings.
     """

-    def __init__(self, schema=None, visible=None, dialect=None):
+    def __init__(
+        self,
+        schema: t.Optional[t.Dict] = None,
+        visible: t.Optional[t.Dict] = None,
+        dialect: t.Optional[str] = None,
+    ) -> None:
         self.schema = schema or {}
-        self.visible = visible
+        self.visible = visible or {}
+        self.schema_trie = self._build_trie(self.schema)
         self.dialect = dialect
-        self._type_mapping_cache = {}
-        self.supported_table_args = []
-        self.forbidden_table_args = set()
-        if self.schema:
-            self._initialize_supported_args()
+        self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {}
+        self._supported_table_args: t.Tuple[str, ...] = tuple()

     @classmethod
-    def from_mapping_schema(cls, mapping_schema):
+    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
         return MappingSchema(
-            schema=mapping_schema.schema, visible=mapping_schema.visible, dialect=mapping_schema.dialect
+            schema=mapping_schema.schema,
+            visible=mapping_schema.visible,
+            dialect=mapping_schema.dialect,
         )

-    def copy(self, **kwargs):
-        return MappingSchema(**{"schema": self.schema.copy(), **kwargs})
+    def copy(self, **kwargs) -> MappingSchema:
+        return MappingSchema(
+            **{  # type: ignore
+                "schema": self.schema.copy(),
+                "visible": self.visible.copy(),
+                "dialect": self.dialect,
+                **kwargs,
+            }
+        )

+    @property
+    def supported_table_args(self):
+        if not self._supported_table_args and self.schema:
+            depth = _dict_depth(self.schema)
+
+            if not depth or depth == 1:  # {}
+                self._supported_table_args = tuple()
+            elif 2 <= depth <= 4:
+                self._supported_table_args = TABLE_ARGS[: depth - 1]
+            else:
+                raise SchemaError(f"Invalid schema shape. Depth: {depth}")
+
+        return self._supported_table_args
+
-    def add_table(self, table, column_mapping=None):
+    def add_table(
+        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
+    ) -> None:
         """
         Register or update a table. Updates are only performed if a new column mapping is provided.

         Args:
-            table (sqlglot.expressions.Table|str): Table expression instance or string representing the table
-            column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table
+            table: the `Table` expression instance or string representing the table.
+            column_mapping: a column mapping that describes the structure of the table.
         """
-        table = exp.to_table(table)
-        self._validate_table(table)
+        table_ = self._ensure_table(table)
         column_mapping = ensure_column_mapping(column_mapping)
-        table_args = [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)]
-        existing_column_mapping = _nested_get(
-            self.schema, *zip(self.supported_table_args, table_args), raise_on_missing=False
-        )
-        if existing_column_mapping and not column_mapping:
+        schema = self.find_schema(table_, raise_on_missing=False)
+
+        if schema and not column_mapping:
             return

         _nested_set(
             self.schema,
-            [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)],
+            list(reversed(self.table_parts(table_))),
             column_mapping,
         )
-        self._initialize_supported_args()
+        self.schema_trie = self._build_trie(self.schema)

-    def _get_table_args_from_table(self, table):
-        if table.args.get("catalog") is not None:
-            return "catalog", "db", "this"
-        if table.args.get("db") is not None:
-            return "db", "this"
-        return ("this",)
+    def _ensure_table(self, table: exp.Table | str) -> exp.Table:
+        table_ = exp.to_table(table)

-    def _validate_table(self, table):
-        if not self.supported_table_args and isinstance(table, exp.Table):
-            return
-        for forbidden in self.forbidden_table_args:
-            if table.text(forbidden):
-                raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
-        for expected in self.supported_table_args:
-            if not table.text(expected):
-                raise ValueError(f"Table is expected to have {expected}. Received: {table.sql()} ")
+        if not table_:
+            raise SchemaError(f"Not a valid table '{table}'")

-    def column_names(self, table, only_visible=False):
-        table = exp.to_table(table)
-        if not isinstance(table.this, exp.Identifier):
-            return fs_get(table)
+        return table_

-        args = tuple(table.text(p) for p in self.supported_table_args)
+    def table_parts(self, table: exp.Table) -> t.List[str]:
+        return [table.text(part) for part in TABLE_ARGS if table.text(part)]

-        for forbidden in self.forbidden_table_args:
-            if table.text(forbidden):
-                raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
+    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
+        table_ = self._ensure_table(table)
+
+        if not isinstance(table_.this, exp.Identifier):
+            return fs_get(table)  # type: ignore
+
+        schema = self.find_schema(table_)
+
+        if schema is None:
+            raise SchemaError(f"Could not find table schema {table}")

-        columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
         if not only_visible or not self.visible:
-            return columns
+            return list(schema)

-        visible = _nested_get(self.visible, *zip(self.supported_table_args, args))
-        return [col for col in columns if col in visible]
+        visible = self._nested_get(self.table_parts(table_), self.visible)
+        return [col for col in schema if col in visible]  # type: ignore

-    def get_column_type(self, table, column):
-        try:
-            schema_type = self.schema.get(table.name, {}).get(column.name).upper()
+    def find_schema(
+        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
+    ) -> t.Optional[t.Dict[str, str]]:
+        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
+        value, trie = in_trie(self.schema_trie if trie is None else trie, parts)
+
+        if value == 0:
+            if raise_on_missing:
+                raise SchemaError(f"Cannot find schema for {table}.")
+            else:
+                return None
+        elif value == 1:
+            possibilities = flatten_schema(trie)
+            if len(possibilities) == 1:
+                parts.extend(possibilities[0])
+            else:
+                message = ", ".join(".".join(parts) for parts in possibilities)
+                if raise_on_missing:
+                    raise SchemaError(f"Ambiguous schema for {table}: {message}.")
+                return None
+
+        return self._nested_get(parts, raise_on_missing=raise_on_missing)
+
+    def get_column_type(
+        self, table: exp.Table | str, column: exp.Column | str
+    ) -> exp.DataType.Type:
+        column_name = column if isinstance(column, str) else column.name
+        table_ = exp.to_table(table)
+        if table_:
+            table_schema = self.find_schema(table_)
+            schema_type = table_schema.get(column_name).upper()  # type: ignore
             return self._convert_type(schema_type)
-        except:
-            raise OptimizeError(f"Failed to get type for column {column.sql()}")
+
+        raise SchemaError(f"Could not convert table '{table}'")

-    def _convert_type(self, schema_type):
+    def _convert_type(self, schema_type: str) -> exp.DataType.Type:
         """
-        Convert a type represented as a string to the corresponding exp.DataType.Type object.
+        Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.

         Args:
-            schema_type (str): The type we want to convert.
+            schema_type: the type we want to convert.

         Returns:
-            sqlglot.expressions.DataType.Type: The resulting expression type.
+            The resulting expression type.
         """
         if schema_type not in self._type_mapping_cache:
             try:
-                self._type_mapping_cache[schema_type] = exp.maybe_parse(
-                    schema_type, into=exp.DataType, dialect=self.dialect
-                ).this
+                expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
+                if expression is None:
+                    raise ValueError(f"Could not parse {schema_type}")
+                self._type_mapping_cache[schema_type] = expression.this
             except AttributeError:
-                raise OptimizeError(f"Failed to convert type {schema_type}")
+                raise SchemaError(f"Failed to convert type {schema_type}")

         return self._type_mapping_cache[schema_type]

-    def _initialize_supported_args(self):
-        if not self.supported_table_args:
-            depth = _dict_depth(self.schema)
+    def _build_trie(self, schema: t.Dict):
+        return new_trie(tuple(reversed(t)) for t in flatten_schema(schema))

-            all_args = ["this", "db", "catalog"]
-            if not depth or depth == 1:  # {}
-                self.supported_table_args = []
-            elif 2 <= depth <= 4:
-                self.supported_table_args = tuple(reversed(all_args[: depth - 1]))
-            else:
-                raise OptimizeError(f"Invalid schema shape. Depth: {depth}")
-
-            self.forbidden_table_args = {"catalog", "db", "this"} - set(self.supported_table_args)
+    def _nested_get(
+        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
+    ) -> t.Optional[t.Any]:
+        return _nested_get(
+            d or self.schema,
+            *zip(self.supported_table_args, reversed(parts)),
+            raise_on_missing=raise_on_missing,
+        )


-def ensure_schema(schema):
+def ensure_schema(schema: t.Any) -> Schema:
     if isinstance(schema, Schema):
         return schema

     return MappingSchema(schema)


-def ensure_column_mapping(mapping):
+def ensure_column_mapping(mapping: t.Optional[ColumnMapping]):
     if isinstance(mapping, dict):
         return mapping
     elif isinstance(mapping, str):
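Taken together, the rewritten MappingSchema resolves tables through a trie instead of the old forbidden/supported argument checks; a usage sketch grounded in the signatures above:

    from sqlglot.schema import MappingSchema

    schema = MappingSchema({"db": {"users": {"id": "INT", "name": "VARCHAR"}}})
    print(schema.supported_table_args)              # ('this', 'db') for a depth-3 mapping
    print(schema.column_names("db.users"))          # ['id', 'name']
    print(schema.get_column_type("db.users", "id")) # DataType.Type.INT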
@@ -196,7 +259,7 @@ def ensure_column_mapping(mapping):
         }
     # Check if mapping looks like a DataFrame StructType
     elif hasattr(mapping, "simpleString"):
-        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
+        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}  # type: ignore
     elif isinstance(mapping, list):
         return {x.strip(): None for x in mapping}
     elif mapping is None:
@@ -204,7 +267,20 @@ def ensure_column_mapping(mapping):
     raise ValueError(f"Invalid mapping provided: {type(mapping)}")


-def fs_get(table):
+def flatten_schema(schema: t.Dict, keys: t.Optional[t.List[str]] = None) -> t.List[t.List[str]]:
+    tables = []
+    keys = keys or []
+    depth = _dict_depth(schema)
+
+    for k, v in schema.items():
+        if depth >= 3:
+            tables.extend(flatten_schema(v, keys + [k]))
+        elif depth == 2:
+            tables.append(keys + [k])
+    return tables
+
+
+def fs_get(table: exp.Table) -> t.List[str]:
     name = table.this.name

     if name.upper() == "READ_CSV":
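`flatten_schema` enumerates every table path in a nested mapping, which is what `_build_trie` feeds into `new_trie`; for instance:

    from sqlglot.schema import flatten_schema

    nested = {"catalog": {"db": {"users": {"id": "INT"}, "orders": {"id": "INT"}}}}
    # expected: [['catalog', 'db', 'users'], ['catalog', 'db', 'orders']]
    print(flatten_schema(nested))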
@@ -214,21 +290,23 @@ def fs_get(table):
         raise ValueError(f"Cannot read schema for {table}")


-def _nested_get(d, *path, raise_on_missing=True):
+def _nested_get(
+    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
+) -> t.Optional[t.Any]:
     """
     Get a value for a nested dictionary.

     Args:
-        d (dict): dictionary
-        *path (tuple[str, str]): tuples of (name, key)
+        d: the dictionary to search.
+        *path: tuples of (name, key), where:
+            `key` is the key in the dictionary to get.
+            `name` is a string to use in the error if `key` isn't found.

     Returns:
-        The value or None if it doesn't exist
+        The value or None if it doesn't exist.
     """
     for name, key in path:
-        d = d.get(key)
+        d = d.get(key)  # type: ignore
         if d is None:
             if raise_on_missing:
                 name = "table" if name == "this" else name
@@ -237,36 +315,44 @@ def _nested_get(d, *path, raise_on_missing=True):
     return d


-def _nested_set(d, keys, value):
+def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict:
     """
     In-place set a value for a nested dictionary

-    Ex:
+    Example:
         >>> _nested_set({}, ["top_key", "second_key"], "value")
         {'top_key': {'second_key': 'value'}}

        >>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

-        d (dict): dictionary
-        keys (Iterable[str]): ordered iterable of keys that makeup path to value
-        value (Any): The value to set in the dictionary for the given key path
+    Args:
+        d: dictionary to update.
+        keys: the keys that makeup the path to `value`.
+        value: the value to set in the dictionary for the given key path.
+
+    Returns:
+        The (possibly) updated dictionary.
     """
     if not keys:
-        return
+        return d

     if len(keys) == 1:
         d[keys[0]] = value
-        return
+        return d

     subd = d
     for key in keys[:-1]:
         if key not in subd:
             subd = subd.setdefault(key, {})
         else:
             subd = subd[key]

     subd[keys[-1]] = value
+    return d


-def _dict_depth(d):
+def _dict_depth(d: t.Dict) -> int:
     """
     Get the nesting depth of a dictionary.
@@ -1,9 +1,13 @@
-# the generic time format is based on python time.strftime
+import typing as t
+
+# The generic time format is based on python time.strftime.
 # https://docs.python.org/3/library/time.html#time.strftime
 from sqlglot.trie import in_trie, new_trie


-def format_time(string, mapping, trie=None):
+def format_time(
+    string: str, mapping: t.Dict[str, str], trie: t.Optional[t.Dict] = None
+) -> t.Optional[str]:
     """
     Converts a time string given a mapping.

@@ -11,11 +15,16 @@ def format_time(string, mapping, trie=None):
     >>> format_time("%Y", {"%Y": "YYYY"})
     'YYYY'

-    mapping: Dictionary of time format to target time format
-    trie: Optional trie, can be passed in for performance
+    Args:
+        mapping: dictionary of time format to target time format.
+        trie: optional trie, can be passed in for performance.
+
+    Returns:
+        The converted time string.
     """
     if not string:
         return None

     start = 0
     end = 1
     size = len(string)
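A slightly longer example than the doctest, showing several mappings applied in one pass:

    from sqlglot.time import format_time

    python_to_hive = {"%Y": "yyyy", "%m": "MM", "%d": "dd"}
    assert format_time("%Y-%m-%d", python_to_hive) == "yyyy-MM-dd"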
@@ -1,3 +1,6 @@
+from __future__ import annotations
+
+import typing as t
 from enum import auto

 from sqlglot.helper import AutoName
@@ -27,6 +30,7 @@ class TokenType(AutoName):
     NOT = auto()
     EQ = auto()
     NEQ = auto()
+    NULLSAFE_EQ = auto()
     AND = auto()
     OR = auto()
     AMP = auto()
@@ -36,12 +40,14 @@ class TokenType(AutoName):
     TILDA = auto()
     ARROW = auto()
     DARROW = auto()
+    FARROW = auto()
+    HASH = auto()
     HASH_ARROW = auto()
     DHASH_ARROW = auto()
     LR_ARROW = auto()
-    ANNOTATION = auto()
     DOLLAR = auto()
     PARAMETER = auto()
+    SESSION_PARAMETER = auto()

     SPACE = auto()
     BREAK = auto()
@@ -73,7 +79,7 @@ class TokenType(AutoName):
     NVARCHAR = auto()
     TEXT = auto()
     BINARY = auto()
-    BYTEA = auto()
+    VARBINARY = auto()
     JSON = auto()
     TIMESTAMP = auto()
     TIMESTAMPTZ = auto()
@@ -142,6 +148,7 @@ class TokenType(AutoName):
     DESCRIBE = auto()
     DETERMINISTIC = auto()
     DISTINCT = auto()
+    DISTINCT_FROM = auto()
     DISTRIBUTE_BY = auto()
     DIV = auto()
     DROP = auto()
@@ -238,6 +245,7 @@ class TokenType(AutoName):
     RETURNS = auto()
     RIGHT = auto()
     RLIKE = auto()
+    ROLLBACK = auto()
     ROLLUP = auto()
     ROW = auto()
     ROWS = auto()
@@ -287,37 +295,49 @@ class TokenType(AutoName):


 class Token:
-    __slots__ = ("token_type", "text", "line", "col")
+    __slots__ = ("token_type", "text", "line", "col", "comment")

     @classmethod
-    def number(cls, number):
+    def number(cls, number: int) -> Token:
+        """Returns a NUMBER token with `number` as its text."""
         return cls(TokenType.NUMBER, str(number))

     @classmethod
-    def string(cls, string):
+    def string(cls, string: str) -> Token:
+        """Returns a STRING token with `string` as its text."""
         return cls(TokenType.STRING, string)

     @classmethod
-    def identifier(cls, identifier):
+    def identifier(cls, identifier: str) -> Token:
+        """Returns an IDENTIFIER token with `identifier` as its text."""
         return cls(TokenType.IDENTIFIER, identifier)

     @classmethod
-    def var(cls, var):
+    def var(cls, var: str) -> Token:
+        """Returns an VAR token with `var` as its text."""
         return cls(TokenType.VAR, var)

-    def __init__(self, token_type, text, line=1, col=1):
+    def __init__(
+        self,
+        token_type: TokenType,
+        text: str,
+        line: int = 1,
+        col: int = 1,
+        comment: t.Optional[str] = None,
+    ) -> None:
         self.token_type = token_type
         self.text = text
         self.line = line
         self.col = max(col - len(text), 1)
+        self.comment = comment

-    def __repr__(self):
+    def __repr__(self) -> str:
         attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
         return f"<Token {attributes}>"


 class _Tokenizer(type):
-    def __new__(cls, clsname, bases, attrs):
+    def __new__(cls, clsname, bases, attrs):  # type: ignore
         klass = super().__new__(cls, clsname, bases, attrs)

         klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES)
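With the new `comment` slot, the attachment behaviour can be observed directly from the tokenizer:

    from sqlglot.tokens import Tokenizer

    tokens = Tokenizer().tokenize("SELECT 1 /* one */")
    # a same-line comment is attached to the token that precedes it
    print(tokens[-1].text, tokens[-1].comment)  # expected: 1 and " one "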
@@ -325,27 +345,29 @@ class _Tokenizer(type):
         klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
         klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS)
         klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS)
+        klass._ESCAPES = set(klass.ESCAPES)
         klass._COMMENTS = dict(
-            (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) for comment in klass.COMMENTS
+            (comment, None) if isinstance(comment, str) else (comment[0], comment[1])
+            for comment in klass.COMMENTS
         )

         klass.KEYWORD_TRIE = new_trie(
             key.upper()
-            for key, value in {
+            for key in {
                 **klass.KEYWORDS,
                 **{comment: TokenType.COMMENT for comment in klass._COMMENTS},
                 **{quote: TokenType.QUOTE for quote in klass._QUOTES},
                 **{bit_string: TokenType.BIT_STRING for bit_string in klass._BIT_STRINGS},
                 **{hex_string: TokenType.HEX_STRING for hex_string in klass._HEX_STRINGS},
                 **{byte_string: TokenType.BYTE_STRING for byte_string in klass._BYTE_STRINGS},
-            }.items()
+            }
             if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
         )

         return klass

     @staticmethod
-    def _delimeter_list_to_dict(list):
+    def _delimeter_list_to_dict(list: t.List[str | t.Tuple[str, str]]) -> t.Dict[str, str]:
         return dict((item, item) if isinstance(item, str) else (item[0], item[1]) for item in list)
@@ -375,26 +397,26 @@ class Tokenizer(metaclass=_Tokenizer):
         "*": TokenType.STAR,
         "~": TokenType.TILDA,
         "?": TokenType.PLACEHOLDER,
-        "#": TokenType.ANNOTATION,
         "@": TokenType.PARAMETER,
         # used for breaking a var like x'y' but nothing else
         # the token type doesn't matter
         "'": TokenType.QUOTE,
         "`": TokenType.IDENTIFIER,
         '"': TokenType.IDENTIFIER,
+        "#": TokenType.HASH,
     }

-    QUOTES = ["'"]
+    QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]

-    BIT_STRINGS = []
+    BIT_STRINGS: t.List[str | t.Tuple[str, str]] = []

-    HEX_STRINGS = []
+    HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []

-    BYTE_STRINGS = []
+    BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = []

-    IDENTIFIERS = ['"']
+    IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']

-    ESCAPE = "'"
+    ESCAPES = ["'"]

     KEYWORDS = {
         "/*+": TokenType.HINT,
@@ -406,8 +428,10 @@ class Tokenizer(metaclass=_Tokenizer):
         "<=": TokenType.LTE,
         "<>": TokenType.NEQ,
         "!=": TokenType.NEQ,
+        "<=>": TokenType.NULLSAFE_EQ,
         "->": TokenType.ARROW,
         "->>": TokenType.DARROW,
+        "=>": TokenType.FARROW,
         "#>": TokenType.HASH_ARROW,
         "#>>": TokenType.DHASH_ARROW,
         "<->": TokenType.LR_ARROW,
@@ -454,6 +478,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "DESCRIBE": TokenType.DESCRIBE,
         "DETERMINISTIC": TokenType.DETERMINISTIC,
         "DISTINCT": TokenType.DISTINCT,
+        "DISTINCT FROM": TokenType.DISTINCT_FROM,
         "DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
         "DIV": TokenType.DIV,
         "DROP": TokenType.DROP,
@@ -543,6 +568,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "RETURNS": TokenType.RETURNS,
         "RIGHT": TokenType.RIGHT,
         "RLIKE": TokenType.RLIKE,
+        "ROLLBACK": TokenType.ROLLBACK,
         "ROLLUP": TokenType.ROLLUP,
         "ROW": TokenType.ROW,
         "ROWS": TokenType.ROWS,
@@ -622,8 +648,9 @@ class Tokenizer(metaclass=_Tokenizer):
         "TEXT": TokenType.TEXT,
         "CLOB": TokenType.TEXT,
         "BINARY": TokenType.BINARY,
-        "BLOB": TokenType.BINARY,
-        "BYTEA": TokenType.BINARY,
+        "BLOB": TokenType.VARBINARY,
+        "BYTEA": TokenType.VARBINARY,
+        "VARBINARY": TokenType.VARBINARY,
         "TIMESTAMP": TokenType.TIMESTAMP,
         "TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
         "TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,
@@ -655,13 +682,13 @@ class Tokenizer(metaclass=_Tokenizer):
         TokenType.SET,
         TokenType.SHOW,
         TokenType.TRUNCATE,
+        TokenType.USE,
         TokenType.VACUUM,
-        TokenType.ROLLBACK,
     }

     # handle numeric literals like in hive (3L = BIGINT)
-    NUMERIC_LITERALS = {}
-    ENCODE = None
+    NUMERIC_LITERALS: t.Dict[str, str] = {}
+    ENCODE: t.Optional[str] = None

     COMMENTS = ["--", ("/*", "*/")]
     KEYWORD_TRIE = None  # autofilled
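The `<=>` keyword gives MySQL's null-safe equality a first-class token; a hedged sketch (it assumes the parser maps NULLSAFE_EQ onto the NullSafeEQ expression added elsewhere in this release):

    import sqlglot

    expr = sqlglot.parse_one("SELECT a <=> b FROM t", read="mysql")
    print(expr.expressions[0])  # expected: a NullSafeEQ comparison node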
@@ -674,33 +701,39 @@ class Tokenizer(metaclass=_Tokenizer):
         "_current",
         "_line",
         "_col",
+        "_comment",
         "_char",
         "_end",
         "_peek",
+        "_prev_token_line",
+        "_prev_token_comment",
         "_prev_token_type",
+        "_replace_backslash",
     )

-    def __init__(self):
-        """
-        Tokenizer consumes a sql string and produces an array of :class:`~sqlglot.tokens.Token`
-        """
+    def __init__(self) -> None:
+        self._replace_backslash = "\\" in self._ESCAPES  # type: ignore
         self.reset()

-    def reset(self):
+    def reset(self) -> None:
         self.sql = ""
         self.size = 0
-        self.tokens = []
+        self.tokens: t.List[Token] = []
         self._start = 0
         self._current = 0
         self._line = 1
         self._col = 1
+        self._comment = None

         self._char = None
         self._end = None
         self._peek = None
+        self._prev_token_line = -1
+        self._prev_token_comment = None
         self._prev_token_type = None

-    def tokenize(self, sql):
+    def tokenize(self, sql: str) -> t.List[Token]:
+        """Returns a list of tokens corresponding to the SQL string `sql`."""
         self.reset()
         self.sql = sql
         self.size = len(sql)
@@ -712,14 +745,14 @@ class Tokenizer(metaclass=_Tokenizer):
             if not self._char:
                 break

-            white_space = self.WHITE_SPACE.get(self._char)
-            identifier_end = self._IDENTIFIERS.get(self._char)
+            white_space = self.WHITE_SPACE.get(self._char)  # type: ignore
+            identifier_end = self._IDENTIFIERS.get(self._char)  # type: ignore

             if white_space:
                 if white_space == TokenType.BREAK:
                     self._col = 1
                     self._line += 1
-            elif self._char.isdigit():
+            elif self._char.isdigit():  # type:ignore
                 self._scan_number()
             elif identifier_end:
                 self._scan_identifier(identifier_end)
@@ -727,38 +760,51 @@ class Tokenizer(metaclass=_Tokenizer):
                 self._scan_keywords()
         return self.tokens

-    def _chars(self, size):
+    def _chars(self, size: int) -> str:
         if size == 1:
-            return self._char
+            return self._char  # type: ignore
         start = self._current - 1
         end = start + size
         if end <= self.size:
             return self.sql[start:end]
         return ""

-    def _advance(self, i=1):
+    def _advance(self, i: int = 1) -> None:
         self._col += i
         self._current += i
-        self._end = self._current >= self.size
-        self._char = self.sql[self._current - 1]
-        self._peek = self.sql[self._current] if self._current < self.size else ""
+        self._end = self._current >= self.size  # type: ignore
+        self._char = self.sql[self._current - 1]  # type: ignore
+        self._peek = self.sql[self._current] if self._current < self.size else ""  # type: ignore

     @property
-    def _text(self):
+    def _text(self) -> str:
         return self.sql[self._start : self._current]

-    def _add(self, token_type, text=None):
-        self._prev_token_type = token_type
-        self.tokens.append(Token(token_type, self._text if text is None else text, self._line, self._col))
+    def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
+        self._prev_token_line = self._line
+        self._prev_token_comment = self._comment
+        self._prev_token_type = token_type  # type: ignore
+        self.tokens.append(
+            Token(
+                token_type,
+                self._text if text is None else text,
+                self._line,
+                self._col,
+                self._comment,
+            )
+        )
+        self._comment = None

-        if token_type in self.COMMANDS and (len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON):
+        if token_type in self.COMMANDS and (
+            len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
+        ):
             self._start = self._current
             while not self._end and self._peek != ";":
                 self._advance()
             if self._start < self._current:
                 self._add(TokenType.STRING)

-    def _scan_keywords(self):
+    def _scan_keywords(self) -> None:
         size = 0
         word = None
         chars = self._text
@@ -771,7 +817,7 @@ class Tokenizer(metaclass=_Tokenizer):
             if skip:
                 result = 1
             else:
-                result, trie = in_trie(trie, char.upper())
+                result, trie = in_trie(trie, char.upper())  # type: ignore

             if result == 0:
                 break
@@ -793,15 +839,11 @@ class Tokenizer(metaclass=_Tokenizer):
             else:
                 skip = True
         else:
-            chars = None
+            chars = None  # type: ignore

         if not word:
             if self._char in self.SINGLE_TOKENS:
-                token = self.SINGLE_TOKENS[self._char]
-                if token == TokenType.ANNOTATION:
-                    self._scan_annotation()
-                    return
-                self._add(token)
+                self._add(self.SINGLE_TOKENS[self._char])  # type: ignore
                 return
             self._scan_var()
             return
@@ -816,31 +858,41 @@ class Tokenizer(metaclass=_Tokenizer):
         self._advance(size - 1)
         self._add(self.KEYWORDS[word.upper()])

-    def _scan_comment(self, comment_start):
-        if comment_start not in self._COMMENTS:
+    def _scan_comment(self, comment_start: str) -> bool:
+        if comment_start not in self._COMMENTS:  # type: ignore
             return False

-        comment_end = self._COMMENTS[comment_start]
+        comment_start_line = self._line
+        comment_start_size = len(comment_start)
+        comment_end = self._COMMENTS[comment_start]  # type: ignore

         if comment_end:
             comment_end_size = len(comment_end)

             while not self._end and self._chars(comment_end_size) != comment_end:
                 self._advance()

+            self._comment = self._text[comment_start_size : -comment_end_size + 1]  # type: ignore
             self._advance(comment_end_size - 1)
         else:
-            while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK:
+            while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK:  # type: ignore
                 self._advance()
+            self._comment = self._text[comment_start_size:]  # type: ignore
+
+        # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. If both
+        # types of comment can be attached to a token, the trailing one is discarded in favour of the leading one.
+
+        if comment_start_line == self._prev_token_line:
+            if self._prev_token_comment is None:
+                self.tokens[-1].comment = self._comment
+
+            self._comment = None

         return True

-    def _scan_annotation(self):
-        while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK and self._peek != ",":
-            self._advance()
-        self._add(TokenType.ANNOTATION, self._text[1:])
-
-    def _scan_number(self):
+    def _scan_number(self) -> None:
         if self._char == "0":
-            peek = self._peek.upper()
+            peek = self._peek.upper()  # type: ignore
             if peek == "B":
                 return self._scan_bits()
             elif peek == "X":
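The leading/trailing rule in `_scan_comment` can be seen end to end:

    from sqlglot.tokens import Tokenizer

    sql = "-- leading, attaches to SELECT\nSELECT 1 -- trailing, attaches to 1"
    tokens = Tokenizer().tokenize(sql)
    print(tokens[0].comment)   # " leading, attaches to SELECT"
    print(tokens[-1].comment)  # " trailing, attaches to 1"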
@@ -850,7 +902,7 @@ class Tokenizer(metaclass=_Tokenizer):
         scientific = 0

         while True:
-            if self._peek.isdigit():
+            if self._peek.isdigit():  # type: ignore
                 self._advance()
             elif self._peek == "." and not decimal:
                 decimal = True
@@ -858,25 +910,25 @@ class Tokenizer(metaclass=_Tokenizer):
             elif self._peek in ("-", "+") and scientific == 1:
                 scientific += 1
                 self._advance()
-            elif self._peek.upper() == "E" and not scientific:
+            elif self._peek.upper() == "E" and not scientific:  # type: ignore
                 scientific += 1
                 self._advance()
-            elif self._peek.isalpha():
+            elif self._peek.isalpha():  # type: ignore
                 self._add(TokenType.NUMBER)
                 literal = []
-                while self._peek.isalpha():
-                    literal.append(self._peek.upper())
+                while self._peek.isalpha():  # type: ignore
+                    literal.append(self._peek.upper())  # type: ignore
                     self._advance()
-                literal = "".join(literal)
-                token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))
+                literal = "".join(literal)  # type: ignore
+                token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))  # type: ignore
                 if token_type:
                     self._add(TokenType.DCOLON, "::")
-                    return self._add(token_type, literal)
+                    return self._add(token_type, literal)  # type: ignore
                 return self._advance(-len(literal))
             else:
                 return self._add(TokenType.NUMBER)

-    def _scan_bits(self):
+    def _scan_bits(self) -> None:
         self._advance()
         value = self._extract_value()
         try:
@@ -884,7 +936,7 @@ class Tokenizer(metaclass=_Tokenizer):
         except ValueError:
             self._add(TokenType.IDENTIFIER)

-    def _scan_hex(self):
+    def _scan_hex(self) -> None:
         self._advance()
         value = self._extract_value()
         try:
@@ -892,9 +944,9 @@ class Tokenizer(metaclass=_Tokenizer):
         except ValueError:
             self._add(TokenType.IDENTIFIER)

-    def _extract_value(self):
+    def _extract_value(self) -> str:
         while True:
-            char = self._peek.strip()
+            char = self._peek.strip()  # type: ignore
             if char and char not in self.SINGLE_TOKENS:
                 self._advance()
             else:
@ -902,31 +954,30 @@ class Tokenizer(metaclass=_Tokenizer):
|
|||
|
||||
return self._text
|
||||
|
||||
def _scan_string(self, quote):
|
||||
quote_end = self._QUOTES.get(quote)
|
||||
def _scan_string(self, quote: str) -> bool:
|
||||
quote_end = self._QUOTES.get(quote) # type: ignore
|
||||
if quote_end is None:
|
||||
return False
|
||||
|
||||
self._advance(len(quote))
|
||||
text = self._extract_string(quote_end)
|
||||
|
||||
text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
|
||||
text = text.replace("\\\\", "\\") if self.ESCAPE == "\\" else text
|
||||
text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text # type: ignore
|
||||
text = text.replace("\\\\", "\\") if self._replace_backslash else text
|
||||
self._add(TokenType.STRING, text)
|
||||
return True

    # X'1234', b'0110', E'\\\\\' etc.
    def _scan_formatted_string(self, string_start):
        if string_start in self._HEX_STRINGS:
            delimiters = self._HEX_STRINGS
    def _scan_formatted_string(self, string_start: str) -> bool:
        if string_start in self._HEX_STRINGS: # type: ignore
            delimiters = self._HEX_STRINGS # type: ignore
            token_type = TokenType.HEX_STRING
            base = 16
        elif string_start in self._BIT_STRINGS:
            delimiters = self._BIT_STRINGS
        elif string_start in self._BIT_STRINGS: # type: ignore
            delimiters = self._BIT_STRINGS # type: ignore
            token_type = TokenType.BIT_STRING
            base = 2
        elif string_start in self._BYTE_STRINGS:
            delimiters = self._BYTE_STRINGS
        elif string_start in self._BYTE_STRINGS: # type: ignore
            delimiters = self._BYTE_STRINGS # type: ignore
            token_type = TokenType.BYTE_STRING
            base = None
        else:

@ -942,11 +993,13 @@ class Tokenizer(metaclass=_Tokenizer):
        try:
            self._add(token_type, f"{int(text, base)}")
        except:
            raise RuntimeError(f"Numeric string contains invalid characters from {self._line}:{self._start}")
            raise RuntimeError(
                f"Numeric string contains invalid characters from {self._line}:{self._start}"
            )

        return True
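
The `int(text, base)` call above is what normalizes hex and bit literals: the token text stored for `x'...'` and `b'...'` strings is the decimal rendering of the extracted digits, while byte strings (`base = None`) are presumably added verbatim in the elided context. A minimal illustration of the conversion itself:

# The same conversion the tokenizer applies to the extracted text:
assert f"{int('1F', 16)}" == "31"  # x'1F'   -> HEX_STRING token text "31"
assert f"{int('0110', 2)}" == "6"  # b'0110' -> BIT_STRING token text "6"
# int("1G", 16) raises ValueError, which surfaces as the RuntimeError above.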

    def _scan_identifier(self, identifier_end):
    def _scan_identifier(self, identifier_end: str) -> None:
        while self._peek != identifier_end:
            if self._end:
                raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}")

@ -954,9 +1007,9 @@ class Tokenizer(metaclass=_Tokenizer):
        self._advance()
        self._add(TokenType.IDENTIFIER, self._text[1:-1])

    def _scan_var(self):
    def _scan_var(self) -> None:
        while True:
            char = self._peek.strip()
            char = self._peek.strip() # type: ignore
            if char and char not in self.SINGLE_TOKENS:
                self._advance()
            else:

@ -967,12 +1020,12 @@ class Tokenizer(metaclass=_Tokenizer):
            else self.KEYWORDS.get(self._text.upper(), TokenType.VAR)
        )

    def _extract_string(self, delimiter):
    def _extract_string(self, delimiter: str) -> str:
        text = ""
        delim_size = len(delimiter)

        while True:
            if self._char == self.ESCAPE and self._peek == delimiter:
            if self._char in self._ESCAPES and self._peek == delimiter: # type: ignore
                text += delimiter
                self._advance(2)
            else:

@ -983,7 +1036,7 @@ class Tokenizer(metaclass=_Tokenizer):

                if self._end:
                    raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}")
                text += self._char
                text += self._char # type: ignore
                self._advance()

        return text
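
A self-contained sketch of the loop above, outside the `Tokenizer` (the function name and signature are simplifications for illustration, not sqlglot API): when the current character is in the escape set and the next character is the delimiter, both are consumed and a literal delimiter lands in the output.

def extract_string(sql: str, pos: int, delimiter: str, escapes: set) -> tuple:
    """Simplified stand-in for Tokenizer._extract_string (sketch only)."""
    text = ""
    while True:
        if pos >= len(sql):
            raise RuntimeError(f"Missing {delimiter}")
        char, peek = sql[pos], sql[pos + 1 : pos + 2]
        if char in escapes and peek == delimiter:
            text += delimiter  # escaped delimiter becomes a literal one
            pos += 2
        elif sql.startswith(delimiter, pos):
            return text, pos + len(delimiter)  # skip past the closing delimiter
        else:
            text += char
            pos += 1

# extract_string("it''s' AND 1", 0, "'", {"'"}) returns ("it's", 6)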

@ -1,7 +1,14 @@
from __future__ import annotations

import typing as t

if t.TYPE_CHECKING:
    from sqlglot.generator import Generator

from sqlglot import expressions as exp


def unalias_group(expression):
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

@ -9,6 +16,12 @@ def unalias_group(expression):
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
        aliased_selects = {

@ -30,19 +43,20 @@ def unalias_group(expression):
    return expression


def preprocess(transforms, to_sql):
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
    to_sql: t.Callable[[Generator, exp.Expression], str],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Create a new transform function that can be used a value in `Generator.TRANSFORMS`
    to convert expressions to SQL.
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using an appropriate `Generator.TRANSFORMS` function.

    Args:
        transforms (list[(exp.Expression) -> exp.Expression]):
            Sequence of transform functions. These will be called in order.
        to_sql ((sqlglot.generator.Generator, exp.Expression) -> str):
            Final transform that converts the resulting expression to a SQL string.
        transforms: sequence of transform functions. These will be called in order.
        to_sql: final transform that converts the resulting expression to a SQL string.

    Returns:
        (sqlglot.generator.Generator, exp.Expression) -> str:
            Function that can be used as a generator transform.
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression):

@ -54,12 +68,10 @@ def preprocess(transforms, to_sql):
    return _to_sql
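
For reference, this is the shape in which dialects typically wire these helpers together; the `TRANSFORMS` entry below is an assumption modeled on how sqlglot dialects combine `preprocess`, `unalias_group`, and `delegate`, not a line quoted from this diff:

from sqlglot import exp, transforms

# Rewrite GROUP BY alias references to ordinals first, then hand the
# rewritten expression to the generator's regular group_sql method.
TRANSFORMS = {
    exp.Group: transforms.preprocess(
        [transforms.unalias_group],
        transforms.delegate("group_sql"),
    ),
}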


def delegate(attr):
def delegate(attr: str) -> t.Callable:
    """
    Create a new method that delegates to `attr`.

    This is useful for creating `Generator.TRANSFORMS` functions that delegate
    to existing generator methods.
    Create a new method that delegates to `attr`. This is useful for creating `Generator.TRANSFORMS`
    functions that delegate to existing generator methods.
    """

    def _transform(self, *args, **kwargs):

@ -1,5 +1,26 @@
def new_trie(keywords):
    trie = {}
import typing as t

key = t.Sequence[t.Hashable]


def new_trie(keywords: t.Iterable[key]) -> t.Dict:
    """
    Creates a new trie out of a collection of keywords.

    The trie is represented as a sequence of nested dictionaries keyed by either single character
    strings, or by 0, which is used to designate that a keyword is in the trie.

    Example:
        >>> new_trie(["bla", "foo", "blab"])
        {'b': {'l': {'a': {0: True, 'b': {0: True}}}}, 'f': {'o': {'o': {0: True}}}}

    Args:
        keywords: the keywords to create the trie from.

    Returns:
        The trie corresponding to `keywords`.
    """
    trie: t.Dict = {}

    for key in keywords:
        current = trie

@ -11,7 +32,28 @@ def new_trie(keywords):
    return trie


def in_trie(trie, key):
def in_trie(trie: t.Dict, key: key) -> t.Tuple[int, t.Dict]:
    """
    Checks whether a key is in a trie.

    Examples:
        >>> in_trie(new_trie(["cat"]), "bob")
        (0, {'c': {'a': {'t': {0: True}}}})

        >>> in_trie(new_trie(["cat"]), "ca")
        (1, {'t': {0: True}})

        >>> in_trie(new_trie(["cat"]), "cat")
        (2, {0: True})

    Args:
        trie: the trie to be searched.
        key: the target key.

    Returns:
        A pair `(value, subtrie)`, where `subtrie` is the sub-trie we get at the point where the search stops, and `value`
        is either 0 (search was unsuccessful), 1 (`key` is a prefix of a keyword in `trie`) or 2 (`key` is in `trie`).
    """
    if not key:
        return (0, trie)
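
Putting the two trie helpers together, a hedged sketch of the longest-match scan they enable (`longest_keyword` is illustrative, not sqlglot API; the real tokenizer advances through the trie one character at a time instead of re-walking it per prefix, and also handles whitespace and multi-word keywords):

import typing as t

from sqlglot.trie import in_trie, new_trie

def longest_keyword(trie: t.Dict, text: str) -> t.Optional[str]:
    """Return the longest prefix of `text` that is a keyword in `trie`."""
    match = None
    for end in range(1, len(text) + 1):
        result, _ = in_trie(trie, text[:end])
        if result == 0:  # no keyword starts with this prefix
            break
        if result == 2:  # exact keyword; remember it and keep scanning
            match = text[:end]
    return match

# longest_keyword(new_trie(["IN", "INT", "INTO"]), "INTO x") returns "INTO"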