Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,22 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
#if NET
using System.Buffers;
using System.Buffers.Text;
#endif
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
#if !NET
using System.Runtime.InteropServices;
#endif
using System.Text.Json.Serialization;
using Microsoft.Shared.Diagnostics;

#pragma warning disable S3996 // URI properties should not be strings
#pragma warning disable CA1054 // URI-like parameters should not be strings
#pragma warning disable CA1056 // URI-like properties should not be strings
#pragma warning disable CA1307 // Specify StringComparison for clarity

namespace Microsoft.Extensions.AI;

Expand Down Expand Up @@ -70,39 +78,35 @@ public DataContent(Uri uri, string? mediaType = null)
[JsonConstructor]
public DataContent([StringSyntax(StringSyntaxAttribute.Uri)] string uri, string? mediaType = null)
{
// Store and validate the data URI.
_uri = Throw.IfNullOrWhitespace(uri);

if (!uri.StartsWith(DataUriParser.Scheme, StringComparison.OrdinalIgnoreCase))
{
Throw.ArgumentException(nameof(uri), "The provided URI is not a data URI.");
}

// Parse the data URI to extract the data and media type.
_dataUri = DataUriParser.Parse(uri.AsMemory());

// Validate and store the media type.
mediaType ??= _dataUri.MediaType;
if (mediaType is null)
{
mediaType = _dataUri.MediaType;
if (mediaType is null)
{
Throw.ArgumentNullException(nameof(mediaType), $"{nameof(uri)} did not contain a media type, and {nameof(mediaType)} was not provided.");
}
}
else
{
if (mediaType != _dataUri.MediaType)
{
// If the data URI contains a media type that's different from a non-null media type
// explicitly provided, prefer the one explicitly provided as an override.

// Extract the bytes from the data URI and null out the uri.
// Then we'll lazily recreate it later if needed based on the updated media type.
_data = _dataUri.ToByteArray();
_dataUri = null;
_uri = null;
}
Throw.ArgumentNullException(nameof(mediaType), $"{nameof(uri)} did not contain a media type, and {nameof(mediaType)} was not provided.");
}

MediaType = DataUriParser.ThrowIfInvalidMediaType(mediaType);

if (!_dataUri.IsBase64 || mediaType != _dataUri.MediaType)
{
// In rare cases, the data URI may contain non-base64 data, in which case we
// want to normalize it to base64. The supplied media type may also be different
// from the one in the data URI. In either case, we extract the bytes from the data URI
// and then throw away the uri; we'll recreate it lazily in the canonical form.
_data = _dataUri.ToByteArray();
_dataUri = null;
_uri = null;
}
}

/// <summary>
Expand Down Expand Up @@ -134,9 +138,8 @@ public DataContent(ReadOnlyMemory<byte> data, string mediaType)

/// <summary>Gets the data URI for this <see cref="DataContent"/>.</summary>
/// <remarks>
/// The returned URI is always a valid URI string, even if the instance was constructed from a <see cref="ReadOnlyMemory{Byte}"/>
/// or from a <see cref="System.Uri"/>. In the case of a <see cref="ReadOnlyMemory{T}"/>, this property returns a data URI containing
/// that data.
/// The returned URI is always a valid data URI string, even if the instance was constructed from a <see cref="ReadOnlyMemory{Byte}"/>
/// or from a <see cref="System.Uri"/>.
/// </remarks>
[StringSyntax(StringSyntaxAttribute.Uri)]
public string Uri
Expand All @@ -145,27 +148,26 @@ public string Uri
{
if (_uri is null)
{
if (_dataUri is null)
{
Debug.Assert(_data is not null, "Expected _data to be initialized.");
_uri = string.Concat("data:", MediaType, ";base64,", Convert.ToBase64String(_data.GetValueOrDefault()
#if NET
.Span));
#else
.Span.ToArray()));
#endif
}
else
{
_uri = _dataUri.IsBase64 ?
Debug.Assert(_data is not null, "Expected _data to be initialized.");
ReadOnlyMemory<byte> data = _data.GetValueOrDefault();

#if NET
$"data:{MediaType};base64,{_dataUri.Data.Span}" :
$"data:{MediaType};,{_dataUri.Data.Span}";
char[] array = ArrayPool<char>.Shared.Rent(
"data:".Length + MediaType.Length + ";base64,".Length + Base64.GetMaxEncodedToUtf8Length(data.Length));

bool wrote = array.AsSpan().TryWrite($"data:{MediaType};base64,", out int prefixLength);
wrote |= Convert.TryToBase64Chars(data.Span, array.AsSpan(prefixLength), out int dataLength);
Debug.Assert(wrote, "Expected to successfully write the data URI.");
_uri = array.AsSpan(0, prefixLength + dataLength).ToString();

ArrayPool<char>.Shared.Return(array);
#else
$"data:{MediaType};base64,{_dataUri.Data}" :
$"data:{MediaType};,{_dataUri.Data}";
string base64 = MemoryMarshal.TryGetArray(data, out ArraySegment<byte> segment) ?
Convert.ToBase64String(segment.Array!, segment.Offset, segment.Count) :
Convert.ToBase64String(data.ToArray());

_uri = $"data:{MediaType};base64,{base64}";
#endif
}
}

return _uri;
Expand Down Expand Up @@ -205,6 +207,20 @@ public ReadOnlyMemory<byte> Data
}
}

/// <summary>Gets the data represented by this instance as a Base64 character sequence.</summary>
/// <returns>The base64 representation of the data.</returns>
[JsonIgnore]
public ReadOnlyMemory<char> Base64Data
{
get
{
string uri = Uri;
int pos = uri.IndexOf(',');
Debug.Assert(pos >= 0, "Expected comma to be present in the URI.");
return uri.AsMemory(pos + 1);
}
}

/// <summary>Gets a string representing this instance to display in the debugger.</summary>
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
private string DebuggerDisplay
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,6 @@ internal static bool IsImageWithSupportedFormat(this AIContent content) =>
(content is UriContent uriContent && IsSupportedImageFormat(uriContent.MediaType)) ||
(content is DataContent dataContent && IsSupportedImageFormat(dataContent.MediaType));

internal static bool IsUriBase64Encoded(this DataContent dataContent)
{
ReadOnlyMemory<char> uri = dataContent.Uri.AsMemory();

int commaIndex = uri.Span.IndexOf(',');
if (commaIndex == -1)
{
return false;
}

ReadOnlyMemory<char> metadata = uri.Slice(0, commaIndex);

bool isBase64Encoded = metadata.Span.EndsWith(";base64".AsSpan(), StringComparison.OrdinalIgnoreCase);
return isBase64Encoded;
}

private static bool IsSupportedImageFormat(string mediaType)
{
// 'image/jpeg' is the official MIME type for JPEG. However, some systems recognize 'image/jpg' as well.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,25 +343,13 @@ IEnumerable<JsonObject> GetContents(ChatMessage message)
}
else if (content is DataContent dataContent && dataContent.HasTopLevelMediaType("image"))
{
string url;
if (dataContent.IsUriBase64Encoded())
{
url = dataContent.Uri;
}
else
{
BinaryData imageBytes = BinaryData.FromBytes(dataContent.Data);
string base64ImageData = Convert.ToBase64String(imageBytes.ToArray());
url = $"data:{dataContent.MediaType};base64,{base64ImageData}";
}

yield return new JsonObject
{
["type"] = "image_url",
["image_url"] =
new JsonObject
{
["url"] = url
["url"] = dataContent.Uri
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,12 +402,7 @@ private IEnumerable<OllamaChatRequestMessage> ToOllamaChatRequestMessages(ChatMe
if (item is DataContent dataContent && dataContent.HasTopLevelMediaType("image"))
{
IList<string> images = currentTextMessage?.Images ?? [];
images.Add(Convert.ToBase64String(dataContent.Data
#if NET
.Span));
#else
.ToArray()));
#endif
images.Add(dataContent.Base64Data.ToString());

if (currentTextMessage is not null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Text;
using System.Text.Json;
using Xunit;

Expand Down Expand Up @@ -66,21 +67,27 @@ public void Ctor_ValidMediaType_Roundtrips(string mediaType)
{
var content = new DataContent("data:image/png;base64,aGVsbG8=", mediaType);
Assert.Equal(mediaType, content.MediaType);
Assert.Equal("aGVsbG8=", content.Base64Data.ToString());

content = new DataContent("data:,", mediaType);
Assert.Equal(mediaType, content.MediaType);
Assert.Equal("", content.Base64Data.ToString());

content = new DataContent("data:text/plain,", mediaType);
Assert.Equal(mediaType, content.MediaType);
Assert.Equal("", content.Base64Data.ToString());

content = new DataContent(new Uri("data:text/plain,"), mediaType);
Assert.Equal(mediaType, content.MediaType);
Assert.Equal("", content.Base64Data.ToString());

content = new DataContent(new byte[] { 0, 1, 2 }, mediaType);
Assert.Equal(mediaType, content.MediaType);
Assert.Equal("AAEC", content.Base64Data.ToString());

content = new DataContent(content.Uri);
Assert.Equal(mediaType, content.MediaType);
Assert.Equal("AAEC", content.Base64Data.ToString());
}

[Fact]
Expand All @@ -91,10 +98,12 @@ public void Ctor_NoMediaType_Roundtrips()
content = new DataContent("data:image/png;base64,aGVsbG8=");
Assert.Equal("data:image/png;base64,aGVsbG8=", content.Uri);
Assert.Equal("image/png", content.MediaType);
Assert.Equal("aGVsbG8=", content.Base64Data.ToString());

content = new DataContent(new Uri("data:image/png;base64,aGVsbG8="));
Assert.Equal("data:image/png;base64,aGVsbG8=", content.Uri);
Assert.Equal("image/png", content.MediaType);
Assert.Equal("aGVsbG8=", content.Base64Data.ToString());
}

[Fact]
Expand Down Expand Up @@ -128,6 +137,7 @@ public void Deserialize_MatchesExpectedData()

Assert.Equal("data:application/octet-stream;base64,AQIDBA==", content.Uri);
Assert.Equal([0x01, 0x02, 0x03, 0x04], content.Data.ToArray());
Assert.Equal("AQIDBA==", content.Base64Data.ToString());
Assert.Equal("application/octet-stream", content.MediaType);

// Uri referenced content-only
Expand All @@ -150,6 +160,7 @@ public void Deserialize_MatchesExpectedData()

Assert.Equal("data:audio/wav;base64,AQIDBA==", content.Uri);
Assert.Equal([0x01, 0x02, 0x03, 0x04], content.Data.ToArray());
Assert.Equal("AQIDBA==", content.Base64Data.ToString());
Assert.Equal("audio/wav", content.MediaType);
Assert.Equal("value", content.AdditionalProperties!["key"]!.ToString());
}
Expand Down Expand Up @@ -224,4 +235,29 @@ public void HasMediaTypePrefix_ReturnsFalse(string mediaType, string prefix)
var content = new DataContent("data:application/octet-stream;base64,AQIDBA==", mediaType);
Assert.False(content.HasTopLevelMediaType(prefix));
}

[Fact]
public void Data_Roundtrips()
{
Random rand = new(42);
for (int length = 0; length < 100; length++)
{
byte[] data = new byte[length];
rand.NextBytes(data);

var content = new DataContent(data, "application/octet-stream");
Assert.Equal(data, content.Data.ToArray());
Assert.Equal(Convert.ToBase64String(data), content.Base64Data.ToString());
Assert.Equal($"data:application/octet-stream;base64,{Convert.ToBase64String(data)}", content.Uri);
}
}

[Fact]
public void NonBase64Data_Normalized()
{
var content = new DataContent("data:text/plain,hello world");
Assert.Equal("data:text/plain;base64,aGVsbG8gd29ybGQ=", content.Uri);
Assert.Equal("aGVsbG8gd29ybGQ=", content.Base64Data.ToString());
Assert.Equal("hello world", Encoding.ASCII.GetString(content.Data.ToArray()));
}
}
Loading