Skip to content

Commit

Permalink
Combine post processing and initialization
Browse files Browse the repository at this point in the history
It's going to be useful for Any support
  • Loading branch information
vitalybuka committed Feb 4, 2020
1 parent a268b67 commit 045acda
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 79 deletions.
140 changes: 65 additions & 75 deletions src/mutator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,67 @@ class DataSourceSampler {
WeightedReservoirSampler<ConstFieldInstance, RandomEngine> sampler_;
};

class PostProcessing {
public:
using PostProcessors = std::unordered_multimap<const protobuf::Descriptor*,
Mutator::PostProcess>;

PostProcessing(bool keep_initialized, const PostProcessors& post_processors,
RandomEngine* random)
: keep_initialized_(keep_initialized),
post_processors_(post_processors),
random_(random) {}

void Run(Message* message, int max_depth) {
--max_depth;
const Descriptor* descriptor = message->GetDescriptor();

// Apply custom mutators in nested messages before packing any.
const Reflection* reflection = message->GetReflection();
for (int i = 0; i < descriptor->field_count(); i++) {
const FieldDescriptor* field = descriptor->field(i);
if (keep_initialized_ &&
(field->is_required() || descriptor->options().map_entry()) &&
!reflection->HasField(*message, field)) {
CreateDefaultField()(FieldInstance(message, field));
}

if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) continue;

if (max_depth < 0 && !field->is_required()) {
// Clear deep optional fields to avoid stack overflow.
reflection->ClearField(message, field);
if (field->is_repeated())
assert(!reflection->FieldSize(*message, field));
else
assert(!reflection->HasField(*message, field));
continue;
}

if (field->is_repeated()) {
const int field_size = reflection->FieldSize(*message, field);
for (int j = 0; j < field_size; ++j) {
Message* nested_message =
reflection->MutableRepeatedMessage(message, field, j);
Run(nested_message, max_depth);
}
} else if (reflection->HasField(*message, field)) {
Message* nested_message = reflection->MutableMessage(message, field);
Run(nested_message, max_depth);
}
}

auto range = post_processors_.equal_range(descriptor);
for (auto it = range.first; it != range.second; ++it)
it->second(message, (*random_)());
}

private:
bool keep_initialized_;
const PostProcessors& post_processors_;
RandomEngine* random_;
};

} // namespace

class FieldMutator {
Expand Down Expand Up @@ -479,47 +540,16 @@ void Mutator::Mutate(Message* message, size_t max_size_hint) {
static_cast<int>(max_size_hint) -
static_cast<int>(message->ByteSizeLong()));

InitializeAndTrim(message, kMaxInitializeDepth);
PostProcessing(keep_initialized_, post_processors_, &random_)
.Run(message, kMaxInitializeDepth);
assert(IsInitialized(*message));

if (!post_processors_.empty()) {
ApplyPostProcessing(message);
}
}

void Mutator::RegisterPostProcessor(const Descriptor* desc,
PostProcess callback) {
post_processors_.emplace(desc, callback);
}

void Mutator::ApplyPostProcessing(Message* message) {
const Descriptor* descriptor = message->GetDescriptor();

auto range = post_processors_.equal_range(descriptor);
for (auto it = range.first; it != range.second; ++it)
it->second(message, random_());

// Now recursively apply custom mutators.
const Reflection* reflection = message->GetReflection();
for (int i = 0; i < descriptor->field_count(); i++) {
const FieldDescriptor* field = descriptor->field(i);
if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
continue;
}
if (field->is_repeated()) {
const int field_size = reflection->FieldSize(*message, field);
for (int j = 0; j < field_size; ++j) {
Message* nested_message =
reflection->MutableRepeatedMessage(message, field, j);
ApplyPostProcessing(nested_message);
}
} else if (reflection->HasField(*message, field)) {
Message* nested_message = reflection->MutableMessage(message, field);
ApplyPostProcessing(nested_message);
}
}
}

bool Mutator::MutateImpl(const Message& source, Message* message,
bool copy_clone_only, int size_increase_hint) {
if (size_increase_hint > 0) size_increase_hint /= 2;
Expand Down Expand Up @@ -578,49 +608,9 @@ void Mutator::CrossOver(const Message& message1, Message* message2,
MutateImpl(message1, message2, true, size_increase_hint) ||
MutateImpl(*message2, message2, true, size_increase_hint);

InitializeAndTrim(message2, kMaxInitializeDepth);
PostProcessing(keep_initialized_, post_processors_, &random_)
.Run(message2, kMaxInitializeDepth);
assert(IsInitialized(*message2));

if (!post_processors_.empty()) {
ApplyPostProcessing(message2);
}
}

void Mutator::InitializeAndTrim(Message* message, int max_depth) {
const Descriptor* descriptor = message->GetDescriptor();
const Reflection* reflection = message->GetReflection();
for (int i = 0; i < descriptor->field_count(); ++i) {
const FieldDescriptor* field = descriptor->field(i);
if (keep_initialized_ &&
(field->is_required() || descriptor->options().map_entry()) &&
!reflection->HasField(*message, field)) {
CreateDefaultField()(FieldInstance(message, field));
}

if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
if (max_depth <= 0 && !field->is_required()) {
// Clear deep optional fields to avoid stack overflow.
reflection->ClearField(message, field);
if (field->is_repeated())
assert(!reflection->FieldSize(*message, field));
else
assert(!reflection->HasField(*message, field));
continue;
}

if (field->is_repeated()) {
const int field_size = reflection->FieldSize(*message, field);
for (int j = 0; j < field_size; ++j) {
Message* nested_message =
reflection->MutableRepeatedMessage(message, field, j);
InitializeAndTrim(nested_message, max_depth - 1);
}
} else if (reflection->HasField(*message, field)) {
Message* nested_message = reflection->MutableMessage(message, field);
InitializeAndTrim(nested_message, max_depth - 1);
}
}
}
}

int32_t Mutator::MutateInt32(int32_t value) { return FlipBit(value, &random_); }
Expand Down
7 changes: 3 additions & 4 deletions src/mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,17 @@ class Mutator {
private:
friend class FieldMutator;
friend class TestMutator;
void InitializeAndTrim(protobuf::Message* message, int max_depth);
bool MutateImpl(const protobuf::Message& source, protobuf::Message* message,
bool copy_clone_only, int size_increase_hint);
std::string MutateUtf8String(const std::string& value,
int size_increase_hint);
void ApplyPostProcessing(protobuf::Message* message);
bool IsInitialized(const protobuf::Message& message) const;
bool keep_initialized_ = true;
size_t random_to_default_ratio_ = 100;
RandomEngine random_;
std::unordered_multimap<const protobuf::Descriptor*, PostProcess>
post_processors_;
using PostProcessors =
std::unordered_multimap<const protobuf::Descriptor*, PostProcess>;
PostProcessors post_processors_;
};

} // namespace protobuf_mutator
Expand Down

0 comments on commit 045acda

Please sign in to comment.