1
0
Fork 0
golang-github-blevesearch-b.../search_knn_test.go
Daniel Baumann 982828099e
Adding upstream version 2.5.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
2025-05-19 00:20:02 +02:00

1703 lines
46 KiB
Go

// Copyright (c) 2023 Couchbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build vectors
// +build vectors
package bleve
import (
"archive/zip"
"bytes"
"encoding/base64"
"encoding/binary"
"encoding/json"
"fmt"
"math"
"math/rand"
"sort"
"strconv"
"sync"
"testing"
"github.com/blevesearch/bleve/v2/analysis/lang/en"
"github.com/blevesearch/bleve/v2/index/scorch"
"github.com/blevesearch/bleve/v2/mapping"
"github.com/blevesearch/bleve/v2/search"
"github.com/blevesearch/bleve/v2/search/query"
index "github.com/blevesearch/bleve_index_api"
)
// Test fixtures: a zip archive bundling both the dataset and the query set
// used by every kNN test in this file.
const testInputCompressedFile = "test/knn/knn_dataset_queries.zip"
const testDatasetFileName = "knn_dataset.json"
const testQueryFileName = "knn_queries.json"

// Dimensionality of every vector in the test dataset.
const testDatasetDims = 384

// Every test case is run under both kNN operators (AND and OR).
var knnOperators []knnOperator = []knnOperator{knnOperatorAnd, knnOperatorOr}
// TestSimilaritySearchPartitionedIndex verifies that kNN/hybrid search over a
// partitioned index alias returns the same results as an equivalent
// single-partition index, for all three similarity metrics (l2_norm,
// dot_product, cosine) and both kNN operators. It covers:
//   1. flat alias (N partitions at one level) vs. single-partition control,
//   2. multi-level alias tree vs. the flat alias,
//   3. facet/sort combinations against the alias,
//   4. from+size pagination for both 2-phase and early-exit (1-phase) search.
func TestSimilaritySearchPartitionedIndex(t *testing.T) {
	dataset, searchRequests, err := readDatasetAndQueries(testInputCompressedFile)
	if err != nil {
		t.Fatal(err)
	}
	documents := makeDatasetIntoDocuments(dataset)
	contentFieldMapping := NewTextFieldMapping()
	contentFieldMapping.Analyzer = en.AnalyzerName
	// One index mapping per similarity metric, all sharing the content field.
	vecFieldMappingL2 := mapping.NewVectorFieldMapping()
	vecFieldMappingL2.Dims = testDatasetDims
	vecFieldMappingL2.Similarity = index.EuclideanDistance
	indexMappingL2Norm := NewIndexMapping()
	indexMappingL2Norm.DefaultMapping.AddFieldMappingsAt("content", contentFieldMapping)
	indexMappingL2Norm.DefaultMapping.AddFieldMappingsAt("vector", vecFieldMappingL2)
	vecFieldMappingDot := mapping.NewVectorFieldMapping()
	vecFieldMappingDot.Dims = testDatasetDims
	vecFieldMappingDot.Similarity = index.InnerProduct
	indexMappingDotProduct := NewIndexMapping()
	indexMappingDotProduct.DefaultMapping.AddFieldMappingsAt("content", contentFieldMapping)
	indexMappingDotProduct.DefaultMapping.AddFieldMappingsAt("vector", vecFieldMappingDot)
	vecFieldMappingCosine := mapping.NewVectorFieldMapping()
	vecFieldMappingCosine.Dims = testDatasetDims
	vecFieldMappingCosine.Similarity = index.CosineSimilarity
	indexMappingCosine := NewIndexMapping()
	indexMappingCosine.DefaultMapping.AddFieldMappingsAt("content", contentFieldMapping)
	indexMappingCosine.DefaultMapping.AddFieldMappingsAt("vector", vecFieldMappingCosine)
	type testCase struct {
		testType           string // descriptive label, not used programmatically
		queryIndex         int    // index into searchRequests
		numIndexPartitions int    // number of child indexes behind the alias
		mapping            mapping.IndexMapping
	}
	testCases := []testCase{
		// l2 norm similarity
		{
			testType:           "multi_partition:match_none:oneKNNreq:k=3",
			queryIndex:         0,
			numIndexPartitions: 4,
			mapping:            indexMappingL2Norm,
		},
		{
			testType:           "multi_partition:match_none:oneKNNreq:k=2",
			queryIndex:         0,
			numIndexPartitions: 10,
			mapping:            indexMappingL2Norm,
		},
		{
			testType:           "multi_partition:match:oneKNNreq:k=2",
			queryIndex:         1,
			numIndexPartitions: 5,
			mapping:            indexMappingL2Norm,
		},
		{
			testType:           "multi_partition:disjunction:twoKNNreq:k=2,2",
			queryIndex:         2,
			numIndexPartitions: 4,
			mapping:            indexMappingL2Norm,
		},
		// dot product similarity
		{
			testType:           "multi_partition:match_none:oneKNNreq:k=3",
			queryIndex:         0,
			numIndexPartitions: 4,
			mapping:            indexMappingDotProduct,
		},
		{
			testType:           "multi_partition:match_none:oneKNNreq:k=2",
			queryIndex:         0,
			numIndexPartitions: 10,
			mapping:            indexMappingDotProduct,
		},
		{
			testType:           "multi_partition:match:oneKNNreq:k=2",
			queryIndex:         1,
			numIndexPartitions: 5,
			mapping:            indexMappingDotProduct,
		},
		{
			testType:           "multi_partition:disjunction:twoKNNreq:k=2,2",
			queryIndex:         2,
			numIndexPartitions: 4,
			mapping:            indexMappingDotProduct,
		},
		// cosine similarity
		{
			testType:           "multi_partition:match_none:oneKNNreq:k=3",
			queryIndex:         0,
			numIndexPartitions: 7,
			mapping:            indexMappingCosine,
		},
		{
			testType:           "multi_partition:match_none:oneKNNreq:k=2",
			queryIndex:         0,
			numIndexPartitions: 5,
			mapping:            indexMappingCosine,
		},
		{
			testType:           "multi_partition:match:oneKNNreq:k=2",
			queryIndex:         1,
			numIndexPartitions: 3,
			mapping:            indexMappingCosine,
		},
		{
			testType:           "multi_partition:disjunction:twoKNNreq:k=2,2",
			queryIndex:         2,
			numIndexPartitions: 9,
			mapping:            indexMappingCosine,
		},
	}
	// NOTE: `index` here is a *indexAliasImpl local and shadows the imported
	// bleve_index_api package within this function.
	index := NewIndexAlias()
	var reqSort = search.SortOrder{&search.SortScore{Desc: true}, &search.SortDocID{Desc: true}, &search.SortField{Desc: false, Field: "content"}}
	for testCaseNum, testCase := range testCases {
		originalRequest := searchRequests[testCase.queryIndex]
		for _, operator := range knnOperators {
			// Control run: same docs behind a single-partition alias.
			index.indexes = make([]Index, 0)
			query := copySearchRequest(originalRequest, nil)
			query.AddKNNOperator(operator)
			query.Sort = reqSort.Copy()
			query.Explain = true
			nameToIndex := createPartitionedIndex(documents, index, 1, testCase.mapping, t, false)
			controlResult, err := index.Search(query)
			if err != nil {
				cleanUp(t, nameToIndex)
				t.Fatal(err)
			}
			if !finalHitsHaveValidIndex(controlResult.Hits, nameToIndex) {
				cleanUp(t, nameToIndex)
				t.Fatalf("test case #%d failed: expected control result hits to have valid `Index`", testCaseNum)
			}
			cleanUp(t, nameToIndex)
			// Experimental run: flat alias over numIndexPartitions children.
			index.indexes = make([]Index, 0)
			query = copySearchRequest(originalRequest, nil)
			query.AddKNNOperator(operator)
			query.Sort = reqSort.Copy()
			query.Explain = true
			nameToIndex = createPartitionedIndex(documents, index, testCase.numIndexPartitions, testCase.mapping, t, false)
			experimentalResult, err := index.Search(query)
			if err != nil {
				cleanUp(t, nameToIndex)
				t.Fatal(err)
			}
			if !finalHitsHaveValidIndex(experimentalResult.Hits, nameToIndex) {
				cleanUp(t, nameToIndex)
				t.Fatalf("test case #%d failed: expected experimental Result hits to have valid `Index`", testCaseNum)
			}
			// Partitioned scores can differ slightly, so compare doc IDs only.
			verifyResult(t, controlResult, experimentalResult, testCaseNum, true)
			cleanUp(t, nameToIndex)
			// Experimental run: multi-level alias tree over the same partitions;
			// must match the flat alias exactly (including scores).
			index.indexes = make([]Index, 0)
			query = copySearchRequest(originalRequest, nil)
			query.AddKNNOperator(operator)
			query.Sort = reqSort.Copy()
			query.Explain = true
			nameToIndex = createPartitionedIndex(documents, index, testCase.numIndexPartitions, testCase.mapping, t, true)
			multiLevelIndexResult, err := index.Search(query)
			if err != nil {
				cleanUp(t, nameToIndex)
				t.Fatal(err)
			}
			if !finalHitsHaveValidIndex(multiLevelIndexResult.Hits, nameToIndex) {
				cleanUp(t, nameToIndex)
				t.Fatalf("test case #%d failed: expected experimental Result hits to have valid `Index`", testCaseNum)
			}
			verifyResult(t, multiLevelIndexResult, experimentalResult, testCaseNum, false)
			cleanUp(t, nameToIndex)
		}
	}
	var facets = map[string]*FacetRequest{
		"content": {
			Field: "content",
			Size:  10,
		},
	}
	index = NewIndexAlias()
	for testCaseNum, testCase := range testCases {
		index.indexes = make([]Index, 0)
		nameToIndex := createPartitionedIndex(documents, index, testCase.numIndexPartitions, testCase.mapping, t, false)
		originalRequest := searchRequests[testCase.queryIndex]
		for _, operator := range knnOperators {
			from, size := originalRequest.From, originalRequest.Size
			query := copySearchRequest(originalRequest, nil)
			query.AddKNNOperator(operator)
			query.Explain = true
			query.From = from
			query.Size = size
			// Three types of queries to run wrt sort and facet fields that require fields.
			// 1. Sort And Facet are there
			// 2. Sort is there, Facet is not there
			// 3. Sort is not there, Facet is there
			// The case where both sort and facet are not there is already covered in the previous tests.
			// 1. Sort And Facet are there
			query.Facets = facets
			query.Sort = reqSort.Copy()
			res1, err := index.Search(query)
			if err != nil {
				cleanUp(t, nameToIndex)
				t.Fatal(err)
			}
			if !finalHitsHaveValidIndex(res1.Hits, nameToIndex) {
				cleanUp(t, nameToIndex)
				t.Fatalf("test case #%d failed: expected experimental Result hits to have valid `Index`", testCaseNum)
			}
			facetRes1 := res1.Facets
			facetRes1Str, err := json.Marshal(facetRes1)
			if err != nil {
				cleanUp(t, nameToIndex)
				t.Fatal(err)
			}
			// 2. Sort is there, Facet is not there
			query.Facets = nil
			query.Sort = reqSort.Copy()
			res2, err := index.Search(query)
			if err != nil {
				cleanUp(t, nameToIndex)
				t.Fatal(err)
			}
			if !finalHitsHaveValidIndex(res2.Hits, nameToIndex) {
				cleanUp(t, nameToIndex)
				t.Fatalf("test case #%d failed: expected experimental Result hits to have valid `Index`", testCaseNum)
			}
			// 3. Sort is not there, Facet is there
			query.Facets = facets
			query.Sort = nil
			res3, err := index.Search(query)
			if err != nil {
				cleanUp(t, nameToIndex)
				t.Fatal(err)
			}
			if !finalHitsHaveValidIndex(res3.Hits, nameToIndex) {
				cleanUp(t, nameToIndex)
				t.Fatalf("test case #%d failed: expected experimental Result hits to have valid `Index`", testCaseNum)
			}
			facetRes3 := res3.Facets
			facetRes3Str, err := json.Marshal(facetRes3)
			if err != nil {
				cleanUp(t, nameToIndex)
				t.Fatal(err)
			}
			// Verify the facet results
			if string(facetRes1Str) != string(facetRes3Str) {
				cleanUp(t, nameToIndex)
				t.Fatalf("test case #%d failed: expected facet results to be equal", testCaseNum)
			}
			// Verify the results
			verifyResult(t, res1, res2, testCaseNum, false)
			verifyResult(t, res2, res3, testCaseNum, true)
			// Test early exit fail case -> matchNone + facetRequest
			query.Query = NewMatchNoneQuery()
			query.Sort = reqSort.Copy()
			// control case
			query.Facets = nil
			res4Ctrl, err := index.Search(query)
			if err != nil {
				cleanUp(t, nameToIndex)
				t.Fatal(err)
			}
			if !finalHitsHaveValidIndex(res4Ctrl.Hits, nameToIndex) {
				cleanUp(t, nameToIndex)
				t.Fatalf("test case #%d failed: expected control Result hits to have valid `Index`", testCaseNum)
			}
			// experimental case
			query.Facets = facets
			res4Exp, err := index.Search(query)
			if err != nil {
				cleanUp(t, nameToIndex)
				t.Fatal(err)
			}
			if !finalHitsHaveValidIndex(res4Exp.Hits, nameToIndex) {
				cleanUp(t, nameToIndex)
				t.Fatalf("test case #%d failed: expected experimental Result hits to have valid `Index`", testCaseNum)
			}
			if !(operator == knnOperatorAnd && res4Ctrl.Total == 0 && res4Exp.Total == 0) {
				// catch case where no hits are returned
				// due to matchNone query with a KNN request with operator AND
				// where no hits are part of the intersection in multi knn request
				verifyResult(t, res4Ctrl, res4Exp, testCaseNum, false)
			}
		}
		cleanUp(t, nameToIndex)
	}
	// Test Pagination with multi partitioned index
	index = NewIndexAlias()
	index.indexes = make([]Index, 0)
	nameToIndex := createPartitionedIndex(documents, index, 8, indexMappingL2Norm, t, true)
	// Test From + Size pagination for Hybrid Search (2-Phase)
	query := copySearchRequest(searchRequests[4], nil)
	query.Sort = reqSort.Copy()
	query.Facets = facets
	query.Explain = true
	testFromSizePagination(t, query, index, nameToIndex)
	// Test From + Size pagination for Early Exit Hybrid Search (1-Phase)
	query = copySearchRequest(searchRequests[4], nil)
	query.Query = NewMatchNoneQuery()
	query.Sort = reqSort.Copy()
	query.Facets = nil
	query.Explain = true
	testFromSizePagination(t, query, index, nameToIndex)
	cleanUp(t, nameToIndex)
}
// testFromSizePagination checks that paging through results with From+Size
// yields exactly the hits, in the same order, that a single large request
// returns: a 30-hit control request is compared against six 5-hit pages.
func testFromSizePagination(t *testing.T, query *SearchRequest, index Index, nameToIndex map[string]Index) {
	// Control: fetch the first 30 hits in one request.
	query.From = 0
	query.Size = 30
	control, err := index.Search(query)
	if err != nil {
		cleanUp(t, nameToIndex)
		t.Fatal(err)
	}
	wantIDs := make([]string, 0, len(control.Hits))
	for _, hit := range control.Hits {
		wantIDs = append(wantIDs, hit.ID)
	}
	// Experiment: walk the same window five hits at a time.
	const pageSize = 5
	for from := 0; from <= 25; from += pageSize {
		query.From = from
		query.Size = pageSize
		page, err := index.Search(query)
		if err != nil {
			cleanUp(t, nameToIndex)
			t.Fatal(err)
		}
		// Pages starting beyond the control result must be empty.
		if from >= len(wantIDs) {
			if len(page.Hits) != 0 {
				cleanUp(t, nameToIndex)
				t.Fatalf("expected 0 hits, got %d", len(page.Hits))
			}
			continue
		}
		// A partial last page is allowed; its size must match exactly.
		wantCount := min(len(wantIDs)-from, pageSize)
		if len(page.Hits) != wantCount {
			cleanUp(t, nameToIndex)
			t.Fatalf("expected %d hits, got %d", wantCount, len(page.Hits))
		}
		for i, hit := range page.Hits {
			if hit.ID != wantIDs[from+i] {
				cleanUp(t, nameToIndex)
				t.Fatalf("expected %s at index %d, got %s", wantIDs[from+i], i, hit.ID)
			}
		}
	}
}
// TestVectorBase64Index verifies that a base64-encoded vector field
// ("vectorEncoded", little-endian float32 bytes) produces exactly the same
// kNN hit IDs as the equivalent raw []float32 "vector" field, on both an
// l2_norm index and a dot_product index, for both kNN operators.
func TestVectorBase64Index(t *testing.T) {
	dataset, searchRequests, err := readDatasetAndQueries(testInputCompressedFile)
	if err != nil {
		t.Fatal(err)
	}
	documents := makeDatasetIntoDocuments(dataset)
	// A second, independent copy of the queries to retarget at the encoded field.
	_, searchRequestsCopy, err := readDatasetAndQueries(testInputCompressedFile)
	if err != nil {
		t.Fatal(err)
	}
	// Add a base64 (little-endian float32) encoding of each document's vector.
	for _, doc := range documents {
		vec, ok := doc["vector"].([]float32)
		if !ok {
			t.Fatal("Typecasting vector to float array failed")
		}
		buf := new(bytes.Buffer)
		for _, v := range vec {
			err := binary.Write(buf, binary.LittleEndian, v)
			if err != nil {
				t.Fatal(err)
			}
		}
		doc["vectorEncoded"] = base64.StdEncoding.EncodeToString(buf.Bytes())
	}
	// Point the copied kNN requests at the encoded field.
	for _, sr := range searchRequestsCopy {
		for _, kr := range sr.KNN {
			kr.Field = "vectorEncoded"
		}
	}
	contentFM := NewTextFieldMapping()
	contentFM.Analyzer = en.AnalyzerName
	// Each index maps both the raw field and its base64 twin with the same
	// dims/similarity, so both query forms search equivalent data.
	vecFML2 := mapping.NewVectorFieldMapping()
	vecFML2.Dims = testDatasetDims
	vecFML2.Similarity = index.EuclideanDistance
	vecBFML2 := mapping.NewVectorBase64FieldMapping()
	vecBFML2.Dims = testDatasetDims
	vecBFML2.Similarity = index.EuclideanDistance
	vecFMDot := mapping.NewVectorFieldMapping()
	vecFMDot.Dims = testDatasetDims
	vecFMDot.Similarity = index.InnerProduct
	vecBFMDot := mapping.NewVectorBase64FieldMapping()
	vecBFMDot.Dims = testDatasetDims
	vecBFMDot.Similarity = index.InnerProduct
	indexMappingL2 := NewIndexMapping()
	indexMappingL2.DefaultMapping.AddFieldMappingsAt("content", contentFM)
	indexMappingL2.DefaultMapping.AddFieldMappingsAt("vector", vecFML2)
	indexMappingL2.DefaultMapping.AddFieldMappingsAt("vectorEncoded", vecBFML2)
	indexMappingDot := NewIndexMapping()
	indexMappingDot.DefaultMapping.AddFieldMappingsAt("content", contentFM)
	indexMappingDot.DefaultMapping.AddFieldMappingsAt("vector", vecFMDot)
	indexMappingDot.DefaultMapping.AddFieldMappingsAt("vectorEncoded", vecBFMDot)
	tmpIndexPathL2 := createTmpIndexPath(t)
	defer cleanupTmpIndexPath(t, tmpIndexPathL2)
	tmpIndexPathDot := createTmpIndexPath(t)
	defer cleanupTmpIndexPath(t, tmpIndexPathDot)
	indexL2, err := New(tmpIndexPathL2, indexMappingL2)
	if err != nil {
		t.Fatal(err)
	}
	defer func() {
		err := indexL2.Close()
		if err != nil {
			t.Fatal(err)
		}
	}()
	indexDot, err := New(tmpIndexPathDot, indexMappingDot)
	if err != nil {
		t.Fatal(err)
	}
	defer func() {
		err := indexDot.Close()
		if err != nil {
			t.Fatal(err)
		}
	}()
	// Index the full dataset into both indexes, one batch each.
	batchL2 := indexL2.NewBatch()
	batchDot := indexDot.NewBatch()
	for _, doc := range documents {
		err = batchL2.Index(doc["id"].(string), doc)
		if err != nil {
			t.Fatal(err)
		}
		err = batchDot.Index(doc["id"].(string), doc)
		if err != nil {
			t.Fatal(err)
		}
	}
	err = indexL2.Batch(batchL2)
	if err != nil {
		t.Fatal(err)
	}
	err = indexDot.Batch(batchDot)
	if err != nil {
		t.Fatal(err)
	}
	for i := range searchRequests {
		for _, operator := range knnOperators {
			controlQuery := searchRequests[i]  // raw "vector" field
			testQuery := searchRequestsCopy[i] // base64 "vectorEncoded" field
			controlQuery.AddKNNOperator(operator)
			testQuery.AddKNNOperator(operator)
			controlResultL2, err := indexL2.Search(controlQuery)
			if err != nil {
				t.Fatal(err)
			}
			testResultL2, err := indexL2.Search(testQuery)
			if err != nil {
				t.Fatal(err)
			}
			// Hit IDs must match pairwise between raw and encoded fields.
			if controlResultL2 != nil && testResultL2 != nil {
				if len(controlResultL2.Hits) == len(testResultL2.Hits) {
					for j := range controlResultL2.Hits {
						if controlResultL2.Hits[j].ID != testResultL2.Hits[j].ID {
							t.Fatalf("testcase %d failed: expected hit id %s, got hit id %s", i, controlResultL2.Hits[j].ID, testResultL2.Hits[j].ID)
						}
					}
				}
			} else if (controlResultL2 == nil && testResultL2 != nil) ||
				(controlResultL2 != nil && testResultL2 == nil) {
				t.Fatalf("testcase %d failed: expected result %s, got result %s", i, controlResultL2, testResultL2)
			}
			controlResultDot, err := indexDot.Search(controlQuery)
			if err != nil {
				t.Fatal(err)
			}
			testResultDot, err := indexDot.Search(testQuery)
			if err != nil {
				t.Fatal(err)
			}
			if controlResultDot != nil && testResultDot != nil {
				if len(controlResultDot.Hits) == len(testResultDot.Hits) {
					for j := range controlResultDot.Hits {
						if controlResultDot.Hits[j].ID != testResultDot.Hits[j].ID {
							t.Fatalf("testcase %d failed: expected hit id %s, got hit id %s", i, controlResultDot.Hits[j].ID, testResultDot.Hits[j].ID)
						}
					}
				}
			} else if (controlResultDot == nil && testResultDot != nil) ||
				(controlResultDot != nil && testResultDot == nil) {
				t.Fatalf("testcase %d failed: expected result %s, got result %s", i, controlResultDot, testResultDot)
			}
		}
	}
}
// testDocument is the JSON shape of one record in knn_dataset.json.
type testDocument struct {
	ID      string    `json:"id"`
	Content string    `json:"content"`
	Vector  []float32 `json:"vector"` // embedding; testDatasetDims entries
}
// readDatasetAndQueries opens the given zip archive and decodes the dataset
// (testDatasetFileName) and the search requests (testQueryFileName) from it.
// Entries with any other name are skipped. Returns the decoded dataset and
// queries, or an error if the archive cannot be opened, an entry cannot be
// read, or decoding fails.
func readDatasetAndQueries(fileName string) ([]testDocument, []*SearchRequest, error) {
	// Open the zip archive for reading
	r, err := zip.OpenReader(fileName)
	if err != nil {
		return nil, nil, err
	}
	defer r.Close()
	var dataset []testDocument
	var queries []*SearchRequest
	for _, f := range r.File {
		jsonFile, err := f.Open()
		if err != nil {
			return nil, nil, err
		}
		switch f.Name {
		case testDatasetFileName:
			err = json.NewDecoder(jsonFile).Decode(&dataset)
		case testQueryFileName:
			err = json.NewDecoder(jsonFile).Decode(&queries)
		}
		// Close each entry as soon as it's consumed; the original deferred all
		// closes to function exit (defer-in-loop) and dropped close errors.
		closeErr := jsonFile.Close()
		if err != nil {
			return nil, nil, err
		}
		if closeErr != nil {
			return nil, nil, closeErr
		}
	}
	return dataset, queries, nil
}
// makeDatasetIntoDocuments converts each testDocument into the generic
// map form ("id", "content", "vector") that the indexer consumes, preserving
// dataset order.
func makeDatasetIntoDocuments(dataset []testDocument) []map[string]interface{} {
	documents := make([]map[string]interface{}, 0, len(dataset))
	for _, d := range dataset {
		documents = append(documents, map[string]interface{}{
			"id":      d.ID,
			"content": d.Content,
			"vector":  d.Vector,
		})
	}
	return documents
}
// cleanUp closes every index in nameToIndex and removes its on-disk path
// (map keys are the tmp index paths). Fails the test on any close error.
func cleanUp(t *testing.T, nameToIndex map[string]Index) {
	for path, idx := range nameToIndex {
		if err := idx.Close(); err != nil {
			t.Fatal(err)
		}
		cleanupTmpIndexPath(t, path)
	}
}
// createChildIndex builds a fresh index at a temp path using the given
// mapping, registers it in nameToIndex (keyed by index name, which is the
// tmp path), and indexes all docs in a single batch. On any failure it
// cleans up everything registered so far and fails the test.
func createChildIndex(docs []map[string]interface{}, mapping mapping.IndexMapping, t *testing.T, nameToIndex map[string]Index) Index {
	child, err := New(createTmpIndexPath(t), mapping)
	if err != nil {
		t.Fatal(err)
	}
	nameToIndex[child.Name()] = child
	b := child.NewBatch()
	for _, doc := range docs {
		if err := b.Index(doc["id"].(string), doc); err != nil {
			cleanUp(t, nameToIndex)
			t.Fatal(err)
		}
	}
	if err := child.Batch(b); err != nil {
		cleanUp(t, nameToIndex)
		t.Fatal(err)
	}
	return child
}
// createPartitionedIndex splits documents as evenly as possible across
// numPartitions child indexes and attaches them to the alias `index`.
// When multiLevel is false, all children hang directly off the alias;
// when true, children are paired under intermediate aliases ("alias0",
// "alias1", ...) forming a two-level tree. Returns the name→Index map of
// all child indexes created (for cleanup/validation).
func createPartitionedIndex(documents []map[string]interface{}, index *indexAliasImpl, numPartitions int,
	mapping mapping.IndexMapping, t *testing.T, multiLevel bool) map[string]Index {
	// Even split: the first len%numPartitions partitions take one extra doc.
	base := len(documents) / numPartitions
	extra := len(documents) % numPartitions
	partitions := make([][]map[string]interface{}, numPartitions)
	start := 0
	for i := range partitions {
		count := base
		if i < extra {
			count++
		}
		partitions[i] = documents[start : start+count]
		start += count
	}
	rv := make(map[string]Index)
	if !multiLevel {
		// all indexes are at the same level
		for _, part := range partitions {
			index.Add(createChildIndex(part, mapping, t, rv))
		}
		return rv
	}
	// alias tree: pair children under ceil(numPartitions/2) intermediate aliases
	children := make([]Index, numPartitions)
	for i, part := range partitions {
		children[i] = createChildIndex(part, mapping, t, rv)
	}
	numAlias := (numPartitions + 1) / 2
	for i := 0; i < numAlias; i++ {
		alias := NewIndexAlias()
		alias.SetName(fmt.Sprintf("alias%d", i))
		alias.Add(children[2*i])
		if 2*i+1 < numPartitions {
			alias.Add(children[2*i+1])
		}
		index.Add(alias)
	}
	return rv
}
// createMultipleSegmentsIndex indexes documents through numSegments separate
// batches, applied concurrently, to force the index to contain multiple
// segments. Returns the first batch error observed, if any.
func createMultipleSegmentsIndex(documents []map[string]interface{}, index Index, numSegments int) error {
	// Even split: the first len%numSegments batches take one extra doc.
	base := len(documents) / numSegments
	extra := len(documents) % numSegments
	batches := make([]*Batch, 0, numSegments)
	cursor := 0
	for i := 0; i < numSegments; i++ {
		count := base
		if i < extra {
			count++
		}
		batch := index.NewBatch()
		for _, doc := range documents[cursor : cursor+count] {
			if err := batch.Index(doc["id"].(string), doc); err != nil {
				return err
			}
		}
		cursor += count
		batches = append(batches, batch)
	}
	// Apply every batch on its own goroutine; keep only the first error.
	var (
		mu       sync.Mutex
		firstErr error
		wg       sync.WaitGroup
	)
	for _, batch := range batches {
		wg.Add(1)
		// Pass the batch explicitly to avoid pre-Go-1.22 loop-var capture.
		go func(b *Batch) {
			defer wg.Done()
			if err := index.Batch(b); err != nil {
				mu.Lock()
				if firstErr == nil {
					firstErr = err
				}
				mu.Unlock()
			}
		}(batch)
	}
	wg.Wait()
	return firstErr
}
// truncateScore truncates score to 6 decimal places, and snaps any result
// within 1e-4 of 1.0 to exactly 1.0 so near-unit scores compare equal
// across index layouts.
func truncateScore(score float64) float64 {
	const epsilon = 1e-4
	truncated := math.Trunc(score*1e6) / 1e6
	if math.Abs(truncated-1.0) <= epsilon {
		return 1.0
	}
	return truncated
}
// compareExplanation recursively compares two Explanation trees for
// equivalence: truncated values must match at every node and the child
// sets must match pairwise after sorting by value. NOTE: it sorts both
// trees' Children slices in place as a side effect.
func compareExplanation(a, b *search.Explanation) bool {
	if a == nil || b == nil {
		// equal only when both are nil
		return a == b
	}
	if truncateScore(a.Value) != truncateScore(b.Value) || len(a.Children) != len(b.Children) {
		return false
	}
	// Order-insensitive child comparison: sort both sides by value first.
	sortChildren(a.Children)
	sortChildren(b.Children)
	for i, child := range a.Children {
		if !compareExplanation(child, b.Children[i]) {
			return false
		}
	}
	return true
}
// sortChildren orders an Explanation's children in place by ascending Value.
func sortChildren(children []*search.Explanation) {
	sort.Slice(children, func(x, y int) bool {
		return children[x].Value < children[y].Value
	})
}
// finalHitsOmitKNNMetadata reports whether every hit has had its internal
// kNN bookkeeping (IndexNames, ScoreBreakdown) stripped — final results
// from a hybrid/kNN search must not expose either.
func finalHitsOmitKNNMetadata(hits []*search.DocumentMatch) bool {
	for _, hit := range hits {
		if hit.IndexNames == nil && hit.ScoreBreakdown == nil {
			continue
		}
		return false
	}
	return true
}
// finalHitsHaveValidIndex reports whether every hit carries a non-empty
// Index name that resolves to a known, non-nil index in `indexes`, and
// whether that index can actually retrieve the hit's document.
func finalHitsHaveValidIndex(hits []*search.DocumentMatch, indexes map[string]Index) bool {
	for _, hit := range hits {
		if hit.Index == "" {
			return false
		}
		idx, ok := indexes[hit.Index]
		if !ok || idx == nil {
			return false
		}
		// Fix: the original assigned to an undeclared `err`
		// (`doc, err = idx.Document(...)`), which does not compile.
		doc, err := idx.Document(hit.ID)
		if err != nil || doc == nil {
			return false
		}
	}
	return true
}
// verifyResult asserts that two search results are equivalent. Both must be
// non-empty, equal in hit count and Total, free of kNN metadata, and carry
// non-zero Took and a non-nil Request. When verifyOnlyDocIDs is true only
// the set of hit IDs is compared (order/score ignored); otherwise hits must
// match pairwise in ID, truncated score, and explanation, and MaxScore must
// agree (truncated).
func verifyResult(t *testing.T, controlResult *SearchResult, experimentalResult *SearchResult, testCaseNum int, verifyOnlyDocIDs bool) {
	if controlResult.Hits.Len() == 0 || experimentalResult.Hits.Len() == 0 {
		t.Fatalf("test case #%d failed: 0 hits returned", testCaseNum)
	}
	if len(controlResult.Hits) != len(experimentalResult.Hits) {
		t.Fatalf("test case #%d failed: expected %d results, got %d", testCaseNum, len(controlResult.Hits), len(experimentalResult.Hits))
	}
	if controlResult.Total != experimentalResult.Total {
		t.Fatalf("test case #%d failed: expected total hits to be %d, got %d", testCaseNum, controlResult.Total, experimentalResult.Total)
	}
	// KNN Metadata -> Score Breakdown and IndexNames MUST be omitted from the final hits
	if !finalHitsOmitKNNMetadata(controlResult.Hits) || !finalHitsOmitKNNMetadata(experimentalResult.Hits) {
		t.Fatalf("test case #%d failed: expected no KNN metadata in hits", testCaseNum)
	}
	if controlResult.Took == 0 || experimentalResult.Took == 0 {
		t.Fatalf("test case #%d failed: expected non-zero took time", testCaseNum)
	}
	if controlResult.Request == nil || experimentalResult.Request == nil {
		t.Fatalf("test case #%d failed: expected non-nil request", testCaseNum)
	}
	if verifyOnlyDocIDs {
		// in multi partitioned index, we cannot be sure of the score or the ordering of the hits as the tf-idf scores are localized to each partition
		// so we only check the ids
		controlMap := make(map[string]struct{})
		experimentalMap := make(map[string]struct{})
		for _, hit := range controlResult.Hits {
			controlMap[hit.ID] = struct{}{}
		}
		for _, hit := range experimentalResult.Hits {
			experimentalMap[hit.ID] = struct{}{}
		}
		if len(controlMap) != len(experimentalMap) {
			t.Fatalf("test case #%d failed: expected %d results, got %d", testCaseNum, len(controlMap), len(experimentalMap))
		}
		for id := range controlMap {
			if _, ok := experimentalMap[id]; !ok {
				t.Fatalf("test case #%d failed: expected id %s to be in experimental result", testCaseNum, id)
			}
		}
		return
	}
	// Strict mode: hits must match pairwise in ID, score, and explanation.
	for i := 0; i < len(controlResult.Hits); i++ {
		if controlResult.Hits[i].ID != experimentalResult.Hits[i].ID {
			t.Fatalf("test case #%d failed: expected hit %d to have id %s, got %s", testCaseNum, i, controlResult.Hits[i].ID, experimentalResult.Hits[i].ID)
		}
		// Truncate to 6 decimal places
		actualScore := truncateScore(experimentalResult.Hits[i].Score)
		expectScore := truncateScore(controlResult.Hits[i].Score)
		if expectScore != actualScore {
			t.Fatalf("test case #%d failed: expected hit %d to have score %f, got %f", testCaseNum, i, expectScore, actualScore)
		}
		if !compareExplanation(controlResult.Hits[i].Expl, experimentalResult.Hits[i].Expl) {
			t.Fatalf("test case #%d failed: expected hit %d to have explanation %v, got %v", testCaseNum, i, controlResult.Hits[i].Expl, experimentalResult.Hits[i].Expl)
		}
	}
	if truncateScore(controlResult.MaxScore) != truncateScore(experimentalResult.MaxScore) {
		t.Fatalf("test case #%d: expected maxScore to be %f, got %f", testCaseNum, controlResult.MaxScore, experimentalResult.MaxScore)
	}
}
// TestSimilaritySearchMultipleSegments verifies that kNN search over an index
// whose docs were applied in many concurrent batches (producing multiple
// segments) matches a single-segment control index, for all three similarity
// metrics. The score-none cases additionally check that Score: "none"
// requests return the same doc IDs as scored requests.
func TestSimilaritySearchMultipleSegments(t *testing.T) {
	// using scorch options to prevent merges during the course of this test
	// so that the knnCollector can be accurately tested
	scorch.DefaultMemoryPressurePauseThreshold = 0
	scorch.DefaultMinSegmentsForInMemoryMerge = math.MaxInt
	dataset, searchRequests, err := readDatasetAndQueries(testInputCompressedFile)
	if err != nil {
		t.Fatal(err)
	}
	documents := makeDatasetIntoDocuments(dataset)
	contentFieldMapping := NewTextFieldMapping()
	contentFieldMapping.Analyzer = en.AnalyzerName
	// One index mapping per similarity metric.
	vecFieldMappingL2 := mapping.NewVectorFieldMapping()
	vecFieldMappingL2.Dims = testDatasetDims
	vecFieldMappingL2.Similarity = index.EuclideanDistance
	vecFieldMappingDot := mapping.NewVectorFieldMapping()
	vecFieldMappingDot.Dims = testDatasetDims
	vecFieldMappingDot.Similarity = index.InnerProduct
	vecFieldMappingCosine := mapping.NewVectorFieldMapping()
	vecFieldMappingCosine.Dims = testDatasetDims
	vecFieldMappingCosine.Similarity = index.CosineSimilarity
	indexMappingL2Norm := NewIndexMapping()
	indexMappingL2Norm.DefaultMapping.AddFieldMappingsAt("content", contentFieldMapping)
	indexMappingL2Norm.DefaultMapping.AddFieldMappingsAt("vector", vecFieldMappingL2)
	indexMappingDotProduct := NewIndexMapping()
	indexMappingDotProduct.DefaultMapping.AddFieldMappingsAt("content", contentFieldMapping)
	indexMappingDotProduct.DefaultMapping.AddFieldMappingsAt("vector", vecFieldMappingDot)
	indexMappingCosine := NewIndexMapping()
	indexMappingCosine.DefaultMapping.AddFieldMappingsAt("content", contentFieldMapping)
	indexMappingCosine.DefaultMapping.AddFieldMappingsAt("vector", vecFieldMappingCosine)
	var reqSort = search.SortOrder{&search.SortScore{Desc: true}, &search.SortDocID{Desc: true}, &search.SortField{Desc: false, Field: "content"}}
	testCases := []struct {
		numSegments int    // batches applied concurrently => segments
		queryIndex  int    // index into searchRequests
		mapping     mapping.IndexMapping
		scoreValue  string // "" for scored, "none" for score-suppressed
	}{
		// L2 norm similarity
		{
			numSegments: 6,
			queryIndex:  0,
			mapping:     indexMappingL2Norm,
		},
		{
			numSegments: 7,
			queryIndex:  1,
			mapping:     indexMappingL2Norm,
		},
		{
			numSegments: 8,
			queryIndex:  2,
			mapping:     indexMappingL2Norm,
		},
		{
			numSegments: 9,
			queryIndex:  3,
			mapping:     indexMappingL2Norm,
		},
		{
			numSegments: 10,
			queryIndex:  4,
			mapping:     indexMappingL2Norm,
		},
		{
			numSegments: 11,
			queryIndex:  5,
			mapping:     indexMappingL2Norm,
		},
		// dot_product similarity
		{
			numSegments: 6,
			queryIndex:  0,
			mapping:     indexMappingDotProduct,
		},
		{
			numSegments: 7,
			queryIndex:  1,
			mapping:     indexMappingDotProduct,
		},
		{
			numSegments: 8,
			queryIndex:  2,
			mapping:     indexMappingDotProduct,
		},
		{
			numSegments: 9,
			queryIndex:  3,
			mapping:     indexMappingDotProduct,
		},
		{
			numSegments: 10,
			queryIndex:  4,
			mapping:     indexMappingDotProduct,
		},
		{
			numSegments: 11,
			queryIndex:  5,
			mapping:     indexMappingDotProduct,
		},
		// cosine similarity
		{
			numSegments: 9,
			queryIndex:  0,
			mapping:     indexMappingCosine,
		},
		{
			numSegments: 5,
			queryIndex:  1,
			mapping:     indexMappingCosine,
		},
		{
			numSegments: 4,
			queryIndex:  2,
			mapping:     indexMappingCosine,
		},
		{
			numSegments: 12,
			queryIndex:  3,
			mapping:     indexMappingCosine,
		},
		{
			numSegments: 7,
			queryIndex:  4,
			mapping:     indexMappingCosine,
		},
		{
			numSegments: 11,
			queryIndex:  5,
			mapping:     indexMappingCosine,
		},
		// score none test
		{
			numSegments: 3,
			queryIndex:  0,
			mapping:     indexMappingL2Norm,
			scoreValue:  "none",
		},
		{
			numSegments: 7,
			queryIndex:  1,
			mapping:     indexMappingL2Norm,
			scoreValue:  "none",
		},
		{
			numSegments: 8,
			queryIndex:  2,
			mapping:     indexMappingL2Norm,
			scoreValue:  "none",
		},
		{
			numSegments: 3,
			queryIndex:  0,
			mapping:     indexMappingDotProduct,
			scoreValue:  "none",
		},
		{
			numSegments: 7,
			queryIndex:  1,
			mapping:     indexMappingDotProduct,
			scoreValue:  "none",
		},
		{
			numSegments: 8,
			queryIndex:  2,
			mapping:     indexMappingDotProduct,
			scoreValue:  "none",
		},
		{
			numSegments: 3,
			queryIndex:  0,
			mapping:     indexMappingCosine,
			scoreValue:  "none",
		},
		{
			numSegments: 7,
			queryIndex:  1,
			mapping:     indexMappingCosine,
			scoreValue:  "none",
		},
		{
			numSegments: 8,
			queryIndex:  2,
			mapping:     indexMappingCosine,
			scoreValue:  "none",
		},
	}
	for testCaseNum, testCase := range testCases {
		originalRequest := searchRequests[testCase.queryIndex]
		for _, operator := range knnOperators {
			// run single segment test first
			// NOTE: the local `index` shadows the imported bleve_index_api package.
			tmpIndexPath := createTmpIndexPath(t)
			index, err := New(tmpIndexPath, testCase.mapping)
			if err != nil {
				t.Fatal(err)
			}
			query := copySearchRequest(originalRequest, nil)
			query.Sort = reqSort.Copy()
			query.AddKNNOperator(operator)
			query.Explain = true
			nameToIndex := make(map[string]Index)
			nameToIndex[index.Name()] = index
			err = createMultipleSegmentsIndex(documents, index, 1)
			if err != nil {
				cleanUp(t, nameToIndex)
				t.Fatal(err)
			}
			controlResult, err := index.Search(query)
			if err != nil {
				cleanUp(t, nameToIndex)
				t.Fatal(err)
			}
			if !finalHitsHaveValidIndex(controlResult.Hits, nameToIndex) {
				cleanUp(t, nameToIndex)
				t.Fatalf("test case #%d failed: expected control result hits to have valid `Index`", testCaseNum)
			}
			if testCase.scoreValue == "none" {
				// Score-suppressed request must return the same doc IDs.
				query := copySearchRequest(originalRequest, nil)
				query.Sort = reqSort.Copy()
				query.AddKNNOperator(operator)
				query.Explain = true
				query.Score = testCase.scoreValue
				expectedResultScoreNone, err := index.Search(query)
				if err != nil {
					cleanUp(t, nameToIndex)
					t.Fatal(err)
				}
				if !finalHitsHaveValidIndex(expectedResultScoreNone.Hits, nameToIndex) {
					cleanUp(t, nameToIndex)
					t.Fatalf("test case #%d failed: expected score none hits to have valid `Index`", testCaseNum)
				}
				verifyResult(t, controlResult, expectedResultScoreNone, testCaseNum, true)
			}
			cleanUp(t, nameToIndex)
			// run multiple segments test
			tmpIndexPath = createTmpIndexPath(t)
			index, err = New(tmpIndexPath, testCase.mapping)
			if err != nil {
				t.Fatal(err)
			}
			nameToIndex = make(map[string]Index)
			nameToIndex[index.Name()] = index
			err = createMultipleSegmentsIndex(documents, index, testCase.numSegments)
			if err != nil {
				cleanUp(t, nameToIndex)
				t.Fatal(err)
			}
			query = copySearchRequest(originalRequest, nil)
			query.Sort = reqSort.Copy()
			query.AddKNNOperator(operator)
			query.Explain = true
			experimentalResult, err := index.Search(query)
			if err != nil {
				cleanUp(t, nameToIndex)
				t.Fatal(err)
			}
			if !finalHitsHaveValidIndex(experimentalResult.Hits, nameToIndex) {
				cleanUp(t, nameToIndex)
				t.Fatalf("test case #%d failed: expected experimental result hits to have valid `Index`", testCaseNum)
			}
			// Multi-segment results must match the single-segment control
			// exactly (IDs, scores, and explanations).
			verifyResult(t, controlResult, experimentalResult, testCaseNum, false)
			cleanUp(t, nameToIndex)
		}
	}
}
// Test to determine the impact of boost on kNN queries.
// A kNN clause with boost b should produce scores exactly b times the
// scores of the identical clause with a neutral boost of 1.0.
func TestKNNScoreBoosting(t *testing.T) {
	tmpIndexPath := createTmpIndexPath(t)
	defer cleanupTmpIndexPath(t, tmpIndexPath)

	const dims = 5
	getRandomVector := func() []float32 {
		vec := make([]float32, dims)
		for i := 0; i < dims; i++ {
			vec[i] = rand.Float32()
		}
		return vec
	}

	// Indexing 100 docs to populate the index.
	// The slice must start with length 0 — a non-zero initial length
	// (as previously written) leaves nil documents at the front, which
	// would then be indexed below.
	dataset := make([]map[string]interface{}, 0, 100)
	for i := 0; i < 100; i++ {
		dataset = append(dataset, map[string]interface{}{
			"type":    "vectorStuff",
			"content": strconv.Itoa(i),
			"vector":  getRandomVector(),
		})
	}

	indexMapping := NewIndexMapping()
	indexMapping.TypeField = "type"
	indexMapping.DefaultAnalyzer = "en"
	documentMapping := NewDocumentMapping()
	indexMapping.AddDocumentMapping("vectorStuff", documentMapping)

	contentFieldMapping := NewTextFieldMapping()
	contentFieldMapping.Index = true
	contentFieldMapping.Store = true
	documentMapping.AddFieldMappingsAt("content", contentFieldMapping)

	vecFieldMapping := mapping.NewVectorFieldMapping()
	vecFieldMapping.Index = true
	vecFieldMapping.Dims = 5
	vecFieldMapping.Similarity = "dot_product"
	documentMapping.AddFieldMappingsAt("vector", vecFieldMapping)

	index, err := New(tmpIndexPath, indexMapping)
	if err != nil {
		t.Fatal(err)
	}
	defer func() {
		err := index.Close()
		if err != nil {
			t.Fatal(err)
		}
	}()

	batch := index.NewBatch()
	for i := 0; i < len(dataset); i++ {
		err = batch.Index(strconv.Itoa(i), dataset[i])
		if err != nil {
			t.Fatal(err)
		}
	}
	err = index.Batch(batch)
	if err != nil {
		t.Fatal(err)
	}

	queryVec := getRandomVector()

	// Baseline: kNN search with a neutral boost of 1.0.
	searchRequest := NewSearchRequest(NewMatchNoneQuery())
	searchRequest.AddKNN("vector", queryVec, 3, 1.0)
	searchRequest.Fields = []string{"content", "vector"}
	hits, err := index.Search(searchRequest)
	if err != nil {
		t.Fatal(err)
	}
	hitsMap := make(map[string]float64, 0)
	for _, hit := range hits.Hits {
		hitsMap[hit.ID] = hit.Score
	}

	// Same query vector with a 10x boost; scores should scale by 10.
	// NOTE: the kNN clause was previously (and mistakenly) added to
	// searchRequest instead of searchRequest2, so the boosted search
	// matched nothing and the comparison loop below never ran.
	searchRequest2 := NewSearchRequest(NewMatchNoneQuery())
	searchRequest2.AddKNN("vector", queryVec, 3, 10.0)
	searchRequest2.Fields = []string{"content", "vector"}
	hits2, err := index.Search(searchRequest2)
	if err != nil {
		t.Fatal(err)
	}
	hitsMap2 := make(map[string]float64, 0)
	for _, hit := range hits2.Hits {
		hitsMap2[hit.ID] = hit.Score
	}

	// Compare with a small tolerance rather than exact float equality,
	// since (score*10)/10 need not round-trip exactly in float64.
	for _, hit := range hits2.Hits {
		if math.Abs(hitsMap[hit.ID]-hitsMap2[hit.ID]/10) > 1e-9 {
			t.Errorf("boosting not working: %v %v \n", hitsMap[hit.ID], hitsMap2[hit.ID])
		}
	}
}
// Test to see if KNN Operators get added right to the query.
// Verifies that AND/OR operators both produce a disjunction over the kNN
// clauses, and that an unknown operator yields an error.
func TestKNNOperator(t *testing.T) {
	tmpIndexPath := createTmpIndexPath(t)
	defer cleanupTmpIndexPath(t, tmpIndexPath)

	const dims = 5
	getRandomVector := func() []float32 {
		vec := make([]float32, dims)
		for i := 0; i < dims; i++ {
			vec[i] = rand.Float32()
		}
		return vec
	}

	// Indexing just a few docs to populate index.
	// The slice must start with length 0 — a non-zero initial length
	// (as previously written) prepends nil documents that then get indexed.
	dataset := make([]map[string]interface{}, 0, 10)
	for i := 0; i < 10; i++ {
		dataset = append(dataset, map[string]interface{}{
			"type":    "vectorStuff",
			"content": strconv.Itoa(i),
			"vector":  getRandomVector(),
		})
	}

	indexMapping := NewIndexMapping()
	indexMapping.TypeField = "type"
	indexMapping.DefaultAnalyzer = "en"
	documentMapping := NewDocumentMapping()
	indexMapping.AddDocumentMapping("vectorStuff", documentMapping)

	contentFieldMapping := NewTextFieldMapping()
	contentFieldMapping.Index = true
	contentFieldMapping.Store = true
	documentMapping.AddFieldMappingsAt("content", contentFieldMapping)

	vecFieldMapping := mapping.NewVectorFieldMapping()
	vecFieldMapping.Index = true
	vecFieldMapping.Dims = 5
	vecFieldMapping.Similarity = "dot_product"
	documentMapping.AddFieldMappingsAt("vector", vecFieldMapping)

	index, err := New(tmpIndexPath, indexMapping)
	if err != nil {
		t.Fatal(err)
	}
	defer func() {
		err := index.Close()
		if err != nil {
			t.Fatal(err)
		}
	}()

	batch := index.NewBatch()
	for i := 0; i < len(dataset); i++ {
		err = batch.Index(strconv.Itoa(i), dataset[i])
		if err != nil {
			t.Fatal(err)
		}
	}
	err = index.Batch(batch)
	if err != nil {
		t.Fatal(err)
	}

	termQuery := query.NewTermQuery("2")
	searchRequest := NewSearchRequest(termQuery)
	searchRequest.AddKNN("vector", getRandomVector(), 3, 2.0)
	searchRequest.AddKNN("vector", getRandomVector(), 2, 1.5)
	searchRequest.Fields = []string{"content", "vector"}

	// Conjunction — kNN clauses are still combined as a disjunction
	// internally; only the scoring semantics differ.
	searchRequest.AddKNNOperator(knnOperatorAnd)
	conjunction, _, _, err := createKNNQuery(searchRequest, nil)
	if err != nil {
		t.Fatalf("unexpected error for AND knn operator")
	}
	conj, ok := conjunction.(*query.DisjunctionQuery)
	if !ok {
		t.Fatalf("expected disjunction query")
	}
	if len(conj.Disjuncts) != 2 {
		t.Fatalf("expected 2 disjuncts")
	}

	// Disjunction
	searchRequest.AddKNNOperator(knnOperatorOr)
	disjunction, _, _, err := createKNNQuery(searchRequest, nil)
	if err != nil {
		t.Fatalf("unexpected error for OR knn operator")
	}
	disj, ok := disjunction.(*query.DisjunctionQuery)
	if !ok {
		t.Fatalf("expected disjunction query")
	}
	if len(disj.Disjuncts) != 2 {
		t.Fatalf("expected 2 disjuncts")
	}

	// Incorrect operator must surface an error.
	searchRequest.AddKNNOperator("bs_op")
	searchRequest.Query, _, _, err = createKNNQuery(searchRequest, nil)
	if err == nil {
		t.Fatalf("expected error for incorrect knn operator")
	}
}
// TestKNNFiltering checks that filtered kNN searches only return documents
// matched by the filter query, and that a match-none filter yields no hits.
func TestKNNFiltering(t *testing.T) {
	tmpIndexPath := createTmpIndexPath(t)
	defer cleanupTmpIndexPath(t, tmpIndexPath)

	const dims = 5
	randomVector := func() []float32 {
		v := make([]float32, dims)
		for j := 0; j < dims; j++ {
			v[j] = rand.Float32()
		}
		return v
	}

	// Build a small corpus whose "content" terms are "1000".."1009".
	docs := make([]map[string]interface{}, 0)
	for i := 0; i < 10; i++ {
		docs = append(docs, map[string]interface{}{
			"type":    "vectorStuff",
			"content": strconv.Itoa(i + 1000),
			"vector":  randomVector(),
		})
	}

	indexMapping := NewIndexMapping()
	indexMapping.TypeField = "type"
	indexMapping.DefaultAnalyzer = "en"
	docMapping := NewDocumentMapping()
	indexMapping.AddDocumentMapping("vectorStuff", docMapping)

	contentMapping := NewTextFieldMapping()
	contentMapping.Index = true
	contentMapping.Store = true
	docMapping.AddFieldMappingsAt("content", contentMapping)

	vectorMapping := mapping.NewVectorFieldMapping()
	vectorMapping.Index = true
	vectorMapping.Dims = 5
	vectorMapping.Similarity = "dot_product"
	docMapping.AddFieldMappingsAt("vector", vectorMapping)

	index, err := New(tmpIndexPath, indexMapping)
	if err != nil {
		t.Fatal(err)
	}
	defer func() {
		if err := index.Close(); err != nil {
			t.Fatal(err)
		}
	}()

	batch := index.NewBatch()
	for i, doc := range docs {
		// the id of term "i" is (i-1000)
		if err := batch.Index(strconv.Itoa(i), doc); err != nil {
			t.Fatal(err)
		}
	}
	if err := index.Batch(batch); err != nil {
		t.Fatal(err)
	}

	// gatherFilteredIDs runs the filter query on its own and collects the
	// matching doc IDs, to compare against the filtered kNN results.
	gatherFilteredIDs := func(q query.Query) map[string]struct{} {
		res, err := index.Search(NewSearchRequest(q))
		if err != nil {
			t.Fatal(err)
		}
		ids := make(map[string]struct{})
		for _, hit := range res.Hits {
			ids[hit.ID] = struct{}{}
		}
		return ids
	}

	termQuery := query.NewTermQuery("1004")
	allowedIDs := gatherFilteredIDs(termQuery)

	searchRequest := NewSearchRequest(NewMatchNoneQuery())
	searchRequest.AddKNNWithFilter("vector", randomVector(), 3, 2.0, termQuery)
	searchRequest.Fields = []string{"content", "vector"}
	res, err := index.Search(searchRequest)
	if err != nil {
		t.Fatal(err)
	}
	// Every returned hit must belong to the filter's result set.
	for _, hit := range res.Hits {
		if _, ok := allowedIDs[hit.ID]; !ok {
			t.Errorf("returned result not present in filtered hits")
		}
	}

	// A match-none filter must eliminate all candidates.
	searchRequest = NewSearchRequest(NewMatchNoneQuery())
	searchRequest.AddKNNWithFilter("vector", randomVector(), 3, 2.0,
		NewMatchNoneQuery())
	res, err = index.Search(searchRequest)
	if err != nil {
		t.Fatal(err)
	}
	if len(res.Hits) != 0 {
		t.Errorf("match none filter should return no hits")
	}

	// Repeat the containment check with a disjunction filter.
	disjQuery := query.NewDisjunctionQuery([]query.Query{
		query.NewTermQuery("1003"),
		query.NewTermQuery("1005"),
	})
	allowedIDs = gatherFilteredIDs(disjQuery)

	searchRequest = NewSearchRequest(NewMatchNoneQuery())
	searchRequest.AddKNNWithFilter("vector", randomVector(), 3, 2.0, disjQuery)
	searchRequest.Fields = []string{"content", "vector"}
	res, err = index.Search(searchRequest)
	if err != nil {
		t.Fatal(err)
	}
	for _, hit := range res.Hits {
		if _, ok := allowedIDs[hit.ID]; !ok {
			t.Errorf("returned result not present in filtered hits")
		}
	}
}
// -----------------------------------------------------------------------------
// Test nested vectors: a document may hold either a single vector or a slice
// of vectors in the same field; each query vector must resolve to the
// document containing its exact match.
func TestNestedVectors(t *testing.T) {
	tmpIndexPath := createTmpIndexPath(t)
	defer cleanupTmpIndexPath(t, tmpIndexPath)

	const (
		dims         = 3
		k            = 1 // one nearest neighbor
		vecFieldName = "vecData"
	)

	// docID -> Doc; doc2 carries two nested vectors.
	corpus := map[string]map[string]interface{}{
		"doc1": {
			vecFieldName: []float32{100, 100, 100},
		},
		"doc2": {
			vecFieldName: [][]float32{{0, 0, 0}, {1000, 1000, 1000}},
		},
	}

	// Index mapping with a single l2_norm vector field.
	indexMapping := NewIndexMapping()
	vecField := mapping.NewVectorFieldMapping()
	vecField.Dims = dims
	vecField.Similarity = "l2_norm"
	indexMapping.DefaultMapping.AddFieldMappingsAt(vecFieldName, vecField)

	// Create the index and upload the documents.
	index, err := New(tmpIndexPath, indexMapping)
	if err != nil {
		t.Fatal(err)
	}
	defer func() {
		if err := index.Close(); err != nil {
			t.Fatal(err)
		}
	}()

	batch := index.NewBatch()
	for id, doc := range corpus {
		if err := batch.Index(id, doc); err != nil {
			t.Fatal(err)
		}
	}
	if err := index.Batch(batch); err != nil {
		t.Fatal(err)
	}

	// Each query vector is an exact copy of an indexed vector, so the
	// single nearest neighbor must be its owning document.
	for _, tc := range []struct {
		queryVec      []float32
		expectedDocID string
	}{
		{queryVec: []float32{100, 100, 100}, expectedDocID: "doc1"},
		{queryVec: []float32{0, 0, 0}, expectedDocID: "doc2"},
		{queryVec: []float32{1000, 1000, 1000}, expectedDocID: "doc2"},
	} {
		searchReq := NewSearchRequest(query.NewMatchNoneQuery())
		searchReq.AddKNNWithFilter(vecFieldName, tc.queryVec, k, 1000,
			NewMatchAllQuery())
		res, err := index.Search(searchReq)
		if err != nil {
			t.Fatal(err)
		}
		if len(res.Hits) != 1 {
			t.Fatalf("expected 1 hit, got %d", len(res.Hits))
		}
		if got := res.Hits[0].ID; got != tc.expectedDocID {
			t.Fatalf("expected docID %s, got %s", tc.expectedDocID, got)
		}
	}
}
// TestNumVecsStat verifies that the index stats report the total number of
// vectors indexed in a vector field: 10 batches x 30 docs, where each batch
// re-indexes the same 30 document IDs, leaving 300 live docs/vectors.
func TestNumVecsStat(t *testing.T) {
	dataset, _, err := readDatasetAndQueries(testInputCompressedFile)
	if err != nil {
		t.Fatal(err)
	}
	documents := makeDatasetIntoDocuments(dataset)

	indexMapping := NewIndexMapping()
	contentFieldMapping := NewTextFieldMapping()
	contentFieldMapping.Analyzer = en.AnalyzerName
	indexMapping.DefaultMapping.AddFieldMappingsAt("content", contentFieldMapping)

	vecFieldMapping1 := mapping.NewVectorFieldMapping()
	vecFieldMapping1.Dims = testDatasetDims
	vecFieldMapping1.Similarity = index.EuclideanDistance
	indexMapping.DefaultMapping.AddFieldMappingsAt("vector", vecFieldMapping1)

	tmpIndexPath := createTmpIndexPath(t)
	// Remove the temporary index directory when done (previously leaked;
	// every other test in this file performs this cleanup).
	defer cleanupTmpIndexPath(t, tmpIndexPath)

	index, err := New(tmpIndexPath, indexMapping)
	if err != nil {
		t.Fatal(err)
	}
	defer func() {
		err := index.Close()
		if err != nil {
			t.Fatal(err)
		}
	}()

	// 10 batches, each indexing doc IDs 0..299 in chunks of 30; later
	// batches overwrite earlier ones' IDs, so 300 vectors remain.
	for i := 0; i < 10; i++ {
		batch := index.NewBatch()
		for j := 0; j < 3; j++ {
			for k := 0; k < 10; k++ {
				err := batch.Index(fmt.Sprintf("%d", i*30+j*10+k), documents[j*10+k])
				if err != nil {
					t.Fatal(err)
				}
			}
		}
		err = index.Batch(batch)
		if err != nil {
			t.Fatal(err)
		}
	}

	// Fail loudly if the stats entry is missing or has an unexpected
	// shape — previously the assertions were silently skipped in that case.
	statsMap := index.StatsMap()
	indexStats, exists := statsMap["index"]
	if !exists {
		t.Fatal("expected 'index' entry in the stats map")
	}
	indexStatsMap, ok := indexStats.(map[string]interface{})
	if !ok {
		t.Fatalf("expected 'index' stats to be a map, got %T", indexStats)
	}
	v1, ok := indexStatsMap["field:vector:num_vectors"].(uint64)
	if !ok || v1 != uint64(300) {
		t.Fatalf("mismatch in the number of vectors, expected 300, got %d", indexStatsMap["field:vector:num_vectors"])
	}
}