diff --git a/src/lib/asn1/ASN1Writer.cpp b/src/lib/asn1/ASN1Writer.cpp index 643d79038bca92..7d5a81f4cf1205 100644 --- a/src/lib/asn1/ASN1Writer.cpp +++ b/src/lib/asn1/ASN1Writer.cpp @@ -234,13 +234,13 @@ CHIP_ERROR ASN1Writer::PutBitString(uint8_t unusedBitCount, const uint8_t * enco CHIP_ERROR ASN1Writer::PutBitString(uint8_t unusedBitCount, chip::TLV::TLVReader & tlvReader) { - ReturnErrorCodeIf(IsNullWriter(), CHIP_NO_ERROR); - ByteSpan encodedBits; ReturnErrorOnFailure(tlvReader.Get(encodedBits)); VerifyOrReturnError(CanCastTo(encodedBits.size() + 1), ASN1_ERROR_LENGTH_OVERFLOW); + ReturnErrorCodeIf(IsNullWriter(), CHIP_NO_ERROR); + ReturnErrorOnFailure( EncodeHead(kASN1TagClass_Universal, kASN1UniversalTag_BitString, false, static_cast(encodedBits.size() + 1))); @@ -335,13 +335,13 @@ CHIP_ERROR ASN1Writer::PutValue(uint8_t cls, uint8_t tag, bool isConstructed, co CHIP_ERROR ASN1Writer::PutValue(uint8_t cls, uint8_t tag, bool isConstructed, chip::TLV::TLVReader & tlvReader) { - ReturnErrorCodeIf(IsNullWriter(), CHIP_NO_ERROR); - ByteSpan val; ReturnErrorOnFailure(tlvReader.Get(val)); VerifyOrReturnError(CanCastTo(val.size()), ASN1_ERROR_LENGTH_OVERFLOW); + ReturnErrorCodeIf(IsNullWriter(), CHIP_NO_ERROR); + ReturnErrorOnFailure(EncodeHead(cls, tag, isConstructed, static_cast(val.size()))); WriteData(val.data(), val.size()); diff --git a/src/lib/asn1/tests/TestASN1.cpp b/src/lib/asn1/tests/TestASN1.cpp index 6f4d5fe31d2d05..56899892abcc61 100644 --- a/src/lib/asn1/tests/TestASN1.cpp +++ b/src/lib/asn1/tests/TestASN1.cpp @@ -304,6 +304,17 @@ static void TestASN1_NullWriter(nlTestSuite * inSuite, void * inContext) encodedLen = writer.GetLengthWritten(); NL_TEST_ASSERT(inSuite, encodedLen == 0); + + // Methods that take a reader should still read from it, + // even if the output is suppressed by the null writer. + TLVReader emptyTlvReader; + emptyTlvReader.Init(ByteSpan()); + err = writer.PutBitString(0, emptyTlvReader); + NL_TEST_ASSERT(inSuite, err == CHIP_ERROR_WRONG_TLV_TYPE); + + emptyTlvReader.Init(ByteSpan()); + err = writer.PutOctetString(kASN1TagClass_ContextSpecific, 123, emptyTlvReader); + NL_TEST_ASSERT(inSuite, err == CHIP_ERROR_WRONG_TLV_TYPE); } static void TestASN1_ASN1UniversalTime(nlTestSuite * inSuite, void * inContext)