Skip to content

Commit

Permalink
Merge pull request BVLC#1239 from sguada/encoded
Browse files Browse the repository at this point in the history
Allow using encoded images in Datum, LevelDB, LMDB
  • Loading branch information
sguada committed Oct 16, 2014
2 parents ea43036 + f17cd3e commit 7effdca
Show file tree
Hide file tree
Showing 16 changed files with 599 additions and 91 deletions.
4 changes: 2 additions & 2 deletions examples/cifar10/create_cifar10.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ rm -rf $EXAMPLE/cifar10_train_$DBTYPE $EXAMPLE/cifar10_test_$DBTYPE

echo "Computing image mean..."

./build/tools/compute_image_mean $EXAMPLE/cifar10_train_$DBTYPE \
$EXAMPLE/mean.binaryproto $DBTYPE
./build/tools/compute_image_mean -backend=$DBTYPE \
$EXAMPLE/cifar10_train_$DBTYPE $EXAMPLE/mean.binaryproto

echo "Done."
5 changes: 5 additions & 0 deletions include/caffe/data_layers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,11 @@ class WindowDataLayer : public BasePrefetchingDataLayer<Dtype> {
vector<vector<float> > fg_windows_;
vector<vector<float> > bg_windows_;
Blob<Dtype> data_mean_;
vector<Dtype> mean_values_;
bool has_mean_file_;
bool has_mean_values_;
bool cache_images_;
vector<std::pair<std::string, Datum > > image_database_cache_;
};

} // namespace caffe
Expand Down
10 changes: 10 additions & 0 deletions include/caffe/net.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"
#ifdef TIMING
#include "caffe/util/benchmark.hpp"
#endif

namespace caffe {

Expand Down Expand Up @@ -76,9 +79,16 @@ class Net {
void Reshape();

// Runs one forward pass (computing the loss) followed by one backward pass
// (computing the gradients), and returns the loss.
// When the library is built with -DTIMING, also logs the wall-clock time of
// the combined forward/backward computation.
Dtype ForwardBackward(const vector<Blob<Dtype>* > & bottom) {
#ifdef TIMING
Timer timer;
timer.Start();
#endif
Dtype loss;
// Forward fills 'loss' as a side effect of the pass over 'bottom'.
Forward(bottom, &loss);
Backward();
#ifdef TIMING
// NOTE(review): no explicit timer.Stop() here -- assumes
// Timer::MilliSeconds() handles a still-running timer; confirm against
// the Timer implementation.
LOG(INFO) << "ForwardBackward Time: " << timer.MilliSeconds() << "ms.";
#endif
return loss;
}

Expand Down
20 changes: 15 additions & 5 deletions include/caffe/util/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ class Timer {
public:
Timer();
virtual ~Timer();
void Start();
void Stop();
float MilliSeconds();
float MicroSeconds();
float Seconds();
virtual void Start();
virtual void Stop();
virtual float MilliSeconds();
virtual float MicroSeconds();
virtual float Seconds();

inline bool initted() { return initted_; }
inline bool running() { return running_; }
Expand All @@ -37,6 +37,16 @@ class Timer {
float elapsed_microseconds_;
};

// A Timer subclass measured entirely on the host CPU -- presumably it avoids
// any device/GPU synchronization so it is safe to use inside worker threads
// (e.g. data prefetch); confirm against the definitions in benchmark.cpp.
// Seconds() is intentionally not overridden: the base implementation is used.
class CPUTimer : public Timer {
 public:
  // NOTE: 'explicit' dropped from the original declaration -- it has no
  // effect on a parameterless constructor and only forbids harmless
  // copy-list-initialization.
  CPUTimer();
  virtual ~CPUTimer() {}
  virtual void Start();
  virtual void Stop();
  virtual float MilliSeconds();
  virtual float MicroSeconds();
};

} // namespace caffe

#endif // CAFFE_UTIL_BENCHMARK_H_
38 changes: 38 additions & 0 deletions include/caffe/util/io.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ inline void WriteProtoToBinaryFile(
WriteProtoToBinaryFile(proto, filename.c_str());
}

// Reads the raw contents of 'filename' into 'datum' and tags it with 'label'.
// Returns true on success. (Defined in io.cpp.)
bool ReadFileToDatum(const string& filename, const int label, Datum* datum);

// Convenience overload: same as above with a sentinel label of -1
// (i.e. "no label").
inline bool ReadFileToDatum(const string& filename, Datum* datum) {
return ReadFileToDatum(filename, -1, datum);
}

bool ReadImageToDatum(const string& filename, const int label,
const int height, const int width, const bool is_color, Datum* datum);

Expand All @@ -106,6 +112,21 @@ inline bool ReadImageToDatum(const string& filename, const int label,
return ReadImageToDatum(filename, label, 0, 0, true, datum);
}

// Decodes an encoded image held in 'datum' in place, optionally resizing to
// height x width and converting color. Returns true if a decode happened
// (presumably false when the datum is not encoded -- confirm in io.cpp).
bool DecodeDatum(const int height, const int width, const bool is_color,
Datum* datum);

// Convenience overload: decode as color at the given size.
inline bool DecodeDatum(const int height, const int width, Datum* datum) {
return DecodeDatum(height, width, true, datum);
}

// Convenience overload: decode at the native size (0, 0 = no resize).
inline bool DecodeDatum(const bool is_color, Datum* datum) {
return DecodeDatum(0, 0, is_color, datum);
}

// Convenience overload: decode as color at the native size.
inline bool DecodeDatum(Datum* datum) {
return DecodeDatum(0, 0, true, datum);
}

#ifndef OSX
cv::Mat ReadImageToCVMat(const string& filename,
const int height, const int width, const bool is_color);
Expand All @@ -124,6 +145,23 @@ inline cv::Mat ReadImageToCVMat(const string& filename) {
return ReadImageToCVMat(filename, 0, 0, true);
}

// Decodes the encoded image held in 'datum' into a cv::Mat, optionally
// resizing to height x width and converting color. (Defined in io.cpp.)
cv::Mat DecodeDatumToCVMat(const Datum& datum,
const int height, const int width, const bool is_color);

// Convenience overload: decode as color at the given size.
inline cv::Mat DecodeDatumToCVMat(const Datum& datum,
const int height, const int width) {
return DecodeDatumToCVMat(datum, height, width, true);
}

// Convenience overload: decode at the native size (0, 0 = no resize).
inline cv::Mat DecodeDatumToCVMat(const Datum& datum,
const bool is_color) {
return DecodeDatumToCVMat(datum, 0, 0, is_color);
}

// Convenience overload: decode as color at the native size.
inline cv::Mat DecodeDatumToCVMat(const Datum& datum) {
return DecodeDatumToCVMat(datum, 0, 0, true);
}

void CVMatToDatum(const cv::Mat& cv_img, Datum* datum);
#endif

Expand Down
99 changes: 96 additions & 3 deletions src/caffe/data_transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,102 @@ void DataTransformer<Dtype>::Transform(const vector<Datum> & datum_vector,
// Transforms an 8-bit OpenCV image into 'transformed_blob', applying the
// configured transformations: optional random mirror, optional crop (random
// at TRAIN phase, center otherwise), mean subtraction (from a mean file or
// per-channel mean values), and scaling. Output layout is CHW (channel-major).
//
// NOTE(review): this block reconstructs the added side of a diff hunk; the
// three removed lines of the old implementation (Datum round-trip via
// CVMatToDatum) that were embedded in the extracted text have been dropped,
// along with one line of dead commented-out code. The live logic is unchanged.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
    Blob<Dtype>* transformed_blob) {
  const int img_channels = cv_img.channels();
  const int img_height = cv_img.rows;
  const int img_width = cv_img.cols;

  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();
  const int num = transformed_blob->num();

  // The blob must match the image channel count and fit inside the image.
  CHECK_EQ(channels, img_channels);
  CHECK_LE(height, img_height);
  CHECK_LE(width, img_width);
  CHECK_GE(num, 1);

  CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte";

  const int crop_size = param_.crop_size();
  const Dtype scale = param_.scale();
  const bool do_mirror = param_.mirror() && Rand(2);
  const bool has_mean_file = param_.has_mean_file();
  const bool has_mean_values = mean_values_.size() > 0;

  CHECK_GT(img_channels, 0);
  CHECK_GE(img_height, crop_size);
  CHECK_GE(img_width, crop_size);

  Dtype* mean = NULL;
  if (has_mean_file) {
    // Mean image must match the input image dimensions exactly.
    CHECK_EQ(img_channels, data_mean_.channels());
    CHECK_EQ(img_height, data_mean_.height());
    CHECK_EQ(img_width, data_mean_.width());
    mean = data_mean_.mutable_cpu_data();
  }
  if (has_mean_values) {
    CHECK(mean_values_.size() == 1 || mean_values_.size() == img_channels) <<
        "Specify either 1 mean_value or as many as channels: " << img_channels;
    if (img_channels > 1 && mean_values_.size() == 1) {
      // Replicate the mean_value for simplicity
      for (int c = 1; c < img_channels; ++c) {
        mean_values_.push_back(mean_values_[0]);
      }
    }
  }

  int h_off = 0;
  int w_off = 0;
  cv::Mat cv_cropped_img = cv_img;
  if (crop_size) {
    CHECK_EQ(crop_size, height);
    CHECK_EQ(crop_size, width);
    // We only do random crop when we do training.
    if (phase_ == Caffe::TRAIN) {
      h_off = Rand(img_height - crop_size + 1);
      w_off = Rand(img_width - crop_size + 1);
    } else {
      h_off = (img_height - crop_size) / 2;
      w_off = (img_width - crop_size) / 2;
    }
    // ROI view shares pixel data with cv_img; no copy is made here.
    cv::Rect roi(w_off, h_off, crop_size, crop_size);
    cv_cropped_img = cv_img(roi);
  } else {
    CHECK_EQ(img_height, height);
    CHECK_EQ(img_width, width);
  }

  CHECK(cv_cropped_img.data);

  Dtype* transformed_data = transformed_blob->mutable_cpu_data();
  int top_index;
  for (int h = 0; h < height; ++h) {
    // cv::Mat rows are interleaved (HWC); 'ptr' walks one row of pixels.
    const uchar* ptr = cv_cropped_img.ptr<uchar>(h);
    int img_index = 0;
    for (int w = 0; w < width; ++w) {
      for (int c = 0; c < img_channels; ++c) {
        if (do_mirror) {
          // Mirror horizontally by writing to the flipped column.
          top_index = (c * height + h) * width + (width - 1 - w);
        } else {
          top_index = (c * height + h) * width + w;
        }
        Dtype pixel = static_cast<Dtype>(ptr[img_index++]);
        if (has_mean_file) {
          // Mean is indexed in the UNcropped image coordinates.
          int mean_index = (c * img_height + h_off + h) * img_width + w_off + w;
          transformed_data[top_index] =
              (pixel - mean[mean_index]) * scale;
        } else {
          if (has_mean_values) {
            transformed_data[top_index] =
                (pixel - mean_values_[c]) * scale;
          } else {
            transformed_data[top_index] = pixel * scale;
          }
        }
      }
    }
  }
}
#endif

Expand Down
33 changes: 29 additions & 4 deletions src/caffe/layers/data_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "caffe/dataset_factory.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/benchmark.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/util/rng.hpp"
Expand Down Expand Up @@ -45,8 +46,11 @@ void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
}
// Read a data point, and use it to initialize the top blob.
CHECK(iter_ != dataset_->end());
const Datum& datum = iter_->value;
Datum datum = iter_->value;

if (DecodeDatum(&datum)) {
LOG(INFO) << "Decoding Datum";
}
// image
int crop_size = this->layer_param_.transform_param().crop_size();
if (crop_size > 0) {
Expand Down Expand Up @@ -78,6 +82,11 @@ void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
// This function is used to create a thread that prefetches the data.
template <typename Dtype>
void DataLayer<Dtype>::InternalThreadEntry() {
CPUTimer batch_timer;
batch_timer.Start();
double read_time = 0;
double trans_time = 0;
CPUTimer timer;
CHECK(this->prefetch_data_.count());
CHECK(this->transformed_data_.count());
Dtype* top_data = this->prefetch_data_.mutable_cpu_data();
Expand All @@ -87,25 +96,41 @@ void DataLayer<Dtype>::InternalThreadEntry() {
top_label = this->prefetch_label_.mutable_cpu_data();
}
const int batch_size = this->layer_param_.data_param().batch_size();

for (int item_id = 0; item_id < batch_size; ++item_id) {
timer.Start();
// get a blob
CHECK(iter_ != dataset_->end());
const Datum& datum = iter_->value;

cv::Mat cv_img;
if (datum.encoded()) {
cv_img = DecodeDatumToCVMat(datum);
}
read_time += timer.MicroSeconds();
timer.Start();

// Apply data transformations (mirror, scale, crop...)
int offset = this->prefetch_data_.offset(item_id);
this->transformed_data_.set_cpu_data(top_data + offset);
this->data_transformer_.Transform(datum, &(this->transformed_data_));
if (datum.encoded()) {
this->data_transformer_.Transform(cv_img, &(this->transformed_data_));
} else {
this->data_transformer_.Transform(datum, &(this->transformed_data_));
}
if (this->output_labels_) {
top_label[item_id] = datum.label();
}

trans_time += timer.MicroSeconds();
// go to the next iter
++iter_;
if (iter_ == dataset_->end()) {
iter_ = dataset_->begin();
}
}
batch_timer.Stop();
DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";
DLOG(INFO) << " Read time: " << read_time / 1000 << " ms.";
DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";
}

INSTANTIATE_CLASS(DataLayer);
Expand Down
Loading

0 comments on commit 7effdca

Please sign in to comment.