262 lines
6.8 KiB
Go
262 lines
6.8 KiB
Go
//  Copyright (c) 2023 Couchbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// 		http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
|
|
|
|
//go:build vectors
|
|
// +build vectors
|
|
|
|
package collector
|
|
|
|
import (
|
|
"context"
|
|
"time"
|
|
|
|
"github.com/blevesearch/bleve/v2/search"
|
|
index "github.com/blevesearch/bleve_index_api"
|
|
)
|
|
|
|
type collectStoreKNN struct {
|
|
internalHeaps []collectorStore
|
|
kValues []int64
|
|
allHits map[*search.DocumentMatch]struct{}
|
|
ejectedDocs map[*search.DocumentMatch]struct{}
|
|
}
|
|
|
|
func newStoreKNN(internalHeaps []collectorStore, kValues []int64) *collectStoreKNN {
|
|
return &collectStoreKNN{
|
|
internalHeaps: internalHeaps,
|
|
kValues: kValues,
|
|
ejectedDocs: make(map[*search.DocumentMatch]struct{}),
|
|
allHits: make(map[*search.DocumentMatch]struct{}),
|
|
}
|
|
}
|
|
|
|
// Adds a document to the collector store and returns the documents that were ejected
|
|
// from the store. The documents that were ejected from the store are the ones that
|
|
// were not in the top K documents for any of the heaps.
|
|
// These document are put back into the pool document match pool in the KNN Collector.
|
|
func (c *collectStoreKNN) AddDocument(doc *search.DocumentMatch) []*search.DocumentMatch {
|
|
for heapIdx := 0; heapIdx < len(c.internalHeaps); heapIdx++ {
|
|
if _, ok := doc.ScoreBreakdown[heapIdx]; !ok {
|
|
continue
|
|
}
|
|
ejectedDoc := c.internalHeaps[heapIdx].AddNotExceedingSize(doc, int(c.kValues[heapIdx]))
|
|
if ejectedDoc != nil {
|
|
delete(ejectedDoc.ScoreBreakdown, heapIdx)
|
|
c.ejectedDocs[ejectedDoc] = struct{}{}
|
|
}
|
|
}
|
|
var rv []*search.DocumentMatch
|
|
for doc := range c.ejectedDocs {
|
|
if len(doc.ScoreBreakdown) == 0 {
|
|
rv = append(rv, doc)
|
|
}
|
|
// clear out the ejectedDocs map to reuse it in the next AddDocument call
|
|
delete(c.ejectedDocs, doc)
|
|
}
|
|
return rv
|
|
}
|
|
|
|
func (c *collectStoreKNN) Final(fixup collectorFixup) (search.DocumentMatchCollection, error) {
|
|
for _, heap := range c.internalHeaps {
|
|
for _, doc := range heap.Internal() {
|
|
// duplicates may be present across the internal heaps
|
|
// meaning the same document match may be in the top K
|
|
// for multiple KNN queries.
|
|
c.allHits[doc] = struct{}{}
|
|
}
|
|
}
|
|
size := len(c.allHits)
|
|
if size <= 0 {
|
|
return make(search.DocumentMatchCollection, 0), nil
|
|
}
|
|
rv := make(search.DocumentMatchCollection, size)
|
|
i := 0
|
|
for doc := range c.allHits {
|
|
if fixup != nil {
|
|
err := fixup(doc)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
rv[i] = doc
|
|
i++
|
|
}
|
|
return rv, nil
|
|
}
|
|
|
|
func MakeKNNDocMatchHandler(ctx *search.SearchContext) (search.DocumentMatchHandler, error) {
|
|
var hc *KNNCollector
|
|
var ok bool
|
|
if hc, ok = ctx.Collector.(*KNNCollector); ok {
|
|
return func(d *search.DocumentMatch) error {
|
|
if d == nil {
|
|
return nil
|
|
}
|
|
toRelease := hc.knnStore.AddDocument(d)
|
|
for _, doc := range toRelease {
|
|
ctx.DocumentMatchPool.Put(doc)
|
|
}
|
|
return nil
|
|
}, nil
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func GetNewKNNCollectorStore(kArray []int64) *collectStoreKNN {
|
|
internalHeaps := make([]collectorStore, len(kArray))
|
|
for knnIdx, k := range kArray {
|
|
// TODO - Check if the datatype of k can be made into an int instead of int64
|
|
idx := knnIdx
|
|
internalHeaps[idx] = getOptimalCollectorStore(int(k), 0, func(i, j *search.DocumentMatch) int {
|
|
if i.ScoreBreakdown[idx] < j.ScoreBreakdown[idx] {
|
|
return 1
|
|
}
|
|
return -1
|
|
})
|
|
}
|
|
return newStoreKNN(internalHeaps, kArray)
|
|
}
|
|
|
|
// implements Collector interface
|
|
type KNNCollector struct {
|
|
knnStore *collectStoreKNN
|
|
size int
|
|
total uint64
|
|
took time.Duration
|
|
results search.DocumentMatchCollection
|
|
maxScore float64
|
|
}
|
|
|
|
func NewKNNCollector(kArray []int64, size int64) *KNNCollector {
|
|
return &KNNCollector{
|
|
knnStore: GetNewKNNCollectorStore(kArray),
|
|
size: int(size),
|
|
}
|
|
}
|
|
|
|
func (hc *KNNCollector) Collect(ctx context.Context, searcher search.Searcher, reader index.IndexReader) error {
|
|
startTime := time.Now()
|
|
var err error
|
|
var next *search.DocumentMatch
|
|
|
|
// pre-allocate enough space in the DocumentMatchPool
|
|
// unless the sum of K is too large, then cap it
|
|
// everything should still work, just allocates DocumentMatches on demand
|
|
backingSize := hc.size
|
|
if backingSize > PreAllocSizeSkipCap {
|
|
backingSize = PreAllocSizeSkipCap + 1
|
|
}
|
|
searchContext := &search.SearchContext{
|
|
DocumentMatchPool: search.NewDocumentMatchPool(backingSize+searcher.DocumentMatchPoolSize(), 0),
|
|
Collector: hc,
|
|
IndexReader: reader,
|
|
}
|
|
|
|
dmHandlerMakerKNN := MakeKNNDocMatchHandler
|
|
if cv := ctx.Value(search.MakeKNNDocumentMatchHandlerKey); cv != nil {
|
|
dmHandlerMakerKNN = cv.(search.MakeKNNDocumentMatchHandler)
|
|
}
|
|
// use the application given builder for making the custom document match
|
|
// handler and perform callbacks/invocations on the newly made handler.
|
|
dmHandler, err := dmHandlerMakerKNN(searchContext)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
search.RecordSearchCost(ctx, search.AbortM, 0)
|
|
return ctx.Err()
|
|
default:
|
|
next, err = searcher.Next(searchContext)
|
|
}
|
|
for err == nil && next != nil {
|
|
if hc.total%CheckDoneEvery == 0 {
|
|
select {
|
|
case <-ctx.Done():
|
|
search.RecordSearchCost(ctx, search.AbortM, 0)
|
|
return ctx.Err()
|
|
default:
|
|
}
|
|
}
|
|
hc.total++
|
|
|
|
err = dmHandler(next)
|
|
if err != nil {
|
|
break
|
|
}
|
|
|
|
next, err = searcher.Next(searchContext)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// help finalize/flush the results in case
|
|
// of custom document match handlers.
|
|
err = dmHandler(nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// compute search duration
|
|
hc.took = time.Since(startTime)
|
|
|
|
// finalize actual results
|
|
err = hc.finalizeResults(reader)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (hc *KNNCollector) finalizeResults(r index.IndexReader) error {
|
|
var err error
|
|
hc.results, err = hc.knnStore.Final(func(doc *search.DocumentMatch) error {
|
|
if doc.ID == "" {
|
|
// look up the id since we need it for lookup
|
|
var err error
|
|
doc.ID, err = r.ExternalID(doc.IndexInternalID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (hc *KNNCollector) Results() search.DocumentMatchCollection {
|
|
return hc.results
|
|
}
|
|
|
|
func (hc *KNNCollector) Total() uint64 {
|
|
return hc.total
|
|
}
|
|
|
|
func (hc *KNNCollector) MaxScore() float64 {
|
|
return hc.maxScore
|
|
}
|
|
|
|
func (hc *KNNCollector) Took() time.Duration {
|
|
return hc.took
|
|
}
|
|
|
|
func (hc *KNNCollector) SetFacetsBuilder(facetsBuilder *search.FacetsBuilder) {
|
|
// facet unsupported for vector search
|
|
}
|
|
|
|
func (hc *KNNCollector) FacetResults() search.FacetResults {
|
|
// facet unsupported for vector search
|
|
return nil
|
|
}
|