Straightforward Neural Network is an open-source neural network library written in C++20 and optimized for CPUs. Its goal is to make using neural networks as easy as possible.
See the full documentation here.
| Dataset name | Data type | Problem type | Score | Number of parameters |
|---|---|---|---|---|
| Audio Cats and Dogs | audio | classification | 91.04% accuracy | 382 |
| Daily min temperatures | time series | regression | 1.42 mean absolute error | 30 |
| CIFAR-10 | image | classification | 61.96% accuracy | 207210 |
| Fashion-MNIST | image | classification | 89.65% accuracy | 270926 |
| MNIST | image | classification | 98.71% accuracy | 261206 |
| Wine | multivariate | classification | 100.0% accuracy | 444 |
| Iris | multivariate | classification | 100.0% accuracy | 150 |
- First of all, move to a build folder:
  ```bash
  mkdir build && cd ./build
  ```
- For dataset tests, the datasets must be downloaded:
  ```bash
  ./resources/ImportDatasets.sh
  ```
- Use CMake to build:
  ```bash
  cmake -G"Unix Makefiles" ./.. && make
  ```
- Run the unit tests:
  ```bash
  ./build/tests/unit_tests/UnitTests
  ```
- Run the dataset tests:
  ```bash
  ./tests/dataset_tests/DatasetTests
  ```
- Alternatively, on Windows, use CMake to generate a Visual Studio 2022 project:
  ```
  cmake -G"Visual Studio 17 2022" ./..
  ```
- To run the unit tests, open the generated project `./build/tests/unit_tests/UnitTests.vcxproj`.
- To run the dataset tests, open the generated project `./build/tests/dataset_tests/DatasetTests.vcxproj`.
Create, train and use a neural network in a few lines of code:
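```cpp
#include <vector>
#include <snn/data/Dataset.hpp>
#include <snn/neural_network/StraightforwardNeuralNetwork.hpp>

using namespace snn; // Layers, activations and the _acc/_s literals are assumed to live in namespace snn.

int main()
{
    // Placeholders (type assumed): fill with your own samples and one-hot labels before training.
    std::vector<std::vector<float>> inputData;
    std::vector<std::vector<float>> expectedOutputs;
    snn::Dataset dataset(snn::problem::classification, inputData, expectedOutputs);
    snn::StraightforwardNeuralNetwork neuralNetwork({
        Input(1, 28, 28), // The input shape is (C, X, Y).
        Convolution(16, 3, activation::ReLU), // The layer has 16 filters and (3, 3) kernels.
        FullyConnected(92), // The layer has 92 neurons.
        FullyConnected(10, activation::identity, Softmax()) // Output layer: 10 neurons with a softmax.
    });
    neuralNetwork.train(dataset, 0.90_acc || 20_s); // Train until 90% accuracy is reached or 20 s have elapsed.
    float accuracy = neuralNetwork.getGlobalClusteringRate(); // Retrieve the accuracy.
}
```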