# sqlglot/sqlglot/diff.py
from collections import defaultdict
from dataclasses import dataclass
from heapq import heappop, heappush

from sqlglot import Dialect
from sqlglot import expressions as exp
from sqlglot.helper import ensure_list


@dataclass(frozen=True)
class Insert:
    """Indicates that a new node has been inserted"""

    expression: exp.Expression


@dataclass(frozen=True)
class Remove:
    """Indicates that an existing node has been removed"""

    expression: exp.Expression


@dataclass(frozen=True)
class Move:
    """Indicates that an existing node's position within the tree has changed"""

    expression: exp.Expression


@dataclass(frozen=True)
class Update:
    """Indicates that an existing node has been updated"""

    source: exp.Expression
    target: exp.Expression


@dataclass(frozen=True)
class Keep:
    """Indicates that an existing node hasn't been changed"""

    source: exp.Expression
    target: exp.Expression


def diff(source, target):
    """
    Returns the list of changes between the source and the target expressions.

    Examples:
        >>> diff(parse_one("a + b"), parse_one("a + c"))
        [
            Remove(expression=(COLUMN this: (IDENTIFIER this: b, quoted: False))),
            Insert(expression=(COLUMN this: (IDENTIFIER this: c, quoted: False))),
            Keep(
                source=(ADD this: ...),
                target=(ADD this: ...)
            ),
            Keep(
                source=(COLUMN this: (IDENTIFIER this: a, quoted: False)),
                target=(COLUMN this: (IDENTIFIER this: a, quoted: False))
            ),
        ]

    Args:
        source (sqlglot.Expression): the source expression.
        target (sqlglot.Expression): the target expression against which the diff should be calculated.

    Returns:
        the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the
        target expression trees. This list represents a sequence of steps needed to transform the source
        expression tree into the target one.
    """
    return ChangeDistiller().diff(source.copy(), target.copy())


LEAF_EXPRESSION_TYPES = (
    exp.Boolean,
    exp.DataType,
    exp.Identifier,
    exp.Literal,
)
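
# Note (added for clarity): these types are compared by value in
# _generate_edit_script below, so a matched pair of such nodes that
# differs is reported as an Update rather than a Keep.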


class ChangeDistiller:
    """
    The implementation of the Change Distiller algorithm described by Beat Fluri and Martin Pinzger in
    their paper https://ieeexplore.ieee.org/document/4339230, which in turn is based on the algorithm by
    Chawathe et al. described in http://ilpubs.stanford.edu:8090/115/1/1995-46.pdf.
    """

    def __init__(self, f=0.6, t=0.6):
        self.f = f
        self.t = t
        self._sql_generator = Dialect().generator()

    def diff(self, source, target):
        self._source = source
        self._target = target
        self._source_index = {id(n[0]): n[0] for n in source.bfs()}
        self._target_index = {id(n[0]): n[0] for n in target.bfs()}
        self._unmatched_source_nodes = set(self._source_index)
        self._unmatched_target_nodes = set(self._target_index)
        self._bigram_histo_cache = {}

        matching_set = self._compute_matching_set()
        return self._generate_edit_script(matching_set)
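
    # A minimal usage sketch (illustrative; the threshold values are
    # hypothetical, `parse_one` comes from the sqlglot package):
    #
    #     >>> from sqlglot import parse_one
    #     >>> distiller = ChangeDistiller(f=0.7, t=0.7)
    #     >>> edit_script = distiller.diff(parse_one("a + b"), parse_one("a + c"))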

    def _generate_edit_script(self, matching_set):
        edit_script = []
        for removed_node_id in self._unmatched_source_nodes:
            edit_script.append(Remove(self._source_index[removed_node_id]))
        for inserted_node_id in self._unmatched_target_nodes:
            edit_script.append(Insert(self._target_index[inserted_node_id]))
        for kept_source_node_id, kept_target_node_id in matching_set:
            source_node = self._source_index[kept_source_node_id]
            target_node = self._target_index[kept_target_node_id]
            if not isinstance(source_node, LEAF_EXPRESSION_TYPES) or source_node == target_node:
                edit_script.extend(self._generate_move_edits(source_node, target_node, matching_set))
                edit_script.append(Keep(source_node, target_node))
            else:
                edit_script.append(Update(source_node, target_node))

        return edit_script

    def _generate_move_edits(self, source, target, matching_set):
        source_args = [id(e) for e in _expression_only_args(source)]
        target_args = [id(e) for e in _expression_only_args(target)]

        args_lcs = set(_lcs(source_args, target_args, lambda l, r: (l, r) in matching_set))

        move_edits = []
        for a in source_args:
            if a not in args_lcs and a not in self._unmatched_source_nodes:
                move_edits.append(Move(self._source_index[a]))

        return move_edits
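
    # Illustrative example: when diffing "SELECT a, b" against "SELECT b, a",
    # both column nodes are matched, but only one of them can belong to the
    # longest common subsequence of the two argument lists, so the remaining
    # matched node is emitted as a Move.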

    def _compute_matching_set(self):
        leaves_matching_set = self._compute_leaf_matching_set()
        matching_set = leaves_matching_set.copy()

        ordered_unmatched_source_nodes = {
            id(n[0]): None for n in self._source.bfs() if id(n[0]) in self._unmatched_source_nodes
        }
        ordered_unmatched_target_nodes = {
            id(n[0]): None for n in self._target.bfs() if id(n[0]) in self._unmatched_target_nodes
        }

        for source_node_id in ordered_unmatched_source_nodes:
            for target_node_id in ordered_unmatched_target_nodes:
                source_node = self._source_index[source_node_id]
                target_node = self._target_index[target_node_id]
                if _is_same_type(source_node, target_node):
                    source_leaf_ids = {id(l) for l in _get_leaves(source_node)}
                    target_leaf_ids = {id(l) for l in _get_leaves(target_node)}

                    max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids))
                    if max_leaves_num:
                        common_leaves_num = sum(
                            1 if s in source_leaf_ids and t in target_leaf_ids else 0
                            for s, t in leaves_matching_set
                        )
                        leaf_similarity_score = common_leaves_num / max_leaves_num
                    else:
                        leaf_similarity_score = 0.0

                    # Relax the similarity threshold for small subtrees, in line with
                    # the heuristics suggested in the Change Distiller paper.
                    adjusted_t = self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4

                    if leaf_similarity_score >= 0.8 or (
                        leaf_similarity_score >= adjusted_t
                        and self._dice_coefficient(source_node, target_node) >= self.f
                    ):
                        matching_set.add((source_node_id, target_node_id))
                        self._unmatched_source_nodes.remove(source_node_id)
                        self._unmatched_target_nodes.remove(target_node_id)
                        ordered_unmatched_target_nodes.pop(target_node_id, None)
                        break

        return matching_set

    def _compute_leaf_matching_set(self):
        candidate_matchings = []
        source_leaves = list(_get_leaves(self._source))
        target_leaves = list(_get_leaves(self._target))
        for source_leaf in source_leaves:
            for target_leaf in target_leaves:
                if _is_same_type(source_leaf, target_leaf):
                    similarity_score = self._dice_coefficient(source_leaf, target_leaf)
                    if similarity_score >= self.f:
                        heappush(
                            candidate_matchings,
                            (
                                -similarity_score,
                                len(candidate_matchings),
                                source_leaf,
                                target_leaf,
                            ),
                        )

        # Pick best matchings based on the highest score
        matching_set = set()
        while candidate_matchings:
            _, _, source_leaf, target_leaf = heappop(candidate_matchings)
            if id(source_leaf) in self._unmatched_source_nodes and id(target_leaf) in self._unmatched_target_nodes:
                matching_set.add((id(source_leaf), id(target_leaf)))
                self._unmatched_source_nodes.remove(id(source_leaf))
                self._unmatched_target_nodes.remove(id(target_leaf))

        return matching_set

    def _dice_coefficient(self, source, target):
        source_histo = self._bigram_histo(source)
        target_histo = self._bigram_histo(target)

        total_grams = sum(source_histo.values()) + sum(target_histo.values())
        if not total_grams:
            return 1.0 if source == target else 0.0

        overlap_len = 0
        overlapping_grams = set(source_histo) & set(target_histo)
        for g in overlapping_grams:
            overlap_len += min(source_histo[g], target_histo[g])

        return 2 * overlap_len / total_grams
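
    # Worked example (illustrative): for the strings "name" and "game" the
    # bigram multisets are {na, am, me} and {ga, am, me}. Two bigrams overlap,
    # so the coefficient is 2 * 2 / (3 + 3) = 2/3.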

    def _bigram_histo(self, expression):
        if id(expression) in self._bigram_histo_cache:
            return self._bigram_histo_cache[id(expression)]

        expression_str = self._sql_generator.generate(expression)
        count = max(0, len(expression_str) - 1)

        bigram_histo = defaultdict(int)
        for i in range(count):
            bigram_histo[expression_str[i : i + 2]] += 1

        self._bigram_histo_cache[id(expression)] = bigram_histo
        return bigram_histo


def _get_leaves(expression):
    has_child_exprs = False

    for a in expression.args.values():
        nodes = ensure_list(a)
        for node in nodes:
            if isinstance(node, exp.Expression):
                has_child_exprs = True
                yield from _get_leaves(node)

    if not has_child_exprs:
        yield expression
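
# Illustrative example (doctest-style; exact output may vary across sqlglot
# versions):
#
#     >>> from sqlglot import parse_one
#     >>> [e.sql() for e in _get_leaves(parse_one("a + 1"))]
#     ['a', '1']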


def _is_same_type(source, target):
    if type(source) is type(target):
        if isinstance(source, exp.Join):
            return source.args.get("side") == target.args.get("side")
        if isinstance(source, exp.Anonymous):
            return source.this == target.this
        return True

    return False


def _expression_only_args(expression):
    args = []
    if expression:
        for a in expression.args.values():
            args.extend(ensure_list(a))
    return [a for a in args if isinstance(a, exp.Expression)]


def _lcs(seq_a, seq_b, equal):
    """Calculates the longest common subsequence"""

    len_a = len(seq_a)
    len_b = len(seq_b)
    lcs_result = [[None] * (len_b + 1) for _ in range(len_a + 1)]

    for i in range(len_a + 1):
        for j in range(len_b + 1):
            if i == 0 or j == 0:
                lcs_result[i][j] = []
            elif equal(seq_a[i - 1], seq_b[j - 1]):
                lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]]
            else:
                lcs_result[i][j] = (
                    lcs_result[i - 1][j]
                    if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1])
                    else lcs_result[i][j - 1]
                )

    return lcs_result[len_a][len_b]
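
# Illustrative example (hypothetical call, using plain equality as the
# comparison predicate):
#
#     >>> _lcs([1, 2, 3], [2, 3, 4], lambda l, r: l == r)
#     [2, 3]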