Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Config generator test only for down-casts #106145

Merged
merged 3 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ private TypeSpec CreateDictionarySpec(TypeParseInfo typeParseInfo, ITypeSymbol k
CollectionInstantiationStrategy instantiationStrategy;
CollectionInstantiationConcreteType instantiationConcreteType;
CollectionPopulationCastType populationCastType;
bool shouldTryCast = false;

if (HasPublicParameterLessCtor(type))
{
Expand All @@ -403,6 +404,7 @@ private TypeSpec CreateDictionarySpec(TypeParseInfo typeParseInfo, ITypeSymbol k
}
else if (_typeSymbols.GenericIDictionary is not null && GetInterface(type, _typeSymbols.GenericIDictionary_Unbound) is not null)
{
// implements IDictionary<,> -- cast to it.
populationCastType = CollectionPopulationCastType.IDictionary;
}
else
Expand All @@ -421,7 +423,9 @@ private TypeSpec CreateDictionarySpec(TypeParseInfo typeParseInfo, ITypeSymbol k
{
instantiationStrategy = CollectionInstantiationStrategy.LinqToDictionary;
instantiationConcreteType = CollectionInstantiationConcreteType.Dictionary;
// is IReadonlyDictionary<,> -- test cast to IDictionary<,>
populationCastType = CollectionPopulationCastType.IDictionary;
shouldTryCast = true;
}
else
{
Expand All @@ -431,13 +435,15 @@ private TypeSpec CreateDictionarySpec(TypeParseInfo typeParseInfo, ITypeSymbol k
TypeRef keyTypeRef = EnqueueTransitiveType(typeParseInfo, keyTypeSymbol, DiagnosticDescriptors.DictionaryKeyNotSupported);
TypeRef elementTypeRef = EnqueueTransitiveType(typeParseInfo, elementTypeSymbol, DiagnosticDescriptors.ElementTypeNotSupported);

Debug.Assert(!shouldTryCast || !type.IsValueType, "Should not test cast for value types.");
return new DictionarySpec(type)
{
KeyTypeRef = keyTypeRef,
ElementTypeRef = elementTypeRef,
InstantiationStrategy = instantiationStrategy,
InstantiationConcreteType = instantiationConcreteType,
PopulationCastType = populationCastType,
ShouldTryCast = shouldTryCast
};
}

Expand All @@ -458,6 +464,7 @@ private TypeSpec CreateEnumerableSpec(TypeParseInfo typeParseInfo)
CollectionInstantiationStrategy instantiationStrategy;
CollectionInstantiationConcreteType instantiationConcreteType;
CollectionPopulationCastType populationCastType;
bool shouldTryCast = false;

if (HasPublicParameterLessCtor(type))
{
Expand All @@ -470,6 +477,7 @@ private TypeSpec CreateEnumerableSpec(TypeParseInfo typeParseInfo)
}
else if (_typeSymbols.GenericICollection is not null && GetInterface(type, _typeSymbols.GenericICollection_Unbound) is not null)
{
// implements ICollection<> -- cast to it
populationCastType = CollectionPopulationCastType.ICollection;
}
else
Expand All @@ -487,7 +495,9 @@ private TypeSpec CreateEnumerableSpec(TypeParseInfo typeParseInfo)
{
instantiationStrategy = CollectionInstantiationStrategy.CopyConstructor;
instantiationConcreteType = CollectionInstantiationConcreteType.List;
// is IEnumerable<> -- test cast to ICollection<>
populationCastType = CollectionPopulationCastType.ICollection;
shouldTryCast = true;
}
else if (IsInterfaceMatch(type, _typeSymbols.ISet_Unbound))
{
Expand All @@ -499,13 +509,17 @@ private TypeSpec CreateEnumerableSpec(TypeParseInfo typeParseInfo)
{
instantiationStrategy = CollectionInstantiationStrategy.CopyConstructor;
instantiationConcreteType = CollectionInstantiationConcreteType.HashSet;
// is IReadOnlySet<> -- test cast to ISet<>
populationCastType = CollectionPopulationCastType.ISet;
shouldTryCast = true;
}
else if (IsInterfaceMatch(type, _typeSymbols.IReadOnlyList_Unbound) || IsInterfaceMatch(type, _typeSymbols.IReadOnlyCollection_Unbound))
{
instantiationStrategy = CollectionInstantiationStrategy.CopyConstructor;
instantiationConcreteType = CollectionInstantiationConcreteType.List;
// is IReadOnlyList<> or IReadOnlyCollection<> -- test cast to ICollection<>
populationCastType = CollectionPopulationCastType.ICollection;
shouldTryCast = true;
}
else
{
Expand All @@ -514,12 +528,14 @@ private TypeSpec CreateEnumerableSpec(TypeParseInfo typeParseInfo)

TypeRef elementTypeRef = EnqueueTransitiveType(typeParseInfo, elementType, DiagnosticDescriptors.ElementTypeNotSupported);

Debug.Assert(!shouldTryCast || !type.IsValueType, "Should not test cast for value types.");
return new EnumerableSpec(type)
{
ElementTypeRef = elementTypeRef,
InstantiationStrategy = instantiationStrategy,
InstantiationConcreteType = instantiationConcreteType,
PopulationCastType = populationCastType,
ShouldTryCast = shouldTryCast
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1230,15 +1230,25 @@ private void EmitCollectionCastIfRequired(CollectionWithCtorInitSpec type, out s
return;
}

string castTypeDisplayString = TypeIndex.GetPopulationCastTypeDisplayString(type);
instanceIdentifier = Identifier.temp;
string castExpression = $"{TypeIndex.GetPopulationCastTypeDisplayString(type)} {instanceIdentifier}";

if (type.ShouldTryCast)
{
_writer.WriteLine($$"""
if ({{Identifier.instance}} is not {{castExpression}})
{
return;
}
""");
}
else
{
_writer.WriteLine($$"""
{{castExpression}} = {{Identifier.instance}};
""");
}

_writer.WriteLine($$"""
if ({{Identifier.instance}} is not {{castTypeDisplayString}} {{instanceIdentifier}})
{
return;
}
""");
_writer.WriteLine();

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ protected CollectionWithCtorInitSpec(ITypeSymbol type) : base(type) { }
public required CollectionInstantiationConcreteType InstantiationConcreteType { get; init; }

public required CollectionPopulationCastType PopulationCastType { get; init; }

public required bool ShouldTryCast { get; init; }
}

internal sealed record ArraySpec : CollectionSpec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2368,6 +2368,83 @@ public void TestCollectionWithNullOrEmptyItems()
Assert.Equal("System.Boolean", result[0].Elements[1].Type);
}

[Fact]
public void TestStringValues()
{
// StringValues is a struct that implements IList<string> -- though it doesn't actually support Add

var dic = new Dictionary<string, string>
{
{"StringValues:0", "Yo1"},
{"StringValues:1", "Yo2"},
};
var configurationBuilder = new ConfigurationBuilder();
configurationBuilder.AddInMemoryCollection(dic);

var config = configurationBuilder.Build();

var options = new OptionsWithStructs();

#if BUILDING_SOURCE_GENERATOR_TESTS
Assert.Throws<NotSupportedException>(() => config.Bind(options));
#else
Assert.Throws<InvalidOperationException>(() => config.Bind(options, (bo) => bo.ErrorOnUnknownConfiguration = true));
#endif
}

[Fact]
public void TestOptionsWithStructs()
{
var dic = new Dictionary<string, string>
{
{"CollectionStructExplicit:0", "cs1"},
{"CollectionStructExplicit:1", "cs2"},
{"DictionaryStructExplicit:k0", "ds1"},
{"DictionaryStructExplicit:k1", "ds2"},
};
var configurationBuilder = new ConfigurationBuilder();
configurationBuilder.AddInMemoryCollection(dic);

var config = configurationBuilder.Build();

var options = new OptionsWithStructs();
config.Bind(options);

ICollection<string> collection = options.CollectionStructExplicit;
Assert.Equal(2, collection.Count);
Assert.Equal(collection, ["cs1", "cs2"]);

IDictionary<string, string> dictionary = options.DictionaryStructExplicit;
Assert.Equal(2, dictionary.Count);
Assert.Equal("ds1", dictionary["k0"]);
Assert.Equal("ds2", dictionary["k1"]);
}

[Fact]
public void TestOptionsWithUnsupportedStructs()
{
var dic = new Dictionary<string, string>
{
{"ReadOnlyCollectionStructExplicit:0", "cs1"},
{"ReadOnlyCollectionStructExplicit:1", "cs2"},
{"ReadOnlyDictionaryStructExplicit:k0", "ds1"},
{"ReadOnlyDictionaryStructExplicit:k1", "ds2"},
};
var configurationBuilder = new ConfigurationBuilder();
configurationBuilder.AddInMemoryCollection(dic);

var config = configurationBuilder.Build();

var options = new OptionsWithUnsupportedStructs();
config.Bind(options);

IReadOnlyCollection<string> collection = options.ReadOnlyCollectionStructExplicit;
Assert.Equal(0, collection.Count);

IReadOnlyDictionary<string, string> dictionary = options.ReadOnlyDictionaryStructExplicit;
Assert.Equal(0, dictionary.Count);
}

// Test behavior for root level arrays.

// Tests for TypeConverter usage.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using Microsoft.Extensions.Primitives;

namespace Microsoft.Extensions
#if BUILDING_SOURCE_GENERATOR_TESTS
Expand Down Expand Up @@ -388,5 +389,123 @@ public class Element
{
public string Type { get; set; }
}

public class OptionsWithStructs
{
public StringValues StringValues { get; set; }

public CollectionStructExplicit CollectionStructExplicit { get; set; } = new();

public DictionaryStructExplicit DictionaryStructExplicit { get; set; } = new();
}

public struct CollectionStructExplicit : ICollection<string>
{
public CollectionStructExplicit() {}

ICollection<string> _collection = new List<string>();

int ICollection<string>.Count => _collection.Count;

bool ICollection<string>.IsReadOnly => _collection.IsReadOnly;

void ICollection<string>.Add(string item) => _collection.Add(item);

void ICollection<string>.Clear() => _collection.Clear();

bool ICollection<string>.Contains(string item) => _collection.Contains(item);

void ICollection<string>.CopyTo(string[] array, int arrayIndex) => _collection.CopyTo(array, arrayIndex);

IEnumerator<string> IEnumerable<string>.GetEnumerator() => _collection.GetEnumerator();

IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)_collection).GetEnumerator();

bool ICollection<string>.Remove(string item) => _collection.Remove(item);
}

public struct DictionaryStructExplicit : IDictionary<string, string>
ericstj marked this conversation as resolved.
Show resolved Hide resolved
{
public DictionaryStructExplicit() {}

IDictionary<string, string> _dictionary = new Dictionary<string, string>();

string IDictionary<string, string>.this[string key] { get => _dictionary[key]; set => _dictionary[key] = value; }

ICollection<string> IDictionary<string, string>.Keys => _dictionary.Keys;

ICollection<string> IDictionary<string, string>.Values => _dictionary.Values;

int ICollection<KeyValuePair<string, string>>.Count => _dictionary.Count;

bool ICollection<KeyValuePair<string, string>>.IsReadOnly => _dictionary.IsReadOnly;

void IDictionary<string, string>.Add(string key, string value) => _dictionary.Add(key, value);

void ICollection<KeyValuePair<string, string>>.Add(KeyValuePair<string, string> item) => _dictionary.Add(item);

void ICollection<KeyValuePair<string, string>>.Clear() => _dictionary.Clear();

bool ICollection<KeyValuePair<string, string>>.Contains(KeyValuePair<string, string> item) => _dictionary.Contains(item);

bool IDictionary<string, string>.ContainsKey(string key) => _dictionary.ContainsKey(key);

void ICollection<KeyValuePair<string, string>>.CopyTo(KeyValuePair<string, string>[] array, int arrayIndex) => _dictionary.CopyTo(array, arrayIndex);

IEnumerator<KeyValuePair<string, string>> IEnumerable<KeyValuePair<string, string>>.GetEnumerator() => _dictionary.GetEnumerator();

IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)_dictionary).GetEnumerator();

bool IDictionary<string, string>.Remove(string key) => _dictionary.Remove(key);

bool ICollection<KeyValuePair<string, string>>.Remove(KeyValuePair<string, string> item) => _dictionary.Remove(item);

bool IDictionary<string, string>.TryGetValue(string key, out string value) => _dictionary.TryGetValue(key, out value);
}

public class OptionsWithUnsupportedStructs
{
public ReadOnlyCollectionStructExplicit ReadOnlyCollectionStructExplicit { get; set; } = new();

public ReadOnlyDictionaryStructExplicit ReadOnlyDictionaryStructExplicit { get; set; } = new();
}

public struct ReadOnlyCollectionStructExplicit : IReadOnlyCollection<string>
{
public ReadOnlyCollectionStructExplicit()
{
_collection = new List<string>();
}

private readonly IReadOnlyCollection<string> _collection;
int IReadOnlyCollection<string>.Count => _collection.Count;
IEnumerator<string> IEnumerable<string>.GetEnumerator() => _collection.GetEnumerator();
IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)_collection).GetEnumerator();
}

public struct ReadOnlyDictionaryStructExplicit : IReadOnlyDictionary<string, string>
{
public ReadOnlyDictionaryStructExplicit()
{
_dictionary = new Dictionary<string, string>();
}

private readonly IReadOnlyDictionary<string, string> _dictionary;
string IReadOnlyDictionary<string, string>.this[string key] => _dictionary[key];

IEnumerable<string> IReadOnlyDictionary<string, string>.Keys => _dictionary.Keys;

IEnumerable<string> IReadOnlyDictionary<string, string>.Values => _dictionary.Values;

int IReadOnlyCollection<KeyValuePair<string, string>>.Count => _dictionary.Count;

bool IReadOnlyDictionary<string, string>.ContainsKey(string key) => _dictionary.ContainsKey(key);

IEnumerator<KeyValuePair<string, string>> IEnumerable<KeyValuePair<string, string>>.GetEnumerator() => _dictionary.GetEnumerator();

IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)_dictionary).GetEnumerator();

bool IReadOnlyDictionary<string, string>.TryGetValue(string key, out string value) => _dictionary.TryGetValue(key, out value);
}
}
}
Loading
Loading