Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Solving problems with Header Union verify function (reopened) #3214

Merged
merged 2 commits into from
Apr 14, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 63 additions & 51 deletions midend/interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,33 +310,45 @@ void SymbolicStruct::dbprint(std::ostream& out) const {
}

SymbolicHeaderUnion::SymbolicHeaderUnion(const IR::Type_HeaderUnion* type,
bool uninitialized,
const SymbolicValueFactory* factory) :
SymbolicStruct(type, uninitialized, factory),
valid(new SymbolicBool(false)) {}

void SymbolicHeaderUnion::setValid(bool v) {
if (!v)
setAllUnknown();
valid = new SymbolicBool(v);
bool uninitialized,
const SymbolicValueFactory* factory) :
SymbolicStruct(type, uninitialized, factory) {}

SymbolicBool* SymbolicHeaderUnion::isValid() const {
int validFields = 0;
for (auto f : type->to<IR::Type_StructLike>()->fields) {
if (fieldValue.count(f->name.name)) {
auto fieldValid = fieldValue.at(f->name.name)->checkedTo<SymbolicHeader>()->valid;
if (!fieldValid->isKnown() || fieldValid->isUninitialized()) {
return fieldValid;
} else if (fieldValid->value) {
validFields +=1;
}
} else {
BUG("The number of fields in %1% is different from HeaderUnion fieldValue", type);
}
}
if (validFields == 1) {
return new SymbolicBool(true);
} else if (validFields > 1) {
BUG("In HeaderUnion cannot be more than one valid field");
}
return new SymbolicBool(false);
}

SymbolicValue* SymbolicHeaderUnion::get(const IR::Node* node, cstring field) const {
if (valid->isKnown() && !valid->value)
return new SymbolicStaticError(node, "Reading field from invalid header union");
return SymbolicStruct::get(node, field);
}

void SymbolicHeaderUnion::setAllUnknown() {
SymbolicStruct::setAllUnknown();
valid->setAllUnknown();
this->isValid()->setAllUnknown();
}

SymbolicValue* SymbolicHeaderUnion::clone() const {
auto result = new SymbolicHeaderUnion(type->to<IR::Type_HeaderUnion>());
for (auto f : fieldValue)
result->fieldValue[f.first] = f.second->clone();
result->valid = valid->clone()->to<SymbolicBool>();
return result;
}

Expand All @@ -346,7 +358,6 @@ void SymbolicHeaderUnion::assign(const SymbolicValue* other) {
BUG_CHECK(hv, "%1%: expected a header union", other);
for (auto f : hv->fieldValue)
fieldValue[f.first]->assign(f.second);
valid->assign(hv->valid);
}

bool SymbolicHeaderUnion::merge(const SymbolicValue* other) {
Expand All @@ -355,26 +366,17 @@ bool SymbolicHeaderUnion::merge(const SymbolicValue* other) {
bool changes = false;
for (auto f : hv->fieldValue)
changes = changes || fieldValue[f.first]->merge(f.second);
changes = changes || valid->merge(hv->valid);
return changes;
}

bool SymbolicHeaderUnion::equals(const SymbolicValue* other) const {
if (!other->is<SymbolicHeaderUnion>())
return false;
auto sh = other->to<SymbolicHeaderUnion>();
if (!valid->equals(sh->valid))
return false;
if (valid->isKnown() && !valid->value)
// Invalid headers are equal
return true;
return SymbolicStruct::equals(other);
}

void SymbolicHeaderUnion::dbprint(std::ostream& out) const {
out << "{ ";
out << "valid=>";
valid->dbprint(out);
#if 0
for (auto f : fieldValue) {
out << ", ";
Expand Down Expand Up @@ -417,11 +419,15 @@ SymbolicValue* SymbolicHeader::clone() const {

void SymbolicHeader::assign(const SymbolicValue* other) {
if (other->is<SymbolicError>()) return;
BUG_CHECK(other->is<SymbolicHeader>(), "%1%: expected a header", other);
auto hv = other->to<SymbolicHeader>();
for (auto f : hv->fieldValue)
fieldValue[f.first]->assign(f.second);
valid->assign(hv->valid);
BUG_CHECK(other->is<SymbolicStruct>() , "%1%: expected a struct", other);
if (auto hv = other->to<SymbolicStruct>()) {
for (auto f : hv->fieldValue)
fieldValue[f.first]->assign(f.second);
}
if (auto hv = other->to<SymbolicHeader>())
valid->assign(hv->valid);
else
valid->assign(new SymbolicBool(true));
}

bool SymbolicHeader::merge(const SymbolicValue* other) {
Expand Down Expand Up @@ -486,9 +492,6 @@ void SymbolicArray::shift(int amount) {
if (values[i]->is<SymbolicHeader>()) {
values[i]->to<SymbolicHeader>()->setValid(false);
}
if (values[i]->is<SymbolicHeaderUnion>()) {
values[i]->to<SymbolicHeaderUnion>()->setValid(false);
}
}
} else if (amount > 0) {
for (unsigned i = 0; i < values.size() - amount; i++)
Expand All @@ -497,9 +500,6 @@ void SymbolicArray::shift(int amount) {
if (values[i]->is<SymbolicHeader>()) {
values[i]->to<SymbolicHeader>()->setValid(false);
}
if (values[i]->is<SymbolicHeaderUnion>()) {
values[i]->to<SymbolicHeaderUnion>()->setValid(false);
}
}
}
}
Expand All @@ -515,10 +515,10 @@ SymbolicValue* SymbolicArray::next(const IR::Node* node) {
return v;
}
if (values[i]->is<SymbolicHeaderUnion>()) {
if (v->to<SymbolicHeaderUnion>()->valid->isUnknown() ||
v->to<SymbolicHeaderUnion>()->valid->isUninitialized())
if (v->to<SymbolicHeaderUnion>()->isValid()->isUnknown() ||
v->to<SymbolicHeaderUnion>()->isValid()->isUninitialized())
return new AnyElement(this);
if (!v->to<SymbolicHeaderUnion>()->valid->value)
if (!v->to<SymbolicHeaderUnion>()->isValid()->value)
return v;
}
}
Expand All @@ -538,10 +538,10 @@ SymbolicValue* SymbolicArray::lastIndex(const IR::Node* node) {
}

if (values[i]->is<SymbolicHeaderUnion>()) {
if (v->to<SymbolicHeaderUnion>()->valid->isUnknown() ||
v->to<SymbolicHeaderUnion>()->valid->isUninitialized())
if (v->to<SymbolicHeaderUnion>()->isValid()->isUnknown() ||
v->to<SymbolicHeaderUnion>()->isValid()->isUninitialized())
return new AnyElement(this);
if (v->to<SymbolicHeaderUnion>()->valid->value)
if (v->to<SymbolicHeaderUnion>()->isValid()->value)
return new SymbolicInteger(new IR::Constant(IR::Type_Bits::get(32), index));
}
}
Expand All @@ -560,10 +560,10 @@ SymbolicValue* SymbolicArray::last(const IR::Node* node) {
return v;
}
if (values[i]->is<SymbolicHeaderUnion>()) {
if (v->to<SymbolicHeaderUnion>()->valid->isUnknown() ||
v->to<SymbolicHeaderUnion>()->valid->isUninitialized())
if (v->to<SymbolicHeaderUnion>()->isValid()->isUnknown() ||
v->to<SymbolicHeaderUnion>()->isValid()->isUninitialized())
return new AnyElement(this);
if (v->to<SymbolicHeaderUnion>()->valid->value)
if (v->to<SymbolicHeaderUnion>()->isValid()->value)
return v;
}
}
Expand Down Expand Up @@ -1108,10 +1108,20 @@ void ExpressionEvaluator::postorder(const IR::MethodCallExpression* expression)
auto bim = mi->to<BuiltInMethod>();
auto base = get(bim->appliedTo);
cstring name = bim->name.name;
// Needed to get Header from HeaderUnion
const auto node = expression->method->checkedTo<IR::Member>()->expr;
CHECK_NULL(node);
auto structVar = get(node);
if (name == IR::Type_Header::setInvalid ||
name == IR::Type_Header::setValid) {
BUG_CHECK(base->is<SymbolicHeader>(), "%1%: expected a header", base);
auto hv = base->to<SymbolicHeader>();
auto hv = structVar->checkedTo<SymbolicHeader>();
if (auto member = node->to<IR::Member>()) {
if (auto hu = get(member->expr)->to<SymbolicHeaderUnion>()) {
if (hu->isValid()) {
hu->setAllUnknown();
}
}
}
hv->setValid(name == IR::Type_Header::setValid);
set(expression, SymbolicVoid::get());
return;
Expand All @@ -1137,11 +1147,13 @@ void ExpressionEvaluator::postorder(const IR::MethodCallExpression* expression)
} else {
BUG_CHECK(name == IR::Type_Header::isValid,
"%1%: unexpected method", bim->name);
BUG_CHECK(base->is<SymbolicHeader>(), "%1%: expected a header", base);
auto hv = base->to<SymbolicHeader>();
auto v = hv->valid;
set(expression, v);
return;
if (auto hv = structVar->to<SymbolicHeader>()) {
auto v = hv->valid;
set(expression, v);
return;
} else {
BUG("Unexpected expression (%1%) type: %2%", base, base->type);
}
}
}

Expand Down Expand Up @@ -1253,4 +1265,4 @@ SymbolicValue* ExpressionEvaluator::evaluate(const IR::Expression* expression, b
auto result = get(expression);
return result;
}
} // namespace P4
} // namespace P4
5 changes: 2 additions & 3 deletions midend/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -417,10 +417,9 @@ class SymbolicHeader : public SymbolicStruct {
class SymbolicHeaderUnion : public SymbolicStruct {
public:
explicit SymbolicHeaderUnion(const IR::Type_HeaderUnion* type) : SymbolicStruct(type) {}
SymbolicBool* valid = nullptr;
SymbolicHeaderUnion(const IR::Type_HeaderUnion* type, bool uninitialized,
const SymbolicValueFactory* factory);
virtual void setValid(bool v);
SymbolicBool* isValid() const;
SymbolicValue* clone() const override;
SymbolicValue* get(const IR::Node* node, cstring field) const override;
void setAllUnknown() override;
Expand Down Expand Up @@ -574,4 +573,4 @@ class SymbolicPacketIn final : public SymbolicExtern {

} // namespace P4

#endif /* _MIDEND_INTERPRETER_H_ */
#endif /* _MIDEND_INTERPRETER_H_ */
107 changes: 107 additions & 0 deletions testdata/p4_16_samples/extract_for_header_union.p4
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/* -*- P4_16 -*- */
#include <core.p4>
#include <v1model.p4>

/*************************************************************************
*********************** H E A D E R S ***********************************
*************************************************************************/

header addr_ipv4_t {
bit<32> addr;
}

header addr_ipv6_t {
bit<128> addr;
}

header_union addr_t1 {
addr_ipv4_t ipv4;
addr_ipv6_t ipv6;
}

header_union addr_t2 {
addr_ipv4_t ipv4;
addr_ipv6_t ipv6;
}

struct metadata { }

struct headers {
addr_t1 addr_src1;
addr_t2 addr_src2;
}

/*************************************************************************
*********************** P A R S E R ***********************************
*************************************************************************/
parser ProtParser(packet_in packet,
out headers hdr,
inout metadata meta,
inout standard_metadata_t standard_metadata) {
state start {
packet.extract(hdr.addr_src1.ipv4);
hdr.addr_src1.ipv4.addr = hdr.addr_src2.ipv4.addr;
}
}


/*************************************************************************
************ C H E C K S U M V E R I F I C A T I O N *************
*************************************************************************/

control ProtVerifyChecksum(inout headers hdr, inout metadata meta) {
apply { }
}


/*************************************************************************
************** I N G R E S S P R O C E S S I N G *******************
*************************************************************************/

control ProtIngress(inout headers hdr,
inout metadata meta,
inout standard_metadata_t standard_metadata) {
apply { }
}

/*************************************************************************
**************** E G R E S S P R O C E S S I N G *******************
*************************************************************************/

control ProtEgress(inout headers hdr,
inout metadata meta,
inout standard_metadata_t standard_metadata) {
apply { }
}

/*************************************************************************
************* C H E C K S U M C O M P U T A T I O N **************
*************************************************************************/

control ProtComputeChecksum(inout headers hdr, inout metadata meta) {
apply { }
}


/*************************************************************************
*********************** D E P A R S E R *******************************
*************************************************************************/

control ProtDeparser(packet_out packet, in headers hdr) {
apply {
packet.emit<headers>(hdr);
}
}

/*************************************************************************
*********************** S W I T C H *******************************
*************************************************************************/

V1Switch(
ProtParser(),
ProtVerifyChecksum(),
ProtIngress(),
ProtEgress(),
ProtComputeChecksum(),
ProtDeparser()
) main;
Loading