Skip to content

Commit

Permalink
[release/9.0] Generate correct calls for complex and primitive proper…
Browse files Browse the repository at this point in the history
…ties in the model snapshot (#34587)

Fixes #34578
  • Loading branch information
AndriySvyryd authored Sep 3, 2024
1 parent 972a50a commit 1582f80
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,8 @@ protected virtual void GenerateProperty(
var clrType = (FindValueConverter(property)?.ProviderClrType ?? property.ClrType)
.MakeNullable(property.IsNullable);

var propertyBuilderName = $"{entityTypeBuilderName}.Property<{Code.Reference(clrType)}>({Code.Literal(property.Name)})";
var propertyCall = property.IsPrimitiveCollection ? "PrimitiveCollection" : "Property";
var propertyBuilderName = $"{entityTypeBuilderName}.{propertyCall}<{Code.Reference(clrType)}>({Code.Literal(property.Name)})";

stringBuilder
.AppendLine()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,42 @@ private static readonly MethodInfo PropertyIsSparseMethodInfo
= typeof(SqlServerPropertyBuilderExtensions).GetRuntimeMethod(
nameof(SqlServerPropertyBuilderExtensions.IsSparse), [typeof(PropertyBuilder), typeof(bool)])!;

private static readonly MethodInfo PrimitiveCollectionIsSparseMethodInfo
= typeof(SqlServerPrimitiveCollectionBuilderExtensions).GetRuntimeMethod(
nameof(SqlServerPrimitiveCollectionBuilderExtensions.IsSparse), [typeof(PrimitiveCollectionBuilder), typeof(bool)])!;

private static readonly MethodInfo ComplexTypePropertyIsSparseMethodInfo
= typeof(SqlServerComplexTypePropertyBuilderExtensions).GetRuntimeMethod(
nameof(SqlServerComplexTypePropertyBuilderExtensions.IsSparse), [typeof(ComplexTypePropertyBuilder), typeof(bool)])!;

private static readonly MethodInfo ComplexTypePrimitiveCollectionIsSparseMethodInfo
= typeof(SqlServerComplexTypePrimitiveCollectionBuilderExtensions).GetRuntimeMethod(
nameof(SqlServerComplexTypePrimitiveCollectionBuilderExtensions.IsSparse), [typeof(ComplexTypePrimitiveCollectionBuilder), typeof(bool)])!;

private static readonly MethodInfo PropertyUseIdentityColumnsMethodInfo
= typeof(SqlServerPropertyBuilderExtensions).GetRuntimeMethod(
nameof(SqlServerPropertyBuilderExtensions.UseIdentityColumn), [typeof(PropertyBuilder), typeof(long), typeof(int)])!;

private static readonly MethodInfo ComplexTypePropertyUseIdentityColumnsMethodInfo
= typeof(SqlServerComplexTypePropertyBuilderExtensions).GetRuntimeMethod(
nameof(SqlServerComplexTypePropertyBuilderExtensions.UseIdentityColumn), [typeof(ComplexTypePropertyBuilder), typeof(long), typeof(int)])!;

private static readonly MethodInfo PropertyUseHiLoMethodInfo
= typeof(SqlServerPropertyBuilderExtensions).GetRuntimeMethod(
nameof(SqlServerPropertyBuilderExtensions.UseHiLo), [typeof(PropertyBuilder), typeof(string), typeof(string)])!;

private static readonly MethodInfo ComplexTypePropertyUseHiLoMethodInfo
= typeof(SqlServerComplexTypePropertyBuilderExtensions).GetRuntimeMethod(
nameof(SqlServerComplexTypePropertyBuilderExtensions.UseHiLo), [typeof(ComplexTypePropertyBuilder), typeof(string), typeof(string)])!;

private static readonly MethodInfo PropertyUseSequenceMethodInfo
= typeof(SqlServerPropertyBuilderExtensions).GetRuntimeMethod(
nameof(SqlServerPropertyBuilderExtensions.UseSequence), [typeof(PropertyBuilder), typeof(string), typeof(string)])!;

private static readonly MethodInfo ComplexTypePropertyUseSequenceMethodInfo
= typeof(SqlServerComplexTypePropertyBuilderExtensions).GetRuntimeMethod(
nameof(SqlServerComplexTypePropertyBuilderExtensions.UseSequence), [typeof(ComplexTypePropertyBuilder), typeof(string), typeof(string)])!;

private static readonly MethodInfo IndexIsClusteredMethodInfo
= typeof(SqlServerIndexBuilderExtensions).GetRuntimeMethod(
nameof(SqlServerIndexBuilderExtensions.IsClustered), [typeof(IndexBuilder), typeof(bool)])!;
Expand Down Expand Up @@ -144,7 +168,7 @@ public override IReadOnlyList<MethodCallCodeFragment> GenerateFluentApiCalls(
{
var fragments = new List<MethodCallCodeFragment>(base.GenerateFluentApiCalls(model, annotations));

if (GenerateValueGenerationStrategy(annotations, model, onModel: true) is MethodCallCodeFragment valueGenerationStrategy)
if (GenerateValueGenerationStrategy(annotations, model, onModel: true, complexType: false) is MethodCallCodeFragment valueGenerationStrategy)
{
fragments.Add(valueGenerationStrategy);
}
Expand Down Expand Up @@ -179,18 +203,27 @@ public override IReadOnlyList<MethodCallCodeFragment> GenerateFluentApiCalls(
{
var fragments = new List<MethodCallCodeFragment>(base.GenerateFluentApiCalls(property, annotations));

if (GenerateValueGenerationStrategy(annotations, property.DeclaringType.Model, onModel: false) is MethodCallCodeFragment
var isPrimitiveCollection = property.IsPrimitiveCollection;

if (GenerateValueGenerationStrategy(annotations, property.DeclaringType.Model, onModel: false, complexType: property.DeclaringType is IComplexType) is MethodCallCodeFragment
valueGenerationStrategy)
{
fragments.Add(valueGenerationStrategy);
}

if (GetAndRemove<bool?>(annotations, SqlServerAnnotationNames.Sparse) is bool isSparse)
{
var methodInfo = isPrimitiveCollection
? property.DeclaringType is IComplexType
? ComplexTypePrimitiveCollectionIsSparseMethodInfo
: PrimitiveCollectionIsSparseMethodInfo
: property.DeclaringType is IComplexType
? ComplexTypePropertyIsSparseMethodInfo
: PropertyIsSparseMethodInfo;
fragments.Add(
isSparse
? new MethodCallCodeFragment(PropertyIsSparseMethodInfo)
: new MethodCallCodeFragment(PropertyIsSparseMethodInfo, false));
? new MethodCallCodeFragment(methodInfo)
: new MethodCallCodeFragment(methodInfo, false));
}

return fragments;
Expand Down Expand Up @@ -367,7 +400,8 @@ protected override bool IsHandledByConvention(IProperty property, IAnnotation an
private static MethodCallCodeFragment? GenerateValueGenerationStrategy(
IDictionary<string, IAnnotation> annotations,
IModel model,
bool onModel)
bool onModel,
bool complexType)
{
SqlServerValueGenerationStrategy strategy;
if (annotations.TryGetValue(SqlServerAnnotationNames.ValueGenerationStrategy, out var strategyAnnotation)
Expand Down Expand Up @@ -405,7 +439,11 @@ protected override bool IsHandledByConvention(IProperty property, IAnnotation an
?? model.FindAnnotation(SqlServerAnnotationNames.IdentityIncrement)?.Value as int?
?? 1;
return new MethodCallCodeFragment(
onModel ? ModelUseIdentityColumnsMethodInfo : PropertyUseIdentityColumnsMethodInfo,
onModel
? ModelUseIdentityColumnsMethodInfo
: complexType
? ComplexTypePropertyUseIdentityColumnsMethodInfo
: PropertyUseIdentityColumnsMethodInfo,
(seed, increment) switch
{
(1L, 1) => [],
Expand All @@ -418,7 +456,11 @@ protected override bool IsHandledByConvention(IProperty property, IAnnotation an
var name = GetAndRemove<string>(annotations, SqlServerAnnotationNames.HiLoSequenceName);
var schema = GetAndRemove<string>(annotations, SqlServerAnnotationNames.HiLoSequenceSchema);
return new MethodCallCodeFragment(
onModel ? ModelUseHiLoMethodInfo : PropertyUseHiLoMethodInfo,
onModel
? ModelUseHiLoMethodInfo
: complexType
? ComplexTypePropertyUseHiLoMethodInfo
: PropertyUseHiLoMethodInfo,
(name, schema) switch
{
(null, null) => [],
Expand All @@ -435,7 +477,11 @@ protected override bool IsHandledByConvention(IProperty property, IAnnotation an

var schema = GetAndRemove<string>(annotations, SqlServerAnnotationNames.SequenceSchema);
return new MethodCallCodeFragment(
onModel ? ModelUseKeySequencesMethodInfo : PropertyUseSequenceMethodInfo,
onModel
? ModelUseKeySequencesMethodInfo
: complexType
? ComplexTypePropertyUseSequenceMethodInfo
: PropertyUseSequenceMethodInfo,
(name: nameOrSuffix, schema) switch
{
(null, null) => [],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5944,6 +5944,72 @@ public virtual void SQLServer_property_legacy_identity_seed_int_annotation()

#endregion

#region Primitive collection

[ConditionalFact]
public virtual void PrimitiveCollection_is_stored_in_snapshot()
=> Test(
builder =>
{
builder.Entity<EntityWithOneProperty>()
.PrimitiveCollection<List<int>>("List")
.IsSparse()
.IsFixedLength()
.HasMaxLength(100)
.IsUnicode()
.UseCollation("ListCollation")
.HasSentinel([])
.HasColumnName("ListColumn")
.HasColumnType("nvarchar")
.HasColumnOrder(1)
.HasComment("ListComment")
.HasComputedColumnSql("ListSql")
.HasJsonPropertyName("ListJson")
.ElementType(b => b.HasConversion<string>())
.ValueGeneratedOnUpdateSometimes()
.HasAnnotation("AnnotationName", "AnnotationValue");
builder.Ignore<EntityWithTwoProperties>();
},
AddBoilerPlate(
GetHeading()
+ """
modelBuilder.Entity("Microsoft.EntityFrameworkCore.Migrations.Design.CSharpMigrationsGeneratorTest+EntityWithOneProperty", b =>
{
b.Property<int>("Id")
.ValueGeneratedOnAdd()
.HasColumnType("int");

SqlServerPropertyBuilderExtensions.UseIdentityColumn(b.Property<int>("Id"));

b.PrimitiveCollection<string>("List")
.ValueGeneratedOnUpdateSometimes()
.HasMaxLength(100)
.IsUnicode(true)
.HasColumnType("nvarchar")
.HasColumnName("ListColumn")
.HasColumnOrder(1)
.HasComputedColumnSql("ListSql")
.IsFixedLength()
.HasComment("ListComment")
.UseCollation("ListCollation")
.HasAnnotation("AnnotationName", "AnnotationValue")
.HasAnnotation("Relational:JsonPropertyName", "ListJson");

SqlServerPrimitiveCollectionBuilderExtensions.IsSparse(b.PrimitiveCollection<string>("List"));

b.HasKey("Id");

b.ToTable("EntityWithOneProperty", "DefaultSchema");
});
"""),
o =>
{
var property = o.GetEntityTypes().First().FindProperty("List");
Assert.Equal("AnnotationValue", property["AnnotationName"]);
});
#endregion

#region Complex types

[ConditionalFact]
Expand All @@ -5958,8 +6024,13 @@ public virtual void Complex_properties_are_stored_in_snapshot()
eo => eo.EntityWithTwoProperties, eb =>
{
eb.IsRequired();
eb.Property(e => e.AlternateId).HasColumnOrder(1);
eb.ComplexProperty(e => e.EntityWithStringKey).IsRequired();
eb.Property(e => e.AlternateId).HasColumnOrder(1).IsSparse();
eb.PrimitiveCollection<List<string>>("List")
.HasColumnType("nvarchar(max)")
.IsSparse();
eb.ComplexProperty(e => e.EntityWithStringKey)
.IsRequired()
.Ignore(e => e.Properties);
eb.HasPropertyAnnotation("PropertyAnnotation", 1);
eb.HasTypeAnnotation("TypeAnnotation", 2);
});
Expand All @@ -5984,9 +6055,16 @@ public virtual void Complex_properties_are_stored_in_snapshot()
.HasColumnType("int")
.HasColumnOrder(1);

SqlServerComplexTypePropertyBuilderExtensions.IsSparse(b1.Property<int>("AlternateId"));

b1.Property<int>("Id")
.HasColumnType("int");

b1.PrimitiveCollection<string>("List")
.HasColumnType("nvarchar(max)");

SqlServerComplexTypePrimitiveCollectionBuilderExtensions.IsSparse(b1.PrimitiveCollection<string>("List"));

b1.ComplexProperty<Dictionary<string, object>>("EntityWithStringKey", "Microsoft.EntityFrameworkCore.Migrations.Design.CSharpMigrationsGeneratorTest+EntityWithOneProperty.EntityWithTwoProperties#EntityWithTwoProperties.EntityWithStringKey#EntityWithStringKey", b2 =>
{
b2.IsRequired();
Expand All @@ -6005,7 +6083,7 @@ public virtual void Complex_properties_are_stored_in_snapshot()
b.ToTable("EntityWithOneProperty", "DefaultSchema");
});
""", usingCollections: true),
o =>
(_, o) =>
{
var entityWithOneProperty = o.FindEntityType(typeof(EntityWithOneProperty));
Assert.Equal(nameof(EntityWithOneProperty), entityWithOneProperty.GetTableName());
Expand Down Expand Up @@ -6037,7 +6115,8 @@ public virtual void Complex_properties_are_stored_in_snapshot()
Assert.Equal(nameof(EntityWithOneProperty), nestedComplexType.GetTableName());
var nestedIdProperty = nestedComplexType.FindProperty(nameof(EntityWithStringKey.Id));
Assert.True(nestedIdProperty.IsNullable);
});
},
validate: true);

#endregion

Expand Down Expand Up @@ -7981,7 +8060,7 @@ protected override void BuildModel(ModelBuilder modelBuilder)

SqlServerPropertyBuilderExtensions.UseIdentityColumn(b.Property<int>("Id"));

b.Property<string>("BoolCollection")
b.PrimitiveCollection<string>("BoolCollection")
.HasColumnType("nvarchar(max)");

b.Property<bool>("Boolean")
Expand All @@ -7993,7 +8072,7 @@ protected override void BuildModel(ModelBuilder modelBuilder)
b.Property<byte[]>("Bytes")
.HasColumnType("varbinary(max)");

b.Property<string>("BytesCollection")
b.PrimitiveCollection<string>("BytesCollection")
.HasColumnType("nvarchar(max)");

b.Property<string>("Character")
Expand All @@ -8003,7 +8082,7 @@ protected override void BuildModel(ModelBuilder modelBuilder)
b.Property<DateTime>("DateTime")
.HasColumnType("datetime2");

b.Property<string>("DateTimeCollection")
b.PrimitiveCollection<string>("DateTimeCollection")
.HasColumnType("nvarchar(max)");

b.Property<DateTimeOffset>("DateTimeOffset")
Expand All @@ -8015,7 +8094,7 @@ protected override void BuildModel(ModelBuilder modelBuilder)
b.Property<double>("Double")
.HasColumnType("float");

b.Property<string>("DoubleCollection")
b.PrimitiveCollection<string>("DoubleCollection")
.HasColumnType("nvarchar(max)");

b.Property<short>("Enum16")
Expand Down Expand Up @@ -8048,7 +8127,7 @@ protected override void BuildModel(ModelBuilder modelBuilder)
b.Property<int>("Int32")
.HasColumnType("int");

b.Property<string>("Int32Collection")
b.PrimitiveCollection<string>("Int32Collection")
.HasColumnType("nvarchar(max)");

b.Property<long>("Int64")
Expand Down Expand Up @@ -8108,7 +8187,7 @@ protected override void BuildModel(ModelBuilder modelBuilder)
b.Property<string>("String")
.HasColumnType("nvarchar(max)");

b.Property<string>("StringCollection")
b.PrimitiveCollection<string>("StringCollection")
.HasColumnType("nvarchar(max)");

b.Property<TimeSpan>("TimeSpan")
Expand Down Expand Up @@ -8403,15 +8482,15 @@ protected override void BuildModel(ModelBuilder modelBuilder)
protected void Test(Action<ModelBuilder> buildModel, string expectedCode, Action<IModel> assert)
=> Test(buildModel, expectedCode, (m, _) => assert(m));

protected void Test(Action<ModelBuilder> buildModel, string expectedCode, Action<IModel, IModel> assert)
protected void Test(Action<ModelBuilder> buildModel, string expectedCode, Action<IModel, IModel> assert, bool validate = false)
{
var modelBuilder = CreateConventionalModelBuilder();
modelBuilder.HasDefaultSchema("DefaultSchema");
modelBuilder.HasChangeTrackingStrategy(ChangeTrackingStrategy.Snapshot);
modelBuilder.Model.RemoveAnnotation(CoreAnnotationNames.ProductVersion);
buildModel(modelBuilder);

var model = modelBuilder.FinalizeModel(designTime: true, skipValidation: true);
var model = modelBuilder.FinalizeModel(designTime: true, skipValidation: !validate);

Test(model, expectedCode, assert);
}
Expand Down

0 comments on commit 1582f80

Please sign in to comment.