Skip to content

Commit f34bcf5

Browse files
authored
.Net MEVD: SqlServerVectorStore should accept a connection string to be thread safe (#11042)
By accepting a connection string rather than a connection itself we can allow the store to be thread safe and we don't need to worry about the lifetime of provided connection. It's a breaking change, but we really want to steer the users to do the right thing, so the sooner we release it, the better.
1 parent ea00f4f commit f34bcf5

File tree

5 files changed

+82
-62
lines changed

5 files changed

+82
-62
lines changed

dotnet/src/Connectors/Connectors.Memory.SqlServer/ExceptionWrapper.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ internal static async Task<T> WrapAsync<T>(
3333
}
3434
catch (Exception ex)
3535
{
36+
#if NET
37+
await connection.DisposeAsync().ConfigureAwait(false);
38+
#else
39+
connection.Dispose();
40+
#endif
41+
3642
throw new VectorStoreOperationException(ex.Message, ex)
3743
{
3844
OperationName = operationName,

dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStore.cs

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

3-
using System;
43
using System.Collections.Generic;
54
using System.Runtime.CompilerServices;
65
using System.Threading;
@@ -12,36 +11,35 @@ namespace Microsoft.SemanticKernel.Connectors.SqlServer;
1211
/// <summary>
1312
/// An implementation of <see cref="IVectorStore"/> backed by a SQL Server or Azure SQL database.
1413
/// </summary>
15-
public sealed class SqlServerVectorStore : IVectorStore, IDisposable
14+
public sealed class SqlServerVectorStore : IVectorStore
1615
{
17-
private readonly SqlConnection _connection;
16+
private readonly string _connectionString;
1817
private readonly SqlServerVectorStoreOptions _options;
1918

2019
/// <summary>
2120
/// Initializes a new instance of the <see cref="SqlServerVectorStore"/> class.
2221
/// </summary>
23-
/// <param name="connection">Database connection.</param>
22+
/// <param name="connectionString">The connection string.</param>
2423
/// <param name="options">Optional configuration options.</param>
25-
public SqlServerVectorStore(SqlConnection connection, SqlServerVectorStoreOptions? options = null)
24+
public SqlServerVectorStore(string connectionString, SqlServerVectorStoreOptions? options = null)
2625
{
27-
this._connection = connection;
26+
Verify.NotNullOrWhiteSpace(connectionString);
27+
28+
this._connectionString = connectionString;
2829
// We need to create a copy, so any changes made to the option bag after
2930
// the ctor call do not affect this instance.
3031
this._options = options is not null
3132
? new() { Schema = options.Schema }
3233
: SqlServerVectorStoreOptions.Defaults;
3334
}
3435

35-
/// <inheritdoc/>
36-
public void Dispose() => this._connection.Dispose();
37-
3836
/// <inheritdoc/>
3937
public IVectorStoreRecordCollection<TKey, TRecord> GetCollection<TKey, TRecord>(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) where TKey : notnull
4038
{
4139
Verify.NotNull(name);
4240

4341
return new SqlServerVectorStoreRecordCollection<TKey, TRecord>(
44-
this._connection,
42+
this._connectionString,
4543
name,
4644
new()
4745
{
@@ -53,9 +51,10 @@ public IVectorStoreRecordCollection<TKey, TRecord> GetCollection<TKey, TRecord>(
5351
/// <inheritdoc/>
5452
public async IAsyncEnumerable<string> ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default)
5553
{
56-
using SqlCommand command = SqlServerCommandBuilder.SelectTableNames(this._connection, this._options.Schema);
54+
using SqlConnection connection = new(this._connectionString);
55+
using SqlCommand command = SqlServerCommandBuilder.SelectTableNames(connection, this._options.Schema);
5756

58-
using SqlDataReader reader = await ExceptionWrapper.WrapAsync(this._connection, command,
57+
using SqlDataReader reader = await ExceptionWrapper.WrapAsync(connection, command,
5958
static (cmd, ct) => cmd.ExecuteReaderAsync(ct),
6059
cancellationToken, "ListCollection").ConfigureAwait(false);
6160

dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerVectorStoreRecordCollection.cs

Lines changed: 60 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,23 @@ public sealed class SqlServerVectorStoreRecordCollection<TKey, TRecord>
2222
private static readonly VectorSearchOptions<TRecord> s_defaultVectorSearchOptions = new();
2323
private static readonly SqlServerVectorStoreRecordCollectionOptions<TRecord> s_defaultOptions = new();
2424

25-
private readonly SqlConnection _sqlConnection;
25+
private readonly string _connectionString;
2626
private readonly SqlServerVectorStoreRecordCollectionOptions<TRecord> _options;
2727
private readonly VectorStoreRecordPropertyReader _propertyReader;
2828
private readonly IVectorStoreRecordMapper<TRecord, IDictionary<string, object?>> _mapper;
2929

3030
/// <summary>
3131
/// Initializes a new instance of the <see cref="SqlServerVectorStoreRecordCollection{TKey, TRecord}"/> class.
3232
/// </summary>
33-
/// <param name="connection">Database connection.</param>
33+
/// <param name="connectionString">Database connection string.</param>
3434
/// <param name="name">The name of the collection.</param>
3535
/// <param name="options">Optional configuration options.</param>
3636
public SqlServerVectorStoreRecordCollection(
37-
SqlConnection connection,
37+
string connectionString,
3838
string name,
3939
SqlServerVectorStoreRecordCollectionOptions<TRecord>? options = null)
4040
{
41-
Verify.NotNull(connection);
41+
Verify.NotNullOrWhiteSpace(connectionString);
4242
Verify.NotNull(name);
4343

4444
VectorStoreRecordPropertyReader propertyReader = new(typeof(TRecord),
@@ -61,7 +61,7 @@ public SqlServerVectorStoreRecordCollection(
6161
propertyReader.VerifyDataProperties(SqlServerConstants.SupportedDataTypes, supportEnumerable: false);
6262
propertyReader.VerifyVectorProperties(SqlServerConstants.SupportedVectorTypes);
6363

64-
this._sqlConnection = connection;
64+
this._connectionString = connectionString;
6565
this.CollectionName = name;
6666
// We need to create a copy, so any changes made to the option bag after
6767
// the ctor call do not affect this instance.
@@ -96,10 +96,11 @@ public SqlServerVectorStoreRecordCollection(
9696
/// <inheritdoc/>
9797
public async Task<bool> CollectionExistsAsync(CancellationToken cancellationToken = default)
9898
{
99+
using SqlConnection connection = new(this._connectionString);
99100
using SqlCommand command = SqlServerCommandBuilder.SelectTableName(
100-
this._sqlConnection, this._options.Schema, this.CollectionName);
101+
connection, this._options.Schema, this.CollectionName);
101102

102-
return await ExceptionWrapper.WrapAsync(this._sqlConnection, command,
103+
return await ExceptionWrapper.WrapAsync(connection, command,
103104
static async (cmd, ct) =>
104105
{
105106
using SqlDataReader reader = await cmd.ExecuteReaderAsync(ct).ConfigureAwait(false);
@@ -125,27 +126,29 @@ private async Task CreateCollectionAsync(bool ifNotExists, CancellationToken can
125126
}
126127
}
127128

129+
using SqlConnection connection = new(this._connectionString);
128130
using SqlCommand command = SqlServerCommandBuilder.CreateTable(
129-
this._sqlConnection,
131+
connection,
130132
this._options.Schema,
131133
this.CollectionName,
132134
ifNotExists,
133135
this._propertyReader.KeyProperty,
134136
this._propertyReader.DataProperties,
135137
this._propertyReader.VectorProperties);
136138

137-
await ExceptionWrapper.WrapAsync(this._sqlConnection, command,
139+
await ExceptionWrapper.WrapAsync(connection, command,
138140
static (cmd, ct) => cmd.ExecuteNonQueryAsync(ct),
139141
cancellationToken, "CreateCollection", this.CollectionName).ConfigureAwait(false);
140142
}
141143

142144
/// <inheritdoc/>
143145
public async Task DeleteCollectionAsync(CancellationToken cancellationToken = default)
144146
{
147+
using SqlConnection connection = new(this._connectionString);
145148
using SqlCommand command = SqlServerCommandBuilder.DropTableIfExists(
146-
this._sqlConnection, this._options.Schema, this.CollectionName);
149+
connection, this._options.Schema, this.CollectionName);
147150

148-
await ExceptionWrapper.WrapAsync(this._sqlConnection, command,
151+
await ExceptionWrapper.WrapAsync(connection, command,
149152
static (cmd, ct) => cmd.ExecuteNonQueryAsync(ct),
150153
cancellationToken, "DeleteCollection", this.CollectionName).ConfigureAwait(false);
151154
}
@@ -155,14 +158,15 @@ public async Task DeleteAsync(TKey key, CancellationToken cancellationToken = de
155158
{
156159
Verify.NotNull(key);
157160

161+
using SqlConnection connection = new(this._connectionString);
158162
using SqlCommand command = SqlServerCommandBuilder.DeleteSingle(
159-
this._sqlConnection,
163+
connection,
160164
this._options.Schema,
161165
this.CollectionName,
162166
this._propertyReader.KeyProperty,
163167
key);
164168

165-
await ExceptionWrapper.WrapAsync(this._sqlConnection, command,
169+
await ExceptionWrapper.WrapAsync(connection, command,
166170
static (cmd, ct) => cmd.ExecuteNonQueryAsync(ct),
167171
cancellationToken, "Delete", this.CollectionName).ConfigureAwait(false);
168172
}
@@ -172,8 +176,9 @@ public async Task DeleteBatchAsync(IEnumerable<TKey> keys, CancellationToken can
172176
{
173177
Verify.NotNull(keys);
174178

179+
using SqlConnection connection = new(this._connectionString);
175180
using SqlCommand? command = SqlServerCommandBuilder.DeleteMany(
176-
this._sqlConnection,
181+
connection,
177182
this._options.Schema,
178183
this.CollectionName,
179184
this._propertyReader.KeyProperty,
@@ -184,7 +189,7 @@ public async Task DeleteBatchAsync(IEnumerable<TKey> keys, CancellationToken can
184189
return; // keys is empty, there is nothing to delete
185190
}
186191

187-
await ExceptionWrapper.WrapAsync(this._sqlConnection, command,
192+
await ExceptionWrapper.WrapAsync(connection, command,
188193
static (cmd, ct) => cmd.ExecuteNonQueryAsync(ct),
189194
cancellationToken, "DeleteBatch", this.CollectionName).ConfigureAwait(false);
190195
}
@@ -196,16 +201,17 @@ await ExceptionWrapper.WrapAsync(this._sqlConnection, command,
196201

197202
bool includeVectors = options?.IncludeVectors is true;
198203

204+
using SqlConnection connection = new(this._connectionString);
199205
using SqlCommand command = SqlServerCommandBuilder.SelectSingle(
200-
this._sqlConnection,
206+
connection,
201207
this._options.Schema,
202208
this.CollectionName,
203209
this._propertyReader.KeyProperty,
204210
this._propertyReader.Properties,
205211
key,
206212
includeVectors);
207213

208-
using SqlDataReader reader = await ExceptionWrapper.WrapAsync(this._sqlConnection, command,
214+
using SqlDataReader reader = await ExceptionWrapper.WrapAsync(connection, command,
209215
static async (cmd, ct) =>
210216
{
211217
SqlDataReader reader = await cmd.ExecuteReaderAsync(ct).ConfigureAwait(false);
@@ -228,8 +234,9 @@ public async IAsyncEnumerable<TRecord> GetBatchAsync(IEnumerable<TKey> keys, Get
228234

229235
bool includeVectors = options?.IncludeVectors is true;
230236

237+
using SqlConnection connection = new(this._connectionString);
231238
using SqlCommand? command = SqlServerCommandBuilder.SelectMany(
232-
this._sqlConnection,
239+
connection,
233240
this._options.Schema,
234241
this.CollectionName,
235242
this._propertyReader.KeyProperty,
@@ -242,7 +249,7 @@ public async IAsyncEnumerable<TRecord> GetBatchAsync(IEnumerable<TKey> keys, Get
242249
yield break; // keys is empty
243250
}
244251

245-
using SqlDataReader reader = await ExceptionWrapper.WrapAsync(this._sqlConnection, command,
252+
using SqlDataReader reader = await ExceptionWrapper.WrapAsync(connection, command,
246253
static (cmd, ct) => cmd.ExecuteReaderAsync(ct),
247254
cancellationToken, "GetBatch", this.CollectionName).ConfigureAwait(false);
248255

@@ -259,15 +266,16 @@ public async Task<TKey> UpsertAsync(TRecord record, CancellationToken cancellati
259266
{
260267
Verify.NotNull(record);
261268

269+
using SqlConnection connection = new(this._connectionString);
262270
using SqlCommand command = SqlServerCommandBuilder.MergeIntoSingle(
263-
this._sqlConnection,
271+
connection,
264272
this._options.Schema,
265273
this.CollectionName,
266274
this._propertyReader.KeyProperty,
267275
this._propertyReader.Properties,
268276
this._mapper.MapFromDataToStorageModel(record));
269277

270-
return await ExceptionWrapper.WrapAsync(this._sqlConnection, command,
278+
return await ExceptionWrapper.WrapAsync(connection, command,
271279
async static (cmd, ct) =>
272280
{
273281
using SqlDataReader reader = await cmd.ExecuteReaderAsync(ct).ConfigureAwait(false);
@@ -282,8 +290,9 @@ public async IAsyncEnumerable<TKey> UpsertBatchAsync(IEnumerable<TRecord> record
282290
{
283291
Verify.NotNull(records);
284292

293+
using SqlConnection connection = new(this._connectionString);
285294
using SqlCommand? command = SqlServerCommandBuilder.MergeIntoMany(
286-
this._sqlConnection,
295+
connection,
287296
this._options.Schema,
288297
this.CollectionName,
289298
this._propertyReader.KeyProperty,
@@ -295,7 +304,7 @@ public async IAsyncEnumerable<TKey> UpsertBatchAsync(IEnumerable<TRecord> record
295304
yield break; // records is empty
296305
}
297306

298-
using SqlDataReader reader = await ExceptionWrapper.WrapAsync(this._sqlConnection, command,
307+
using SqlDataReader reader = await ExceptionWrapper.WrapAsync(connection, command,
299308
static (cmd, ct) => cmd.ExecuteReaderAsync(ct),
300309
cancellationToken, "GetBatch", this.CollectionName).ConfigureAwait(false);
301310

@@ -326,8 +335,13 @@ public async Task<VectorSearchResults<TRecord>> VectorizedSearchAsync<TVector>(T
326335
var searchOptions = options ?? s_defaultVectorSearchOptions;
327336
var vectorProperty = this._propertyReader.GetVectorPropertyOrSingle(searchOptions);
328337

338+
#pragma warning disable CA2000 // Dispose objects before losing scope
339+
// This connection will be disposed by the ReadVectorSearchResultsAsync
340+
// when the user is done with the results.
341+
SqlConnection connection = new(this._connectionString);
342+
#pragma warning restore CA2000 // Dispose objects before losing scope
329343
using SqlCommand command = SqlServerCommandBuilder.SelectVector(
330-
this._sqlConnection,
344+
connection,
331345
this._options.Schema,
332346
this.CollectionName,
333347
vectorProperty,
@@ -336,34 +350,42 @@ public async Task<VectorSearchResults<TRecord>> VectorizedSearchAsync<TVector>(T
336350
searchOptions,
337351
allowed);
338352

339-
return await ExceptionWrapper.WrapAsync(this._sqlConnection, command,
353+
return await ExceptionWrapper.WrapAsync(connection, command,
340354
(cmd, ct) =>
341355
{
342-
var results = this.ReadVectorSearchResultsAsync(cmd, searchOptions.IncludeVectors, ct);
356+
var results = this.ReadVectorSearchResultsAsync(connection, cmd, searchOptions.IncludeVectors, ct);
343357
return Task.FromResult(new VectorSearchResults<TRecord>(results));
344358
}, cancellationToken, "VectorizedSearch", this.CollectionName).ConfigureAwait(false);
345359
}
346360

347361
private async IAsyncEnumerable<VectorSearchResult<TRecord>> ReadVectorSearchResultsAsync(
362+
SqlConnection connection,
348363
SqlCommand command,
349364
bool includeVectors,
350365
[EnumeratorCancellation] CancellationToken cancellationToken)
351366
{
352-
StorageToDataModelMapperOptions options = new() { IncludeVectors = includeVectors };
353-
var vectorPropertyStoragePropertyNames = includeVectors ? this._propertyReader.VectorPropertyStoragePropertyNames : [];
354-
using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
355-
356-
int scoreIndex = -1;
357-
while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false))
367+
try
358368
{
359-
if (scoreIndex < 0)
369+
StorageToDataModelMapperOptions options = new() { IncludeVectors = includeVectors };
370+
var vectorPropertyStoragePropertyNames = includeVectors ? this._propertyReader.VectorPropertyStoragePropertyNames : [];
371+
using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
372+
373+
int scoreIndex = -1;
374+
while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false))
360375
{
361-
scoreIndex = reader.GetOrdinal("score");
376+
if (scoreIndex < 0)
377+
{
378+
scoreIndex = reader.GetOrdinal("score");
379+
}
380+
381+
yield return new VectorSearchResult<TRecord>(
382+
this._mapper.MapFromStorageToDataModel(new SqlDataReaderDictionary(reader, vectorPropertyStoragePropertyNames), options),
383+
reader.GetDouble(scoreIndex));
362384
}
363-
364-
yield return new VectorSearchResult<TRecord>(
365-
this._mapper.MapFromStorageToDataModel(new SqlDataReaderDictionary(reader, vectorPropertyStoragePropertyNames), options),
366-
reader.GetDouble(scoreIndex));
385+
}
386+
finally
387+
{
388+
connection.Dispose();
367389
}
368390
}
369391
}

dotnet/src/VectorDataIntegrationTests/SqlServerIntegrationTests/SqlServerVectorStoreTests.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,7 @@ public async Task CustomMapper()
159159
{
160160
Mapper = mapper
161161
};
162-
using SqlConnection connection = new(SqlServerTestEnvironment.ConnectionString);
163-
SqlServerVectorStoreRecordCollection<string, TestModel> collection = new(connection, collectionName, options);
162+
SqlServerVectorStoreRecordCollection<string, TestModel> collection = new(SqlServerTestEnvironment.ConnectionString!, collectionName, options);
164163

165164
try
166165
{

0 commit comments

Comments
 (0)