Skip to content

Commit

Permalink
.Net: Adding generic data model support for Redis (#8816)
Browse files Browse the repository at this point in the history
### Motivation and Context

In some cases users may not want to define their own data model, e.g.
where the database schema is driven from configuration.
To support this we allow a generic data model which uses object
dictionaries for most of the fields, but this requires custom mapping,
since each field from storage has to be mapped specifically into the
right dictionary.

### Description

- Adding support to Redis for using a generic data model where the schema is
determined from the record definition.
- Adding unit tests / integration tests
- Refactoring some of the existing unit tests to reuse test code.

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
  • Loading branch information
westey-m authored Sep 16, 2024
1 parent db0faca commit e9f1fca
Show file tree
Hide file tree
Showing 12 changed files with 957 additions and 116 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using Microsoft.SemanticKernel.Data;
using StackExchange.Redis;

namespace Microsoft.SemanticKernel.Connectors.Redis;

/// <summary>
/// A mapper that maps between the generic semantic kernel data model and the model that the data is stored in within Redis when using hash sets.
/// </summary>
/// <remarks>
/// The schema is taken entirely from the supplied <see cref="VectorStoreRecordDefinition"/>: data properties are
/// read from / written to <see cref="VectorStoreGenericDataModel{TKey}.Data"/> and vector properties are
/// read from / written to <see cref="VectorStoreGenericDataModel{TKey}.Vectors"/>. The record key is carried
/// separately in the storage tuple and is never mapped as a hash entry.
/// </remarks>
internal class RedisHashSetGenericDataModelMapper : IVectorStoreRecordMapper<VectorStoreGenericDataModel<string>, (string Key, HashEntry[] HashEntries)>
{
    /// <summary>A <see cref="VectorStoreRecordDefinition"/> that defines the schema of the data in the database.</summary>
    private readonly VectorStoreRecordDefinition _vectorStoreRecordDefinition;

    /// <summary>
    /// Initializes a new instance of the <see cref="RedisHashSetGenericDataModelMapper"/> class.
    /// </summary>
    /// <param name="vectorStoreRecordDefinition">A <see cref="VectorStoreRecordDefinition"/> that defines the schema of the data in the database.</param>
    public RedisHashSetGenericDataModelMapper(VectorStoreRecordDefinition vectorStoreRecordDefinition)
    {
        Verify.NotNull(vectorStoreRecordDefinition);

        this._vectorStoreRecordDefinition = vectorStoreRecordDefinition;
    }

    /// <inheritdoc />
    /// <exception cref="VectorStoreRecordMappingException">Thrown when a vector property holds a value that is neither a float nor a double vector.</exception>
    public (string Key, HashEntry[] HashEntries) MapFromDataToStorageModel(VectorStoreGenericDataModel<string> dataModel)
    {
        var hashEntries = new List<HashEntry>();

        foreach (var property in this._vectorStoreRecordDefinition.Properties)
        {
            // Only data and vector properties are stored as hash entries. Anything else (e.g. the key,
            // which travels in the tuple's Key element) must not be looked up in either dictionary.
            if (property is not VectorStoreRecordDataProperty && property is not VectorStoreRecordVectorProperty)
            {
                continue;
            }

            var storagePropertyName = property.StoragePropertyName ?? property.DataModelPropertyName;
            var sourceDictionary = property is VectorStoreRecordDataProperty ? dataModel.Data : dataModel.Vectors;

            // Only map properties across that actually exist in the input.
            if (sourceDictionary is null || !sourceDictionary.TryGetValue(property.DataModelPropertyName, out var sourceValue))
            {
                continue;
            }

            // Replicate null if the property exists but is null.
            if (sourceValue is null)
            {
                hashEntries.Add(new HashEntry(storagePropertyName, RedisValue.Null));
                continue;
            }

            // Map data properties.
            if (property is VectorStoreRecordDataProperty)
            {
                hashEntries.Add(new HashEntry(storagePropertyName, RedisValue.Unbox(sourceValue)));
            }
            // Map vector properties.
            else if (property is VectorStoreRecordVectorProperty vectorProperty)
            {
                if (sourceValue is ReadOnlyMemory<float> rom)
                {
                    hashEntries.Add(new HashEntry(storagePropertyName, RedisVectorStoreRecordFieldMapping.ConvertVectorToBytes(rom)));
                }
                else if (sourceValue is ReadOnlyMemory<double> rod)
                {
                    hashEntries.Add(new HashEntry(storagePropertyName, RedisVectorStoreRecordFieldMapping.ConvertVectorToBytes(rod)));
                }
                else
                {
                    throw new VectorStoreRecordMappingException($"Unsupported vector type '{sourceValue.GetType().Name}' found on property '{vectorProperty.DataModelPropertyName}'. Only float and double vectors are supported.");
                }
            }
        }

        return (dataModel.Key, hashEntries.ToArray());
    }

    /// <inheritdoc />
    /// <exception cref="VectorStoreRecordMappingException">Thrown when a vector property is declared with a type other than a float or double vector.</exception>
    public VectorStoreGenericDataModel<string> MapFromStorageToDataModel((string Key, HashEntry[] HashEntries) storageModel, StorageToDataModelMapperOptions options)
    {
        var dataModel = new VectorStoreGenericDataModel<string>(storageModel.Key);

        foreach (var property in this._vectorStoreRecordDefinition.Properties)
        {
            // Only data and vector properties are read back from hash entries; the key was already
            // supplied to the data model constructor above.
            if (property is not VectorStoreRecordDataProperty && property is not VectorStoreRecordVectorProperty)
            {
                continue;
            }

            var storagePropertyName = property.StoragePropertyName ?? property.DataModelPropertyName;
            var targetDictionary = property is VectorStoreRecordDataProperty ? dataModel.Data : dataModel.Vectors;
            var hashEntry = storageModel.HashEntries.FirstOrDefault(x => x.Name == storagePropertyName);

            // Only map properties across that actually exist in the input.
            // When nothing matched, FirstOrDefault yields a default HashEntry whose Name has no value.
            if (!hashEntry.Name.HasValue)
            {
                continue;
            }

            // Replicate null if the property exists but is null.
            if (hashEntry.Value.IsNull)
            {
                targetDictionary.Add(property.DataModelPropertyName, null);
                continue;
            }

            // Map data properties.
            if (property is VectorStoreRecordDataProperty dataProperty)
            {
                // Convert to the declared property type, unwrapping Nullable<T> to its underlying type first.
                var typeOrNullableType = Nullable.GetUnderlyingType(property.PropertyType) ?? property.PropertyType;
                var convertedValue = Convert.ChangeType(hashEntry.Value, typeOrNullableType);
                dataModel.Data.Add(dataProperty.DataModelPropertyName, convertedValue);
            }
            // Map vector properties.
            else if (property is VectorStoreRecordVectorProperty vectorProperty)
            {
                if (property.PropertyType == typeof(ReadOnlyMemory<float>) || property.PropertyType == typeof(ReadOnlyMemory<float>?))
                {
                    // Reinterpret the stored bytes as floats and copy them into a new array.
                    var array = MemoryMarshal.Cast<byte, float>((byte[])hashEntry.Value!).ToArray();
                    dataModel.Vectors.Add(vectorProperty.DataModelPropertyName, new ReadOnlyMemory<float>(array));
                }
                else if (property.PropertyType == typeof(ReadOnlyMemory<double>) || property.PropertyType == typeof(ReadOnlyMemory<double>?))
                {
                    // Reinterpret the stored bytes as doubles and copy them into a new array.
                    var array = MemoryMarshal.Cast<byte, double>((byte[])hashEntry.Value!).ToArray();
                    dataModel.Vectors.Add(vectorProperty.DataModelPropertyName, new ReadOnlyMemory<double>(array));
                }
                else
                {
                    throw new VectorStoreRecordMappingException($"Unsupported vector type '{property.PropertyType.Name}' found on property '{property.DataModelPropertyName}'. Only float and double vectors are supported.");
                }
            }
        }

        return dataModel;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,17 @@ public RedisHashSetVectorStoreRecordCollection(IDatabase database, string collec
// Assign Mapper.
if (this._options.HashEntriesCustomMapper is not null)
{
// Custom Mapper.
this._mapper = this._options.HashEntriesCustomMapper;
}
else if (typeof(TRecord) == typeof(VectorStoreGenericDataModel<string>))
{
// Generic data model mapper.
this._mapper = (IVectorStoreRecordMapper<TRecord, (string Key, HashEntry[] HashEntries)>)new RedisHashSetGenericDataModelMapper(this._vectorStoreRecordDefinition);
}
else
{
// Default Mapper.
this._mapper = new RedisHashSetVectorStoreRecordMapper<TRecord>(this._vectorStoreRecordDefinition, this._storagePropertyNames);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,11 @@ public RedisHashSetVectorStoreRecordMapper(
// collection constructor to ensure that the model has no other vector types.
if (value is ReadOnlyMemory<float> rom)
{
hashEntries.Add(new HashEntry(storageName, ConvertVectorToBytes(rom)));
hashEntries.Add(new HashEntry(storageName, RedisVectorStoreRecordFieldMapping.ConvertVectorToBytes(rom)));
}
else if (value is ReadOnlyMemory<double> rod)
{
hashEntries.Add(new HashEntry(storageName, ConvertVectorToBytes(rod)));
hashEntries.Add(new HashEntry(storageName, RedisVectorStoreRecordFieldMapping.ConvertVectorToBytes(rod)));
}
}
}
Expand Down Expand Up @@ -156,14 +156,4 @@ public TConsumerDataModel MapFromStorageToDataModel((string Key, HashEntry[] Has

return JsonSerializer.Deserialize<TConsumerDataModel>(jsonObject)!;
}

/// <summary>Copies the raw bytes backing a float vector into a new byte array.</summary>
/// <param name="vector">The vector to convert.</param>
/// <returns>A new array containing the bytes of <paramref name="vector"/>.</returns>
private static byte[] ConvertVectorToBytes(ReadOnlyMemory<float> vector)
    => MemoryMarshal.AsBytes(vector.Span).ToArray();

/// <summary>Copies the raw bytes backing a double vector into a new byte array.</summary>
/// <param name="vector">The vector to convert.</param>
/// <returns>A new array containing the bytes of <paramref name="vector"/>.</returns>
private static byte[] ConvertVectorToBytes(ReadOnlyMemory<double> vector)
    => MemoryMarshal.AsBytes(vector.Span).ToArray();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Nodes;
using Microsoft.SemanticKernel.Data;

namespace Microsoft.SemanticKernel.Connectors.Redis;

/// <summary>
/// A mapper that maps between the generic semantic kernel data model and the model that the data is stored in within Redis when using JSON.
/// </summary>
/// <remarks>
/// The schema is taken entirely from the supplied <see cref="VectorStoreRecordDefinition"/>: data properties are
/// read from / written to <see cref="VectorStoreGenericDataModel{TKey}.Data"/> and vector properties are
/// read from / written to <see cref="VectorStoreGenericDataModel{TKey}.Vectors"/>. The record key is carried
/// separately in the storage tuple and is never mapped as a JSON property.
/// </remarks>
internal class RedisJsonGenericDataModelMapper : IVectorStoreRecordMapper<VectorStoreGenericDataModel<string>, (string Key, JsonNode Node)>
{
    /// <summary>A <see cref="VectorStoreRecordDefinition"/> that defines the schema of the data in the database.</summary>
    private readonly VectorStoreRecordDefinition _vectorStoreRecordDefinition;

    /// <summary>The JSON serializer options to use when converting between the data model and the Redis record.</summary>
    private readonly JsonSerializerOptions _jsonSerializerOptions;

    /// <summary>A dictionary that maps from a property name to the storage name that should be used when serializing it to json for data and vector properties.</summary>
    // NOTE(review): declared public despite the private-field naming; visibility left unchanged for compatibility.
    public readonly Dictionary<string, string> _storagePropertyNames;

    /// <summary>
    /// Initializes a new instance of the <see cref="RedisJsonGenericDataModelMapper"/> class.
    /// </summary>
    /// <param name="vectorStoreRecordDefinition">A <see cref="VectorStoreRecordDefinition"/> that defines the schema of the data in the database.</param>
    /// <param name="jsonSerializerOptions">The JSON serializer options to use when converting between the data model and the Redis record.</param>
    public RedisJsonGenericDataModelMapper(
        VectorStoreRecordDefinition vectorStoreRecordDefinition,
        JsonSerializerOptions jsonSerializerOptions)
    {
        Verify.NotNull(vectorStoreRecordDefinition);
        Verify.NotNull(jsonSerializerOptions);

        this._vectorStoreRecordDefinition = vectorStoreRecordDefinition;
        this._jsonSerializerOptions = jsonSerializerOptions;

        // Create a dictionary that maps from the data model property name to the storage property name.
        // Precedence: explicit storage name, then the serializer's naming policy, then the data model name itself.
        this._storagePropertyNames = vectorStoreRecordDefinition.Properties.ToDictionary(
            x => x.DataModelPropertyName,
            x => x.StoragePropertyName
                ?? jsonSerializerOptions.PropertyNamingPolicy?.ConvertName(x.DataModelPropertyName)
                ?? x.DataModelPropertyName);
    }

    /// <inheritdoc />
    public (string Key, JsonNode Node) MapFromDataToStorageModel(VectorStoreGenericDataModel<string> dataModel)
    {
        var jsonObject = new JsonObject();

        foreach (var property in this._vectorStoreRecordDefinition.Properties)
        {
            // Only data and vector properties are written into the JSON document. Anything else (e.g. the
            // key, which travels in the tuple's Key element) must not be looked up in either dictionary.
            if (property is not VectorStoreRecordDataProperty && property is not VectorStoreRecordVectorProperty)
            {
                continue;
            }

            var storagePropertyName = this._storagePropertyNames[property.DataModelPropertyName];
            var sourceDictionary = property is VectorStoreRecordDataProperty ? dataModel.Data : dataModel.Vectors;

            // Only map properties across that actually exist in the input.
            if (sourceDictionary is null || !sourceDictionary.TryGetValue(property.DataModelPropertyName, out var sourceValue))
            {
                continue;
            }

            // Replicate null if the property exists but is null.
            if (sourceValue is null)
            {
                jsonObject.Add(storagePropertyName, null);
                continue;
            }

            // Serialize with the configured options so that custom converters and naming policies are honored.
            jsonObject.Add(storagePropertyName, JsonSerializer.SerializeToNode(sourceValue, property.PropertyType, this._jsonSerializerOptions));
        }

        return (dataModel.Key, jsonObject);
    }

    /// <inheritdoc />
    /// <exception cref="VectorStoreRecordMappingException">Thrown when the storage node is neither a JSON object nor a single-element array containing a JSON object.</exception>
    public VectorStoreGenericDataModel<string> MapFromStorageToDataModel((string Key, JsonNode Node) storageModel, StorageToDataModelMapperOptions options)
    {
        var dataModel = new VectorStoreGenericDataModel<string>(storageModel.Key);

        // The redis result can be either a single object or an array with a single object in the case where we are doing an MGET.
        JsonObject jsonObject;
        if (storageModel.Node is JsonObject topLevelJsonObject)
        {
            jsonObject = topLevelJsonObject;
        }
        else if (storageModel.Node is JsonArray jsonArray && jsonArray.Count == 1 && jsonArray[0] is JsonObject arrayEntryJsonObject)
        {
            jsonObject = arrayEntryJsonObject;
        }
        else
        {
            throw new VectorStoreRecordMappingException($"Invalid data format for document with key '{storageModel.Key}'");
        }

        foreach (var property in this._vectorStoreRecordDefinition.Properties)
        {
            // Only data and vector properties are read back from the JSON document; the key was already
            // supplied to the data model constructor above.
            if (property is not VectorStoreRecordDataProperty && property is not VectorStoreRecordVectorProperty)
            {
                continue;
            }

            var storagePropertyName = this._storagePropertyNames[property.DataModelPropertyName];
            var targetDictionary = property is VectorStoreRecordDataProperty ? dataModel.Data : dataModel.Vectors;

            // Only map properties across that actually exist in the input.
            if (!jsonObject.TryGetPropertyValue(storagePropertyName, out var sourceValue))
            {
                continue;
            }

            // Replicate null if the property exists but is null.
            if (sourceValue is null)
            {
                targetDictionary.Add(property.DataModelPropertyName, null);
                continue;
            }

            // Map data and vector values, deserializing with the same options that were used for serialization.
            targetDictionary.Add(property.DataModelPropertyName, JsonSerializer.Deserialize(sourceValue, property.PropertyType, this._jsonSerializerOptions));
        }

        return dataModel;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,19 @@ public RedisJsonVectorStoreRecordCollection(IDatabase database, string collectio
// Assign Mapper.
if (this._options.JsonNodeCustomMapper is not null)
{
// Custom Mapper.
this._mapper = this._options.JsonNodeCustomMapper;
}
else if (typeof(TRecord) == typeof(VectorStoreGenericDataModel<string>))
{
// Generic data model mapper.
this._mapper = (IVectorStoreRecordMapper<TRecord, (string Key, JsonNode Node)>)new RedisJsonGenericDataModelMapper(
this._vectorStoreRecordDefinition,
this._jsonSerializerOptions);
}
else
{
// Default Mapper.
this._mapper = new RedisJsonVectorStoreRecordMapper<TRecord>(keyJsonPropertyName, this._jsonSerializerOptions);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Runtime.InteropServices;

namespace Microsoft.SemanticKernel.Connectors.Redis;

/// <summary>
/// Helper methods for converting field values to and from the format required by the Redis client sdk.
/// </summary>
internal static class RedisVectorStoreRecordFieldMapping
{
    /// <summary>
    /// Convert a single-precision vector to a byte array as required by the Redis client sdk when using hashsets.
    /// </summary>
    /// <param name="vector">The vector to convert.</param>
    /// <returns>A new byte array holding a copy of the vector's bytes.</returns>
    public static byte[] ConvertVectorToBytes(ReadOnlyMemory<float> vector)
    {
        // View the float elements as bytes, then copy that view into a new array.
        ReadOnlySpan<byte> rawBytes = MemoryMarshal.AsBytes(vector.Span);
        return rawBytes.ToArray();
    }

    /// <summary>
    /// Convert a double-precision vector to a byte array as required by the Redis client sdk when using hashsets.
    /// </summary>
    /// <param name="vector">The vector to convert.</param>
    /// <returns>A new byte array holding a copy of the vector's bytes.</returns>
    public static byte[] ConvertVectorToBytes(ReadOnlyMemory<double> vector)
    {
        // View the double elements as bytes, then copy that view into a new array.
        ReadOnlySpan<byte> rawBytes = MemoryMarshal.AsBytes(vector.Span);
        return rawBytes.ToArray();
    }
}
Loading

0 comments on commit e9f1fca

Please sign in to comment.