1
0
Fork 0
golang-github-blevesearch-b.../search/collector/knn.go
Daniel Baumann 982828099e
Adding upstream version 2.5.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
2025-05-19 00:20:02 +02:00

262 lines
6.8 KiB
Go

// Copyright (c) 2023 Couchbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build vectors
// +build vectors
package collector
import (
"context"
"time"
"github.com/blevesearch/bleve/v2/search"
index "github.com/blevesearch/bleve_index_api"
)
// collectStoreKNN keeps a separate bounded store of candidates for each KNN
// query. internalHeaps[i] holds the documents that have an entry at key i of
// their ScoreBreakdown and is capped at kValues[i] documents.
type collectStoreKNN struct {
	internalHeaps []collectorStore
	kValues       []int64

	// Both maps are used as sets of document matches: allHits de-duplicates
	// documents across heaps in Final, and ejectedDocs is scratch space that
	// every AddDocument call fills and then empties again.
	allHits, ejectedDocs map[*search.DocumentMatch]struct{}
}
// newStoreKNN wraps the per-query stores and their K limits in a
// collectStoreKNN whose hit and ejection sets start out empty.
func newStoreKNN(internalHeaps []collectorStore, kValues []int64) *collectStoreKNN {
	store := &collectStoreKNN{
		internalHeaps: internalHeaps,
		kValues:       kValues,
	}
	store.ejectedDocs = make(map[*search.DocumentMatch]struct{})
	store.allHits = make(map[*search.DocumentMatch]struct{})
	return store
}
// AddDocument offers doc to every internal heap whose index is a key of
// doc.ScoreBreakdown. It returns the documents this store no longer holds any
// reference to, so the caller can put them back into its DocumentMatchPool.
//
// When heap i ejects a document, that document's ScoreBreakdown[i] entry is
// deleted. The document is only returned once its ScoreBreakdown is empty,
// that is, once it has been ejected from every heap it was placed in.
// A document whose ScoreBreakdown matches none of the heaps is never stored,
// so it is returned immediately.
func (c *collectStoreKNN) AddDocument(doc *search.DocumentMatch) []*search.DocumentMatch {
	stored := false
	for heapIdx := range c.internalHeaps {
		if _, ok := doc.ScoreBreakdown[heapIdx]; !ok {
			continue
		}
		stored = true
		ejected := c.internalHeaps[heapIdx].AddNotExceedingSize(doc, int(c.kValues[heapIdx]))
		if ejected != nil {
			delete(ejected.ScoreBreakdown, heapIdx)
			c.ejectedDocs[ejected] = struct{}{}
		}
	}

	var rv []*search.DocumentMatch
	if !stored {
		// doc was not offered to any heap, so this store keeps no reference
		// to it; hand it back for recycling instead of dropping it.
		rv = append(rv, doc)
	}
	for ejected := range c.ejectedDocs {
		if len(ejected.ScoreBreakdown) == 0 {
			rv = append(rv, ejected)
		}
		// empty the set so it can be reused by the next AddDocument call
		delete(c.ejectedDocs, ejected)
	}
	return rv
}
// Final returns every distinct document currently held by at least one
// internal heap. If fixup is non-nil, it is applied to each of those documents
// exactly once. If fixup fails, Final stops and returns its error with a nil
// collection.
//
// The same document match can be in the top K of several KNN queries, and so
// appear in several heaps; c.allHits is used to de-duplicate them. Documents
// are emitted in heap order (heap 0 first, each heap in its Internal() order)
// rather than map iteration order, which Go leaves unspecified. The result
// order is therefore deterministic for a given heap state.
func (c *collectStoreKNN) Final(fixup collectorFixup) (search.DocumentMatchCollection, error) {
	// Start from an empty set so the result reflects only the heaps' current
	// contents, even if Final is called more than once.
	for doc := range c.allHits {
		delete(c.allHits, doc)
	}

	var rv search.DocumentMatchCollection
	for _, heap := range c.internalHeaps {
		for _, doc := range heap.Internal() {
			if _, seen := c.allHits[doc]; seen {
				continue
			}
			c.allHits[doc] = struct{}{}
			rv = append(rv, doc)
		}
	}
	if len(rv) == 0 {
		return make(search.DocumentMatchCollection, 0), nil
	}

	if fixup != nil {
		for _, doc := range rv {
			if err := fixup(doc); err != nil {
				return nil, err
			}
		}
	}
	return rv, nil
}
// MakeKNNDocMatchHandler is the default builder of the document match handler
// used by KNNCollector.Collect. When ctx.Collector is a *KNNCollector, the
// returned handler does the following:
//   - feeds each non-nil document into the collector's KNN store;
//   - puts every document the store releases back into ctx.DocumentMatchPool.
//
// For any other collector it returns a nil handler and a nil error.
func MakeKNNDocMatchHandler(ctx *search.SearchContext) (search.DocumentMatchHandler, error) {
	hc, isKNN := ctx.Collector.(*KNNCollector)
	if !isKNN {
		return nil, nil
	}
	handler := func(d *search.DocumentMatch) error {
		// Collect passes a nil document at the end so custom handlers can
		// flush; the default handler has nothing to flush.
		if d != nil {
			for _, released := range hc.knnStore.AddDocument(d) {
				ctx.DocumentMatchPool.Put(released)
			}
		}
		return nil
	}
	return handler, nil
}
// GetNewKNNCollectorStore builds a collectStoreKNN with one store per entry of
// kArray. Store i is created via getOptimalCollectorStore with kArray[i] as its
// size. Its comparator returns 1 when the first document's ScoreBreakdown[i]
// is lower than the second's, and -1 otherwise (ties are never reported as 0).
func GetNewKNNCollectorStore(kArray []int64) *collectStoreKNN {
	internalHeaps := make([]collectorStore, len(kArray))
	for i := range kArray {
		// TODO - Check if the datatype of k can be made into an int instead of int64
		breakdownIdx := i // per-iteration copy captured by the comparator below
		compare := func(a, b *search.DocumentMatch) int {
			if a.ScoreBreakdown[breakdownIdx] < b.ScoreBreakdown[breakdownIdx] {
				return 1
			}
			return -1
		}
		internalHeaps[i] = getOptimalCollectorStore(int(kArray[i]), 0, compare)
	}
	return newStoreKNN(internalHeaps, kArray)
}
// KNNCollector implements the Collector interface for KNN (vector) searches.
// It keeps the top K matches of each KNN query in a collectStoreKNN. Its
// results are the de-duplicated union of those per-query stores.
type KNNCollector struct {
	// knnStore holds one bounded store per KNN query.
	knnStore *collectStoreKNN
	// size is used by Collect to size the pre-allocated DocumentMatchPool.
	size int
	// total counts the document matches Collect has handed to the handler.
	total uint64
	// took is the time Collect spent before finalizing its results.
	took time.Duration
	// results holds the finalized hits produced by Collect.
	results search.DocumentMatchCollection
	// NOTE(review): maxScore is not assigned anywhere in this file, so
	// MaxScore presumably always reports 0 — confirm this is intended.
	maxScore float64
}
// NewKNNCollector returns a KNNCollector with one top-K store per entry of
// kArray. Collect uses size to decide how many DocumentMatches to
// pre-allocate.
func NewKNNCollector(kArray []int64, size int64) *KNNCollector {
	hc := new(KNNCollector)
	hc.knnStore = GetNewKNNCollectorStore(kArray)
	hc.size = int(size)
	return hc
}
// Collect drains searcher over reader and keeps the top K matches of every KNN
// query, then finalizes them into hc.results.
//
// Every match is passed to a document match handler. By default the handler
// is built by MakeKNNDocMatchHandler, but ctx may supply another builder under
// search.MakeKNNDocumentMatchHandlerKey. Cancellation of ctx is checked before
// the first match and then whenever hc.total is a multiple of CheckDoneEvery.
// On cancellation the abort is recorded with search.RecordSearchCost and
// ctx.Err() is returned.
func (hc *KNNCollector) Collect(ctx context.Context, searcher search.Searcher, reader index.IndexReader) error {
	startTime := time.Now()

	// Pre-allocate DocumentMatches for hc.size hits plus what the searcher
	// itself needs. Beyond PreAllocSizeSkipCap the pre-allocation is capped;
	// everything still works, extra DocumentMatches are allocated on demand.
	backingSize := hc.size
	if backingSize > PreAllocSizeSkipCap {
		backingSize = PreAllocSizeSkipCap + 1
	}
	searchContext := &search.SearchContext{
		DocumentMatchPool: search.NewDocumentMatchPool(backingSize+searcher.DocumentMatchPoolSize(), 0),
		Collector:         hc,
		IndexReader:       reader,
	}

	// The application may provide its own builder for the document match
	// handler; the built handler is then invoked for every match.
	makeHandler := MakeKNNDocMatchHandler
	if custom := ctx.Value(search.MakeKNNDocumentMatchHandlerKey); custom != nil {
		makeHandler = custom.(search.MakeKNNDocumentMatchHandler)
	}
	dmHandler, err := makeHandler(searchContext)
	if err != nil {
		return err
	}

	// aborted reports, without blocking, whether ctx is done, recording the
	// abort cost when it is.
	aborted := func() bool {
		select {
		case <-ctx.Done():
			search.RecordSearchCost(ctx, search.AbortM, 0)
			return true
		default:
			return false
		}
	}

	if aborted() {
		return ctx.Err()
	}
	next, err := searcher.Next(searchContext)
	for err == nil && next != nil {
		if hc.total%CheckDoneEvery == 0 && aborted() {
			return ctx.Err()
		}
		hc.total++
		if err = dmHandler(next); err != nil {
			return err
		}
		next, err = searcher.Next(searchContext)
	}
	if err != nil {
		return err
	}

	// A final nil document lets custom handlers finalize/flush their results.
	if err = dmHandler(nil); err != nil {
		return err
	}

	hc.took = time.Since(startTime)

	return hc.finalizeResults(reader)
}
// finalizeResults stores the KNN store's final hits in hc.results. For every
// hit whose ID is still empty, it first looks up the external ID through r.
// The first lookup error is returned.
func (hc *KNNCollector) finalizeResults(r index.IndexReader) error {
	resolveID := func(doc *search.DocumentMatch) error {
		if doc.ID != "" {
			return nil
		}
		// the external id is needed for subsequent lookups
		externalID, lookupErr := r.ExternalID(doc.IndexInternalID)
		doc.ID = externalID
		return lookupErr
	}

	var err error
	hc.results, err = hc.knnStore.Final(resolveID)
	return err
}
// Results returns the hits finalized by Collect.
func (hc *KNNCollector) Results() search.DocumentMatchCollection { return hc.results }

// Total returns how many document matches Collect has handed to its handler.
func (hc *KNNCollector) Total() uint64 { return hc.total }

// MaxScore returns the collector's recorded maximum score.
// NOTE(review): maxScore is never assigned in this file, so this presumably
// always returns 0 — confirm.
func (hc *KNNCollector) MaxScore() float64 { return hc.maxScore }

// Took returns how long Collect ran before finalizing its results.
func (hc *KNNCollector) Took() time.Duration { return hc.took }
// SetFacetsBuilder does nothing: faceting is not supported for vector search.
func (hc *KNNCollector) SetFacetsBuilder(facetsBuilder *search.FacetsBuilder) {}

// FacetResults always returns nil because faceting is not supported for
// vector search.
func (hc *KNNCollector) FacetResults() search.FacetResults { return nil }