Commit a001c0b

Reworked the MNIST example to include the SIFT dataset.
Added support for reading the SIFT dataset to pico_toolshed. Renamed the mnist example to kd_forest. Documentation updated.
1 parent 4f41141 commit a001c0b

16 files changed: 277 additions and 165 deletions.

README.md

Lines changed: 9 additions & 7 deletions
@@ -61,7 +61,7 @@ PicoTree can interface with different types of points and point sets through tra
 * Creating a [custom search visitor](./examples/kd_tree/kd_tree_custom_search_visitor.cpp).
 * [Saving and loading](./examples/kd_tree/kd_tree_save_and_load.cpp) a KdTree to and from a file.
 * Support for [Eigen](./examples/eigen/eigen.cpp) and [OpenCV](./examples/opencv/opencv.cpp) data types.
-* Running the KdTree on the [MNIST](./examples/mnist/mnist.cpp) [database](http://yann.lecun.com/exdb/mnist/).
+* [Running the KdTree and KdForest](./examples/kd_forest/kd_forest.cpp) on the [MNIST](http://yann.lecun.com/exdb/mnist/) and [SIFT](http://corpus-texmex.irisa.fr/) datasets.
 * How to use the [KdTree with Python](./examples/python/kd_tree.py).
 
 # Requirements
@@ -113,9 +113,11 @@ $ pip install ./pico_tree
 
 # References
 
-* [Computational Geometry - Algorithms and Applications.](https://www.springer.com/gp/book/9783540779735) Mark de Berg, Otfried Cheong, Marc van Kreveld, and Mark Overmars, Springer-Verlag, third edition, 2008.
-* S. Maneewongvatana and D. M. Mount. [It's okay to be skinny, if your friends are fat.](http://www.cs.umd.edu/~mount/Papers/cgc99-smpack.pdf) 4th Annual CGC Workshop on Computational Geometry, 1999.
-* S. Arya and H. Y. Fu. [Expected-case complexity of approximate nearest neighbor searching.](https://www.cse.ust.hk/faculty/arya/pub/exp.pdf) In Proceedings of the 11th ACM-SIAM Symposium on Discrete Algorithms, 2000.
-* S. Arya and D. M. Mount. [Algorithms for fast vector quantization.](https://www.cs.umd.edu/~mount/Papers/DCC.pdf) In IEEE Data Compression Conference, pages 381–390, March 1993.
-* N. Sample, M. Haines, M. Arnold and T. Purcell. [Optimizing Search Strategies in k-d Trees.](http://infolab.stanford.edu/~nsample/pubs/samplehaines.pdf) In: 5th WSES/IEEE World Multiconference on Circuits, Systems, Communications & Computers (CSCC 2001), July 2001.
-* A. Yershova and S. M. LaValle, [Improving Motion-Planning Algorithms by Efficient Nearest-Neighbor Searching.](http://msl.cs.uiuc.edu/~lavalle/papers/YerLav06.pdf) In IEEE Transactions on Robotics, vol. 23, no. 1, pp. 151-157, Feb. 2007.
+* J. L. Bentley, [Multidimensional binary search trees used for associative searching](https://dl.acm.org/doi/pdf/10.1145/361002.361007), Communications of the ACM, vol. 18, no. 9, pp. 509–517, 1975.
+* S. Arya and D. M. Mount, [Algorithms for fast vector quantization](https://www.cs.umd.edu/~mount/Papers/DCC.pdf), In IEEE Data Compression Conference, pp. 381–390, March 1993.
+* S. Maneewongvatana and D. M. Mount, [It's okay to be skinny, if your friends are fat](http://www.cs.umd.edu/~mount/Papers/cgc99-smpack.pdf), 4th Annual CGC Workshop on Computational Geometry, 1999.
+* S. Arya and H. Y. Fu, [Expected-case complexity of approximate nearest neighbor searching](https://www.cse.ust.hk/faculty/arya/pub/exp.pdf), In Proceedings of the 11th ACM-SIAM Symposium on Discrete Algorithms, 2000.
+* N. Sample, M. Haines, M. Arnold and T. Purcell, [Optimizing Search Strategies in k-d Trees](http://infolab.stanford.edu/~nsample/pubs/samplehaines.pdf), In: 5th WSES/IEEE World Multiconference on Circuits, Systems, Communications & Computers (CSCC 2001), July 2001.
+* A. Yershova and S. M. LaValle, [Improving Motion-Planning Algorithms by Efficient Nearest-Neighbor Searching](http://msl.cs.uiuc.edu/~lavalle/papers/YerLav06.pdf), In IEEE Transactions on Robotics, vol. 23, no. 1, pp. 151-157, Feb. 2007.
+* M. de Berg, O. Cheong, M. van Kreveld, and M. Overmars, [Computational Geometry - Algorithms and Applications](https://www.springer.com/gp/book/9783540779735), Springer-Verlag, third edition, 2008.
+* C. Silpa-Anan and R. Hartley, [Optimised KD-trees for fast image descriptor matching](http://vigir.missouri.edu/~gdesouza/Research/Conference_CDs/IEEE_CVPR_2008/data/papers/298.pdf), In CVPR, 2008.

examples/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ add_subdirectory(pico_understory)
 
 add_subdirectory(kd_tree)
 
-add_subdirectory(mnist)
+add_subdirectory(kd_forest)
 
 find_package(Eigen3 QUIET)
 

examples/kd_forest/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+add_executable(kd_forest kd_forest.cpp)
+set_default_target_properties(kd_forest)
+target_link_libraries(kd_forest PUBLIC pico_toolshed pico_understory)

examples/kd_forest/kd_forest.cpp

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+#include <iostream>
+#include <pico_toolshed/format/format_bin.hpp>
+#include <pico_toolshed/scoped_timer.hpp>
+#include <pico_tree/array_traits.hpp>
+#include <pico_tree/kd_tree.hpp>
+#include <pico_tree/vector_traits.hpp>
+#include <pico_understory/kd_forest.hpp>
+
+#include "mnist.hpp"
+#include "sift.hpp"
+
+template <typename Dataset>
+void RunDataset(
+    std::size_t max_leaf_size_exact,
+    std::size_t max_leaf_size_apprx,
+    std::size_t forest_size) {
+  using Point = typename Dataset::PointType;
+  using Space = std::reference_wrapper<std::vector<Point>>;
+  using Scalar = typename Point::value_type;
+
+  auto train = Dataset::ReadTrain();
+  auto test = Dataset::ReadTest();
+  std::size_t count = test.size();
+  std::vector<pico_tree::Neighbor<int, Scalar>> nns(count);
+  std::string fn_nns_gt = Dataset::kDatasetName + "_nns_gt.bin";
+
+  if (!std::filesystem::exists(fn_nns_gt)) {
+    auto kd_tree = [&train, &max_leaf_size_exact]() {
+      ScopedTimer t0("kd_tree build");
+      return pico_tree::KdTree<Space>(train, max_leaf_size_exact);
+    }();
+
+    {
+      ScopedTimer t1("kd_tree query");
+      for (std::size_t i = 0; i < nns.size(); ++i) {
+        kd_tree.SearchNn(test[i], nns[i]);
+      }
+    }
+
+    std::cout << "Writing " << fn_nns_gt << "." << std::endl;
+    pico_tree::WriteBin(fn_nns_gt, nns);
+  } else {
+    std::cout << "Reading " << fn_nns_gt << "." << std::endl;
+    pico_tree::ReadBin(fn_nns_gt, nns);
+  }
+
+  std::size_t equal = 0;
+
+  // Building the KdForest takes roughly forest_size times longer compared to
+  // building the regular KdTree. However, querying it is usually a lot faster.
+  {
+    auto rkd_tree = [&train, &max_leaf_size_apprx, &forest_size]() {
+      ScopedTimer t0("kd_forest build");
+      return pico_tree::KdForest<Space>(
+          train, max_leaf_size_apprx, forest_size);
+    }();
+
+    ScopedTimer t1("kd_forest query");
+    pico_tree::Neighbor<int, Scalar> nn;
+    for (std::size_t i = 0; i < nns.size(); ++i) {
+      rkd_tree.SearchNn(test[i], nn);
+
+      if (nns[i].index == nn.index) {
+        ++equal;
+      }
+    }
+  }
+
+  std::cout << "Precision: "
+            << (static_cast<float>(equal) / static_cast<float>(count))
+            << std::endl;
+}
+
+int main() {
+  // max_leaf_size_apprx = 128:
+  // forest_size 8: a precision of around 79%.
+  // forest_size 16: a precision of around 93%.
+  // forest_size 32: a precision of around 98%.
+  RunDataset<Mnist>(16, 128, 8);
+  // max_leaf_size_apprx = 1024:
+  // forest_size 8: a precision of around 58%.
+  // forest_size 16: a precision of around 68%.
+  // forest_size 32: a precision of around 80%.
+  // forest_size 64: a precision of around 87%.
+  // forest_size 128: out of memory :'(
+  RunDataset<Sift>(16, 1024, 8);
+  return 0;
+}
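
Note on the precision figure: `RunDataset` reports the fraction of test points for which the approximate KdForest search returns the same index as the exact KdTree search. The sketch below isolates that calculation. `Precision` is a hypothetical helper that is not part of this commit, and it takes plain index vectors instead of `pico_tree::Neighbor` values to stay self-contained.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of the commit: returns the fraction of
// queries for which the approximate search found the same index as the
// exact search. Both vectors hold one nearest neighbor index per query.
inline float Precision(
    std::vector<int> const& exact, std::vector<int> const& approximate) {
  assert(exact.size() == approximate.size());

  // Avoid dividing by zero when there are no queries.
  if (exact.empty()) {
    return 0.0f;
  }

  std::size_t equal = 0;
  for (std::size_t i = 0; i < exact.size(); ++i) {
    if (exact[i] == approximate[i]) {
      ++equal;
    }
  }

  return static_cast<float>(equal) / static_cast<float>(exact.size());
}
```

With four queries of which three agree, `Precision({1, 2, 3, 4}, {1, 2, 0, 4})` evaluates to 0.75, which corresponds to the percentages quoted in the comments of `main`.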

examples/kd_forest/mnist.hpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#pragma once
2+
3+
#include <algorithm>
4+
#include <filesystem>
5+
#include <pico_toolshed/format/format_mnist.hpp>
6+
7+
template <typename U, typename T, std::size_t N>
8+
std::array<U, N> Cast(std::array<T, N> const& i) {
9+
std::array<U, N> c;
10+
std::transform(i.begin(), i.end(), c.begin(), [](T a) -> U {
11+
return static_cast<U>(a);
12+
});
13+
return c;
14+
}
15+
16+
template <typename U, typename T, std::size_t N>
17+
std::vector<std::array<U, N>> Cast(std::vector<std::array<T, N>> const& i) {
18+
std::vector<std::array<U, N>> c;
19+
std::transform(
20+
i.begin(),
21+
i.end(),
22+
std::back_inserter(c),
23+
[](std::array<T, N> const& a) -> std::array<U, N> { return Cast<U>(a); });
24+
return c;
25+
}
26+
27+
class Mnist {
28+
private:
29+
using Scalar = float;
30+
using ImageByte = std::array<std::byte, 28 * 28>;
31+
using ImageFloat = std::array<Scalar, 28 * 28>;
32+
33+
static std::vector<ImageFloat> ReadImages(std::string const& filename) {
34+
if (!std::filesystem::exists(filename)) {
35+
throw std::runtime_error(filename + " doesn't exist.");
36+
}
37+
38+
std::vector<ImageByte> images_u8;
39+
pico_tree::ReadMnistImages(filename, images_u8);
40+
return Cast<Scalar>(images_u8);
41+
}
42+
43+
public:
44+
using PointType = ImageFloat;
45+
46+
static std::string const kDatasetName;
47+
48+
static std::vector<PointType> ReadTrain() {
49+
std::string fn_images_train = "train-images.idx3-ubyte";
50+
return ReadImages(fn_images_train);
51+
}
52+
53+
static std::vector<PointType> ReadTest() {
54+
std::string fn_images_test = "t10k-images.idx3-ubyte";
55+
return ReadImages(fn_images_test);
56+
}
57+
};
58+
59+
std::string const Mnist::kDatasetName = "mnist";

examples/kd_forest/sift.hpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#pragma once
2+
3+
#include <filesystem>
4+
#include <pico_toolshed/format/format_xvecs.hpp>
5+
6+
class Sift {
7+
private:
8+
using VectorFloat = std::array<float, 128>;
9+
10+
static std::vector<VectorFloat> ReadVectors(std::string const& filename) {
11+
if (!std::filesystem::exists(filename)) {
12+
throw std::runtime_error(filename + " doesn't exist.");
13+
}
14+
15+
std::vector<VectorFloat> vectors;
16+
pico_tree::ReadXvecs(filename, vectors);
17+
return vectors;
18+
}
19+
20+
public:
21+
using PointType = VectorFloat;
22+
23+
static std::string const kDatasetName;
24+
25+
static std::vector<PointType> ReadTrain() {
26+
std::string fn_images_train = "sift_base.fvecs";
27+
return ReadVectors(fn_images_train);
28+
}
29+
30+
static std::vector<PointType> ReadTest() {
31+
std::string fn_images_test = "sift_query.fvecs";
32+
return ReadVectors(fn_images_test);
33+
}
34+
};
35+
36+
std::string const Sift::kDatasetName = "sift";
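
Note on the `.fvecs` files: the implementation of `pico_tree::ReadXvecs` is not shown in this diff. As a rough illustration of what reading `sift_base.fvecs` involves, the sketch below parses the `.fvecs` layout used by the SIFT dataset, where every record is a 32-bit integer dimension followed by that many 32-bit floats. `ReadFvecsSketch` is a hypothetical name, and the code assumes a little-endian host with 4-byte `float`. It should not be read as the library's actual implementation.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical sketch, not the pico_toolshed implementation. Reads all
// records of an .fvecs file. Each record stores its dimension as an int32
// followed by that many float values. Assumes a little-endian host.
template <std::size_t Dim>
std::vector<std::array<float, Dim>> ReadFvecsSketch(
    std::string const& filename) {
  static_assert(sizeof(float) == 4, "The format stores 32-bit floats.");

  std::ifstream stream(filename, std::ios::in | std::ios::binary);
  if (!stream.is_open()) {
    throw std::runtime_error("Unable to open file: " + filename);
  }

  std::vector<std::array<float, Dim>> vectors;
  std::int32_t dim = 0;
  // Every iteration consumes one record: the dimension, then the values.
  while (stream.read(reinterpret_cast<char*>(&dim), sizeof(dim))) {
    if (dim != static_cast<std::int32_t>(Dim)) {
      throw std::runtime_error("Unexpected dimension in: " + filename);
    }

    std::array<float, Dim> v;
    if (!stream.read(reinterpret_cast<char*>(v.data()), sizeof(float) * Dim)) {
      throw std::runtime_error("Truncated record in: " + filename);
    }
    vectors.push_back(v);
  }

  // A failed header read is only expected once the end of file is reached.
  assert(stream.eof());
  return vectors;
}
```

Under these assumptions, `ReadFvecsSketch<128>("sift_base.fvecs")` would return the same kind of `std::vector<std::array<float, 128>>` that `Sift::ReadTrain` produces.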

examples/mnist/CMakeLists.txt

Lines changed: 0 additions & 3 deletions
This file was deleted.

examples/mnist/mnist.cpp

Lines changed: 0 additions & 126 deletions
This file was deleted.

examples/pico_toolshed/pico_toolshed/format/format_bin.hpp

Lines changed: 3 additions & 13 deletions
@@ -1,6 +1,6 @@
 #pragma once
 
-#include <limits>
+#include <filesystem>
 #include <pico_tree/internal/stream.hpp>
 
 namespace pico_tree {
@@ -22,20 +22,10 @@ void ReadBin(std::string const& filename, std::vector<T>& v) {
   std::fstream stream =
       internal::OpenStream(filename, std::ios::in | std::ios::binary);
 
-  // The four lines below are used when determining the file size.
-  // C++17 is not used here in order to keep the benchmark C++11.
-  stream.ignore(std::numeric_limits<std::streamsize>::max());
-  std::streamsize byte_count = stream.gcount();
-  stream.clear();
-  stream.seekg(0, std::ios_base::beg);
-
-  if (byte_count == 0) {
-    return;
-  }
-
+  auto bytes = std::filesystem::file_size(filename);
   std::size_t const element_size = sizeof(T);
   std::size_t const element_count =
-      static_cast<std::size_t>(byte_count / element_size);
+      static_cast<std::size_t>(bytes) / element_size;
   v.resize(element_count);
   stream.read(reinterpret_cast<char*>(&v[0]), element_size * element_count);
 }
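
Note on empty files: this hunk removes the early return for `byte_count == 0`, so for an empty file `v` is resized to zero elements and `&v[0]` is then evaluated on an empty vector, which as far as I can tell is undefined behaviour. The sketch below shows one way to keep the C++17 `std::filesystem::file_size` approach while handling that case. `ReadBinSketch` is a hypothetical stand-in that opens the file with `std::ifstream` instead of `internal::OpenStream`. It is not the code in pico_toolshed.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <filesystem>
#include <fstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

// Hypothetical stand-in for ReadBin, not the pico_toolshed implementation.
// Reads a file of tightly packed values of type T into v.
template <typename T>
void ReadBinSketch(std::string const& filename, std::vector<T>& v) {
  static_assert(
      std::is_trivially_copyable<T>::value,
      "T must be trivially copyable to be read as raw bytes.");

  std::ifstream stream(filename, std::ios::in | std::ios::binary);
  if (!stream.is_open()) {
    throw std::runtime_error("Unable to open file: " + filename);
  }

  // C++17: ask the file system for the size instead of scanning the stream.
  std::uintmax_t const bytes = std::filesystem::file_size(filename);
  std::size_t const element_size = sizeof(T);
  std::size_t const element_count =
      static_cast<std::size_t>(bytes) / element_size;
  // Trailing bytes that do not form a whole element are ignored.
  assert(element_count * element_size <= bytes);
  v.resize(element_count);

  // Nothing to read for an empty file. Returning here also documents that
  // the buffer is never touched in that case.
  if (element_count == 0) {
    return;
  }

  // v.data() is valid for any vector, unlike &v[0] on an empty one.
  stream.read(
      reinterpret_cast<char*>(v.data()),
      static_cast<std::streamsize>(element_size * element_count));
  if (!stream) {
    throw std::runtime_error("Unable to read file: " + filename);
  }
}
```

Using `v.data()` alone would already avoid indexing into an empty vector. The explicit early return is kept to mirror the guard that the removed lines provided.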
