Adding upstream version 6.0.4.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent d01130b3f1
commit 527597d2af
122 changed files with 23162 additions and 0 deletions
sqlglot/diff.py (Normal file, 314 lines)
@@ -0,0 +1,314 @@
from collections import defaultdict
from dataclasses import dataclass
from heapq import heappop, heappush

from sqlglot import Dialect
from sqlglot import expressions as exp
from sqlglot.helper import ensure_list


@dataclass(frozen=True)
class Insert:
    """Indicates that a new node has been inserted"""

    expression: exp.Expression


@dataclass(frozen=True)
class Remove:
    """Indicates that an existing node has been removed"""

    expression: exp.Expression


@dataclass(frozen=True)
class Move:
    """Indicates that an existing node's position within the tree has changed"""

    expression: exp.Expression


@dataclass(frozen=True)
class Update:
    """Indicates that an existing node has been updated"""

    source: exp.Expression
    target: exp.Expression


@dataclass(frozen=True)
class Keep:
    """Indicates that an existing node hasn't been changed"""

    source: exp.Expression
    target: exp.Expression


def diff(source, target):
    """
    Returns the list of changes between the source and the target expressions.

    Examples:
        >>> diff(parse_one("a + b"), parse_one("a + c"))
        [
            Remove(expression=(COLUMN this: (IDENTIFIER this: b, quoted: False))),
            Insert(expression=(COLUMN this: (IDENTIFIER this: c, quoted: False))),
            Keep(
                source=(ADD this: ...),
                target=(ADD this: ...)
            ),
            Keep(
                source=(COLUMN this: (IDENTIFIER this: a, quoted: False)),
                target=(COLUMN this: (IDENTIFIER this: a, quoted: False))
            ),
        ]

    Args:
        source (sqlglot.Expression): the source expression.
        target (sqlglot.Expression): the target expression against which the diff should be calculated.

    Returns:
        the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the target expression trees.
        This list represents a sequence of steps needed to transform the source expression tree into the target one.
    """
    return ChangeDistiller().diff(source.copy(), target.copy())
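# A minimal usage sketch, assuming sqlglot 6.x (where `parse_one` is exported
# from the top-level package):
#
#     from sqlglot import parse_one
#     from sqlglot.diff import diff
#
#     edits = diff(parse_one("SELECT a, b FROM t"), parse_one("SELECT a, c FROM t"))
#     print([type(e).__name__ for e in edits])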


LEAF_EXPRESSION_TYPES = (
    exp.Boolean,
    exp.DataType,
    exp.Identifier,
    exp.Literal,
)


class ChangeDistiller:
    """
    The implementation of the Change Distiller algorithm described by Beat Fluri and Martin Pinzger in
    their paper https://ieeexplore.ieee.org/document/4339230, which in turn is based on the algorithm by
    Chawathe et al. described in http://ilpubs.stanford.edu:8090/115/1/1995-46.pdf.
    """

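    # Threshold semantics as used below: `f` is the minimum bigram dice
    # similarity for two nodes to be considered similar, while `t` is the
    # minimum share of matched leaves an inner-node pair needs (relaxed to
    # 0.4 for small subtrees).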
    def __init__(self, f=0.6, t=0.6):
        self.f = f
        self.t = t
        self._sql_generator = Dialect().generator()

    def diff(self, source, target):
        self._source = source
        self._target = target
        self._source_index = {id(n[0]): n[0] for n in source.bfs()}
        self._target_index = {id(n[0]): n[0] for n in target.bfs()}
        self._unmatched_source_nodes = set(self._source_index)
        self._unmatched_target_nodes = set(self._target_index)
        self._bigram_histo_cache = {}

        matching_set = self._compute_matching_set()
        return self._generate_edit_script(matching_set)

    def _generate_edit_script(self, matching_set):
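        # Nodes left unmatched by _compute_matching_set are pure removals
        # (source side) or insertions (target side); matched pairs become
        # Keep edits, possibly with Moves for reordered children, or Updates.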
        edit_script = []
        for removed_node_id in self._unmatched_source_nodes:
            edit_script.append(Remove(self._source_index[removed_node_id]))
        for inserted_node_id in self._unmatched_target_nodes:
            edit_script.append(Insert(self._target_index[inserted_node_id]))
        for kept_source_node_id, kept_target_node_id in matching_set:
            source_node = self._source_index[kept_source_node_id]
            target_node = self._target_index[kept_target_node_id]
            if (
                not isinstance(source_node, LEAF_EXPRESSION_TYPES)
                or source_node == target_node
            ):
                edit_script.extend(
                    self._generate_move_edits(source_node, target_node, matching_set)
                )
                edit_script.append(Keep(source_node, target_node))
            else:
                edit_script.append(Update(source_node, target_node))

        return edit_script

    def _generate_move_edits(self, source, target, matching_set):
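        # Matched children that fall outside the longest common subsequence
        # of the two argument lists have changed position, so they are
        # reported as Moves.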
        source_args = [id(e) for e in _expression_only_args(source)]
        target_args = [id(e) for e in _expression_only_args(target)]

        args_lcs = set(
            _lcs(source_args, target_args, lambda l, r: (l, r) in matching_set)
        )

        move_edits = []
        for a in source_args:
            if a not in args_lcs and a not in self._unmatched_source_nodes:
                move_edits.append(Move(self._source_index[a]))

        return move_edits

    def _compute_matching_set(self):
        leaves_matching_set = self._compute_leaf_matching_set()
        matching_set = leaves_matching_set.copy()

        ordered_unmatched_source_nodes = {
            id(n[0]): None
            for n in self._source.bfs()
            if id(n[0]) in self._unmatched_source_nodes
        }
        ordered_unmatched_target_nodes = {
            id(n[0]): None
            for n in self._target.bfs()
            if id(n[0]) in self._unmatched_target_nodes
        }

        for source_node_id in ordered_unmatched_source_nodes:
            for target_node_id in ordered_unmatched_target_nodes:
                source_node = self._source_index[source_node_id]
                target_node = self._target_index[target_node_id]
                if _is_same_type(source_node, target_node):
                    source_leaf_ids = {id(l) for l in _get_leaves(source_node)}
                    target_leaf_ids = {id(l) for l in _get_leaves(target_node)}

                    max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids))
                    if max_leaves_num:
                        common_leaves_num = sum(
                            1 if s in source_leaf_ids and t in target_leaf_ids else 0
                            for s, t in leaves_matching_set
                        )
                        leaf_similarity_score = common_leaves_num / max_leaves_num
                    else:
                        leaf_similarity_score = 0.0

                    adjusted_t = (
                        self.t
                        if min(len(source_leaf_ids), len(target_leaf_ids)) > 4
                        else 0.4
                    )

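                    # The 0.8 fast path, the 0.4 fallback for pairs whose
                    # smaller subtree has at most 4 leaves, and the dice check
                    # against `f` appear to follow the tuning described in the
                    # Change Distiller paper.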
                    if leaf_similarity_score >= 0.8 or (
                        leaf_similarity_score >= adjusted_t
                        and self._dice_coefficient(source_node, target_node) >= self.f
                    ):
                        matching_set.add((source_node_id, target_node_id))
                        self._unmatched_source_nodes.remove(source_node_id)
                        self._unmatched_target_nodes.remove(target_node_id)
                        ordered_unmatched_target_nodes.pop(target_node_id, None)
                        break

        return matching_set

    def _compute_leaf_matching_set(self):
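        # Score every pair of same-type leaves; a min-heap on the negated
        # score (with insertion order as a tiebreaker) then drives a greedy,
        # best-score-first matching below.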
        candidate_matchings = []
        source_leaves = list(_get_leaves(self._source))
        target_leaves = list(_get_leaves(self._target))
        for source_leaf in source_leaves:
            for target_leaf in target_leaves:
                if _is_same_type(source_leaf, target_leaf):
                    similarity_score = self._dice_coefficient(source_leaf, target_leaf)
                    if similarity_score >= self.f:
                        heappush(
                            candidate_matchings,
                            (
                                -similarity_score,
                                len(candidate_matchings),
                                source_leaf,
                                target_leaf,
                            ),
                        )

        # Pick best matchings based on the highest score
        matching_set = set()
        while candidate_matchings:
            _, _, source_leaf, target_leaf = heappop(candidate_matchings)
            if (
                id(source_leaf) in self._unmatched_source_nodes
                and id(target_leaf) in self._unmatched_target_nodes
            ):
                matching_set.add((id(source_leaf), id(target_leaf)))
                self._unmatched_source_nodes.remove(id(source_leaf))
                self._unmatched_target_nodes.remove(id(target_leaf))

        return matching_set

    def _dice_coefficient(self, source, target):
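        # Dice similarity over the character-bigram multisets of the
        # generated SQL: 2 * overlap / (total bigrams on both sides). For
        # example, "ab" vs "abc" has bigrams {ab} and {ab, bc}, giving
        # 2 * 1 / (1 + 2) = 0.67 (rounded).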
        source_histo = self._bigram_histo(source)
        target_histo = self._bigram_histo(target)

        total_grams = sum(source_histo.values()) + sum(target_histo.values())
        if not total_grams:
            return 1.0 if source == target else 0.0

        overlap_len = 0
        overlapping_grams = set(source_histo) & set(target_histo)
        for g in overlapping_grams:
            overlap_len += min(source_histo[g], target_histo[g])

        return 2 * overlap_len / total_grams

    def _bigram_histo(self, expression):
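        # Memoized per node identity for the duration of a single run; the
        # cache is re-initialized at the top of diff().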
        if id(expression) in self._bigram_histo_cache:
            return self._bigram_histo_cache[id(expression)]

        expression_str = self._sql_generator.generate(expression)
        count = max(0, len(expression_str) - 1)
        bigram_histo = defaultdict(int)
        for i in range(count):
            bigram_histo[expression_str[i : i + 2]] += 1

        self._bigram_histo_cache[id(expression)] = bigram_histo
        return bigram_histo


def _get_leaves(expression):
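    # A leaf is a node with no Expression children. Note that arg values may
    # be lists of nodes, hence ensure_list.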
    has_child_exprs = False

    for a in expression.args.values():
        nodes = ensure_list(a)
        for node in nodes:
            if isinstance(node, exp.Expression):
                has_child_exprs = True
                yield from _get_leaves(node)

    if not has_child_exprs:
        yield expression


def _is_same_type(source, target):
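    # Matching Python types is necessary but not sufficient: joins must also
    # agree on their side (e.g. LEFT vs RIGHT), and anonymous functions on
    # their name.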
    if type(source) is type(target):
        if isinstance(source, exp.Join):
            return source.args.get("side") == target.args.get("side")

        if isinstance(source, exp.Anonymous):
            return source.this == target.this

        return True

    return False


def _expression_only_args(expression):
    args = []
    if expression:
        for a in expression.args.values():
            args.extend(ensure_list(a))
    return [a for a in args if isinstance(a, exp.Expression)]


def _lcs(seq_a, seq_b, equal):
    """Calculates the longest common subsequence"""

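    # Textbook O(len_a * len_b) dynamic programming; each cell stores the
    # subsequence itself rather than just its length, which keeps
    # reconstruction trivial at the cost of extra memory.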
    len_a = len(seq_a)
    len_b = len(seq_b)
    lcs_result = [[None] * (len_b + 1) for _ in range(len_a + 1)]

    for i in range(len_a + 1):
        for j in range(len_b + 1):
            if i == 0 or j == 0:
                lcs_result[i][j] = []
            elif equal(seq_a[i - 1], seq_b[j - 1]):
                lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]]
            else:
                lcs_result[i][j] = (
                    lcs_result[i - 1][j]
                    if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1])
                    else lcs_result[i][j - 1]
                )

    return lcs_result[len_a][len_b]