From c57604106bf44dbec83fc7dfa57cbac6925906fb Mon Sep 17 00:00:00 2001 From: Apoorv Deshmukh Date: Wed, 16 Apr 2025 22:17:11 +0530 Subject: [PATCH 1/2] Feature Extension Changes for Vector Datatype Support (#3209) This commit adds feature extension for vector datatype support. Additionally, GenericTDSServer has been enhanced to enable tests to validate login requests and response for the vector feature extension. Tests have been added under SqlConnectionBasicTests through TestConnWithVectorFeatExtVersionNegotiation. --- .../SqlClient/SqlInternalConnectionTds.cs | 22 +++++ .../src/Microsoft/Data/SqlClient/TdsParser.cs | 35 +++++++ .../SqlClient/SqlInternalConnectionTds.cs | 22 +++++ .../src/Microsoft/Data/SqlClient/TdsParser.cs | 35 +++++++ .../src/Microsoft/Data/SqlClient/TdsEnums.cs | 7 +- .../SqlConnectionBasicTests.cs | 96 +++++++++++++++++++ .../TDS/TDS.EndPoint/ITDSServerSession.cs | 5 + .../tools/TDS/TDS.Servers/GenericTDSServer.cs | 80 ++++++++++++++++ .../TDS.Servers/GenericTDSServerSession.cs | 5 + .../tests/tools/TDS/TDS/TDSFeatureID.cs | 5 + 10 files changed, 311 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index 0d1da114cc..fdcf44d6b6 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -207,6 +207,9 @@ internal bool IsDNSCachingBeforeRedirectSupported // Json Support Flag internal bool IsJsonSupportEnabled = false; + // Vector Support Flag + internal bool IsVectorSupportEnabled = false; + // TCE flags internal byte _tceVersionSupported; @@ -1425,6 +1428,7 @@ private void Login(ServerInfo server, TimeoutTimer timeout, string newPassword, // The SQLDNSCaching and JSON features are implicitly set requestedFeatures |= TdsEnums.FeatureExtension.SQLDNSCaching; requestedFeatures |= TdsEnums.FeatureExtension.JsonSupport; + requestedFeatures |= TdsEnums.FeatureExtension.VectorSupport; _parser.TdsLogin(login, requestedFeatures, _recoverySessionData, _fedAuthFeatureExtensionData, encrypt); } @@ -3011,6 +3015,24 @@ internal void OnFeatureExtAck(int featureId, byte[] data) break; } + case TdsEnums.FEATUREEXT_VECTORSUPPORT: + { + SqlClientEventSource.Log.TryAdvancedTraceEvent(" {0}, Received feature extension acknowledgement for VECTORSUPPORT", ObjectID); + if (data.Length != 1) + { + SqlClientEventSource.Log.TryTraceEvent(" {0}, Unknown token for VECTORSUPPORT", ObjectID); + throw SQL.ParsingError(ParsingErrorState.CorruptedTdsStream); + } + byte vectorSupportVersion = data[0]; + if (vectorSupportVersion == 0 || vectorSupportVersion > TdsEnums.MAX_SUPPORTED_VECTOR_VERSION) + { + SqlClientEventSource.Log.TryTraceEvent(" {0}, Invalid version number {1} for VECTORSUPPORT, Max supported version is {2}", ObjectID, vectorSupportVersion, TdsEnums.MAX_SUPPORTED_VECTOR_VERSION); + throw SQL.ParsingError(); + } + IsVectorSupportEnabled = true; + break; + } + default: { // Unknown feature ack diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index 7eaee4e150..be95454ddd 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -8487,6 +8487,36 @@ internal int WriteFedAuthFeatureRequest(FederatedAuthenticationFeatureExtensionD return len; } + /// + /// Writes the Vector Support feature request to the physical state object. + /// The request includes the feature ID, feature data length, and version number. + /// + /// If true, writes the feature request to the physical state object. + /// The length of the feature request in bytes. + /// + /// The feature request consists of: + /// - 1 byte for the feature ID. + /// - 4 bytes for the feature data length. + /// - 1 byte for the version number. + /// + internal int WriteVectorSupportFeatureRequest(bool write) + { + const int len = 6; + + if (write) + { + // Write Feature ID + _physicalStateObj.WriteByte(TdsEnums.FEATUREEXT_VECTORSUPPORT); + + // Feature Data Length + WriteInt(1, _physicalStateObj); + + _physicalStateObj.WriteByte(TdsEnums.MAX_SUPPORTED_VECTOR_VERSION); + } + + return len; + } + private void WriteLoginData(SqlLogin rec, TdsEnums.FeatureExtension requestedFeatures, SessionData recoverySessionData, @@ -8810,6 +8840,11 @@ private int ApplyFeatureExData(TdsEnums.FeatureExtension requestedFeatures, length += WriteJsonSupportFeatureRequest(write); } + if ((requestedFeatures & TdsEnums.FeatureExtension.VectorSupport) != 0) + { + length += WriteVectorSupportFeatureRequest(write); + } + length++; // for terminator if (write) { diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index d460c61619..9da128fb6d 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -208,6 +208,9 @@ internal bool IsDNSCachingBeforeRedirectSupported // Json Support Flag internal bool IsJsonSupportEnabled = false; + // Vector Support Flag + internal bool IsVectorSupportEnabled = false; + // TCE flags internal byte _tceVersionSupported; @@ -1431,6 +1434,7 @@ private void Login(ServerInfo server, TimeoutTimer timeout, string newPassword, // The SQLDNSCaching and JSON features are implicitly set requestedFeatures |= TdsEnums.FeatureExtension.SQLDNSCaching; requestedFeatures |= TdsEnums.FeatureExtension.JsonSupport; + requestedFeatures |= TdsEnums.FeatureExtension.VectorSupport; _parser.TdsLogin(login, requestedFeatures, _recoverySessionData, _fedAuthFeatureExtensionData, encrypt); } @@ -3043,6 +3047,24 @@ internal void OnFeatureExtAck(int featureId, byte[] data) break; } + case TdsEnums.FEATUREEXT_VECTORSUPPORT: + { + SqlClientEventSource.Log.TryAdvancedTraceEvent(" {0}, Received feature extension acknowledgement for VECTORSUPPORT", ObjectID); + if (data.Length != 1) + { + SqlClientEventSource.Log.TryTraceEvent(" {0}, Unknown token for VECTORSUPPORT", ObjectID); + throw SQL.ParsingError(ParsingErrorState.CorruptedTdsStream); + } + byte vectorSupportVersion = data[0]; + if (vectorSupportVersion == 0 || vectorSupportVersion > TdsEnums.MAX_SUPPORTED_VECTOR_VERSION) + { + SqlClientEventSource.Log.TryTraceEvent(" {0}, Invalid version number {1} for VECTORSUPPORT, Max supported version is {2}", ObjectID, vectorSupportVersion, TdsEnums.MAX_SUPPORTED_VECTOR_VERSION); + throw SQL.ParsingError(ParsingErrorState.CorruptedTdsStream); + } + IsVectorSupportEnabled = true; + break; + } + default: { // Unknown feature ack diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index 85de3484d8..c5a6508fea 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -8670,6 +8670,36 @@ internal int WriteFedAuthFeatureRequest(FederatedAuthenticationFeatureExtensionD return len; } + /// + /// Writes the Vector Support feature request to the physical state object. + /// The request includes the feature ID, feature data length, and version number. + /// + /// If true, writes the feature request to the physical state object. + /// The length of the feature request in bytes. + /// + /// The feature request consists of: + /// - 1 byte for the feature ID. + /// - 4 bytes for the feature data length. + /// - 1 byte for the version number. + /// + internal int WriteVectorSupportFeatureRequest(bool write) + { + const int len = 6; + + if (write) + { + // Write Feature ID + _physicalStateObj.WriteByte(TdsEnums.FEATUREEXT_VECTORSUPPORT); + + // Feature Data Length + WriteInt(1, _physicalStateObj); + + _physicalStateObj.WriteByte(TdsEnums.MAX_SUPPORTED_VECTOR_VERSION); + } + + return len; + } + private void WriteLoginData(SqlLogin rec, TdsEnums.FeatureExtension requestedFeatures, SessionData recoverySessionData, @@ -9002,6 +9032,11 @@ private int ApplyFeatureExData(TdsEnums.FeatureExtension requestedFeatures, length += WriteJsonSupportFeatureRequest(write); } + if ((requestedFeatures & TdsEnums.FeatureExtension.VectorSupport) != 0) + { + length += WriteVectorSupportFeatureRequest(write); + } + length++; // for terminator if (write) { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs index 280469a5e0..b17e52500c 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs @@ -240,6 +240,7 @@ public enum EnvChangeType : byte public const byte FEATUREEXT_UTF8SUPPORT = 0x0A; public const byte FEATUREEXT_SQLDNSCACHING = 0x0B; public const byte FEATUREEXT_JSONSUPPORT = 0x0D; + public const byte FEATUREEXT_VECTORSUPPORT = 0x0E; [Flags] public enum FeatureExtension : uint @@ -253,7 +254,8 @@ public enum FeatureExtension : uint DataClassification = 1 << (TdsEnums.FEATUREEXT_DATACLASSIFICATION - 1), UTF8Support = 1 << (TdsEnums.FEATUREEXT_UTF8SUPPORT - 1), SQLDNSCaching = 1 << (TdsEnums.FEATUREEXT_SQLDNSCACHING - 1), - JsonSupport = 1 << (TdsEnums.FEATUREEXT_JSONSUPPORT - 1) + JsonSupport = 1 << (TdsEnums.FEATUREEXT_JSONSUPPORT - 1), + VectorSupport = 1 << (TdsEnums.FEATUREEXT_VECTORSUPPORT -1) } public const uint UTF8_IN_TDSCOLLATION = 0x4000000; @@ -978,6 +980,9 @@ internal enum FedAuthInfoId : byte // JSON Support constants internal const byte MAX_SUPPORTED_JSON_VERSION = 0x01; + // Vector Support constants + internal const byte MAX_SUPPORTED_VECTOR_VERSION = 0x01; + // TCE Related constants internal const byte MAX_SUPPORTED_TCE_VERSION = 0x03; // max version internal const byte MIN_TCE_VERSION_WITH_ENCLAVE_SUPPORT = 0x02; // min version with enclave support diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs index 446ddefd36..616a8fec6f 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs @@ -7,11 +7,15 @@ using System.Data.Common; using System.Diagnostics; using System.Globalization; +using System.Linq; using System.Reflection; using System.Runtime.InteropServices; using System.Security; using System.Threading; using System.Threading.Tasks; +using Microsoft.SqlServer.TDS; +using Microsoft.SqlServer.TDS.FeatureExtAck; +using Microsoft.SqlServer.TDS.Login7; using Microsoft.SqlServer.TDS.PreLogin; using Microsoft.SqlServer.TDS.Servers; using Xunit; @@ -524,5 +528,97 @@ public void ConnectionTestDeniedVersion(int major, int minor, int build) Assert.Throws(() => conn.Open()); } + + + + // Test to verify that the server and client negotiate + // the common feature extension version. + // MDS currently supports vector feature ext version 0x1. + [Theory] + [InlineData(true, 0x2, 0x1)] + [InlineData(false, 0x0, 0x0)] + [InlineData(true, 0x1, 0x1)] + [InlineData(true, 0xFF, 0x0)] + public void TestConnWithVectorFeatExtVersionNegotiation(bool expectedConnectionResult, byte serverVersion, byte expectedNegotiatedVersion) + { + // Start the test TDS server. + using var server = TestTdsServer.StartTestServer(); + server.ServerSupportedVectorFeatureExtVersion = serverVersion; + server.EnableVectorFeatureExt = serverVersion == 0xFF ? false : true; + + byte expectedLoginReqFeatureExtId = (byte)TDSFeatureID.VectorSupport; + byte expectedLoginReqFeatureExtVersion = 0x1; + byte actualLoginReqFeatureExtId = 0; + byte actualLoginReqFeatureExtVersion = 0; + byte actualFeatureExtAckId = 0; + byte actualFeatureExtAckVersion = 0; + bool loginValuesFound = false; + bool responseValuesFound = false; + + server.OnLogin7Validated = loginToken => + { + if (loginToken.FeatureExt != null) + { + var optionToken = loginToken.FeatureExt + .OfType() + .FirstOrDefault(token => token.FeatureID == TDSFeatureID.VectorSupport); + + if (optionToken != null) + { + actualLoginReqFeatureExtId = (byte)optionToken.FeatureID; + actualLoginReqFeatureExtVersion = optionToken.Data[0]; + loginValuesFound = true; + } + } + }; + + server.OnAuthenticationResponseCompleted = response => + { + var featureExtAckToken = response + .OfType() + .FirstOrDefault(); + + if (featureExtAckToken != null) + { + var featureExtensionOption = featureExtAckToken.Options + .OfType() + .FirstOrDefault(option => option.FeatureID == TDSFeatureID.VectorSupport); + + if (featureExtensionOption != null) + { + actualFeatureExtAckId = (byte)featureExtensionOption.FeatureID; + actualFeatureExtAckVersion = featureExtensionOption.FeatureAckData[0]; + responseValuesFound = true; + } + } + }; + + // Connect to the test TDS server. + using var connection = new SqlConnection(server.ConnectionString); + if (expectedConnectionResult) + { + connection.Open(); + // Verify that the expected value was sent in the LOGIN packet. + Assert.Equal(expectedLoginReqFeatureExtId, actualLoginReqFeatureExtId); + Assert.Equal(expectedLoginReqFeatureExtVersion, actualLoginReqFeatureExtVersion); + Assert.True(loginValuesFound, "Expected login values not found in the login packet."); + // Verify that the expected values were received in the TDS response. + if (server.EnableVectorFeatureExt) + { + Assert.Equal(expectedLoginReqFeatureExtId, actualFeatureExtAckId); + Assert.Equal(expectedNegotiatedVersion, actualFeatureExtAckVersion); + Assert.True(responseValuesFound, "Expected response values not found in the login response."); + } + else + { + Assert.Equal(0x0, actualFeatureExtAckId); + Assert.Equal(0x0, actualFeatureExtAckVersion); + } + } + else + { + Assert.Throws(() => connection.Open()); + } + } } } diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/ITDSServerSession.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/ITDSServerSession.cs index ee4c7f169f..9b5b7804b4 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/ITDSServerSession.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/ITDSServerSession.cs @@ -88,5 +88,10 @@ public interface ITDSServerSession /// Indicates whether the client supports Json column type /// bool IsJsonSupportEnabled { get; set; } + + /// + /// Indicates whether the client supports Vector column type + /// + bool IsVectorSupportEnabled { get; set; } } } diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTDSServer.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTDSServer.cs index a214c51f88..ac04fd2f57 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTDSServer.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTDSServer.cs @@ -27,6 +27,43 @@ namespace Microsoft.SqlServer.TDS.Servers /// public class GenericTDSServer : ITDSServer { + /// + /// Delegate to be called when a LOGIN7 request has been received and is + /// validated. This is called before any authentication work is done, + /// and before any response is sent. + /// + public delegate void OnLogin7ValidatedDelegate( + TDSLogin7Token login7Token); + public OnLogin7ValidatedDelegate OnLogin7Validated { private get; set; } + + /// + /// Delegate to be called when authentication is completed and TDSResponse + /// message is sent to the client. + /// + public delegate void OnAuthenticationCompletedDelegate( + TDSMessage response); + public OnAuthenticationCompletedDelegate OnAuthenticationResponseCompleted { private get; set; } + + /// + /// Default feature extension version supported on the server for vector support. + /// + public const byte DefaultSupportedVectorFeatureExtVersion = 0x01; + + /// + /// Property for setting server version for vector feature extension. + /// + public bool EnableVectorFeatureExt { get; set; } = false; + + /// + /// Property for setting server version for vector feature extension. + /// + public byte ServerSupportedVectorFeatureExtVersion { get; set; } = DefaultSupportedVectorFeatureExtVersion; + + /// + /// Client version for vector FeatureExtension. + /// + private byte _clientSupportedVectorFeatureExtVersion = 0; + /// /// Session counter /// @@ -239,6 +276,18 @@ public virtual TDSMessageCollection OnLogin7Request(ITDSServerSession session, T session.IsJsonSupportEnabled = true; break; } + + case TDSFeatureID.VectorSupport: + { + if (EnableVectorFeatureExt) + { + // Enable Vector Support + session.IsVectorSupportEnabled = true; + _clientSupportedVectorFeatureExtVersion = ((TDSLogin7GenericOptionToken)option).Data[0]; + } + break; + } + default: { // Do nothing @@ -248,6 +297,8 @@ public virtual TDSMessageCollection OnLogin7Request(ITDSServerSession session, T } } + OnLogin7Validated?.Invoke(loginRequest); + // Check if SSPI authentication is requested if (loginRequest.OptionalFlags2.IntegratedSecurity == TDSLogin7OptionalFlags2IntSecurity.On) { @@ -577,6 +628,32 @@ protected virtual TDSMessageCollection OnAuthenticationCompleted(ITDSServerSessi } } + // Check if Vector is supported + if (session.IsVectorSupportEnabled) + { + // Create ack data (1 byte: Version number) + byte[] data = new byte[1]; + data[0] = ServerSupportedVectorFeatureExtVersion > _clientSupportedVectorFeatureExtVersion ? _clientSupportedVectorFeatureExtVersion : ServerSupportedVectorFeatureExtVersion; + + // Create vector support as a generic feature extension option + TDSFeatureExtAckGenericOption vectorSupportOption = new TDSFeatureExtAckGenericOption(TDSFeatureID.VectorSupport, (uint)data.Length, data); + + // Look for feature extension token + TDSFeatureExtAckToken featureExtAckToken = (TDSFeatureExtAckToken)responseMessage.Where(t => t is TDSFeatureExtAckToken).FirstOrDefault(); + + if (featureExtAckToken == null) + { + // Create feature extension ack token + featureExtAckToken = new TDSFeatureExtAckToken(vectorSupportOption); + responseMessage.Add(featureExtAckToken); + } + else + { + // Update the existing token + featureExtAckToken.Options.Add(vectorSupportOption); + } + } + // Create DONE token TDSDoneToken doneToken = new TDSDoneToken(TDSDoneTokenStatusType.Final); @@ -586,6 +663,9 @@ protected virtual TDSMessageCollection OnAuthenticationCompleted(ITDSServerSessi // Serialize DONE token into the response packet responseMessage.Add(doneToken); + // Invoke delegate for response validation + OnAuthenticationResponseCompleted?.Invoke(responseMessage); + // Wrap a single message in a collection return new TDSMessageCollection(responseMessage); } diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTDSServerSession.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTDSServerSession.cs index c2cfde2e29..e9e65d5f8f 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTDSServerSession.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTDSServerSession.cs @@ -119,6 +119,11 @@ public class GenericTDSServerSession : ITDSServerSession /// public bool IsJsonSupportEnabled { get; set; } + /// + /// Indicates whether this session supports Vector column type + /// + public bool IsVectorSupportEnabled { get; set; } + #region Session Options /// diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS/TDSFeatureID.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS/TDSFeatureID.cs index eb84a631d0..6bb6fbc8d2 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS/TDSFeatureID.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS/TDSFeatureID.cs @@ -24,6 +24,11 @@ public enum TDSFeatureID : byte /// JsonSupport = 0x0D, + /// + /// Vector Support + /// + VectorSupport = 0x0E, + /// /// End of the list /// From 5cb5245bfa28e19b07d792c24fbaeb4db316d3b9 Mon Sep 17 00:00:00 2001 From: Apoorv Deshmukh Date: Wed, 25 Jun 2025 21:55:51 +0530 Subject: [PATCH 2/2] Change API for vector datatype from SqlVectorFloat32 to SqlVector for supporting SqlDbType Vector (#3441) This commit changes APIs for vector datatype. from SqlType class SqlVectorFloat32 to SqlVector to reduce API surface area for unmanaged types. All the tests have been modified to use SqlVector. --- .../SqlDataReader.xml | 26 + .../Microsoft.Data.SqlTypes/SqlVector.xml | 53 ++ .../Microsoft.Data/SqlDbTypeExtensions.xml | 8 + src/Microsoft.Data.SqlClient.sln | 1 + .../netcore/ref/Microsoft.Data.SqlClient.cs | 28 + .../src/Microsoft.Data.SqlClient.csproj | 6 + .../Microsoft/Data/SqlClient/SqlCommand.cs | 10 + .../SqlClient/SqlInternalConnectionTds.cs | 2 +- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 101 ++- .../netfx/ref/Microsoft.Data.SqlClient.cs | 28 + .../netfx/src/Microsoft.Data.SqlClient.csproj | 11 +- .../Microsoft/Data/SqlClient/SqlCommand.cs | 7 + .../src/Microsoft/Data/SqlClient/TdsParser.cs | 100 ++- .../src/Microsoft/Data/Common/AdapterUtil.cs | 9 + .../Microsoft/Data/SqlClient/ISqlVector.cs | 35 + .../src/Microsoft/Data/SqlClient/SqlBuffer.cs | 104 ++- .../Microsoft/Data/SqlClient/SqlBulkCopy.cs | 10 +- .../Microsoft/Data/SqlClient/SqlDataReader.cs | 49 +- .../src/Microsoft/Data/SqlClient/SqlEnums.cs | 34 + .../Microsoft/Data/SqlClient/SqlParameter.cs | 64 ++ .../src/Microsoft/Data/SqlClient/SqlUtil.cs | 10 + .../src/Microsoft/Data/SqlClient/TdsEnums.cs | 4 +- .../src/Microsoft/Data/SqlDbTypeExtensions.cs | 6 + .../src/Microsoft/Data/SqlTypes/SqlVector.cs | 246 +++++++ .../src/Resources/Strings.Designer.cs | 59 +- .../src/Resources/Strings.resx | 15 + .../CompilerServices/IsExternalInit.netfx.cs | 23 + .../ManualTests/DataCommon/DataTestUtility.cs | 5 +- ....Data.SqlClient.ManualTesting.Tests.csproj | 2 + .../VectorTest/NativeVectorFloat32Tests.cs | 593 +++++++++++++++++ .../VectorTypeBackwardCompatibilityTests.cs | 630 ++++++++++++++++++ .../tests/UnitTests/SqlVectorTest.cs | 228 +++++++ 32 files changed, 2481 insertions(+), 26 deletions(-) create mode 100644 doc/snippets/Microsoft.Data.SqlTypes/SqlVector.xml create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ISqlVector.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlTypes/SqlVector.cs create mode 100644 src/Microsoft.Data.SqlClient/src/System/Runtime/CompilerServices/IsExternalInit.netfx.cs create mode 100644 src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/NativeVectorFloat32Tests.cs create mode 100644 src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/VectorTypeBackwardCompatibilityTests.cs create mode 100644 src/Microsoft.Data.SqlClient/tests/UnitTests/SqlVectorTest.cs diff --git a/doc/snippets/Microsoft.Data.SqlClient/SqlDataReader.xml b/doc/snippets/Microsoft.Data.SqlClient/SqlDataReader.xml index 8e72434a4a..ff54d18a17 100644 --- a/doc/snippets/Microsoft.Data.SqlClient/SqlDataReader.xml +++ b/doc/snippets/Microsoft.Data.SqlClient/SqlDataReader.xml @@ -415,6 +415,7 @@ No conversions are performed; therefore. the data retrieved must already be a ch SqlMoney SqlSingle SqlString + SqlVector Stream String TextReader @@ -489,6 +490,7 @@ No conversions are performed; therefore. the data retrieved must already be a ch SqlMoney SqlSingle SqlString + SqlVector Stream String TextReader @@ -958,7 +960,31 @@ The method retur Gets the value of the specified column as a . A object representing the column at the given ordinal. + + No conversions are performed; therefore, the data retrieved must already be a JSON string, or an exception is generated. + + + + + Gets the value of the specified column as a . + + + A object representing the column at the given ordinal. + + + The index passed was outside the range of 0 to - 1 + + + An attempt was made to read or access columns in a closed . + + + The retrieved data is not compatible with the type. + + + No conversions are performed; therefore, the data retrieved must already be a vector value, or an exception is generated. + + The zero-based column ordinal. diff --git a/doc/snippets/Microsoft.Data.SqlTypes/SqlVector.xml b/doc/snippets/Microsoft.Data.SqlTypes/SqlVector.xml new file mode 100644 index 0000000000..9a0d621d5d --- /dev/null +++ b/doc/snippets/Microsoft.Data.SqlTypes/SqlVector.xml @@ -0,0 +1,53 @@ + + + + + Represents a vector value in SQL Server. + + + + + Takes a as input and represents a null vector value. + + + + + + Takes an array of values as input and initializes a new instance of the SqlVector class. + + + + + + + + Represents a null instance without any attributes. + + + This property is provided for compatibility with DataTable. + In most cases, prefer using IsNull to check if a SqlVector instance is a null vector. + This is equivalent to null value. + + + + + Indicates the count of the elements in the vector. + + + + + Indicates the size in bytes for a SqlVector value. + + + + Returns the values as a instance. + + + Returns an array representing the vector values. + + + Returns the string representation of a JSON array of the values. + A JSON value. + + + diff --git a/doc/snippets/Microsoft.Data/SqlDbTypeExtensions.xml b/doc/snippets/Microsoft.Data/SqlDbTypeExtensions.xml index 1ac732ea8c..cf2fda1536 100644 --- a/doc/snippets/Microsoft.Data/SqlDbTypeExtensions.xml +++ b/doc/snippets/Microsoft.Data/SqlDbTypeExtensions.xml @@ -18,5 +18,13 @@ The class provides enum value for JSON datatype. + + + Gets the enum value for the vector datatype. + + + enum value for vector datatype. + + diff --git a/src/Microsoft.Data.SqlClient.sln b/src/Microsoft.Data.SqlClient.sln index 2b68b2ea45..e4d29d999c 100644 --- a/src/Microsoft.Data.SqlClient.sln +++ b/src/Microsoft.Data.SqlClient.sln @@ -172,6 +172,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.Data.SqlTypes", " ProjectSection(SolutionItems) = preProject ..\doc\snippets\Microsoft.Data.SqlTypes\SqlFileStream.xml = ..\doc\snippets\Microsoft.Data.SqlTypes\SqlFileStream.xml ..\doc\snippets\Microsoft.Data.SqlTypes\SqlJson.xml = ..\doc\snippets\Microsoft.Data.SqlTypes\SqlJson.xml + ..\doc\snippets\Microsoft.Data.SqlTypes\SqlVector.xml = ..\doc\snippets\Microsoft.Data.SqlTypes\SqlVector.xml EndProjectSection EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Data.SqlClient.TestUtilities", "Microsoft.Data.SqlClient\tests\tools\Microsoft.Data.SqlClient.TestUtilities\Microsoft.Data.SqlClient.TestUtilities.csproj", "{89D6D382-9B36-43C9-A912-03802FDA8E36}" diff --git a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs index f80e590ba8..6eda29b17d 100644 --- a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs +++ b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs @@ -19,6 +19,8 @@ public static class SqlDbTypeExtensions { /// public const System.Data.SqlDbType Json = (System.Data.SqlDbType)35; + /// + public const System.Data.SqlDbType Vector = (System.Data.SqlDbType)36; } } @@ -119,6 +121,30 @@ public SqlJson(System.Text.Json.JsonDocument jsonDoc) { } /// public override string ToString() { throw null; } } + + /// + public sealed class SqlVector : System.Data.SqlTypes.INullable + where T : unmanaged + { + /// + public SqlVector(int length) { } + /// + public SqlVector(System.ReadOnlyMemory memory) { } + /// + public bool IsNull => throw null; + /// + public static SqlVector Null => throw null; + /// + public int Length { get { throw null; } } + /// + public int Size { get { throw null; } } + /// + public System.ReadOnlyMemory Memory { get { throw null; } } + /// + public override string ToString() { throw null; } + /// + public T[] ToArray() { throw null; } + } } namespace Microsoft.Data.SqlClient { @@ -1370,6 +1396,8 @@ public override void Close() { } public virtual object GetSqlValue(int i) { throw null; } /// public virtual int GetSqlValues(object[] values) { throw null; } + /// + public virtual Microsoft.Data.SqlTypes.SqlVector GetSqlVector(int i) where T : unmanaged { throw null; } /// public virtual System.Data.SqlTypes.SqlXml GetSqlXml(int i) { throw null; } /// diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index 0a48c5499e..db60c1c927 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -471,6 +471,9 @@ Microsoft\Data\SqlClient\Server\ValueUtilsSmi.cs + + Microsoft\Data\SqlClient\ISqlVector.cs + Microsoft\Data\SqlClient\SignatureVerificationCache.cs @@ -792,6 +795,9 @@ Microsoft\Data\SqlTypes\SqlJson.cs + + Microsoft\Data\SqlTypes\SqlVector.cs + Resources\ResCategoryAttribute.cs diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs index 4fbbd619e8..47349b492e 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -6435,6 +6435,16 @@ internal string BuildParamList(TdsParser parser, SqlParameterCollection paramete paramList.Append(scale); paramList.Append(')'); } + else if (mt.SqlDbType == SqlDbTypeExtensions.Vector) + { + // The validate function for SqlParameters would + // have already thrown InvalidCastException if an incompatible + // value is specified for SqlDbType Vector. + var sqlVectorProps = (ISqlVector)sqlParam.Value; + paramList.Append('('); + paramList.Append(sqlVectorProps.Length); + paramList.Append(')'); + } else if (!mt.IsFixed && !mt.IsLong && mt.SqlDbType != SqlDbType.Timestamp && mt.SqlDbType != SqlDbType.Udt && SqlDbType.Structured != mt.SqlDbType) { int size = sqlParam.Size; diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index fdcf44d6b6..f4f51f9b56 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -1425,7 +1425,7 @@ private void Login(ServerInfo server, TimeoutTimer timeout, string newPassword, requestedFeatures |= TdsEnums.FeatureExtension.AzureSQLSupport; } - // The SQLDNSCaching and JSON features are implicitly set + // The following features are implicitly set requestedFeatures |= TdsEnums.FeatureExtension.SQLDNSCaching; requestedFeatures |= TdsEnums.FeatureExtension.JsonSupport; requestedFeatures |= TdsEnums.FeatureExtension.VectorSupport; diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index be95454ddd..04594902d6 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -4252,6 +4252,15 @@ internal TdsOperationStatus TryProcessReturnValue(int length, } } + if (tdsType == TdsEnums.SQLVECTOR) + { + result = stateObj.TryReadByte(out rec.scale); + if (result != TdsOperationStatus.Done) + { + return result; + } + } + if (rec.type == SqlDbType.Xml) { // Read schema info @@ -4373,6 +4382,11 @@ internal TdsOperationStatus TryProcessReturnValue(int length, intlen = int.MaxValue; // If plp data, read it all } + if (rec.type == SqlDbTypeExtensions.Vector) + { + rec.length = tdsLen; + } + if (isNull) { GetNullSqlValue(rec.value, rec, SqlCommandColumnEncryptionSetting.Disabled, _connHandler); @@ -5200,6 +5214,15 @@ private TdsOperationStatus TryProcessTypeInfo(TdsParserStateObject stateObj, Sql } } + if (col.type == SqlDbTypeExtensions.Vector) + { + result = stateObj.TryReadByte(out col.scale); + if (result != TdsOperationStatus.Done) + { + return result; + } + } + return TdsOperationStatus.Done; } @@ -5836,6 +5859,11 @@ internal static object GetNullSqlValue(SqlBuffer nullVal, nullVal.SetToNullOfType(SqlBuffer.StorageType.Json); break; + case SqlDbTypeExtensions.Vector: + nullVal.SetToNullOfType(SqlBuffer.StorageType.Vector); + nullVal.SetVectorInfo(MetaType.GetVectorElementCount(md.length, md.scale), md.scale, true); + break; + default: Debug.Fail("unknown null sqlType!" + md.type.ToString()); break; @@ -6442,6 +6470,27 @@ internal TdsOperationStatus TryReadSqlValue(SqlBuffer value, } break; + case TdsEnums.SQLVECTOR: + // Vector data is read as non-plp binary value. + // This is same as reading varbinary(8000). + result = stateObj.TryReadByteArrayWithContinue(length, out b); + if (result != TdsOperationStatus.Done) + { + return result; + } + + // Internally, we use Sqlbinary to deal with varbinary data and store it in + // SqlBuffer as SqlBinary value. + value.SqlBinary = SqlBinary.WrapBytes(b); + + // Extract the metadata from the payload and set it as the vector attributes + // in the SqlBuffer. This metadata is further used when constructing a SqlVector + // object from binary payload. + int elementCount = BinaryPrimitives.ReadUInt16LittleEndian(b.AsSpan(2)); + byte elementType = b[4]; + value.SetVectorInfo(elementCount, elementType, false); + break; + case TdsEnums.SQLCHAR: case TdsEnums.SQLBIGCHAR: case TdsEnums.SQLVARCHAR: @@ -6747,6 +6796,7 @@ internal TdsOperationStatus TryReadSqlValueInternal(SqlBuffer value, byte tdsTyp case TdsEnums.SQLBIGVARBINARY: case TdsEnums.SQLVARBINARY: case TdsEnums.SQLIMAGE: + case TdsEnums.SQLVECTOR: { // Note: Better not come here with plp data!! Debug.Assert(length <= TdsEnums.MAXSIZE); @@ -8050,6 +8100,18 @@ internal TdsOperationStatus TryGetTokenLength(byte token, TdsParserStateObject s tokenLength = -1; return TdsOperationStatus.Done; } + else if (token == TdsEnums.SQLVECTOR) + { + ushort value; + result = stateObj.TryReadUInt16(out value); + if (result != TdsOperationStatus.Done) + { + tokenLength = 0; + return result; + } + tokenLength = value; + return TdsOperationStatus.Done; + } } switch (token & TdsEnums.SQLLenMask) @@ -9601,7 +9663,13 @@ private Task TDSExecuteRPCAddParameter(TdsParserStateObject stateObj, SqlParamet if (param.Direction == ParameterDirection.Output) { isSqlVal = param.ParameterIsSqlType; // We have to forward the TYPE info, we need to know what type we are returning. Once we null the parameter we will no longer be able to distinguish what type were seeing. - param.Value = null; + + // Output parameter of SqlDbType vector are defined through SqlParameter.Value. + // This check ensures that we do not discard the parameter value when SqlDbType is vector. + if (mt.SqlDbType != SqlDbTypeExtensions.Vector) + { + param.Value = null; + } param.ParameterIsSqlType = isSqlVal; } else @@ -9934,6 +10002,14 @@ private Task TDSExecuteRPCAddParameter(TdsParserStateObject stateObj, SqlParamet maxsize = 1; } + if (mt.SqlDbType == SqlDbTypeExtensions.Vector) + { + // For vector type we need to write the size in bytes required to represent + // vector value when communicating with SQL Server. + var sqlVectorProps = ((ISqlVector)param.Value); + maxsize = sqlVectorProps.Size; + } + WriteParameterVarLen(mt, maxsize, false /*IsNull*/, stateObj); } } @@ -9956,7 +10032,11 @@ private Task TDSExecuteRPCAddParameter(TdsParserStateObject stateObj, SqlParamet { stateObj.WriteByte(param.GetActualScale()); } - + else if (mt.SqlDbType == SqlDbTypeExtensions.Vector) + { + // For vector type we need to write scale as the element type of the vector. + stateObj.WriteByte(((ISqlVector)param.Value).ElementType); + } // write out collation or xml metadata if ((mt.SqlDbType == SqlDbType.Xml || mt.SqlDbType == SqlDbTypeExtensions.Json)) @@ -10035,7 +10115,9 @@ private Task TDSExecuteRPCAddParameter(TdsParserStateObject stateObj, SqlParamet { // for codePageEncoded types, WriteValue simply expects the number of characters // For plp types, we also need the encoded byte size - writeParamTask = WriteValue(value, mt, isParameterEncrypted ? (byte)0 : param.GetActualScale(), actualSize, codePageByteSize, isParameterEncrypted ? 0 : param.Offset, stateObj, isParameterEncrypted ? 0 : param.Size, isDataFeed); + // For vector type we need to write scale as the element type of the vector. + byte writeScale = mt.SqlDbType == SqlDbTypeExtensions.Vector ? ((ISqlVector)param.Value).ElementType : param.GetActualScale(); + writeParamTask = WriteValue(value, mt, isParameterEncrypted ? (byte)0 : writeScale, actualSize, codePageByteSize, isParameterEncrypted ? 0 : param.Offset, stateObj, isParameterEncrypted ? 0 : param.Size, isDataFeed); } } @@ -10868,6 +10950,11 @@ internal void WriteBulkCopyMetaData(_SqlMetaDataSet metadataCollection, int coun case SqlDbTypeExtensions.Json: stateObj.WriteByteArray(s_jsonMetadataSubstituteSequence, s_jsonMetadataSubstituteSequence.Length, 0); break; + case SqlDbTypeExtensions.Vector: + stateObj.WriteByte(md.tdsType); + WriteTokenLength(md.tdsType, md.length, stateObj); + stateObj.WriteByte(md.scale); + break; default: stateObj.WriteByte(md.tdsType); WriteTokenLength(md.tdsType, md.length, stateObj); @@ -11083,6 +11170,7 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars case TdsEnums.SQLBIGVARBINARY: case TdsEnums.SQLIMAGE: case TdsEnums.SQLUDT: + case TdsEnums.SQLVECTOR: ccb = (isSqlType) ? ((SqlBinary)value).Length : ((byte[])value).Length; break; case TdsEnums.SQLUNIQUEID: @@ -11416,6 +11504,12 @@ private void WriteTokenLength(byte token, int length, TdsParserStateObject state { tokenLength = 8; } + else if (token == TdsEnums.SQLVECTOR) + { + tokenLength = 2; + WriteShort(length, stateObj); + return; + } } if (tokenLength == 0) @@ -12252,6 +12346,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int case TdsEnums.SQLBIGVARBINARY: case TdsEnums.SQLIMAGE: case TdsEnums.SQLUDT: + case TdsEnums.SQLVECTOR: { // An array should be in the object Debug.Assert(isDataFeed || value is byte[], "Value should be an array of bytes"); diff --git a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs index 95f4f8a9b0..a3bfa95add 100644 --- a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs +++ b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs @@ -23,6 +23,8 @@ public static class SqlDbTypeExtensions { /// public const System.Data.SqlDbType Json = (System.Data.SqlDbType)35; + /// + public const System.Data.SqlDbType Vector = (System.Data.SqlDbType)36; } } @@ -1368,6 +1370,8 @@ public override void Close() { } public virtual object GetSqlValue(int i) { throw null; } /// public virtual int GetSqlValues(object[] values) { throw null; } + /// + public virtual Microsoft.Data.SqlTypes.SqlVector GetSqlVector(int i) where T : unmanaged { throw null; } /// public virtual System.Data.SqlTypes.SqlXml GetSqlXml(int i) { throw null; } /// @@ -2411,4 +2415,28 @@ public SqlJson(System.Text.Json.JsonDocument jsonDoc) { } /// public override string ToString() { throw null; } } + + /// + public sealed class SqlVector : System.Data.SqlTypes.INullable + where T : unmanaged + { + /// + public SqlVector(int length) { } + /// + public SqlVector(System.ReadOnlyMemory memory) { } + /// + public bool IsNull => throw null; + /// + public static SqlVector Null => throw null; + /// + public int Length { get { throw null; } } + /// + public int Size { get { throw null; } } + /// + public System.ReadOnlyMemory Memory { get { throw null; } } + /// + public override string ToString() { throw null; } + /// + public T[] ToArray() { throw null; } + } } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index 6ecdb6e17d..1b1625bbb7 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -585,6 +585,9 @@ Microsoft\Data\SqlClient\Server\SqlSer.cs + + Microsoft\Data\SqlClient\ISqlVector.cs + Microsoft\Data\SqlClient\SessionHandle.Windows.cs @@ -906,6 +909,9 @@ Microsoft\Data\SqlTypes\SqlJson.cs + + Microsoft\Data\SqlTypes\SqlVector.cs + Resources\ResDescriptionAttribute.cs @@ -915,7 +921,10 @@ System\IO\StreamExtensions.netfx.cs - + + System\Runtime\CompilerServices\IsExternalInit.netfx.cs + + diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs index b40e3dbfa2..97212843f5 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -6650,6 +6650,13 @@ internal string BuildParamList(TdsParser parser, SqlParameterCollection paramete paramList.Append(scale); paramList.Append(')'); } + else if (mt.SqlDbType == SqlDbTypeExtensions.Vector) + { + var sqlVectorProps = (ISqlVector)sqlParam.Value; + paramList.Append('('); + paramList.Append(sqlVectorProps.Length); + paramList.Append(')'); + } else if (!mt.IsFixed && !mt.IsLong && mt.SqlDbType != SqlDbType.Timestamp && mt.SqlDbType != SqlDbType.Udt && SqlDbType.Structured != mt.SqlDbType) { int size = sqlParam.Size; diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index c5a6508fea..a41686945c 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -4304,6 +4304,15 @@ internal TdsOperationStatus TryProcessReturnValue(int length, } } + if (tdsType == TdsEnums.SQLVECTOR) + { + result = stateObj.TryReadByte(out rec.scale); + if (result != TdsOperationStatus.Done) + { + return result; + } + } + if (rec.type == SqlDbType.Xml) { // Read schema info @@ -4425,6 +4434,11 @@ internal TdsOperationStatus TryProcessReturnValue(int length, intlen = int.MaxValue; // If plp data, read it all } + if (rec.type == SqlDbTypeExtensions.Vector) + { + rec.length = tdsLen; + } + if (isNull) { GetNullSqlValue(rec.value, rec, SqlCommandColumnEncryptionSetting.Disabled, _connHandler); @@ -5316,6 +5330,15 @@ private TdsOperationStatus TryProcessTypeInfo(TdsParserStateObject stateObj, Sql } } + if (col.type == SqlDbTypeExtensions.Vector) + { + result = stateObj.TryReadByte(out col.scale); + if (result != TdsOperationStatus.Done) + { + return result; + } + } + return TdsOperationStatus.Done; } @@ -6033,6 +6056,11 @@ internal static object GetNullSqlValue(SqlBuffer nullVal, nullVal.SetToNullOfType(SqlBuffer.StorageType.Json); break; + case SqlDbTypeExtensions.Vector: + nullVal.SetToNullOfType(SqlBuffer.StorageType.Vector); + nullVal.SetVectorInfo(MetaType.GetVectorElementCount(md.length, md.scale), md.scale, true); + break; + default: Debug.Fail("unknown null sqlType!" + md.type.ToString()); break; @@ -6638,6 +6666,27 @@ internal TdsOperationStatus TryReadSqlValue(SqlBuffer value, } break; + case TdsEnums.SQLVECTOR: + // Vector data is read as non-plp binary value. + // This is same as reading varbinary(8000). + result = stateObj.TryReadByteArrayWithContinue(length, out b); + if (result != TdsOperationStatus.Done) + { + return result; + } + + // Internally, we use Sqlbinary to deal with varbinary data and store it in + // SqlBuffer as SqlBinary value. + value.SqlBinary = SqlTypeWorkarounds.SqlBinaryCtor(b, true); + + // Extract the metadata from the payload and set it as the vector attributes + // in the SqlBuffer. This metadata is further used when constructing a SqlVector + // object from binary payload. + int elementCount = BinaryPrimitives.ReadUInt16LittleEndian(b.AsSpan(2)); + byte elementType = b[4]; + value.SetVectorInfo(elementCount, elementType, false); + break; + case TdsEnums.SQLCHAR: case TdsEnums.SQLBIGCHAR: case TdsEnums.SQLVARCHAR: @@ -6947,6 +6996,7 @@ internal TdsOperationStatus TryReadSqlValueInternal(SqlBuffer value, byte tdsTyp case TdsEnums.SQLBIGVARBINARY: case TdsEnums.SQLVARBINARY: case TdsEnums.SQLIMAGE: + case TdsEnums.SQLVECTOR: { // Note: Better not come here with plp data!! Debug.Assert(length <= TdsEnums.MAXSIZE); @@ -8233,6 +8283,18 @@ internal TdsOperationStatus TryGetTokenLength(byte token, TdsParserStateObject s tokenLength = -1; return TdsOperationStatus.Done; } + else if (token == TdsEnums.SQLVECTOR) + { + ushort value; + result = stateObj.TryReadUInt16(out value); + if (result != TdsOperationStatus.Done) + { + tokenLength = 0; + return result; + } + tokenLength = value; + return TdsOperationStatus.Done; + } switch (token & TdsEnums.SQLLenMask) { @@ -9788,7 +9850,13 @@ private Task TDSExecuteRPCAddParameter(TdsParserStateObject stateObj, SqlParamet if (param.Direction == ParameterDirection.Output) { isSqlVal = param.ParameterIsSqlType; // We have to forward the TYPE info, we need to know what type we are returning. Once we null the parameter we will no longer be able to distinguish what type were seeing. - param.Value = null; + + // Output parameter of SqlDbType vector are defined through SqlParameter.Value. + // This check ensures that we do not discard the parameter value when SqlDbType is vector. + if (mt.SqlDbType != SqlDbTypeExtensions.Vector) + { + param.Value = null; + } param.ParameterIsSqlType = isSqlVal; } else @@ -10121,6 +10189,14 @@ private Task TDSExecuteRPCAddParameter(TdsParserStateObject stateObj, SqlParamet maxsize = 1; } + if (mt.SqlDbType == SqlDbTypeExtensions.Vector) + { + // For vector type we need to write the size in bytes required to represent + // vector value when communicating with SQL Server. + var sqlVectorProps = ((ISqlVector)param.Value); + maxsize = sqlVectorProps.Size; + } + WriteParameterVarLen(mt, maxsize, false/*IsNull*/, stateObj); } } @@ -10143,6 +10219,11 @@ private Task TDSExecuteRPCAddParameter(TdsParserStateObject stateObj, SqlParamet { stateObj.WriteByte(param.GetActualScale()); } + // For vector type we need to write scale as the element type of the vector. + else if (mt.SqlDbType == SqlDbTypeExtensions.Vector) + { + stateObj.WriteByte(((ISqlVector)param.Value).ElementType); + } // write out collation or xml metadata @@ -10222,7 +10303,9 @@ private Task TDSExecuteRPCAddParameter(TdsParserStateObject stateObj, SqlParamet { // for codePageEncoded types, WriteValue simply expects the number of characters // For plp types, we also need the encoded byte size - writeParamTask = WriteValue(value, mt, isParameterEncrypted ? (byte)0 : param.GetActualScale(), actualSize, codePageByteSize, isParameterEncrypted ? 0 : param.Offset, stateObj, isParameterEncrypted ? 0 : param.Size, isDataFeed); + // For vector type we need to write scale as the element type of the vector. + byte writeScale = mt.SqlDbType == SqlDbTypeExtensions.Vector ? ((ISqlVector)param.Value).ElementType : param.GetActualScale(); + writeParamTask = WriteValue(value, mt, isParameterEncrypted ? (byte)0 : writeScale, actualSize, codePageByteSize, isParameterEncrypted ? 0 : param.Offset, stateObj, isParameterEncrypted ? 0 : param.Size, isDataFeed); } } @@ -11059,6 +11142,11 @@ internal void WriteBulkCopyMetaData(_SqlMetaDataSet metadataCollection, int coun case SqlDbTypeExtensions.Json: stateObj.WriteByteArray(s_jsonMetadataSubstituteSequence, s_xmlMetadataSubstituteSequence.Length, 0); break; + case SqlDbTypeExtensions.Vector: + stateObj.WriteByte(md.tdsType); + WriteTokenLength(md.tdsType, md.length, stateObj); + stateObj.WriteByte(md.scale); + break; default: stateObj.WriteByte(md.tdsType); WriteTokenLength(md.tdsType, md.length, stateObj); @@ -11274,6 +11362,7 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars case TdsEnums.SQLBIGVARBINARY: case TdsEnums.SQLIMAGE: case TdsEnums.SQLUDT: + case TdsEnums.SQLVECTOR: ccb = (isSqlType) ? ((SqlBinary)value).Length : ((byte[])value).Length; break; case TdsEnums.SQLUNIQUEID: @@ -11606,6 +11695,12 @@ private void WriteTokenLength(byte token, int length, TdsParserStateObject state { tokenLength = 8; } + else if (token == TdsEnums.SQLVECTOR) + { + tokenLength = 2; + WriteShort(length, stateObj); + return; + } if (tokenLength == 0) { @@ -12431,6 +12526,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int case TdsEnums.SQLBIGVARBINARY: case TdsEnums.SQLIMAGE: case TdsEnums.SQLUDT: + case TdsEnums.SQLVECTOR: { // An array should be in the object Debug.Assert(isDataFeed || value is byte[], "Value should be an array of bytes"); diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs index 46e6091ad4..470cb8e35e 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs @@ -1141,6 +1141,15 @@ internal static ArgumentException BadParameterName(string parameterName) return e; } + internal static Exception NullOutputParameterValueForVector(string paramName) + => InvalidOperation(StringsHelper.GetString(Strings.ADP_NullOutputParameterValueForVector, paramName)); + + internal static ArgumentException InvalidVectorHeader() + => Argument(StringsHelper.GetString(Strings.ADP_InvalidVectorHeader)); + + internal static Exception InvalidJsonStringForVector(string value, Exception inner) + => InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidJsonStringForVector, value), inner); + internal static Exception DeriveParametersNotSupported(IDbCommand value) => DataAdapter(StringsHelper.GetString(Strings.ADP_DeriveParametersNotSupported, value.GetType().Name, value.CommandType.ToString())); diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ISqlVector.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ISqlVector.cs new file mode 100644 index 0000000000..ca9fe35743 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ISqlVector.cs @@ -0,0 +1,35 @@ +namespace Microsoft.Data.SqlClient +{ + /// + /// Internal interface for types that represent a vector of SQL values. + /// + internal interface ISqlVector + { + /// + /// Gets the number of elements in the vector. + /// + int Length { get; } + + /// + /// Gets the type of the elements in vector as a + /// TDS Vector Header DimensionType value. + /// Refer TDS section 2.2.5.5.7.4 + /// + byte ElementType { get; } + + /// + /// Gets the size (in bytes) of a single element. + /// + byte ElementSize { get; } + + /// + /// Gets the raw vector data formatted for TDS payload. + /// + byte[] VectorPayload { get; } + + /// + /// Returns the total size in bytes for sending SqlVector value on TDS. + /// + int Size { get; } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBuffer.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBuffer.cs index 8188044a29..87ba2e0629 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBuffer.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBuffer.cs @@ -37,6 +37,7 @@ internal enum StorageType DateTimeOffset, Time, Json, + Vector, } internal struct DateTimeInfo @@ -76,6 +77,12 @@ internal struct DateTimeOffsetInfo internal short _offset; } + internal struct VectorInfo + { + internal int _elementCount; + internal byte _elementType; + } + [StructLayout(LayoutKind.Explicit)] internal struct Storage { @@ -105,6 +112,8 @@ internal struct Storage internal DateTime2Info _dateTime2Info; [FieldOffset(0)] internal DateTimeOffsetInfo _dateTimeOffsetInfo; + [FieldOffset(0)] + internal VectorInfo _vectorInfo; } private bool _isNull; @@ -133,6 +142,15 @@ private SqlBuffer(SqlBuffer value) internal StorageType VariantInternalStorageType => _type; + internal Storage GetVectorInfo() + { + if (_type == StorageType.Vector) + { + return _value; + } + throw new InvalidOperationException(); + } + internal bool Boolean { get @@ -179,7 +197,10 @@ internal byte[] ByteArray { get { - ThrowIfNull(); + if (_type != StorageType.Vector) + { + ThrowIfNull(); + } return SqlBinary.Value; } } @@ -220,16 +241,16 @@ internal decimal Decimal // Only removing trailing zeros from a decimal part won't hit its value! if (_value._numericInfo._scale > 0) { - int zeroCnt = FindTrailingZerosAndPrec((uint)_value._numericInfo._data1, (uint)_value._numericInfo._data2, - (uint)_value._numericInfo._data3, (uint)_value._numericInfo._data4, + int zeroCnt = FindTrailingZerosAndPrec((uint)_value._numericInfo._data1, (uint)_value._numericInfo._data2, + (uint)_value._numericInfo._data3, (uint)_value._numericInfo._data4, _value._numericInfo._scale, out int precision); int minScale = _value._numericInfo._scale - zeroCnt; // minimum possible sacle after removing the trailing zeros. if (zeroCnt > 0 && minScale <= 28 && precision <= 29) { - SqlDecimal sqlValue = new(_value._numericInfo._precision, _value._numericInfo._scale, _value._numericInfo._positive, - _value._numericInfo._data1, _value._numericInfo._data2, + SqlDecimal sqlValue = new(_value._numericInfo._precision, _value._numericInfo._scale, _value._numericInfo._positive, + _value._numericInfo._data1, _value._numericInfo._data2, _value._numericInfo._data3, _value._numericInfo._data4); int integral = precision - minScale; @@ -486,7 +507,17 @@ internal string String get { ThrowIfNull(); - + if (_type == StorageType.Vector) + { + var elementType = (MetaType.SqlVectorElementType)_value._vectorInfo._elementType; + switch (elementType) + { + case MetaType.SqlVectorElementType.Float32: + return GetSqlVector().ToString(); + default: + throw SQL.VectorTypeNotSupported(elementType.ToString()); + } + } if (StorageType.String == _type || StorageType.Json == _type) { return (string)_object; @@ -654,7 +685,7 @@ internal SqlBinary SqlBinary { get { - if (StorageType.SqlBinary == _type) + if (_type is StorageType.SqlBinary or StorageType.Vector) { if (IsNull) { @@ -917,8 +948,23 @@ internal SqlString SqlString { get { + if (_type is StorageType.Vector) + { + if (IsNull) + { + return SqlString.Null; + } + var elementType = (MetaType.SqlVectorElementType)_value._vectorInfo._elementType; + switch (elementType) + { + case MetaType.SqlVectorElementType.Float32: + return new SqlString(GetSqlVector().ToString()); + default: + throw SQL.VectorTypeNotSupported(elementType.ToString()); + } + } // String and Json storage type are both strings. - if (StorageType.String == _type || StorageType.Json == _type) + if (_type is StorageType.String or StorageType.Json) { if (IsNull) { @@ -941,6 +987,19 @@ internal SqlString SqlString internal SqlJson SqlJson => (StorageType.Json == _type) ? (IsNull ? SqlTypes.SqlJson.Null : new SqlJson((string)_object)) : (SqlJson)SqlValue; + internal SqlVector GetSqlVector() where T : unmanaged + { + if (_type is StorageType.Vector) + { + if (IsNull) + { + return new SqlVector(_value._vectorInfo._elementCount); + } + return new SqlVector(SqlBinary.Value); + } + return (SqlVector)SqlValue; + } + internal object SqlValue { get @@ -975,6 +1034,15 @@ internal object SqlValue return SqlString; case StorageType.Json: return SqlJson; + case StorageType.Vector: + var elementType = (MetaType.SqlVectorElementType)_value._vectorInfo._elementType; + switch (elementType) + { + case MetaType.SqlVectorElementType.Float32: + return GetSqlVector(); + default: + throw SQL.VectorTypeNotSupported(elementType.ToString()); + } case StorageType.SqlCachedBuffer: { SqlCachedBuffer data = (SqlCachedBuffer)(_object); @@ -1067,6 +1135,7 @@ internal object Value case StorageType.String: return String; case StorageType.SqlBinary: + case StorageType.Vector: return ByteArray; case StorageType.SqlCachedBuffer: { @@ -1134,6 +1203,7 @@ internal Type GetTypeFromStorageType(bool isSqlType) case StorageType.SqlCachedBuffer: return typeof(SqlString); case StorageType.SqlBinary: + case StorageType.Vector: return typeof(object); case StorageType.SqlGuid: return typeof(SqlGuid); @@ -1190,6 +1260,8 @@ internal Type GetTypeFromStorageType(bool isSqlType) return typeof(DateTimeOffset); case StorageType.Json: return typeof(string); + case StorageType.Vector: + return typeof(byte[]); #if NET case StorageType.Time: return typeof(TimeOnly); @@ -1247,7 +1319,15 @@ internal void SetToDate(DateTime date) _value._int32 = date.Subtract(DateTime.MinValue).Days; _isNull = false; } - #endif +#endif + + internal void SetVectorInfo(int elementCount, byte elementType, bool isNull) + { + _value._vectorInfo._elementCount = elementCount; + _value._vectorInfo._elementType = elementType; + _type = StorageType.Vector; + _isNull = isNull; + } internal void SetToDateTime(int daypart, int timepart) { @@ -1257,8 +1337,8 @@ internal void SetToDateTime(int daypart, int timepart) _type = StorageType.DateTime; _isNull = false; } - - #if NETFRAMEWORK + +#if NETFRAMEWORK internal void SetToDateTime2(DateTime dateTime, byte scale) { Debug.Assert(IsEmpty, "setting value a second time?"); @@ -1269,7 +1349,7 @@ internal void SetToDateTime2(DateTime dateTime, byte scale) _value._dateTime2Info._date = dateTime.Subtract(DateTime.MinValue).Days; _isNull = false; } - #endif +#endif internal void SetToDecimal(byte precision, byte scale, bool positive, int[] bits) { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index ca84636934..7819e31bfe 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -633,6 +633,10 @@ private string AnalyzeTargetAndCreateUpdateBulkCommand(BulkCopySimpleResultSet i { AppendColumnNameAndTypeName(updateBulkCommandText, metadata.column, "json"); } + else if (metadata.type == SqlDbTypeExtensions.Vector) + { + AppendColumnNameAndTypeName(updateBulkCommandText, metadata.column, "vector"); + } else { AppendColumnNameAndTypeName(updateBulkCommandText, metadata.column, metadata.type.ToString()); @@ -677,12 +681,15 @@ private string AnalyzeTargetAndCreateUpdateBulkCommand(BulkCopySimpleResultSet i case TdsEnums.SQLNTEXT: size /= 2; break; + case TdsEnums.SQLVECTOR: + size = MetaType.GetVectorElementCount(metadata.length, metadata.scale); + break; default: break; } updateBulkCommandText.AppendFormat((IFormatProvider)null, "({0})", size); } - else if (metadata.metaType.IsPlp && metadata.metaType.SqlDbType != SqlDbType.Xml && metadata.metaType.SqlDbType != SqlDbTypeExtensions.Json) + else if (metadata.metaType.IsPlp && !(metadata.metaType.SqlDbType is SqlDbType.Xml or SqlDbTypeExtensions.Json or SqlDbTypeExtensions.Vector)) { // Partial length column prefix (max) updateBulkCommandText.Append("(max)"); @@ -1592,6 +1599,7 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re case TdsEnums.SQLTIME: case TdsEnums.SQLDATETIME2: case TdsEnums.SQLDATETIMEOFFSET: + case TdsEnums.SQLVECTOR: mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); value = SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false); break; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlDataReader.cs index dcd5e858b3..373c071b36 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -2835,6 +2835,18 @@ virtual public SqlJson GetSqlJson(int i) return json; } + /// + virtual public SqlVector GetSqlVector(int i) where T : unmanaged + { + if (typeof(T) != typeof(float)) + { + throw SQL.VectorTypeNotSupported(typeof(T).FullName); + } + + ReadColumn(i); + return _data[i].GetSqlVector(); + } + /// virtual public object GetSqlValue(int i) { @@ -2954,7 +2966,6 @@ virtual public int GetSqlValues(object[] values) override public string GetString(int i) { ReadColumn(i); - // Convert 2008 value to string if type system knob is 2005 or earlier if (_typeSystem <= SqlConnectionString.TypeSystem.SQLServer2005 && _metaData[i].Is2008DateTimeType) { @@ -3084,6 +3095,24 @@ private object GetValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData metaDa { // TypeSystem.SQLServer2005 and above + if (metaData.type == SqlDbTypeExtensions.Vector) + { + if (data.IsNull) + { + return DBNull.Value; + } + else + { + switch (metaData.scale) + { + case (byte)MetaType.SqlVectorElementType.Float32: + return data.GetSqlVector(); + default: + throw SQL.VectorTypeNotSupported(metaData.scale.ToString()); + } + } + } + if (metaData.type != SqlDbType.Udt) { return data.Value; @@ -3187,6 +3216,15 @@ private T GetFieldValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData met return (T)(object)data.TimeOnly; } #endif + else if (typeof(T) == typeof(SqlVector)) + { + MetaType metaType = metaData.metaType; + if (metaType.SqlDbType != SqlDbTypeExtensions.Vector) + { + throw SQL.VectorNotSupportedOnColumnType(metaData.column); + } + return (T)(object)data.GetSqlVector(); + } else if (typeof(T) == typeof(XmlReader)) { // XmlReader only allowed on XML types @@ -3325,6 +3363,13 @@ private T GetFieldValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData met } else { + if (typeof(T) == typeof(string) && metaData.metaType.SqlDbType == SqlDbTypeExtensions.Vector) + { + if (data.IsNull) + return (T)(object)data.String; + else + return (T)(object)data.GetSqlVector().ToString(); + } // the requested type is likely to be one that isn't supported so try the cast and // unless there is a null value conversion then feedback the cast exception with // type named to the user so they know what went wrong. Supported types are listed @@ -4712,7 +4757,7 @@ internal TdsOperationStatus TrySetMetaData(_SqlMetaDataSet metaData, bool moreIn _metaDataConsumed = true; if (_parser != null) - { + { // There is a valid case where parser is null // Peek, and if row token present, set _hasRows true since there is a // row in the result diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlEnums.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlEnums.cs index c48a72dcee..27b3eebcd0 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlEnums.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlEnums.cs @@ -63,6 +63,15 @@ internal sealed class MetaType internal readonly bool Is90Supported; internal readonly bool Is100Supported; + // SqlVector Element Types + // + // These underlying values must match the vector "dimension type" values + // in the TDS protocol. + internal enum SqlVectorElementType : byte + { + Float32 = 0x00 + } + public MetaType(byte precision, byte scale, int fixedLength, bool isFixed, bool isLong, bool isPlp, byte tdsType, byte nullableTdsType, string typeName, #if NET [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] @@ -139,6 +148,7 @@ private static bool _IsBinType(SqlDbType type) => type == SqlDbType.VarBinary || type == SqlDbType.Timestamp || type == SqlDbType.Udt || + type == SqlDbTypeExtensions.Vector || (int)type == 24 /*SqlSmallVarBinary*/; private static bool _Is70Supported(SqlDbType type) => @@ -230,6 +240,8 @@ internal static MetaType GetMetaTypeFromSqlDbType(SqlDbType target, bool isMulti return MetaUdt; case SqlDbTypeExtensions.Json: return s_MetaJson; + case SqlDbTypeExtensions.Vector: + return s_MetaVector; case SqlDbType.Structured: if (isMultiValued) { @@ -368,6 +380,8 @@ private static MetaType GetMetaTypeFromValue(Type dataType, object value, bool i return MetaXml; else if (dataType == typeof(SqlJson)) return s_MetaJson; + else if (dataType == typeof(SqlVector)) + return s_MetaVector; else if (dataType == typeof(SqlString)) { return ((inferLen && !((SqlString)value).IsNull) @@ -870,6 +884,8 @@ internal static MetaType GetSqlDataType(int tdsType, uint userType, int length) return MetaDateTimeOffset; case TdsEnums.SQLJSON: return s_MetaJson; + case TdsEnums.SQLVECTOR: + return s_MetaVector; case TdsEnums.SQLVOID: default: @@ -978,6 +994,8 @@ internal static string GetStringFromXml(XmlReader xmlreader) internal static readonly MetaType s_MetaJson = new(255, 255, -1, false, true, true, TdsEnums.SQLJSON, TdsEnums.SQLJSON, MetaTypeName.JSON, typeof(string), typeof(string), SqlDbTypeExtensions.Json, DbType.String, 0); + internal static readonly MetaType s_MetaVector = new(255, 255, -1, false, false, false, TdsEnums.SQLVECTOR, TdsEnums.SQLVECTOR, MetaTypeName.VECTOR, typeof(byte[]), typeof(SqlBinary), SqlDbTypeExtensions.Vector, DbType.Binary, 2); + public static TdsDateTime FromDateTime(DateTime dateTime, byte cb) { SqlDateTime sqlDateTime; @@ -1027,6 +1045,21 @@ internal static int GetTimeSizeFromScale(byte scale) return 5; } + internal static int GetVectorElementSize(byte type) + { + switch (type) + { + case 0: return sizeof(float); + default: + throw SQL.VectorTypeNotSupported(type.ToString()); + } + } + + internal static int GetVectorElementCount(int size, byte elementType) + { + return (size - TdsEnums.VECTOR_HEADER_SIZE) / GetVectorElementSize(elementType); + } + // // please leave string sorted alphabetically // note that these names should only be used in the context of parameters. We always send over BIG* and nullable types for SQL Server @@ -1065,6 +1098,7 @@ private static class MetaTypeName public const string DATETIME2 = "datetime2"; public const string DATETIMEOFFSET = "datetimeoffset"; public const string JSON = "json"; + public const string VECTOR = "vector"; } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlParameter.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlParameter.cs index f0a35b7653..f9b940fec3 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlParameter.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlParameter.cs @@ -14,6 +14,7 @@ using System.IO; using System.Reflection; using System.Text; +using System.Text.Json; using System.Threading; using System.Xml; using Microsoft.Data.Common; @@ -739,6 +740,10 @@ public override object Value { if (ParameterIsSqlType) { + if (_sqlBufferReturnValue.VariantInternalStorageType == SqlBuffer.StorageType.Vector) + { + return GetVectorReturnValue(); + } return _sqlBufferReturnValue.SqlValue; } return _sqlBufferReturnValue.Value; @@ -758,6 +763,30 @@ public override object Value } } + private object GetVectorReturnValue() + { + var elementType = (MetaType.SqlVectorElementType)_sqlBufferReturnValue.GetVectorInfo()._vectorInfo._elementType; + int elementCount = _sqlBufferReturnValue.GetVectorInfo()._vectorInfo._elementCount; + + if (IsNull) + { + switch (elementType) + { + case MetaType.SqlVectorElementType.Float32: + return new SqlVector(elementCount); + default: + throw SQL.VectorTypeNotSupported(elementType.ToString()); + } + } + switch (elementType) + { + case MetaType.SqlVectorElementType.Float32: + return new SqlVector((byte[])_sqlBufferReturnValue.Value); + default: + throw SQL.VectorTypeNotSupported(elementType.ToString()); + } + } + /// [ RefreshProperties(RefreshProperties.All), @@ -1603,6 +1632,7 @@ internal int GetActualSize() case SqlDbType.VarBinary: case SqlDbType.Image: case SqlDbType.Timestamp: + case SqlDbTypeExtensions.Vector: coercedSize = (!HasFlag(SqlParameterFlags.IsNull) && (!HasFlag(SqlParameterFlags.CoercedValueIsDataFeed))) ? (BinarySize(val, HasFlag(SqlParameterFlags.CoercedValueIsSqlType))) : 0; _actualSize = (ShouldSerializeSize() ? Size : 0); _actualSize = ((ShouldSerializeSize() && (_actualSize <= coercedSize)) ? _actualSize : coercedSize); @@ -1896,6 +1926,13 @@ private MetaType GetMetaTypeOnly() { if (_metaType != null) { + if (_metaType.SqlDbType == SqlDbTypeExtensions.Vector && + _direction == ParameterDirection.Input && + (_value == null || _value == DBNull.Value)) + { + _value = DBNull.Value; + return MetaType.GetDefaultMetaType(); + } return _metaType; } if (_value != null && DBNull.Value != _value) @@ -1923,6 +1960,7 @@ private MetaType GetMetaTypeOnly() return MetaType.GetMetaTypeFromType(valueType); } } + return MetaType.GetDefaultMetaType(); } @@ -2000,6 +2038,13 @@ internal void Validate(int index, bool isCommandProc) GetCoercedValue(); } + if (metaType.SqlDbType == SqlDbTypeExtensions.Vector && + (_value == null || _value == DBNull.Value) && + (Direction == ParameterDirection.Output || Direction == ParameterDirection.InputOutput)) + { + throw ADP.NullOutputParameterValueForVector(_parameterName); + } + //check if the UdtTypeName is specified for Udt params if (metaType.SqlDbType == SqlDbType.Udt) { @@ -2149,6 +2194,10 @@ private int ValueSize(object value) } return sqlString.Value.Length; } + if (value is ISqlVector sqlVector) + { + return sqlVector.Size; + } if (value is SqlChars sqlChars) { if (sqlChars.IsNull) @@ -2320,6 +2369,21 @@ internal static object CoerceValue(object value, MetaType destinationType, out b value = ((TimeOnly)value).ToTimeSpan(); } #endif + else if (currentType == typeof(SqlVector)) + { + value = ((ISqlVector)value).VectorPayload; + } + else if (currentType == typeof(string) && destinationType.SqlDbType == SqlDbTypeExtensions.Vector) + { + try + { + value = (new SqlVector(JsonSerializer.Deserialize(value as string)) as ISqlVector).VectorPayload; + } + catch (Exception ex) when (ex is ArgumentNullException || ex is JsonException) + { + throw ADP.InvalidJsonStringForVector(value as string, ex); + } + } else if ( TdsEnums.SQLTABLE == destinationType.TDSType && ( diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs index 3248ac6637..06564633e8 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs @@ -988,6 +988,16 @@ internal static Exception TextReaderNotSupportOnColumnType(string columnName) return ADP.InvalidCast(StringsHelper.GetString(Strings.SQL_TextReaderNotSupportOnColumnType, columnName)); } + internal static Exception VectorNotSupportedOnColumnType(string columnName) + { + return ADP.InvalidCast(StringsHelper.GetString(Strings.SQL_VectorNotSupportedOnColumnType, columnName)); + } + + internal static Exception VectorTypeNotSupported(string value) + { + return ADP.NotSupported(StringsHelper.GetString(Strings.SQL_VectorTypeNotSupported, value)); + } + internal static Exception XmlReaderNotSupportOnColumnType(string columnName) { return ADP.InvalidCast(StringsHelper.GetString(Strings.SQL_XmlReaderNotSupportOnColumnType, columnName)); diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs index b17e52500c..630633a6b4 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs @@ -255,7 +255,7 @@ public enum FeatureExtension : uint UTF8Support = 1 << (TdsEnums.FEATUREEXT_UTF8SUPPORT - 1), SQLDNSCaching = 1 << (TdsEnums.FEATUREEXT_SQLDNSCACHING - 1), JsonSupport = 1 << (TdsEnums.FEATUREEXT_JSONSUPPORT - 1), - VectorSupport = 1 << (TdsEnums.FEATUREEXT_VECTORSUPPORT -1) + VectorSupport = 1 << (TdsEnums.FEATUREEXT_VECTORSUPPORT - 1) } public const uint UTF8_IN_TDSCOLLATION = 0x4000000; @@ -487,6 +487,7 @@ public enum ActiveDirectoryWorkflow : byte public const int SQLDATETIMEOFFSET = 0x2b; public const int SQLJSON = 0xF4; + public const int SQLVECTOR = 0xF5; public const int DEFAULT_VARTIME_SCALE = 7; @@ -982,6 +983,7 @@ internal enum FedAuthInfoId : byte // Vector Support constants internal const byte MAX_SUPPORTED_VECTOR_VERSION = 0x01; + internal const int VECTOR_HEADER_SIZE = 8; // TCE Related constants internal const byte MAX_SUPPORTED_TCE_VERSION = 0x03; // max version diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlDbTypeExtensions.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlDbTypeExtensions.cs index df7f597517..96244fb7a8 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlDbTypeExtensions.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlDbTypeExtensions.cs @@ -14,6 +14,12 @@ public static class SqlDbTypeExtensions public const SqlDbType Json = SqlDbType.Json; #else public const SqlDbType Json = (SqlDbType)35; +#endif + /// +#if NET10_0_OR_GREATER + public const SqlDbType Vector = SqlDbType.Vector; +#else + public const SqlDbType Vector = (SqlDbType)36; #endif } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlTypes/SqlVector.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlTypes/SqlVector.cs new file mode 100644 index 0000000000..3ff7a488a0 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlTypes/SqlVector.cs @@ -0,0 +1,246 @@ +using System; +using System.Buffers.Binary; +using System.Data.SqlTypes; +using System.Diagnostics; +using System.Runtime.InteropServices; +using System.Text.Json; +using Microsoft.Data.Common; +using Microsoft.Data.SqlClient; + +#nullable enable + +namespace Microsoft.Data.SqlTypes; + +/// +public sealed class SqlVector : INullable, ISqlVector +where T : unmanaged +{ + #region Constants + + private const byte VecHeaderMagicNo = 0xA9; + private const byte VecVersionNo = 0x01; + + #endregion + + #region Fields + + private readonly byte _elementType; + private readonly byte _elementSize; + private readonly byte[] _tdsBytes; + private T[] _array; + + #endregion + + #region Constructors + + /// + public SqlVector(int length) + { + if (length < 0) + { + throw ADP.ArgumentOutOfRange(nameof(length), SQLResource.InvalidArraySizeMessage); + } + + (_elementType, _elementSize) = GetTypeFieldsOrThrow(); + + IsNull = true; + + Length = length; + Size = TdsEnums.VECTOR_HEADER_SIZE + (_elementSize * Length); + + _tdsBytes = Array.Empty(); + _array = Array.Empty(); + Memory = new(); + } + + /// + public SqlVector(ReadOnlyMemory memory) + { + (_elementType, _elementSize) = GetTypeFieldsOrThrow(); + + IsNull = false; + + Length = memory.Length; + Size = TdsEnums.VECTOR_HEADER_SIZE + (_elementSize * Length); + + _tdsBytes = MakeTdsBytes(memory); + _array = memory.ToArray(); + Memory = memory; + } + + internal SqlVector(byte[] tdsBytes) + { + (_elementType, _elementSize) = GetTypeFieldsOrThrow(); + + (Length, Size) = GetCountsOrThrow(tdsBytes); + + IsNull = false; + + _tdsBytes = tdsBytes; + _array = MakeArray(); + Memory = new(_array); + } + + #endregion + + #region Methods + + /// + public override string ToString() + { + if (IsNull) + { + return SQLResource.NullString; + } + return JsonSerializer.Serialize(Memory); + } + + /// + public T[] ToArray() + { + return _array; + } + + #endregion + + #region Properties + + /// + public bool IsNull { get; init; } + + /// + public static SqlVector? Null => null; + + /// + public int Length { get; init; } + /// + public int Size { get; init; } + + /// + public ReadOnlyMemory Memory { get; init; } + + #endregion + + #region ISqlVector Internal Properties + byte ISqlVector.ElementType => _elementType; + byte ISqlVector.ElementSize => _elementSize; + byte[] ISqlVector.VectorPayload => _tdsBytes; + #endregion + + #region Helpers + + private (byte, byte) GetTypeFieldsOrThrow() + { + byte elementType; + byte elementSize; + + if (typeof(T) == typeof(float)) + { + elementType = (byte)MetaType.SqlVectorElementType.Float32; + elementSize = sizeof(float); + } + else + { + throw SQL.VectorTypeNotSupported(typeof(T).FullName); + } + + return (elementType, elementSize); + } + + private byte[] MakeTdsBytes(ReadOnlyMemory values) + { + //Refer to TDS section 2.2.5.5.7 for vector header format + // +------------------------+-----------------+----------------------+------------------+----------------------------+--------------+ + // | Field | Size (bytes) | Example Value | Description | + // +------------------------+-----------------+----------------------+--------------------------------------------------------------+ + // | Layout Format | 1 | 0xA9 | Magic number indicating vector layout format | + // | Layout Version | 1 | 0x01 | Version of the vector format | + // | Number of Dimensions | 2 | NN | Number of vector elements | + // | Dimension Type | 1 | 0x00 | Element type indicator (e.g. 0x00 for float32) | + // | Reserved | 3 | 0x00 0x00 0x00 | Reserved for future use | + // | Stream of Values | NN * sizeof(T) | [element bytes...] | Raw bytes for vector elements | + // +------------------------+-----------------+----------------------+--------------------------------------------------------------+ + + byte[] result = new byte[Size]; + + // Header Bytes + result[0] = VecHeaderMagicNo; + result[1] = VecVersionNo; + result[2] = (byte)(Length & 0xFF); + result[3] = (byte)((Length >> 8) & 0xFF); + result[4] = _elementType; + result[5] = 0x00; + result[6] = 0x00; + result[7] = 0x00; + +#if NETFRAMEWORK + // Copy data via marshaling. + if (MemoryMarshal.TryGetArray(values, out ArraySegment segment)) + { + Buffer.BlockCopy(segment.Array, segment.Offset * _elementSize, result, TdsEnums.VECTOR_HEADER_SIZE, segment.Count * _elementSize); + } + else + { + Buffer.BlockCopy(values.ToArray(), 0, result, TdsEnums.VECTOR_HEADER_SIZE, values.Length * _elementSize); + } +#else + // Fast span-based copy. + var byteSpan = MemoryMarshal.AsBytes(values.Span); + byteSpan.CopyTo(result.AsSpan(TdsEnums.VECTOR_HEADER_SIZE)); +#endif + return result; + } + + private (int, int) GetCountsOrThrow(byte[] rawBytes) + { + // Validate some of the header fields. + if ( + // Do we have enough bytes for the header? + rawBytes.Length < TdsEnums.VECTOR_HEADER_SIZE || + // Do we have the expected magic number? + rawBytes[0] != VecHeaderMagicNo || + // Do we support the version? + rawBytes[1] != VecVersionNo || + // Do the vector types match? + rawBytes[4] != _elementType) + { + // No, so throw. + throw ADP.InvalidVectorHeader(); + } + + // The vector length is an unsigned 16-bit integer, little-endian. + int length = BinaryPrimitives.ReadUInt16LittleEndian(rawBytes.AsSpan(2)); + + // The vector size is the number of bytes required to represent the vector in TDS. + int size = TdsEnums.VECTOR_HEADER_SIZE + (_elementSize * length); + + // Are there exactly enough bytes for the vector elements? + if (rawBytes.Length != size) + { + // No, so throw. + throw ADP.InvalidVectorHeader(); + } + + return (length, size); + } + + private T[] MakeArray() + { + if (_tdsBytes.Length == 0) + { + return Array.Empty(); + } + +#if NETFRAMEWORK + // Allocate array and copy bytes into it + T[] result = new T[Length]; + Buffer.BlockCopy(_tdsBytes, 8, result, 0, _elementSize * Length); + return result; +#else + ReadOnlySpan dataSpan = _tdsBytes.AsSpan(8, _elementSize * Length); + return MemoryMarshal.Cast(dataSpan).ToArray(); +#endif + } + + #endregion +} diff --git a/src/Microsoft.Data.SqlClient/src/Resources/Strings.Designer.cs b/src/Microsoft.Data.SqlClient/src/Resources/Strings.Designer.cs index 7337cf5f42..b41b331db7 100644 --- a/src/Microsoft.Data.SqlClient/src/Resources/Strings.Designer.cs +++ b/src/Microsoft.Data.SqlClient/src/Resources/Strings.Designer.cs @@ -10821,10 +10821,65 @@ internal static string SQL_UserInstanceNotAvailableInProc { } /// - /// Looks up a localized string similar to Expecting argument of type {1}, but received type {0}.. + /// Looks up a localized string similar to Invalid attempt to get vector data from column '{0}'. Vectors are only supported for columns of type vector.. /// - internal static string SQL_WrongType { + internal static string SQL_VectorNotSupportedOnColumnType { get { + return ResourceManager.GetString("SQL_VectorNotSupportedOnColumnType", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Unsupported Vector type '{0}'.. + /// + internal static string SQL_VectorTypeNotSupported + { + get + { + return ResourceManager.GetString("SQL_VectorTypeNotSupported", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to 'null' value not supported for output paramter {0} of SqlDbtype Vector.. + /// + internal static string ADP_NullOutputParameterValueForVector + { + get + { + return ResourceManager.GetString("ADP_NullOutputParameterValueForVector", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Invalid vector header received.. + /// + internal static string ADP_InvalidVectorHeader + { + get + { + return ResourceManager.GetString("ADP_InvalidVectorHeader", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to {0} Invalid JSON string for vector... + /// + internal static string ADP_InvalidJsonStringForVector + { + get + { + return ResourceManager.GetString("ADP_InvalidJsonStringForVector", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Expecting argument of type {1}, but received type {0}.. + /// + internal static string SQL_WrongType + { + get + { return ResourceManager.GetString("SQL_WrongType", resourceCulture); } } diff --git a/src/Microsoft.Data.SqlClient/src/Resources/Strings.resx b/src/Microsoft.Data.SqlClient/src/Resources/Strings.resx index 1caf481f7c..90a157876a 100644 --- a/src/Microsoft.Data.SqlClient/src/Resources/Strings.resx +++ b/src/Microsoft.Data.SqlClient/src/Resources/Strings.resx @@ -4743,4 +4743,19 @@ SqlBatchCommand list has not been initialized. + + Invalid attempt to get vector data from column '{0}'. Vectors are only supported for columns of type vector. + + + Unsupported Vector type '{0}'. + + + 'null' value not supported for output parameter '{0}' of SqlDbtype Vector. + + + Invalid vector header received. + + + {0} Invalid JSON string for vector. + \ No newline at end of file diff --git a/src/Microsoft.Data.SqlClient/src/System/Runtime/CompilerServices/IsExternalInit.netfx.cs b/src/Microsoft.Data.SqlClient/src/System/Runtime/CompilerServices/IsExternalInit.netfx.cs new file mode 100644 index 0000000000..0d0181ba6d --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/System/Runtime/CompilerServices/IsExternalInit.netfx.cs @@ -0,0 +1,23 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#if NETFRAMEWORK + +using System.ComponentModel; + + +// This class enables the use of the `init` property accessor in .NET framework. +namespace System.Runtime.CompilerServices +{ + /// + /// Reserved to be used by the compiler for tracking metadata. + /// This class should not be used by developers in source code. + /// + [EditorBrowsable(EditorBrowsableState.Never)] + internal static class IsExternalInit + { + } +} + +#endif diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs index c84e07a6ef..aacc723453 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs @@ -92,9 +92,12 @@ public static class DataTestUtility //SQL Server EngineEdition private static string s_sqlServerEngineEdition; - // JSON Coloumn type + // JSON Column type public static readonly bool IsJsonSupported = false; + // VECTOR column type + public static readonly bool IsVectorSupported = false; + // Azure Synapse EngineEditionId == 6 // More could be read at https://learn.microsoft.com/en-us/sql/t-sql/functions/serverproperty-transact-sql?view=sql-server-ver16#propertyname public static bool IsAzureSynapse diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj index 740542145b..7bd503c5b9 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj @@ -297,6 +297,8 @@ + + diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/NativeVectorFloat32Tests.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/NativeVectorFloat32Tests.cs new file mode 100644 index 0000000000..37c413c7eb --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/NativeVectorFloat32Tests.cs @@ -0,0 +1,593 @@ +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.SqlTypes; +using System.Text.Json; +using System.Threading.Tasks; +using Microsoft.Data.SqlTypes; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.Data.SqlClient.ManualTesting.Tests.SQL.VectorTest +{ + public static class VectorFloat32TestData + { + public const int VectorHeaderSize = 8; + public static float[] testData = new float[] { 1.1f, 2.2f, 3.3f }; + public static int sizeInbytes = VectorHeaderSize + testData.Length * sizeof(float); + public static int vectorColumnLength = testData.Length; + public static IEnumerable GetVectorFloat32TestData() + { + // Pattern 1-4 with SqlVector(values: testData) + yield return new object[] { 1, new SqlVector(testData), testData, sizeInbytes, vectorColumnLength }; + yield return new object[] { 2, new SqlVector(testData), testData, sizeInbytes, vectorColumnLength }; + yield return new object[] { 3, new SqlVector(testData), testData, sizeInbytes, vectorColumnLength }; + yield return new object[] { 4, new SqlVector(testData), testData, sizeInbytes, vectorColumnLength }; + + // Pattern 1�4 with SqlVector(n) + yield return new object[] { 1, new SqlVector(vectorColumnLength), Array.Empty(), sizeInbytes, vectorColumnLength }; + yield return new object[] { 2, new SqlVector(vectorColumnLength), Array.Empty(), sizeInbytes, vectorColumnLength }; + yield return new object[] { 3, new SqlVector(vectorColumnLength), Array.Empty(), sizeInbytes, vectorColumnLength }; + yield return new object[] { 4, new SqlVector(vectorColumnLength), Array.Empty(), sizeInbytes, vectorColumnLength }; + + // Pattern 1�4 with DBNull + yield return new object[] { 1, DBNull.Value, Array.Empty(), sizeInbytes, vectorColumnLength }; + yield return new object[] { 2, DBNull.Value, Array.Empty(), sizeInbytes, vectorColumnLength }; + yield return new object[] { 3, DBNull.Value, Array.Empty(), sizeInbytes, vectorColumnLength }; + yield return new object[] { 4, DBNull.Value, Array.Empty(), sizeInbytes, vectorColumnLength }; + + // Pattern 1�4 with SqlVector.Null + yield return new object[] { 1, SqlVector.Null, Array.Empty(), sizeInbytes, vectorColumnLength }; + + // Following scenario is not supported in SqlClient. + // This can only be fixed with a behavior change that SqlParameter.Value is internally set to DBNull.Value if it is set to null. + //yield return new object[] { 2, SqlVector.Null, Array.Empty(), sizeInbytes, vectorColumnLength }; + + yield return new object[] { 3, SqlVector.Null, Array.Empty(), sizeInbytes, vectorColumnLength }; + yield return new object[] { 4, SqlVector.Null, Array.Empty(), sizeInbytes, vectorColumnLength }; + } + } + + public sealed class NativeVectorFloat32Tests : IDisposable + { + private readonly ITestOutputHelper _output; + private static readonly string s_connectionString = ManualTesting.Tests.DataTestUtility.TCPConnectionString; + private static readonly string s_tableName = DataTestUtility.GetUniqueName("VectorTestTable"); + private static readonly string s_bulkCopySrcTableName = DataTestUtility.GetUniqueName("VectorBulkCopyTestTable"); + private static readonly string s_bulkCopySrcTableDef = $@"(Id INT PRIMARY KEY IDENTITY, VectorData vector(3) NULL)"; + private static readonly string s_tableDefinition = $@"(Id INT PRIMARY KEY IDENTITY, VectorData vector(3) NULL)"; + private static readonly string s_selectCmdString = $"SELECT VectorData FROM {s_tableName} ORDER BY Id DESC"; + private static readonly string s_insertCmdString = $"INSERT INTO {s_tableName} (VectorData) VALUES (@VectorData)"; + private static readonly string s_vectorParamName = $"@VectorData"; + private static readonly string s_outputVectorParamName = $"@OutputVectorData"; + private static readonly string s_storedProcName = DataTestUtility.GetUniqueName("VectorsAsVarcharSp"); + private static readonly string s_storedProcBody = $@" + {s_vectorParamName} vector(3), -- Input: Serialized float[] as JSON string + {s_outputVectorParamName} vector(3) OUTPUT -- Output: Echoed back from latest inserted row + AS + BEGIN + SET NOCOUNT ON; + + -- Insert into vector table + INSERT INTO {s_tableName} (VectorData) + VALUES ({s_vectorParamName}); + + -- Retrieve latest entry (assumes auto-incrementing ID) + SELECT TOP 1 {s_outputVectorParamName} = VectorData + FROM {s_tableName} + ORDER BY Id DESC; + END;"; + + public NativeVectorFloat32Tests(ITestOutputHelper output) + { + _output = output; + using var connection = new SqlConnection(s_connectionString); + connection.Open(); + DataTestUtility.CreateTable(connection, s_tableName, s_tableDefinition); + DataTestUtility.CreateTable(connection, s_bulkCopySrcTableName, s_bulkCopySrcTableDef); + DataTestUtility.CreateSP(connection, s_storedProcName, s_storedProcBody); + } + + public void Dispose() + { + using var connection = new SqlConnection(s_connectionString); + connection.Open(); + DataTestUtility.DropTable(connection, s_tableName); + DataTestUtility.DropTable(connection, s_bulkCopySrcTableName); + DataTestUtility.DropStoredProcedure(connection, s_storedProcName); + } + + private void ValidateSqlVectorFloat32Object(bool isNull, SqlVector sqlVectorFloat32, float[] expectedData, int expectedSize, int expectedLength) + { + Assert.Equal(expectedData, sqlVectorFloat32.Memory.ToArray()); + Assert.Equal(expectedSize, sqlVectorFloat32.Size); + Assert.Equal(expectedLength, sqlVectorFloat32.Length); + if (!isNull) + { + Assert.False(sqlVectorFloat32.IsNull, "IsNull set to true for a non-null value"); + } + else + { + Assert.True(sqlVectorFloat32.IsNull, "IsNull set to false for a null value"); + } + } + + private void ValidateInsertedData(SqlConnection connection, float[] expectedData, int expectedSize, int expectedLength) + { + using var selectCmd = new SqlCommand(s_selectCmdString, connection); + using var reader = selectCmd.ExecuteReader(); + Assert.True(reader.Read(), "No data found in the table."); + + //For both null and non-null cases, validate the SqlVector object + ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector)reader.GetSqlVector(0), expectedData, expectedSize, expectedLength); + ValidateSqlVectorFloat32Object(reader.IsDBNull(0), reader.GetFieldValue>(0), expectedData, expectedSize, expectedLength); + ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector)reader.GetSqlValue(0), expectedData, expectedSize, expectedLength); + + if (!reader.IsDBNull(0)) + { + ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector)reader.GetValue(0), expectedData, expectedSize, expectedLength); + ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector)reader[0], expectedData, expectedSize, expectedLength); + ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector)reader["VectorData"], expectedData, expectedSize, expectedLength); + Assert.Equal(expectedData, JsonSerializer.Deserialize(reader.GetString(0))); + Assert.Equal(expectedData, JsonSerializer.Deserialize(reader.GetSqlString(0).Value)); + Assert.Equal(expectedData, JsonSerializer.Deserialize(reader.GetFieldValue(0))); + } + else + { + Assert.Equal(DBNull.Value, reader.GetValue(0)); + Assert.Equal(DBNull.Value, reader[0]); + Assert.Equal(DBNull.Value, reader["VectorData"]); + Assert.Throws(() => reader.GetString(0)); + Assert.Throws(() => reader.GetSqlString(0).Value); + Assert.Throws(() => reader.GetFieldValue(0)); + } + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsVectorSupported))] + [MemberData(nameof(VectorFloat32TestData.GetVectorFloat32TestData), MemberType = typeof(VectorFloat32TestData))] + public void TestSqlVectorFloat32ParameterInsertionAndReads( + int pattern, + object value, + float[] expectedValues, + int expectedSize, + int expectedLength) + { + using var conn = new SqlConnection(s_connectionString); + conn.Open(); + + using var insertCmd = new SqlCommand(s_insertCmdString, conn); + + SqlParameter param = pattern switch + { + 1 => new SqlParameter + { + ParameterName = s_vectorParamName, + SqlDbType = SqlDbTypeExtensions.Vector, + Value = value + }, + 2 => new SqlParameter(s_vectorParamName, value), + 3 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector) { Value = value }, + 4 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector, new SqlVector(3).Size) { Value = value }, + _ => throw new ArgumentOutOfRangeException(nameof(pattern), $"Unsupported pattern: {pattern}") + }; + + insertCmd.Parameters.Add(param); + Assert.Equal(1, insertCmd.ExecuteNonQuery()); + insertCmd.Parameters.Clear(); + + ValidateInsertedData(conn, expectedValues, expectedSize, expectedLength); + } + + private async Task ValidateInsertedDataAsync(SqlConnection connection, float[] expectedData, int expectedSize, int expectedLength) + { + using var selectCmd = new SqlCommand(s_selectCmdString, connection); + using var reader = await selectCmd.ExecuteReaderAsync(); + Assert.True(await reader.ReadAsync(), "No data found in the table."); + + //For both null and non-null cases, validate the SqlVector object + ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), (SqlVector)reader.GetSqlVector(0), expectedData, expectedSize, expectedLength); + ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), await reader.GetFieldValueAsync>(0), expectedData, expectedSize, expectedLength); + ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), (SqlVector)reader.GetSqlValue(0), expectedData, expectedSize, expectedLength); + + if (!await reader.IsDBNullAsync(0)) + { + ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), (SqlVector)reader.GetValue(0), expectedData, expectedSize, expectedLength); + ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), (SqlVector)reader[0], expectedData, expectedSize, expectedLength); + ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), (SqlVector)reader["VectorData"], expectedData, expectedSize, expectedLength); + Assert.Equal(expectedData, JsonSerializer.Deserialize(reader.GetString(0))); + Assert.Equal(expectedData, JsonSerializer.Deserialize(reader.GetSqlString(0).Value)); + Assert.Equal(expectedData, JsonSerializer.Deserialize(await reader.GetFieldValueAsync(0))); + } + else + { + Assert.Equal(DBNull.Value, reader.GetValue(0)); + Assert.Equal(DBNull.Value, reader[0]); + Assert.Equal(DBNull.Value, reader["VectorData"]); + Assert.Throws(() => reader.GetString(0)); + Assert.Throws(() => reader.GetSqlString(0).Value); + await Assert.ThrowsAsync(async () => await reader.GetFieldValueAsync(0)); + } + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsVectorSupported))] + [MemberData(nameof(VectorFloat32TestData.GetVectorFloat32TestData), MemberType = typeof(VectorFloat32TestData))] + public async Task TestSqlVectorFloat32ParameterInsertionAndReadsAsync( + int pattern, + object value, + float[] expectedValues, + int expectedSize, + int expectedLength) + { + using var conn = new SqlConnection(s_connectionString); + await conn.OpenAsync(); + + using var insertCmd = new SqlCommand(s_insertCmdString, conn); + + SqlParameter param = pattern switch + { + 1 => new SqlParameter + { + ParameterName = s_vectorParamName, + SqlDbType = (SqlDbType)36, // SqlDbTypeExtension.Vector + Value = value + }, + 2 => new SqlParameter(s_vectorParamName, value), + 3 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector) { Value = value }, + 4 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector, new SqlVector(3).Size) { Value = value }, + _ => throw new ArgumentOutOfRangeException(nameof(pattern), $"Unsupported pattern: {pattern}") + }; + + insertCmd.Parameters.Add(param); + Assert.Equal(1, await insertCmd.ExecuteNonQueryAsync()); + insertCmd.Parameters.Clear(); + + await ValidateInsertedDataAsync(conn, expectedValues, expectedSize, expectedLength); + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsVectorSupported))] + [MemberData(nameof(VectorFloat32TestData.GetVectorFloat32TestData), MemberType = typeof(VectorFloat32TestData))] + public void TestStoredProcParamsForVectorFloat32( + int pattern, + object value, + float[] expectedValues, + int expectedSize, + int expectedLength) + { + //Create SP for test + using var conn = new SqlConnection(s_connectionString); + conn.Open(); + DataTestUtility.CreateSP(conn, s_storedProcName, s_storedProcBody); + using var command = new SqlCommand(s_storedProcName, conn) + { + CommandType = CommandType.StoredProcedure + }; + + // Set input and output parameters + SqlParameter inputParam = pattern switch + { + 1 => new SqlParameter + { + ParameterName = s_vectorParamName, + SqlDbType = SqlDbTypeExtensions.Vector, // SqlDbTypeExtension.Vector + Value = value + }, + 2 => new SqlParameter(s_vectorParamName, value), + 3 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector) { Value = value }, + 4 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector, new SqlVector(3).Size) { Value = value }, + _ => throw new ArgumentOutOfRangeException(nameof(pattern), $"Unsupported pattern: {pattern}") + }; + command.Parameters.Add(inputParam); + + var outputParam = new SqlParameter + { + ParameterName = s_outputVectorParamName, + SqlDbType = SqlDbTypeExtensions.Vector, + Direction = ParameterDirection.Output, + Value = new SqlVector(3) + }; + command.Parameters.Add(outputParam); + + // Execute the stored procedure + command.ExecuteNonQuery(); + + // Validate the output parameter + var vector = outputParam.Value as SqlVector; + ValidateSqlVectorFloat32Object(vector.IsNull, vector, expectedValues, expectedSize, expectedLength); + + // Validate error for conventional way of setting output parameters + command.Parameters.Clear(); + command.Parameters.Add(inputParam); + var outputParamWithoutVal = new SqlParameter(s_outputVectorParamName, SqlDbTypeExtensions.Vector, new SqlVector(3).Size) { Direction = ParameterDirection.Output }; + command.Parameters.Add(outputParamWithoutVal); + Assert.Throws(() => command.ExecuteNonQuery()); + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsVectorSupported))] + [MemberData(nameof(VectorFloat32TestData.GetVectorFloat32TestData), MemberType = typeof(VectorFloat32TestData))] + public async Task TestStoredProcParamsForVectorFloat32Async( + int pattern, + object value, + float[] expectedValues, + int expectedSize, + int expectedLength) + { + //Create SP for test + using var conn = new SqlConnection(s_connectionString); + await conn.OpenAsync(); + DataTestUtility.CreateSP(conn, s_storedProcName, s_storedProcBody); + using var command = new SqlCommand(s_storedProcName, conn) + { + CommandType = CommandType.StoredProcedure + }; + + // Set input and output parameters + SqlParameter inputParam = pattern switch + { + 1 => new SqlParameter + { + ParameterName = s_vectorParamName, + SqlDbType = SqlDbTypeExtensions.Vector, // SqlDbTypeExtension.Vector + Value = value + }, + 2 => new SqlParameter(s_vectorParamName, value), + 3 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector) { Value = value }, + 4 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector, new SqlVector(3).Size) { Value = value }, + _ => throw new ArgumentOutOfRangeException(nameof(pattern), $"Unsupported pattern: {pattern}") + }; + command.Parameters.Add(inputParam); + + var outputParam = new SqlParameter + { + ParameterName = s_outputVectorParamName, + SqlDbType = SqlDbTypeExtensions.Vector, + Direction = ParameterDirection.Output, + Value = new SqlVector(3) + }; + command.Parameters.Add(outputParam); + + // Execute the stored procedure + await command.ExecuteNonQueryAsync(); + + // Validate the output parameter + var vector = outputParam.Value as SqlVector; + ValidateSqlVectorFloat32Object(vector.IsNull, vector, expectedValues, expectedSize, expectedLength); + + // Validate error for conventional way of setting output parameters + command.Parameters.Clear(); + command.Parameters.Add(inputParam); + var outputParamWithoutVal = new SqlParameter(s_outputVectorParamName, SqlDbTypeExtensions.Vector, new SqlVector(3).Size) { Direction = ParameterDirection.Output }; + command.Parameters.Add(outputParamWithoutVal); + await Assert.ThrowsAsync(async () => await command.ExecuteNonQueryAsync()); + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsVectorSupported))] + [InlineData(1)] + [InlineData(2)] + public void TestBulkCopyFromSqlTable(int bulkCopySourceMode) + { + //Setup source with test data and create destination table for bulkcopy. + SqlConnection sourceConnection = new SqlConnection(s_connectionString); + sourceConnection.Open(); + SqlConnection destinationConnection = new SqlConnection(s_connectionString); + destinationConnection.Open(); + DataTable table = null; + switch (bulkCopySourceMode) + { + + case 1: + // Use SqlServer table as source + var insertCmd = new SqlCommand($"insert into {s_bulkCopySrcTableName} values (@VectorData)", sourceConnection); + var vectorParam = new SqlParameter(s_vectorParamName, new SqlVector(VectorFloat32TestData.testData)); + + // Insert 2 rows with one non-null and null value + insertCmd.Parameters.Add(vectorParam); + Assert.Equal(1, insertCmd.ExecuteNonQuery()); + insertCmd.Parameters.Clear(); + vectorParam.Value = DBNull.Value; + insertCmd.Parameters.Add(vectorParam); + Assert.Equal(1, insertCmd.ExecuteNonQuery()); + insertCmd.Parameters.Clear(); + break; + case 2: + table = new DataTable(s_bulkCopySrcTableName); + table.Columns.Add("Id", typeof(int)); + table.Columns.Add("VectorData", typeof(SqlVector)); + table.Rows.Add(1, new SqlVector(VectorFloat32TestData.testData)); + table.Rows.Add(2, DBNull.Value); + break; + default: + throw new ArgumentOutOfRangeException(nameof(bulkCopySourceMode), $"Unsupported bulk copy source mode: {bulkCopySourceMode}"); + } + + + + //Bulkcopy from sql server table to destination table + using SqlCommand sourceDataCommand = new SqlCommand($"SELECT Id, VectorData FROM {s_bulkCopySrcTableName}", sourceConnection); + using SqlDataReader reader = sourceDataCommand.ExecuteReader(); + + // Verify that the destination table is empty before bulk copy + using SqlCommand countCommand = new SqlCommand($"SELECT COUNT(*) FROM {s_tableName}", destinationConnection); + Assert.Equal(0, Convert.ToInt16(countCommand.ExecuteScalar())); + + // Initialize bulk copy configuration + using SqlBulkCopy bulkCopy = new SqlBulkCopy(destinationConnection) + { + DestinationTableName = s_tableName, + }; + + try + { + switch (bulkCopySourceMode) + { + case 1: + bulkCopy.WriteToServer(reader); + break; + case 2: + bulkCopy.WriteToServer(table); + break; + default: + throw new ArgumentOutOfRangeException(nameof(bulkCopySourceMode), $"Unsupported bulk copy source mode: {bulkCopySourceMode}"); + } + } + catch (Exception ex) + { + // If bulk copy fails, fail the test with the exception message + Assert.Fail($"Bulk copy failed: {ex.Message}"); + } + + // Verify that the 2 rows from the source table have been copied into the destination table. + Assert.Equal(2, Convert.ToInt16(countCommand.ExecuteScalar())); + + // Read the data from destination table as varbinary to verify the UTF-8 byte sequence + using SqlCommand verifyCommand = new SqlCommand($"SELECT VectorData from {s_tableName}", destinationConnection); + using SqlDataReader verifyReader = verifyCommand.ExecuteReader(); + + // Verify that we have data in the destination table + Assert.True(verifyReader.Read(), "No data found in destination table after bulk copy."); + + // Validate first non-null value. + Assert.True(!verifyReader.IsDBNull(0), "First row in the table is null."); + Assert.Equal(VectorFloat32TestData.testData, ((SqlVector)verifyReader.GetSqlVector(0)).Memory.ToArray()); + Assert.Equal(VectorFloat32TestData.testData.Length, ((SqlVector)verifyReader.GetSqlVector(0)).Length); + Assert.Equal(VectorFloat32TestData.sizeInbytes, ((SqlVector)verifyReader.GetSqlVector(0)).Size); + + // Verify that we have another row + Assert.True(verifyReader.Read(), "Second row not found in the table"); + + // Verify that we have encountered null. + Assert.True(verifyReader.IsDBNull(0)); + Assert.Equal(Array.Empty(), ((SqlVector)verifyReader.GetSqlVector(0)).Memory.ToArray()); + Assert.Equal(VectorFloat32TestData.testData.Length, ((SqlVector)verifyReader.GetSqlVector(0)).Length); + Assert.Equal(VectorFloat32TestData.sizeInbytes, ((SqlVector)verifyReader.GetSqlVector(0)).Size); + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsVectorSupported))] + [InlineData(1)] + [InlineData(2)] + public async Task TestBulkCopyFromSqlTableAsync(int bulkCopySourceMode) + { + //Setup source with test data and create destination table for bulkcopy. + SqlConnection sourceConnection = new SqlConnection(s_connectionString); + await sourceConnection.OpenAsync(); + SqlConnection destinationConnection = new SqlConnection(s_connectionString); + await destinationConnection.OpenAsync(); + + DataTable table = null; + switch (bulkCopySourceMode) + { + + case 1: + // Use SqlServer table as source + var insertCmd = new SqlCommand($"insert into {s_bulkCopySrcTableName} values (@VectorData)", sourceConnection); + var vectorParam = new SqlParameter(s_vectorParamName, new SqlVector(VectorFloat32TestData.testData)); + + // Insert 2 rows with one non-null and null value + insertCmd.Parameters.Add(vectorParam); + Assert.Equal(1, await insertCmd.ExecuteNonQueryAsync()); + insertCmd.Parameters.Clear(); + vectorParam.Value = DBNull.Value; + insertCmd.Parameters.Add(vectorParam); + Assert.Equal(1, await insertCmd.ExecuteNonQueryAsync()); + insertCmd.Parameters.Clear(); + break; + case 2: + table = new DataTable(s_bulkCopySrcTableName); + table.Columns.Add("Id", typeof(int)); + table.Columns.Add("VectorData", typeof(SqlVector)); + table.Rows.Add(1, new SqlVector(VectorFloat32TestData.testData)); + table.Rows.Add(2, DBNull.Value); + break; + default: + throw new ArgumentOutOfRangeException(nameof(bulkCopySourceMode), $"Unsupported bulk copy source mode: {bulkCopySourceMode}"); + } + + //Bulkcopy from sql server table to destination table + using SqlCommand sourceDataCommand = new SqlCommand($"SELECT Id, VectorData FROM {s_bulkCopySrcTableName}", sourceConnection); + using SqlDataReader reader = await sourceDataCommand.ExecuteReaderAsync(); + + // Verify that the destination table is empty before bulk copy + using SqlCommand countCommand = new SqlCommand($"SELECT COUNT(*) FROM {s_tableName}", destinationConnection); + Assert.Equal(0, Convert.ToInt16(await countCommand.ExecuteScalarAsync())); + + // Initialize bulk copy configuration + using SqlBulkCopy bulkCopy = new SqlBulkCopy(destinationConnection) + { + DestinationTableName = s_tableName, + }; + + try + { // Perform bulkcopy + switch (bulkCopySourceMode) + { + case 1: + await bulkCopy.WriteToServerAsync(reader); + break; + case 2: + await bulkCopy.WriteToServerAsync(table); + break; + default: + throw new ArgumentOutOfRangeException(nameof(bulkCopySourceMode), $"Unsupported bulk copy source mode: {bulkCopySourceMode}"); + } + } + catch (Exception ex) + { + // If bulk copy fails, fail the test with the exception message + Assert.Fail($"Bulk copy failed: {ex.Message}"); + } + + // Verify that the 2 rows from the source table have been copied into the destination table. + Assert.Equal(2, Convert.ToInt16(await countCommand.ExecuteScalarAsync())); + + // Read the data from destination table as varbinary to verify the UTF-8 byte sequence + using SqlCommand verifyCommand = new SqlCommand($"SELECT VectorData from {s_tableName}", destinationConnection); + using SqlDataReader verifyReader = await verifyCommand.ExecuteReaderAsync(); + + // Verify that we have data in the destination table + Assert.True(await verifyReader.ReadAsync(), "No data found in destination table after bulk copy."); + + // Validate first non-null value. + Assert.True(!await verifyReader.IsDBNullAsync(0), "First row in the table is null."); + var vector = await verifyReader.GetFieldValueAsync>(0); + Assert.Equal(VectorFloat32TestData.testData, vector.Memory.ToArray()); + Assert.Equal(VectorFloat32TestData.testData.Length, vector.Length); + Assert.Equal(VectorFloat32TestData.sizeInbytes, vector.Size); + + // Verify that we have another row + Assert.True(await verifyReader.ReadAsync(), "Second row not found in the table"); + + // Verify that we have encountered null. + Assert.True(await verifyReader.IsDBNullAsync(0)); + vector = await verifyReader.GetFieldValueAsync>(0); + Assert.Equal(Array.Empty(), vector.Memory.ToArray()); + Assert.Equal(VectorFloat32TestData.testData.Length, vector.Length); + Assert.Equal(VectorFloat32TestData.sizeInbytes, vector.Size); + } + + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsVectorSupported))] + public void TestInsertVectorsFloat32WithPrepare() + { + SqlConnection conn = new SqlConnection(s_connectionString); + conn.Open(); + SqlCommand command = new SqlCommand(s_insertCmdString, conn); + SqlParameter vectorParam = new SqlParameter("@VectorData", SqlDbTypeExtensions.Vector, new SqlVector(3).Size); + command.Parameters.Add(vectorParam); + command.Prepare(); + for (int i = 0; i < 10; i++) + { + vectorParam.Value = new SqlVector(new float[] { i + 0.1f, i + 0.2f, i + 0.3f }); + command.ExecuteNonQuery(); + } + SqlCommand validateCommand = new SqlCommand($"SELECT VectorData FROM {s_tableName}", conn); + using SqlDataReader reader = validateCommand.ExecuteReader(); + int rowcnt = 0; + while (reader.Read()) + { + float[] expectedData = new float[] { rowcnt + 0.1f, rowcnt + 0.2f, rowcnt + 0.3f }; + float[] dbData = reader.GetSqlVector(0).Memory.ToArray(); + Assert.Equal(expectedData, dbData); + rowcnt++; + } + Assert.Equal(10, rowcnt); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/VectorTypeBackwardCompatibilityTests.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/VectorTypeBackwardCompatibilityTests.cs new file mode 100644 index 0000000000..d4857bf5e2 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/VectorTypeBackwardCompatibilityTests.cs @@ -0,0 +1,630 @@ +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.SqlTypes; +using System.Text.Json; +using System.Threading.Tasks; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.Data.SqlClient.ManualTesting.Tests.SQL.VectorTest +{ + public sealed class VectorTypeBackwardCompatibilityTests : IDisposable + { + private readonly ITestOutputHelper _output; + private static readonly string s_connectionString = ManualTesting.Tests.DataTestUtility.TCPConnectionString; + private static readonly string s_tableName = DataTestUtility.GetUniqueName("VectorTestTable"); + private static readonly string s_bulkCopySrcTableName = DataTestUtility.GetUniqueName("VectorBulkCopyTestTable"); + private static readonly string s_bulkCopySrcTableDef = $@"(Id INT PRIMARY KEY IDENTITY, VectorData varchar(max) NULL)"; + private static readonly string s_tableDefinition = $@"(Id INT PRIMARY KEY IDENTITY, VectorData vector(3) NULL)"; + private static readonly string s_selectCmdString = $"SELECT VectorData FROM {s_tableName} ORDER BY Id DESC"; + private static readonly string s_insertCmdString = $"INSERT INTO {s_tableName} (VectorData) VALUES (@VectorData)"; + private static readonly string s_vectorParamName = $"@VectorData"; + private static readonly string s_storedProcName = DataTestUtility.GetUniqueName("VectorsAsVarcharSp"); + private static readonly string s_storedProcBody = $@" + @InputVectorJson VARCHAR(MAX), -- Input: Serialized float[] as JSON string + @OutputVectorJson VARCHAR(MAX) OUTPUT -- Output: Echoed back from latest inserted row + AS + BEGIN + SET NOCOUNT ON; + + -- Insert into vector table + INSERT INTO {s_tableName} (VectorData) + VALUES (@InputVectorJson); + + -- Retrieve latest entry (assumes auto-incrementing ID) + SELECT TOP 1 @OutputVectorJson = VectorData + FROM {s_tableName} + ORDER BY Id DESC; + END;"; + + public VectorTypeBackwardCompatibilityTests(ITestOutputHelper output) + { + _output = output; + using var connection = new SqlConnection(s_connectionString); + connection.Open(); + DataTestUtility.CreateTable(connection, s_tableName, s_tableDefinition); + DataTestUtility.CreateTable(connection, s_bulkCopySrcTableName, s_bulkCopySrcTableDef); + DataTestUtility.CreateSP(connection, s_storedProcName, s_storedProcBody); + } + + public void Dispose() + { + using var connection = new SqlConnection(s_connectionString); + connection.Open(); + DataTestUtility.DropTable(connection, s_tableName); + DataTestUtility.DropTable(connection, s_bulkCopySrcTableName); + DataTestUtility.DropStoredProcedure(connection, s_storedProcName); + } + + private void ValidateInsertedData(SqlConnection connection, float[] expectedData) + { + using var selectCmd = new SqlCommand(s_selectCmdString, connection); + using var reader = selectCmd.ExecuteReader(); + Assert.True(reader.Read(), "No data found in the table."); + + if (!reader.IsDBNull(0)) + { + string jsonFromDb = reader.GetString(0); + float[] deserialized = JsonSerializer.Deserialize(jsonFromDb)!; + Assert.Equal(expectedData, deserialized); + } + else + { + Assert.Null(expectedData); + var val = reader.GetValue(0); + Assert.Equal(DBNull.Value, val); + } + } + + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsVectorSupported))] + public void TestVectorDataInsertionAsVarchar() + { + float[] data = { 1.1f, 2.2f, 3.3f }; + string json = JsonSerializer.Serialize(data); + + using var conn = new SqlConnection(s_connectionString); + conn.Open(); + + using var insertCmd = new SqlCommand(s_insertCmdString, conn); + + // Pattern 1: Default constructor + property setters + var p1 = new SqlParameter(); + p1.ParameterName = s_vectorParamName; + p1.SqlDbType = SqlDbType.VarChar; + p1.Size = -1; //varchar(max) + p1.Value = json; + insertCmd.Parameters.Add(p1); + Assert.Equal(1, insertCmd.ExecuteNonQuery()); + insertCmd.Parameters.Clear(); + ValidateInsertedData(conn, data); + + var nullp1 = new SqlParameter(); + nullp1.ParameterName = s_vectorParamName; + nullp1.SqlDbType = SqlDbType.VarChar; + nullp1.Size = -1; //varchar(max) + nullp1.Value = DBNull.Value; + insertCmd.Parameters.Add(nullp1); + Assert.Equal(1, insertCmd.ExecuteNonQuery()); + insertCmd.Parameters.Clear(); + ValidateInsertedData(conn, null); + + // Pattern 2: Name + value constructor + var p2 = new SqlParameter(s_vectorParamName, json); + insertCmd.Parameters.Add(p2); + Assert.Equal(1, insertCmd.ExecuteNonQuery()); + insertCmd.Parameters.Clear(); + ValidateInsertedData(conn, data); + + var nullp2 = new SqlParameter(s_vectorParamName, DBNull.Value); + insertCmd.Parameters.Add(nullp2); + Assert.Equal(1, insertCmd.ExecuteNonQuery()); + insertCmd.Parameters.Clear(); + ValidateInsertedData(conn, null); + + // Pattern 3: Name + SqlDbType constructor + var p3 = new SqlParameter(s_vectorParamName, SqlDbType.VarChar) { Value = json }; + insertCmd.Parameters.Add(p3); + Assert.Equal(1, insertCmd.ExecuteNonQuery()); + insertCmd.Parameters.Clear(); + ValidateInsertedData(conn, data); + + var nullp3 = new SqlParameter(s_vectorParamName, SqlDbType.VarChar) { Value = DBNull.Value }; + insertCmd.Parameters.Add(nullp3); + Assert.Equal(1, insertCmd.ExecuteNonQuery()); + insertCmd.Parameters.Clear(); + ValidateInsertedData(conn, null); + + // Pattern 4: Name + SqlDbType + Size constructor + var p4 = new SqlParameter(s_vectorParamName, SqlDbType.VarChar, -1) { Value = json }; + insertCmd.Parameters.Add(p4); + Assert.Equal(1, insertCmd.ExecuteNonQuery()); + insertCmd.Parameters.Clear(); + ValidateInsertedData(conn, data); + + var nullp4 = new SqlParameter(s_vectorParamName, SqlDbType.VarChar, -1) { Value = DBNull.Value }; + insertCmd.Parameters.Add(nullp4); + Assert.Equal(1, insertCmd.ExecuteNonQuery()); + insertCmd.Parameters.Clear(); + ValidateInsertedData(conn, null); + } + + private async Task ValidateInsertedDataAsync(SqlConnection connection, float[] expectedData) + { + using var selectCmd = new SqlCommand(s_selectCmdString, connection); + using var reader = await selectCmd.ExecuteReaderAsync(); + Assert.True(await reader.ReadAsync(), "No data found in the table."); + + if (!await reader.IsDBNullAsync(0)) + { + string jsonFromDb = reader.GetString(0); + float[] deserialized = JsonSerializer.Deserialize(jsonFromDb)!; + Assert.Equal(expectedData, deserialized); + } + else + { + Assert.Null(expectedData); + var val = reader.GetValue(0); + Assert.Equal(DBNull.Value, val); + } + } + + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsVectorSupported))] + public async Task TestVectorParameterInitializationAsync() + { + float[] data = { 1.1f, 2.2f, 3.3f }; + string json = JsonSerializer.Serialize(data); + + using var conn = new SqlConnection(s_connectionString); + await conn.OpenAsync(); + + using var insertCmd = new SqlCommand(s_insertCmdString, conn); + + // Pattern 1: Default constructor + property setters + var p1 = new SqlParameter(); + p1.ParameterName = s_vectorParamName; + p1.SqlDbType = SqlDbType.VarChar; + p1.Size = -1; + p1.Value = json; + insertCmd.Parameters.Add(p1); + Assert.Equal(1, await insertCmd.ExecuteNonQueryAsync()); + insertCmd.Parameters.Clear(); + await ValidateInsertedDataAsync(conn, data); + + var nullp1 = new SqlParameter(); + nullp1.ParameterName = s_vectorParamName; + nullp1.SqlDbType = SqlDbType.VarChar; + nullp1.Size = -1; + nullp1.Value = DBNull.Value; + insertCmd.Parameters.Add(nullp1); + Assert.Equal(1, await insertCmd.ExecuteNonQueryAsync()); + insertCmd.Parameters.Clear(); + await ValidateInsertedDataAsync(conn, null); + + // Pattern 2: Name + value constructor + var p2 = new SqlParameter(s_vectorParamName, json); + insertCmd.Parameters.Add(p2); + Assert.Equal(1, await insertCmd.ExecuteNonQueryAsync()); + insertCmd.Parameters.Clear(); + await ValidateInsertedDataAsync(conn, data); + + var nullp2 = new SqlParameter(s_vectorParamName, DBNull.Value); + insertCmd.Parameters.Add(nullp2); + Assert.Equal(1, await insertCmd.ExecuteNonQueryAsync()); + insertCmd.Parameters.Clear(); + await ValidateInsertedDataAsync(conn, null); + + // Pattern 3: Name + SqlDbType constructor + var p3 = new SqlParameter(s_vectorParamName, SqlDbType.VarChar) { Value = json }; + insertCmd.Parameters.Add(p3); + Assert.Equal(1, await insertCmd.ExecuteNonQueryAsync()); + insertCmd.Parameters.Clear(); + await ValidateInsertedDataAsync(conn, data); + + var nullp3 = new SqlParameter(s_vectorParamName, SqlDbType.VarChar) { Value = DBNull.Value }; + insertCmd.Parameters.Add(nullp3); + Assert.Equal(1, await insertCmd.ExecuteNonQueryAsync()); + insertCmd.Parameters.Clear(); + await ValidateInsertedDataAsync(conn, null); + + // Pattern 4: Name + SqlDbType + Size constructor + var p4 = new SqlParameter(s_vectorParamName, SqlDbType.VarChar, -1) { Value = json }; + insertCmd.Parameters.Add(p4); + Assert.Equal(1, await insertCmd.ExecuteNonQueryAsync()); + insertCmd.Parameters.Clear(); + await ValidateInsertedDataAsync(conn, data); + + var nullp4 = new SqlParameter(s_vectorParamName, SqlDbType.VarChar, -1) { Value = DBNull.Value }; + insertCmd.Parameters.Add(nullp4); + Assert.Equal(1, await insertCmd.ExecuteNonQueryAsync()); + insertCmd.Parameters.Clear(); + await ValidateInsertedDataAsync(conn, null); + } + + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsVectorSupported))] + public void TestVectorDataReadsAsVarchar() + { + float[] data = { 1.1f, 2.2f, 3.3f }; + string dataAsJson = JsonSerializer.Serialize(data); + + using var conn = new SqlConnection(s_connectionString); + conn.Open(); + + //Insert non-null values and validate APIs for reading vector data as varchar(max) + using var insertCmd = new SqlCommand(s_insertCmdString, conn); + var p1 = new SqlParameter(s_vectorParamName, SqlDbType.VarChar, -1) { Value = dataAsJson }; + insertCmd.Parameters.Add(p1); + Assert.Equal(1, insertCmd.ExecuteNonQuery()); + + //Validate Reader + using SqlCommand verifyCommand = new SqlCommand(s_selectCmdString, conn); + var reader = verifyCommand.ExecuteReader(); + Assert.True(reader.Read(), "No data found in the table."); + + //Read using GetString + string result = reader.GetString(0); + float[] dbData = JsonSerializer.Deserialize(result)!; + Assert.Equal(data, dbData); + + //Read using GetSqlString + result = reader.GetSqlString(0).Value; + dbData = JsonSerializer.Deserialize(result)!; + Assert.Equal(data, dbData); + + //Read using GetValue.ToString() + result = reader.GetValue(0).ToString()!; + dbData = JsonSerializer.Deserialize(result)!; + Assert.Equal(data, dbData); + + //Read using GetFieldValue + result = reader.GetFieldValue(0); + dbData = JsonSerializer.Deserialize(result)!; + Assert.Equal(data, dbData); + + reader.Close(); + + // Validate For Null Value + insertCmd.Parameters.Clear(); + p1.Value = DBNull.Value; + insertCmd.Parameters.Add(p1); + Assert.Equal(1, insertCmd.ExecuteNonQuery()); + + //Validate Reader for null value + reader = verifyCommand.ExecuteReader(); + Assert.True(reader.Read(), "No data found in the table."); + + //Read using GetString + Assert.Throws(() => reader.GetString(0)); + + //Read using GetSqlString + Assert.Throws(() => reader.GetString(0)); + + //Read using GetValue.ToString() + result = reader.GetValue(0).ToString(); + Assert.Equal(string.Empty, result); + + //Read using GetFieldValue + Assert.Throws(() => reader.GetFieldValue(0)); + } + + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsVectorSupported))] + public async Task TestVectorDataReadsAsVarcharAsync() + { + float[] data = { 1.1f, 2.2f, 3.3f }; + string dataAsJson = JsonSerializer.Serialize(data); + + using var conn = new SqlConnection(s_connectionString); + await conn.OpenAsync(); + + //Insert non-null values and validate APIs for reading vector data as varchar(max) + using var insertCmd = new SqlCommand(s_insertCmdString, conn); + var p1 = new SqlParameter(s_vectorParamName, SqlDbType.VarChar, -1) { Value = dataAsJson }; + insertCmd.Parameters.Add(p1); + Assert.Equal(1, await insertCmd.ExecuteNonQueryAsync()); + + //Validate Reader + using SqlCommand verifyCommand = new SqlCommand(s_selectCmdString, conn); + using var reader = await verifyCommand.ExecuteReaderAsync(); + Assert.True(await reader.ReadAsync(), "No data found in the table."); + + //Read using GetString + string result = reader.GetString(0); + float[] dbData = JsonSerializer.Deserialize(result)!; + Assert.Equal(data, dbData); + + //Read using GetSqlString + result = reader.GetSqlString(0).Value; + dbData = JsonSerializer.Deserialize(result)!; + Assert.Equal(data, dbData); + + //Read using GetValue.ToString() + result = reader.GetValue(0).ToString()!; + dbData = JsonSerializer.Deserialize(result)!; + Assert.Equal(data, dbData); + + //Read using GetFieldValue + result = await reader.GetFieldValueAsync(0); + dbData = JsonSerializer.Deserialize(result)!; + Assert.Equal(data, dbData); + + reader.Close(); + + // Validate For Null Value + insertCmd.Parameters.Clear(); + p1.Value = DBNull.Value; + insertCmd.Parameters.Add(p1); + Assert.Equal(1, await insertCmd.ExecuteNonQueryAsync()); + + //Validate Reader for null value + var reader2 = await verifyCommand.ExecuteReaderAsync(); + Assert.True(await reader2.ReadAsync(), "No data found in the table."); + + //Read using GetString + Assert.Throws(() => reader2.GetString(0)); + + //Read using GetSqlString + Assert.Throws(() => reader2.GetString(0)); + + //Read using GetValue.ToString() + result = reader2.GetValue(0).ToString(); + Assert.Equal(string.Empty, result); + + //Read using GetFieldValueAsync + await Assert.ThrowsAsync(async () => await reader2.GetFieldValueAsync(0)); + } + + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsVectorSupported))] + public void TestStoredProcParamsForVectorAsVarchar() + { + // Test data + float[] data = { 7.1f, 8.2f, 9.3f }; + string dataAsJson = JsonSerializer.Serialize(data); + + //Create SP for test + using var conn = new SqlConnection(s_connectionString); + conn.Open(); + DataTestUtility.CreateSP(conn, s_storedProcName, s_storedProcBody); + using var command = new SqlCommand(s_storedProcName, conn) + { + CommandType = CommandType.StoredProcedure + }; + + // Set input and output parameters + var inputParam = new SqlParameter("@InputVectorJson", SqlDbType.VarChar, -1); + inputParam.Value = dataAsJson; + command.Parameters.Add(inputParam); + var outputParam = new SqlParameter("@OutputVectorJson", SqlDbType.VarChar, -1) + { + Direction = ParameterDirection.Output + }; + command.Parameters.Add(outputParam); + + // Execute the stored procedure + command.ExecuteNonQuery(); + + // Validate the output parameter + var dbDataAsJson = outputParam.Value as string; + float[] dbData = JsonSerializer.Deserialize(dbDataAsJson)!; + Assert.NotNull(dbDataAsJson); + Assert.Equal(data, dbData); + + // Test with null value + command.Parameters.Clear(); + inputParam.Value = DBNull.Value; + command.Parameters.Add(inputParam); + command.Parameters.Add(outputParam); + command.ExecuteNonQuery(); + + // Validate output paramter for null value + Assert.True(outputParam.Value == DBNull.Value); + } + + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsVectorSupported))] + public async Task TestStoredProcParamsForVectorAsVarcharAsync() + { + // Test data + float[] data = { 7.1f, 8.2f, 9.3f }; + string dataAsJson = JsonSerializer.Serialize(data); + + // Create SP for test + using var conn = new SqlConnection(s_connectionString); + await conn.OpenAsync(); + DataTestUtility.CreateSP(conn, s_storedProcName, s_storedProcBody); + + using var command = new SqlCommand(s_storedProcName, conn) + { + CommandType = CommandType.StoredProcedure + }; + + // Set input and output parameters + var inputParam = new SqlParameter("@InputVectorJson", SqlDbType.VarChar, -1) + { + Value = dataAsJson + }; + command.Parameters.Add(inputParam); + + var outputParam = new SqlParameter("@OutputVectorJson", SqlDbType.VarChar, -1) + { + Direction = ParameterDirection.Output + }; + command.Parameters.Add(outputParam); + + // Execute the stored procedure + await command.ExecuteNonQueryAsync(); + + // Validate the output parameter + var dbDataAsJson = outputParam.Value as string; + float[] dbData = JsonSerializer.Deserialize(dbDataAsJson)!; + Assert.NotNull(dbDataAsJson); + Assert.Equal(data, dbData); + + // Test with null value + command.Parameters.Clear(); + inputParam.Value = DBNull.Value; + command.Parameters.Add(inputParam); + command.Parameters.Add(outputParam); + + await command.ExecuteNonQueryAsync(); + + // Validate output parameter for null value + Assert.True(outputParam.Value == DBNull.Value); + } + + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsVectorSupported))] + public void TestSqlBulkCopyForVectorAsVarchar() + { + //Setup source with test data and create destination table for bulkcopy. + SqlConnection sourceConnection = new SqlConnection(s_connectionString); + sourceConnection.Open(); + SqlConnection destinationConnection = new SqlConnection(s_connectionString); + destinationConnection.Open(); + float[] testData = { 1.1f, 2.2f, 3.3f }; + string testDataAsJson = JsonSerializer.Serialize(testData); + using var insertCmd = new SqlCommand($"insert into {s_bulkCopySrcTableName} values (@VectorData)", sourceConnection); + var varcharVectorParam = new SqlParameter(s_vectorParamName, SqlDbType.VarChar, -1) { Value = testDataAsJson }; + insertCmd.Parameters.Add(varcharVectorParam); + Assert.Equal(1, insertCmd.ExecuteNonQuery()); + insertCmd.Parameters.Clear(); + varcharVectorParam.Value = DBNull.Value; + insertCmd.Parameters.Add(varcharVectorParam); + Assert.Equal(1, insertCmd.ExecuteNonQuery()); + + //Bulkcopy from sql server table to destination table + using SqlCommand sourceDataCommand = new SqlCommand($"SELECT Id, VectorData FROM {s_bulkCopySrcTableName}", sourceConnection); + using SqlDataReader reader = sourceDataCommand.ExecuteReader(); + + // Verify that the destination table is empty before bulk copy + using SqlCommand countCommand = new SqlCommand($"SELECT COUNT(*) FROM {s_tableName}", destinationConnection); + Assert.Equal(0, Convert.ToInt16(countCommand.ExecuteScalar())); + + // Initialize bulk copy configuration + using SqlBulkCopy bulkCopy = new SqlBulkCopy(destinationConnection) + { + DestinationTableName = s_tableName, + }; + + try + { + // Perform bulk copy from source to destination table + bulkCopy.WriteToServer(reader); + } + catch (Exception ex) + { + // If bulk copy fails, fail the test with the exception message + Assert.Fail($"Bulk copy failed: {ex.Message}"); + } + + // Verify that the 2 rows from the source table have been copied into the destination table. + Assert.Equal(2, Convert.ToInt16(countCommand.ExecuteScalar())); + + // Read the data from destination table as varbinary to verify the UTF-8 byte sequence + using SqlCommand verifyCommand = new SqlCommand($"SELECT VectorData from {s_tableName}", destinationConnection); + using SqlDataReader verifyReader = verifyCommand.ExecuteReader(); + + // Verify that we have data in the destination table + Assert.True(verifyReader.Read(), "No data found in destination table after bulk copy."); + + // Validate first non-null value. + Assert.True(!verifyReader.IsDBNull(0), "First row in the table is null."); + Assert.Equal(testData, JsonSerializer.Deserialize(verifyReader.GetString(0))); + + // Verify that we have another row + Assert.True(verifyReader.Read(), "Second row not found in the table"); + + // Verify that we have encountered null. + Assert.True(verifyReader.IsDBNull(0)); + } + + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsVectorSupported))] + public async Task TestSqlBulkCopyForVectorAsVarcharAsync() + { + //Setup source with test data and create destination table for bulkcopy. + SqlConnection sourceConnection = new SqlConnection(s_connectionString); + await sourceConnection.OpenAsync(); + SqlConnection destinationConnection = new SqlConnection(s_connectionString); + await destinationConnection.OpenAsync(); + float[] testData = { 1.1f, 2.2f, 3.3f }; + string testDataAsJson = JsonSerializer.Serialize(testData); + using var insertCmd = new SqlCommand($"insert into {s_bulkCopySrcTableName} values (@VectorData)", sourceConnection); + var varcharVectorParam = new SqlParameter(s_vectorParamName, SqlDbType.VarChar, -1) { Value = testDataAsJson }; + insertCmd.Parameters.Add(varcharVectorParam); + Assert.Equal(1, await insertCmd.ExecuteNonQueryAsync()); + insertCmd.Parameters.Clear(); + varcharVectorParam.Value = DBNull.Value; + insertCmd.Parameters.Add(varcharVectorParam); + Assert.Equal(1, await insertCmd.ExecuteNonQueryAsync()); + + //Bulkcopy from sql server table to destination table + using SqlCommand sourceDataCommand = new SqlCommand($"SELECT Id, VectorData FROM {s_bulkCopySrcTableName}", sourceConnection); + using SqlDataReader reader = await sourceDataCommand.ExecuteReaderAsync(); + + // Verify that the destination table is empty before bulk copy + using SqlCommand countCommand = new SqlCommand($"SELECT COUNT(*) FROM {s_tableName}", destinationConnection); + Assert.Equal(0, Convert.ToInt16(await countCommand.ExecuteScalarAsync())); + + // Initialize bulk copy configuration + using SqlBulkCopy bulkCopy = new SqlBulkCopy(destinationConnection) + { + DestinationTableName = s_tableName, + }; + + try + { + // Perform bulk copy from source to destination table + await bulkCopy.WriteToServerAsync(reader); + } + catch (Exception ex) + { + // If bulk copy fails, fail the test with the exception message + Assert.Fail($"Bulk copy failed: {ex.Message}"); + } + + // Verify that the 2 rows from the source table have been copied into the destination table. + Assert.Equal(2, Convert.ToInt16(await countCommand.ExecuteScalarAsync())); + + // Read the data from destination table as varbinary to verify the UTF-8 byte sequence + using SqlCommand verifyCommand = new SqlCommand($"SELECT VectorData from {s_tableName}", destinationConnection); + using SqlDataReader verifyReader = await verifyCommand.ExecuteReaderAsync(); + + // Verify that we have data in the destination table + Assert.True(await verifyReader.ReadAsync(), "No data found in destination table after bulk copy."); + + // Validate first non-null value. + Assert.True(!verifyReader.IsDBNull(0), "First row in the table is null."); + Assert.Equal(testData, JsonSerializer.Deserialize(verifyReader.GetString(0))); + + // Verify that we have another row + Assert.True(await verifyReader.ReadAsync(), "Second row not found in the table"); + + // Verify that we have encountered null. + Assert.True(await verifyReader.IsDBNullAsync(0)); + } + + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsVectorSupported))] + public void TestInsertVectorsAsVarcharWithPrepare() + { + SqlConnection conn = new SqlConnection(s_connectionString); + conn.Open(); + SqlCommand command = new SqlCommand(s_insertCmdString, conn); + SqlParameter vectorParam = new SqlParameter("@VectorData", SqlDbType.VarChar, -1); + command.Parameters.Add(vectorParam); + command.Prepare(); + for (int i = 0; i < 10; i++) + { + vectorParam.Value = JsonSerializer.Serialize(new float[] { i + 0.1f, i + 0.2f, i + 0.3f }); + command.ExecuteNonQuery(); + } + SqlCommand validateCommand = new SqlCommand($"SELECT VectorData FROM {s_tableName}", conn); + using SqlDataReader reader = validateCommand.ExecuteReader(); + int rowcnt = 0; + while (reader.Read()) + { + float[] expectedData = new float[] { rowcnt + 0.1f, rowcnt + 0.2f, rowcnt + 0.3f }; + float[] dbData = JsonSerializer.Deserialize(reader.GetString(0))!; + Assert.Equal(expectedData, dbData); + rowcnt++; + } + Assert.Equal(10, rowcnt); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/SqlVectorTest.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SqlVectorTest.cs new file mode 100644 index 0000000000..bb691d5589 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SqlVectorTest.cs @@ -0,0 +1,228 @@ + +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Microsoft.Data.SqlTypes; +using Xunit; + +#nullable enable + +namespace Microsoft.Data.SqlClient.Tests; + +public class SqlVectorTest +{ + #region Tests + + [Fact] + public void UnsupportedType() + { + Assert.Throws(() => new SqlVector(5)); + Assert.Throws(() => new SqlVector(5)); + Assert.Throws(() => new SqlVector(5)); + } + + [Fact] + public void Construct_Length_Negative() + { + Assert.Throws(() => new SqlVector(-1)); + } + + [Fact] + public void Construct_Length() + { + var vec = new SqlVector(5); + Assert.True(vec.IsNull); + Assert.Equal(5, vec.Length); + Assert.Equal(28, vec.Size); + // Note that ReadOnlyMemory<> equality checks that both instances point + // to the same memory. We want to check memory content equality, so we + // compare their arrays instead. + Assert.Equal(new ReadOnlyMemory().ToArray(), vec.Memory.ToArray()); + Assert.Equal(Array.Empty(), vec.ToArray()); + Assert.Equal(SQLResource.NullString, vec.ToString()); + + var ivec = vec as ISqlVector; + Assert.Equal(0x00, ivec.ElementType); + Assert.Equal(0x04, ivec.ElementSize); + Assert.Empty(ivec.VectorPayload); + } + + [Fact] + public void Construct_Memory_Empty() + { + SqlVector vec = new(new ReadOnlyMemory()); + Assert.False(vec.IsNull); + Assert.Equal(0, vec.Length); + Assert.Equal(8, vec.Size); + Assert.Equal(new ReadOnlyMemory().ToArray(), vec.Memory.ToArray()); + Assert.Equal(Array.Empty(), vec.ToArray()); + Assert.Equal("[]", vec.ToString()); + + var ivec = vec as ISqlVector; + Assert.Equal(0x00, ivec.ElementType); + Assert.Equal(0x04, ivec.ElementSize); + Assert.Equal( + new byte[] { 0xA9, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + ivec.VectorPayload); + } + + [Fact] + public void Construct_Memory() + { + float[] data = [1.1f, 2.2f]; + ReadOnlyMemory memory = new(data); + SqlVector vec = new(memory); + Assert.False(vec.IsNull); + Assert.Equal(2, vec.Length); + Assert.Equal(16, vec.Size); + Assert.Equal(memory.ToArray(), vec.Memory.ToArray()); + Assert.Equal(data, vec.ToArray()); + #if NETFRAMEWORK + Assert.Equal("[1.10000002,2.20000005]", vec.ToString()); + #else + Assert.Equal("[1.1,2.2]", vec.ToString()); + #endif + var ivec = vec as ISqlVector; + Assert.Equal(0x00, ivec.ElementType); + Assert.Equal(0x04, ivec.ElementSize); + Assert.Equal( + MakeTdsPayload( + new byte[] { 0xA9, 0x01, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00 }, + memory), + ivec.VectorPayload); + } + + [Fact] + public void Construct_Memory_ImplicitConversionFromFloatArray() + { + float[] data = new float[] { 3.3f, 4.4f, 5.5f }; + var vec = new SqlVector(data); + Assert.False(vec.IsNull); + Assert.Equal(3, vec.Length); + Assert.Equal(20, vec.Size); + Assert.Equal(new ReadOnlyMemory(data).ToArray(), vec.Memory.ToArray()); + Assert.Equal(data, vec.ToArray()); + #if NETFRAMEWORK + Assert.Equal("[3.29999995,4.4000001,5.5]", vec.ToString()); + #else + Assert.Equal("[3.3,4.4,5.5]", vec.ToString()); + #endif + + var ivec = vec as ISqlVector; + Assert.Equal(0x00, ivec.ElementType); + Assert.Equal(0x04, ivec.ElementSize); + Assert.Equal( + MakeTdsPayload( + new byte[] { 0xA9, 0x01, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00 }, + data), + ivec.VectorPayload); + } + + [Fact] + public void Construct_Bytes() + { + float[] data = new float[] { 6.6f, 7.7f }; + var bytes = + MakeTdsPayload( + new byte[] { 0xA9, 0x01, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00 }, + data); + + var vec = new SqlVector(bytes); + Assert.False(vec.IsNull); + Assert.Equal(2, vec.Length); + Assert.Equal(16, vec.Size); + Assert.Equal(new ReadOnlyMemory(data).ToArray(), vec.Memory.ToArray()); + Assert.Equal(data, vec.ToArray()); + #if NETFRAMEWORK + Assert.Equal("[6.5999999,7.69999981]", vec.ToString()); + #else + Assert.Equal("[6.6,7.7]", vec.ToString()); + #endif + + var ivec = vec as ISqlVector; + Assert.Equal(0x00, ivec.ElementType); + Assert.Equal(0x04, ivec.ElementSize); + Assert.Equal(bytes, ivec.VectorPayload); + } + + [Fact] + public void Construct_Bytes_ShortHeader() + { + Assert.Throws(() => + { + new SqlVector(new byte[] { 0xA9, 0x01, 0x00, 0x00 }); + }); + } + + [Fact] + public void Construct_Bytes_UnknownMagic() + { + Assert.Throws(() => + { + new SqlVector( + new byte[] { 0xA8, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }); + }); + } + + [Fact] + public void Construct_Bytes_UnsupportedVersion() + { + Assert.Throws(() => + { + new SqlVector( + new byte[] { 0xA9, 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }); + }); + } + + [Fact] + public void Construct_Bytes_TypeMismatch() + { + Assert.Throws(() => + { + new SqlVector( + new byte[] { 0xA9, 0x01, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00 }); + }); + } + + [Fact] + public void Construct_Bytes_LengthMismatch() + { + // The header indicates 2 elements, but the payload has 3 floats. + var header = new byte[] { 0xA9, 0x01, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00 }; + var bytes = MakeTdsPayload( + header, + new ReadOnlyMemory(new float[] { 1.1f, 2.2f, 3.3f })); + + Assert.Throws(() => + { + new SqlVector(bytes); + }); + } + + [Fact] + public void Null_Property() + { + Assert.Null(SqlVector.Null); + } + + #endregion + + #region Helpers + + private byte[] MakeTdsPayload(byte[] header, ReadOnlyMemory values) + { + int length = header.Length + (values.Length * sizeof(float)); + byte[] payload = new byte[length]; + header.CopyTo(payload, 0); + for (int i = 0; i < values.Length; i++) + { + var offset = header.Length + (i * sizeof(float)); + BitConverter.GetBytes(values.Span[i]).CopyTo(payload, offset); + } + return payload; + } + + #endregion +}