Skip to content

[NRBF] Reject orphaned records and invalid references #103632

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,16 @@ internal enum AllowedRecordTypes : uint
Nulls = ObjectNull | ObjectNullMultiple256 | ObjectNullMultiple,

/// <summary>
/// Any .NET object (a primitive, a reference type, a reference or single null).
/// Any .NET object (a class, primitive type or an array).
/// </summary>
AnyObject = MemberPrimitiveTyped
| ArraySingleObject | ArraySinglePrimitive | ArraySingleString | BinaryArray
| ClassWithId | ClassWithMembersAndTypes | SystemClassWithMembersAndTypes
| BinaryObjectString
| MemberReference
| ObjectNull,
| MemberReference,

/// <summary>
/// Any .NET object or a reference or a single null.
/// </summary>
AnyObjectOrNullOrReference = AnyObject | ObjectNull | MemberReference,
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ internal static ArraySingleObjectRecord Decode(BinaryReader reader)
internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType()
{
// An array of objects can contain any Object or multiple nulls.
const AllowedRecordTypes Allowed = AllowedRecordTypes.AnyObject | AllowedRecordTypes.Nulls;
const AllowedRecordTypes Allowed = AllowedRecordTypes.AnyObjectOrNullOrReference | AllowedRecordTypes.Nulls;

return (Allowed, default);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,6 @@ internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetA
if (record is MemberReferenceRecord memberReference)
{
record = memberReference.GetReferencedRecord();

if (record is not BinaryObjectStringRecord)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this check is no longer needed here, as there is a universal check for all references in RecordMap.Add. The benefit is that we throw as soon as we encounter an invalid reference.

{
ThrowHelper.ThrowInvalidReference();
}
}

if (record is BinaryObjectStringRecord stringRecord)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Formats.Nrbf.Utils;
using System.IO;
using System.Reflection.Metadata;

Expand All @@ -15,16 +16,19 @@ namespace System.Formats.Nrbf;
/// </remarks>
internal sealed class MemberReferenceRecord : SerializationRecord
{
private MemberReferenceRecord(SerializationRecordId reference, RecordMap recordMap)
private MemberReferenceRecord(SerializationRecordId reference, RecordMap recordMap, AllowedRecordTypes referencedRecordType)
{
Reference = reference;
RecordMap = recordMap;
ReferencedRecordType = referencedRecordType;
}

public override SerializationRecordType RecordType => SerializationRecordType.MemberReference;

internal SerializationRecordId Reference { get; }

private AllowedRecordTypes ReferencedRecordType { get; }

private RecordMap RecordMap { get; }

// MemberReferenceRecord has no Id, which makes it impossible to create a cycle
Expand All @@ -35,8 +39,26 @@ private MemberReferenceRecord(SerializationRecordId reference, RecordMap recordM

internal override object? GetValue() => GetReferencedRecord().GetValue();

internal static MemberReferenceRecord Decode(BinaryReader reader, RecordMap recordMap)
=> new(SerializationRecordId.Decode(reader), recordMap);
internal static MemberReferenceRecord Decode(BinaryReader reader, RecordMap recordMap, AllowedRecordTypes allowed)
{
SerializationRecordId reference = SerializationRecordId.Decode(reader);

// We were supposed to decode a record of specific type or a reference to it.
// Since a reference was decoded and we don't know when the referenced record will be provided.
// We just store the allowed record type and are going to check it later.
AllowedRecordTypes referencedRecordType = allowed & ~(AllowedRecordTypes.MemberReference | AllowedRecordTypes.Nulls);

return new MemberReferenceRecord(reference, recordMap, referencedRecordType);
}

internal SerializationRecord GetReferencedRecord() => RecordMap.GetRecord(Reference);

internal void VerifyReferencedRecordType(SerializationRecord serializationRecord)
{
if (((uint)ReferencedRecordType & (1u << (byte)serializationRecord.RecordType)) == 0)
{
// We expected a reference to a record of a different type.
ThrowHelper.ThrowInvalidReference();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ internal static MemberTypeInfo Decode(BinaryReader reader, int count, PayloadOpt
| AllowedRecordTypes.ObjectNull | AllowedRecordTypes.MemberReference;
const AllowedRecordTypes ObjectArray = AllowedRecordTypes.ArraySingleObject
| AllowedRecordTypes.ObjectNull | AllowedRecordTypes.MemberReference;
const AllowedRecordTypes NonPrimitiveArray = AllowedRecordTypes.BinaryArray
| AllowedRecordTypes.ObjectNull | AllowedRecordTypes.MemberReference;

// Every string can be a string, a null or a reference (to a string)
const AllowedRecordTypes Strings = AllowedRecordTypes.BinaryObjectString
Expand All @@ -92,13 +94,53 @@ internal static MemberTypeInfo Decode(BinaryReader reader, int count, PayloadOpt
{
BinaryType.Primitive => (default, (PrimitiveType)additionalInfo!),
BinaryType.String => (Strings, default),
BinaryType.Object => (AllowedRecordTypes.AnyObject, default),
BinaryType.Object => (AllowedRecordTypes.AnyObjectOrNullOrReference, default),
BinaryType.StringArray => (StringArray, default),
BinaryType.PrimitiveArray => (PrimitiveArray, default),
BinaryType.Class => (NonSystemClass, default),
BinaryType.SystemClass => (SystemClass, default),
_ => (ObjectArray, default)
BinaryType.Class => (((ClassTypeInfo)additionalInfo!).TypeName.IsArray ? NonPrimitiveArray : NonSystemClass, default),
BinaryType.SystemClass => (MapSystemClassTypeName((TypeName)additionalInfo!), default),
_ => (ObjectArray, default),
};

static AllowedRecordTypes MapSystemClassTypeName(TypeName typeName)
{
if (!typeName.IsArray)
{
return SystemClass;
}
else if (typeName.IsSZArray)
{
TypeName elementTypeName = typeName.GetElementType();
if (elementTypeName.IsSimple && elementTypeName.FullName.StartsWith("System.", StringComparison.Ordinal))
{
switch (elementTypeName.FullName)
{
case "System.Boolean":
case "System.Byte":
case "System.SByte":
case "System.Char":
case "System.Int16":
case "System.UInt16":
case "System.Int32":
case "System.UInt32":
case "System.Int64":
case "System.UInt64":
case "System.Single":
case "System.Double":
case "System.Decimal":
case "System.DateTime":
case "System.TimeSpan":
// BinaryFormatter should use BinaryType.PrimitiveArray for these primitive types,
// but it uses BinaryType.SystemClass and we need this workaround.
return PrimitiveArray;
default:
break;
}
}
}

return NonPrimitiveArray;
}
}

internal bool ShouldBeRepresentedAsArrayOfClassRecords()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,14 @@ private static SerializationRecord Decode(BinaryReader reader, PayloadOptions op
Stack<NextInfo> readStack = new();
RecordMap recordMap = new();

// Everything has to start with a header
var header = (SerializedStreamHeaderRecord)DecodeNext(reader, recordMap, AllowedRecordTypes.SerializedStreamHeader, options, out _);
// and can be followed by any Object, BinaryLibrary and a MessageEnd.
const AllowedRecordTypes Allowed = AllowedRecordTypes.AnyObject
| AllowedRecordTypes.BinaryLibrary | AllowedRecordTypes.MessageEnd;
// Every NRBF payload has to start with a header
AllowedRecordTypes allowed = AllowedRecordTypes.SerializedStreamHeader;
var header = (SerializedStreamHeaderRecord)DecodeNext(reader, recordMap, allowed, options, out _);

// The root can be any Object or BinaryLibrary, but not a reference.
allowed = AllowedRecordTypes.AnyObject | AllowedRecordTypes.BinaryLibrary;
SerializationRecord rootRecord = DecodeNext(reader, recordMap, allowed, options, out _);
PushFirstNestedRecordInfo(rootRecord, readStack);

SerializationRecordType recordType;
SerializationRecord nextRecord;
Expand All @@ -184,16 +187,7 @@ private static SerializationRecord Decode(BinaryReader reader, PayloadOptions op
if (nextInfo.Allowed != AllowedRecordTypes.None)
{
// Decode the next Record
do
{
nextRecord = DecodeNext(reader, recordMap, nextInfo.Allowed, options, out _);
// BinaryLibrary often precedes class records.
// It has been already added to the RecordMap and it must not be added
// to the array record, so simply read next record.
// It's possible to read multiple BinaryLibraryRecord in a row, hence the loop.
}
while (nextRecord is BinaryLibraryRecord);

nextRecord = DecodeNext(reader, recordMap, nextInfo.Allowed, options, out _);
// Handle it:
// - add to the parent records list,
// - push next info if there are remaining nested records to read.
Expand All @@ -210,7 +204,20 @@ private static SerializationRecord Decode(BinaryReader reader, PayloadOptions op
}
}

nextRecord = DecodeNext(reader, recordMap, Allowed, options, out recordType);
if (recordMap.UnresolvedReferences == 0)
{
// There are no unresolved references, so the End is the only allowed record.
allowed = AllowedRecordTypes.MessageEnd;
}
else
{
// There are unresolved references and we don't know in what order they are going to appear.
// We allow for any Object (which does not include references or nulls).
// The actual type validation is going to be performed by RecordMap.Add.
allowed = AllowedRecordTypes.AnyObject | AllowedRecordTypes.BinaryLibrary;
}

nextRecord = DecodeNext(reader, recordMap, allowed, options, out recordType, isReferencedRecord: true);
PushFirstNestedRecordInfo(nextRecord, readStack);
}
while (recordType != SerializationRecordType.MessageEnd);
Expand All @@ -220,31 +227,41 @@ private static SerializationRecord Decode(BinaryReader reader, PayloadOptions op
}

private static SerializationRecord DecodeNext(BinaryReader reader, RecordMap recordMap,
AllowedRecordTypes allowed, PayloadOptions options, out SerializationRecordType recordType)
AllowedRecordTypes allowed, PayloadOptions options, out SerializationRecordType recordType, bool isReferencedRecord = false)
{
recordType = reader.ReadSerializationRecordType(allowed);
SerializationRecord? record;

SerializationRecord record = recordType switch
do
{
SerializationRecordType.ArraySingleObject => ArraySingleObjectRecord.Decode(reader),
SerializationRecordType.ArraySinglePrimitive => DecodeArraySinglePrimitiveRecord(reader),
SerializationRecordType.ArraySingleString => ArraySingleStringRecord.Decode(reader),
SerializationRecordType.BinaryArray => BinaryArrayRecord.Decode(reader, recordMap, options),
SerializationRecordType.BinaryLibrary => BinaryLibraryRecord.Decode(reader, options),
SerializationRecordType.BinaryObjectString => BinaryObjectStringRecord.Decode(reader),
SerializationRecordType.ClassWithId => ClassWithIdRecord.Decode(reader, recordMap),
SerializationRecordType.ClassWithMembersAndTypes => ClassWithMembersAndTypesRecord.Decode(reader, recordMap, options),
SerializationRecordType.MemberPrimitiveTyped => DecodeMemberPrimitiveTypedRecord(reader),
SerializationRecordType.MemberReference => MemberReferenceRecord.Decode(reader, recordMap),
SerializationRecordType.MessageEnd => MessageEndRecord.Singleton,
SerializationRecordType.ObjectNull => ObjectNullRecord.Instance,
SerializationRecordType.ObjectNullMultiple => ObjectNullMultipleRecord.Decode(reader),
SerializationRecordType.ObjectNullMultiple256 => ObjectNullMultiple256Record.Decode(reader),
SerializationRecordType.SerializedStreamHeader => SerializedStreamHeaderRecord.Decode(reader),
_ => SystemClassWithMembersAndTypesRecord.Decode(reader, recordMap, options),
};
recordType = reader.ReadSerializationRecordType(allowed);

recordMap.Add(record);
record = recordType switch
{
SerializationRecordType.ArraySingleObject => ArraySingleObjectRecord.Decode(reader),
SerializationRecordType.ArraySinglePrimitive => DecodeArraySinglePrimitiveRecord(reader),
SerializationRecordType.ArraySingleString => ArraySingleStringRecord.Decode(reader),
SerializationRecordType.BinaryArray => BinaryArrayRecord.Decode(reader, recordMap, options),
SerializationRecordType.BinaryLibrary => BinaryLibraryRecord.Decode(reader, options),
SerializationRecordType.BinaryObjectString => BinaryObjectStringRecord.Decode(reader),
SerializationRecordType.ClassWithId => ClassWithIdRecord.Decode(reader, recordMap),
SerializationRecordType.ClassWithMembersAndTypes => ClassWithMembersAndTypesRecord.Decode(reader, recordMap, options),
SerializationRecordType.MemberPrimitiveTyped => DecodeMemberPrimitiveTypedRecord(reader),
SerializationRecordType.MemberReference => MemberReferenceRecord.Decode(reader, recordMap, allowed),
SerializationRecordType.MessageEnd => MessageEndRecord.Singleton,
SerializationRecordType.ObjectNull => ObjectNullRecord.Instance,
SerializationRecordType.ObjectNullMultiple => ObjectNullMultipleRecord.Decode(reader),
SerializationRecordType.ObjectNullMultiple256 => ObjectNullMultiple256Record.Decode(reader),
SerializationRecordType.SerializedStreamHeader => SerializedStreamHeaderRecord.Decode(reader),
_ => SystemClassWithMembersAndTypesRecord.Decode(reader, recordMap, options),
};

recordMap.Add(record, isReferencedRecord);

// BinaryLibrary often precedes class records.
// It has been already added to the RecordMap and it must not be added
// to the array record or class member values, so simply read next record.
// It's possible to read multiple BinaryLibraryRecord in a row, hence the loop.
} while (recordType == SerializationRecordType.BinaryLibrary);

return record;
}
Expand Down
Loading
Loading