145 lines
3.5 KiB
Go
145 lines
3.5 KiB
Go
// Copyright (c) 2023 Couchbase, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
//go:build vectors
|
|
// +build vectors
|
|
|
|
package searcher
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"reflect"
|
|
|
|
"github.com/blevesearch/bleve/v2/mapping"
|
|
"github.com/blevesearch/bleve/v2/search"
|
|
"github.com/blevesearch/bleve/v2/search/scorer"
|
|
"github.com/blevesearch/bleve/v2/size"
|
|
index "github.com/blevesearch/bleve_index_api"
|
|
)
|
|
|
|
var reflectStaticSizeKNNSearcher int
|
|
|
|
func init() {
|
|
var ks KNNSearcher
|
|
reflectStaticSizeKNNSearcher = int(reflect.TypeOf(ks).Size())
|
|
}
|
|
|
|
type KNNSearcher struct {
|
|
field string
|
|
vector []float32
|
|
k int64
|
|
indexReader index.IndexReader
|
|
vectorReader index.VectorReader
|
|
scorer *scorer.KNNQueryScorer
|
|
count uint64
|
|
vd index.VectorDoc
|
|
}
|
|
|
|
func NewKNNSearcher(ctx context.Context, i index.IndexReader, m mapping.IndexMapping,
|
|
options search.SearcherOptions, field string, vector []float32, k int64,
|
|
boost float64, similarityMetric string, searchParams json.RawMessage,
|
|
eligibleSelector index.EligibleDocumentSelector) (
|
|
search.Searcher, error) {
|
|
|
|
if vr, ok := i.(index.VectorIndexReader); ok {
|
|
vectorReader, err := vr.VectorReader(ctx, vector, field, k, searchParams, eligibleSelector)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
knnScorer := scorer.NewKNNQueryScorer(vector, field, boost,
|
|
options, similarityMetric)
|
|
return &KNNSearcher{
|
|
indexReader: i,
|
|
vectorReader: vectorReader,
|
|
field: field,
|
|
vector: vector,
|
|
k: k,
|
|
scorer: knnScorer,
|
|
}, nil
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func (s *KNNSearcher) VectorOptimize(ctx context.Context, octx index.VectorOptimizableContext) (
|
|
index.VectorOptimizableContext, error) {
|
|
o, ok := s.vectorReader.(index.VectorOptimizable)
|
|
if ok {
|
|
return o.VectorOptimize(ctx, octx)
|
|
}
|
|
|
|
return nil, nil
|
|
}
|
|
|
|
func (s *KNNSearcher) Advance(ctx *search.SearchContext, ID index.IndexInternalID) (
|
|
*search.DocumentMatch, error) {
|
|
knnMatch, err := s.vectorReader.Next(s.vd.Reset())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if knnMatch == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
docMatch := s.scorer.Score(ctx, knnMatch)
|
|
|
|
return docMatch, nil
|
|
}
|
|
|
|
func (s *KNNSearcher) Close() error {
|
|
return s.vectorReader.Close()
|
|
}
|
|
|
|
func (s *KNNSearcher) Count() uint64 {
|
|
return s.vectorReader.Count()
|
|
}
|
|
|
|
func (s *KNNSearcher) DocumentMatchPoolSize() int {
|
|
return 1
|
|
}
|
|
|
|
func (s *KNNSearcher) Min() int {
|
|
return 0
|
|
}
|
|
|
|
func (s *KNNSearcher) Next(ctx *search.SearchContext) (*search.DocumentMatch, error) {
|
|
knnMatch, err := s.vectorReader.Next(s.vd.Reset())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if knnMatch == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
docMatch := s.scorer.Score(ctx, knnMatch)
|
|
|
|
return docMatch, nil
|
|
}
|
|
|
|
func (s *KNNSearcher) SetQueryNorm(qnorm float64) {
|
|
s.scorer.SetQueryNorm(qnorm)
|
|
}
|
|
|
|
func (s *KNNSearcher) Size() int {
|
|
return reflectStaticSizeKNNSearcher + size.SizeOfPtr +
|
|
s.vectorReader.Size() +
|
|
s.vd.Size() +
|
|
s.scorer.Size()
|
|
}
|
|
|
|
func (s *KNNSearcher) Weight() float64 {
|
|
return s.scorer.Weight()
|
|
}
|