Skip to content

Commit

Permalink
.Net: Add qdrant vector search implementation. (#8508)
Browse files Browse the repository at this point in the history
### Motivation and Context

As part of the work on vector storage we have to add vector search
capabilities for each implementation.

### Description

1. Adding vector search for Qdrant.
2. Note that, for now, the search interface is implemented directly by
the collection class; in the future it should become part of the
collection interface. Doing it this way lets each implementation be
added one by one, and once all of them have been implemented, we can
make the switch.

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
  • Loading branch information
westey-m authored Sep 5, 2024
1 parent d745d57 commit 7836721
Show file tree
Hide file tree
Showing 7 changed files with 551 additions and 3 deletions.
2 changes: 1 addition & 1 deletion dotnet/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
<PackageVersion Include="Milvus.Client" Version="2.3.0-preview.1" />
<PackageVersion Include="Testcontainers.Milvus" Version="3.8.0" />
<PackageVersion Include="Microsoft.Data.SqlClient" Version="5.2.1" />
<PackageVersion Include="Qdrant.Client" Version="1.9.0" />
<PackageVersion Include="Qdrant.Client" Version="1.11.0" />
<!-- Symbols -->
<PackageVersion Include="Microsoft.SourceLink.GitHub" Version="8.0.0" />
<!-- Toolset -->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,4 +255,61 @@ public virtual Task<IReadOnlyList<RetrievedPoint>> RetrieveAsync(
ShardKeySelector? shardKeySelector = null,
CancellationToken cancellationToken = default)
=> this._qdrantClient.RetrieveAsync(collectionName, ids, withPayload, withVectors, readConsistency, shardKeySelector, cancellationToken);

/// <summary>
/// Runs a universal point query against a collection.
/// A single entry point for search, recommend, discover and filter scenarios,
/// which also supports hybrid and multi-stage queries.
/// </summary>
/// <remarks>
/// This overload exists so that the call can be mocked; it forwards every argument, unchanged and in the same order,
/// to the wrapped Qdrant client.
/// </remarks>
/// <param name="collectionName">Name of the collection to query.</param>
/// <param name="query">The query to run. When omitted, points are returned ordered by their IDs.</param>
/// <param name="prefetch">Sub-requests to run first. When supplied, the query is executed over the results of these prefetches.</param>
/// <param name="usingVector">Name of the vector to query with. When omitted, the default vector is used.</param>
/// <param name="filter">Conditions that a point must satisfy in order to be returned.</param>
/// <param name="scoreThreshold">Only points scoring better than this threshold are returned.</param>
/// <param name="searchParams">Search configuration.</param>
/// <param name="limit">Maximum number of results to return.</param>
/// <param name="offset">Offset of the result.</param>
/// <param name="payloadSelector">Controls which payload to include or exclude.</param>
/// <param name="vectorsSelector">Controls which vectors to include in the response.</param>
/// <param name="readConsistency">Read consistency guarantees to apply.</param>
/// <param name="shardKeySelector">Shards in which to look for the points. When omitted, all shards are searched.</param>
/// <param name="lookupFrom">Location to use for ID lookups. When omitted, the current collection and the 'usingVector' vector are used.</param>
/// <param name="timeout">When set, overrides the global timeout setting for this request.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests. The default value is <see cref="CancellationToken.None" />.</param>
public virtual Task<IReadOnlyList<ScoredPoint>> QueryAsync(
    string collectionName,
    Query? query = null,
    IReadOnlyList<PrefetchQuery>? prefetch = null,
    string? usingVector = null,
    Filter? filter = null,
    float? scoreThreshold = null,
    SearchParams? searchParams = null,
    ulong limit = 10,
    ulong offset = 0,
    WithPayloadSelector? payloadSelector = null,
    WithVectorsSelector? vectorsSelector = null,
    ReadConsistency? readConsistency = null,
    ShardKeySelector? shardKeySelector = null,
    LookupLocation? lookupFrom = null,
    TimeSpan? timeout = null,
    CancellationToken cancellationToken = default)
{
    // Pure pass-through: no validation or transformation happens at this layer.
    return this._qdrantClient.QueryAsync(
        collectionName, query, prefetch, usingVector,
        filter, scoreThreshold, searchParams,
        limit, offset,
        payloadSelector, vectorsSelector,
        readConsistency, shardKeySelector, lookupFrom,
        timeout, cancellationToken);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using Microsoft.SemanticKernel.Data;
using Qdrant.Client.Grpc;

namespace Microsoft.SemanticKernel.Connectors.Qdrant;

/// <summary>
/// Contains mapping helpers to use when searching for documents using Qdrant.
/// </summary>
internal static class QdrantVectorStoreCollectionSearchMapping
{
    /// <summary>
    /// Build a Qdrant <see cref="Filter"/> from the provided <see cref="VectorSearchFilter"/>.
    /// </summary>
    /// <remarks>
    /// Every supported clause is translated into a field match condition that is added to the filter's
    /// <c>Must</c> list. When no filter or no filter clauses are supplied, an empty filter is returned.
    /// </remarks>
    /// <param name="basicVectorSearchFilter">The <see cref="VectorSearchFilter"/> to build a Qdrant <see cref="Filter"/> from.</param>
    /// <param name="storagePropertyNames">A mapping of data model property names to the names under which they are stored.</param>
    /// <returns>The Qdrant <see cref="Filter"/>.</returns>
    /// <exception cref="InvalidOperationException">Thrown when the provided filter contains unsupported types, values (including <see langword="null"/>) or unknown properties.</exception>
    public static Filter BuildFilter(VectorSearchFilter? basicVectorSearchFilter, Dictionary<string, string> storagePropertyNames)
    {
        var filter = new Filter();

        // Return an empty filter if no filter clauses are provided.
        if (basicVectorSearchFilter?.FilterClauses is null)
        {
            return filter;
        }

        foreach (var filterClause in basicVectorSearchFilter.FilterClauses)
        {
            string fieldName;
            object? filterValue;

            // In Qdrant, tag list contains is handled using a keyword match, which is the same as a string equality check.
            // We can therefore just extract the field name and value from each clause and handle them the same.
            if (filterClause is EqualityFilterClause equalityFilterClause)
            {
                fieldName = equalityFilterClause.FieldName;
                filterValue = equalityFilterClause.Value;
            }
            else if (filterClause is TagListContainsFilterClause tagListContainsClause)
            {
                fieldName = tagListContainsClause.FieldName;
                filterValue = tagListContainsClause.Value;
            }
            else
            {
                throw new InvalidOperationException($"Unsupported filter clause type '{filterClause.GetType().Name}'.");
            }

            // Map each type of filter value to the appropriate Qdrant match type.
            // A null value must be rejected explicitly: without the dedicated arm it would reach the discard arm
            // and fail with a NullReferenceException on GetType() instead of the documented InvalidOperationException.
            var match = filterValue switch
            {
                string stringValue => new Match { Keyword = stringValue },
                int intValue => new Match { Integer = intValue },
                long longValue => new Match { Integer = longValue },
                bool boolValue => new Match { Boolean = boolValue },
                null => throw new InvalidOperationException("Unsupported filter value type 'null'."),
                _ => throw new InvalidOperationException($"Unsupported filter value type '{filterValue.GetType().Name}'.")
            };

            // Get the storage name for the field.
            if (!storagePropertyNames.TryGetValue(fieldName, out var storagePropertyName))
            {
                throw new InvalidOperationException($"Property name '{fieldName}' provided as part of the filter clause is not a valid property name.");
            }

            filter.Must.Add(new Condition() { Field = new FieldCondition() { Key = storagePropertyName, Match = match } });
        }

        return filter;
    }

    /// <summary>
    /// Map the given <see cref="ScoredPoint"/> to a <see cref="VectorSearchResult{TRecord}"/>.
    /// </summary>
    /// <typeparam name="TRecord">The type of the record to map to.</typeparam>
    /// <param name="point">The point to map to a <see cref="VectorSearchResult{TRecord}"/>.</param>
    /// <param name="mapper">The mapper to perform the main mapping operation with.</param>
    /// <param name="includeVectors">A value indicating whether to include vectors in the mapped result.</param>
    /// <param name="databaseSystemName">The name of the database system the operation is being run on.</param>
    /// <param name="collectionName">The name of the collection the operation is being run on.</param>
    /// <param name="operationName">The type of database operation being run.</param>
    /// <returns>The mapped <see cref="VectorSearchResult{TRecord}"/>, carrying the score of the source point.</returns>
    public static VectorSearchResult<TRecord> MapScoredPointToVectorSearchResult<TRecord>(ScoredPoint point, IVectorStoreRecordMapper<TRecord, PointStruct> mapper, bool includeVectors, string databaseSystemName, string collectionName, string operationName)
        where TRecord : class
    {
        // Since the mapper doesn't know about scored points, we need to convert the scored point to a point struct first.
        var pointStruct = new PointStruct
        {
            Id = point.Id,
            Vectors = point.Vectors
        };

        // The payload collection is populated in place, entry by entry.
        foreach (KeyValuePair<string, Value> payloadEntry in point.Payload)
        {
            pointStruct.Payload.Add(payloadEntry.Key, payloadEntry.Value);
        }

        // Do the mapping with error handling.
        return new VectorSearchResult<TRecord>(
            VectorStoreErrorHandler.RunModelConversion(
                databaseSystemName,
                collectionName,
                operationName,
                () => mapper.MapFromStorageToDataModel(pointStruct, new() { IncludeVectors = includeVectors })),
            point.Score);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace Microsoft.SemanticKernel.Connectors.Qdrant;
/// </summary>
/// <typeparam name="TRecord">The data model to use for adding, updating and retrieving data from storage.</typeparam>
#pragma warning disable CA1711 // Identifiers should not have incorrect suffix
public sealed class QdrantVectorStoreRecordCollection<TRecord> : IVectorStoreRecordCollection<ulong, TRecord>, IVectorStoreRecordCollection<Guid, TRecord>
public sealed class QdrantVectorStoreRecordCollection<TRecord> : IVectorStoreRecordCollection<ulong, TRecord>, IVectorStoreRecordCollection<Guid, TRecord>, IVectorSearch<TRecord>
#pragma warning restore CA1711 // Identifiers should not have incorrect suffix
where TRecord : class
{
Expand Down Expand Up @@ -56,6 +56,9 @@ public sealed class QdrantVectorStoreRecordCollection<TRecord> : IVectorStoreRec
/// <summary>A dictionary that maps from a property name to the configured name that should be used when storing it.</summary>
private readonly Dictionary<string, string> _storagePropertyNames = new();

/// <summary>The name of the first vector field for the collections that this class is used with.</summary>
private readonly string? _firstVectorPropertyName = null;

/// <summary>
/// Initializes a new instance of the <see cref="QdrantVectorStoreRecordCollection{TRecord}"/> class.
/// </summary>
Expand Down Expand Up @@ -95,6 +98,10 @@ internal QdrantVectorStoreRecordCollection(MockableQdrantClient qdrantClient, st

// Build a map of property names to storage names.
this._storagePropertyNames = VectorStoreRecordPropertyReader.BuildPropertyNameToStorageNameMap(properties);
if (properties.VectorProperties.Count > 0)
{
this._firstVectorPropertyName = this._storagePropertyNames[properties.VectorProperties.First().DataModelPropertyName];
}

// Assign Mapper.
if (this._options.PointStructCustomMapper is not null)
Expand Down Expand Up @@ -432,6 +439,93 @@ private async IAsyncEnumerable<TRecord> GetBatchByPointIdAsync<TKey>(
}
}

/// <inheritdoc />
public async IAsyncEnumerable<VectorSearchResult<TRecord>> SearchAsync(VectorSearchQuery vectorQuery, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    // Operation name reported both when running the query and when mapping its results.
    const string OperationName = "Query";

    // Without a vector property there is nothing to search against.
    if (this._firstVectorPropertyName is null)
    {
        throw new InvalidOperationException("The collection does not have any vector fields, so vector search is not possible.");
    }

    // Only queries carrying a float vector are handled; every other query type is rejected.
    if (vectorQuery is not VectorizedSearchQuery<ReadOnlyMemory<float>> floatQuery)
    {
        throw new NotSupportedException($"A {nameof(VectorSearchQuery)} of type {vectorQuery.QueryType} is not supported by the Qdrant connector.");
    }

    // Fall back to the default search options when the query does not carry any.
    var searchOptions = floatQuery.SearchOptions ?? Data.VectorSearchOptions.Default;

    // Translate the filter clauses into a Qdrant filter.
    var qdrantFilter = QdrantVectorStoreCollectionSearchMapping.BuildFilter(searchOptions.VectorSearchFilter, this._storagePropertyNames);

    // A vector name is only passed along when the collection uses named vectors.
    string? targetVectorName = this._options.HasNamedVectors
        ? this.ResolveVectorFieldName(searchOptions.VectorFieldName)
        : null;

    // Ask for vectors in the response only if the caller requested them.
    var withVectors = new WithVectorsSelector { Enable = searchOptions.IncludeVectors };

    // Nearest-neighbour query over the supplied vector.
    var nearestQuery = new Query { Nearest = new VectorInput(floatQuery.Vector.ToArray()) };

    // Run the search through the shared operation wrapper.
    var scoredPoints = await this.RunOperationAsync(
        OperationName,
        () => this._qdrantClient.QueryAsync(
            this.CollectionName,
            query: nearestQuery,
            usingVector: targetVectorName,
            filter: qdrantFilter,
            limit: (ulong)searchOptions.Limit,
            offset: (ulong)searchOptions.Offset,
            vectorsSelector: withVectors,
            cancellationToken: cancellationToken)).ConfigureAwait(false);

    // Convert each scored point into a search result as the caller enumerates.
    foreach (var scoredPoint in scoredPoints)
    {
        yield return QdrantVectorStoreCollectionSearchMapping.MapScoredPointToVectorSearchResult(
            scoredPoint,
            this._mapper,
            searchOptions.IncludeVectors,
            DatabaseName,
            this._collectionName,
            OperationName);
    }
}

/// <summary>
/// Determine which vector field a search should target: the storage name of the field requested via options
/// when one is given, otherwise the first vector field of the collection.
/// </summary>
/// <remarks>
/// NOTE(review): the requested name is looked up in the same property-name map that is used for filter fields,
/// so this presumably accepts any known property name rather than vector properties only — confirm.
/// </remarks>
/// <param name="optionsVectorFieldName">The vector field name provided via options.</param>
/// <returns>The resolved vector field name.</returns>
/// <exception cref="InvalidOperationException">Thrown if the provided field name is not a valid field name.</exception>
private string ResolveVectorFieldName(string? optionsVectorFieldName)
{
    // No explicit choice: use the first vector field.
    if (optionsVectorFieldName is null)
    {
        return this._firstVectorPropertyName!;
    }

    // Explicit choice: translate the requested name into its storage name.
    if (this._storagePropertyNames.TryGetValue(optionsVectorFieldName, out var storageName))
    {
        return storageName!;
    }

    throw new InvalidOperationException($"The collection does not have a vector field named '{optionsVectorFieldName}'.");
}

/// <summary>
/// Run the given operation and wrap any <see cref="RpcException"/> with <see cref="VectorStoreOperationException"/>.
/// </summary>
Expand Down
Loading

0 comments on commit 7836721

Please sign in to comment.