Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified models/converted_model.tflite
Binary file not shown.
22 changes: 12 additions & 10 deletions src/daemon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,21 @@

// local
#include "image_processing.hpp"
#include "model.hpp"
#include "opencv_helper.hpp"
#include "symbol.hpp"
#include "stroke.hpp"
#include "symbol_processing.hpp"
#include "unix_socket_server/unix_socket_server.hpp"

// libs
// spdlog
#include <spdlog/spdlog.h>
// json
#include <nlohmann/json.hpp>
// opencv
#include <opencv2/opencv.hpp>

// std
#include <string>
#include <vector>

namespace mathboard {
Expand Down Expand Up @@ -84,8 +86,8 @@ void Daemon() {
nlohmann::json jsonData = nlohmann::json::parse(stringData);
auto &strokesData = jsonData["strokes"];

std::vector<Stroke> strokeVector;
strokeVector.reserve(strokesData.size());
std::vector<Stroke> stroke_vector;
stroke_vector.reserve(strokesData.size());

for (std::size_t i = 0; i != strokesData.size(); i++) {
auto &strokeData = strokesData[i];
Expand All @@ -94,14 +96,14 @@ void Daemon() {
processedImage = opencvHelper.GetFrame();
processedImage = GrayScaleImage(processedImage);

strokeVector.emplace_back(strokeData["id"], strokeData["x"],
strokeData["y"], processedImage);
stroke_vector.emplace_back(strokeData["id"], strokeData["x"],
strokeData["y"], processedImage);
}

// TODO
// Do something with the stroke vector
Model model(
"/home/projects/MathBoardAlgoML/models/converted_model.tflite");
const auto symbol_groups = GenerateSymbolGroups(model, stroke_vector);
const auto symbol_group = GetBestSymbolGroup(symbol_groups);
}
}
}

} // namespace mathboard
27 changes: 12 additions & 15 deletions src/grid.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,15 @@ template <typename T> concept HasPosition = requires(T object) {
object.GetPosition();
};

template <typename T> concept HasDimensions = requires(T object) {
object.GetWidth();
object.GetHeight();
template <typename T> concept HasSize = requires(T object) {
object.GetSize();
};

struct Position2f {
float x;
float y;
struct Pos2i {
int x;
int y;
};
struct BoundingBox {
struct Size2i {
int width;
int height;
};
Expand All @@ -32,7 +31,7 @@ namespace mathboard {
// The Grid class handles broad-phase intersection detection by dividing space
// into cells, each containing potential object intersection. Grid as a class
// doesn't own pointers to objects which stores.
template <typename T> requires HasPosition<T> && HasDimensions<T> class Grid {
template <typename T> requires HasPosition<T> &&HasSize<T> class Grid {
public:
// Constructor that sets up a grid covering a specified area,
// defined by the top-left and bottom-right corners, with each cell having the
Expand All @@ -55,13 +54,11 @@ template <typename T> requires HasPosition<T> && HasDimensions<T> class Grid {

// Insert object to grid.
void Insert(T *object) {
const Position2f pos =
Position2f{object->GetPosition().x, object->GetPosition().y};
Position2f object_min = pos;
const BoundingBox bounding_box =
BoundingBox{object->GetWidth(), object->GetHeight()};
Position2f object_max =
Position2f{pos.x + bounding_box.width, pos.y + bounding_box.height};
const Pos2i pos = Pos2i(object->GetPosition().x, object->GetPosition().y);
const Size2i size =
Size2i(object->GetSize().width, object->GetSize().height);
Pos2i object_min = pos;
Pos2i object_max = Pos2i(pos.x + size.width, pos.y + size.height);

// calculating position of vertices in grid
// decrese width and height because containers are 0 index based
Expand Down
77 changes: 70 additions & 7 deletions src/image_processing.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
// local
// header
#include "image_processing.hpp"

// libs
// opencv
#include <opencv2/core/mat.hpp>
#include <opencv2/core/hal/interface.h>
#include <opencv2/core/types.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/videoio.hpp>
// spdlog
Expand Down Expand Up @@ -33,12 +34,22 @@ cv::Mat RasterizeImage(const std::filesystem::path &filename) {
return rasterized_image;
}

cv::Mat CropImageToSymbol(const cv::Mat &input_mat) {
cv::Mat CropToSymbol(const cv::Mat &input_mat) {
if (input_mat.channels() != 1) {
spdlog::error("[CropToSymbol]: input_mat isn't grayscale");
}
cv::Mat cropped_mat;
input_mat.copyTo(cropped_mat);
if (input_mat.type() != CV_8UC1) {
cropped_mat *= 255.0;
double min_val = 0.0f;
double max_val = 0.0f;
cv::minMaxLoc(cropped_mat, &min_val, &max_val);
cropped_mat.convertTo(cropped_mat, CV_8UC1, (max_val - min_val),
-min_val * 255.0 / (max_val - min_val));
}
std::vector<std::vector<cv::Point>> contours;
// merges bounding boxes of all objects on image
// into one bigger bounding box containing all
// content of image
cv::findContours(input_mat, contours, cv::RETR_CCOMP,
cv::findContours(cropped_mat, contours, cv::RETR_CCOMP,
cv::CHAIN_APPROX_SIMPLE);
cv::Rect bounding_box;
for (std::size_t i = 0; i < contours.size(); i++) {
Expand Down Expand Up @@ -80,4 +91,56 @@ std::string RecognizeText(const cv::Mat &img) {
return text;
}

// it won't work unless size of strokes overlap with pixels on img passed
cv::Mat CombineStrokes(const std::vector<mathboard::Stroke> &strokes) {
if (strokes.empty()) {
spdlog::error("[CombineStrokes()]: strokes vector is empty");
}
// Compute bounding box
cv::Rect combined_rect;
for (const auto &stroke : strokes) {
combined_rect |= stroke.GetRect();
}

// Create empty matrix
cv::Mat stroke_combination = cv::Mat::zeros(combined_rect.size(), CV_32F);
for (const auto &stroke : strokes) {
if (stroke.GetMatrix().empty()) {
continue;
}
cv::Point offset = stroke.GetPosition() - combined_rect.tl();
cv::Rect roi(offset, stroke.GetSize());

// Ensure the region of interest is within bounds
if (roi.x >= 0 && roi.y >= 0 &&
roi.x + roi.width <= stroke_combination.cols &&
roi.y + roi.height <= stroke_combination.rows) {
cv::Mat stroke_combination_roi = stroke_combination(roi);
cv::max(stroke_combination_roi, stroke.GetMatrix(), stroke_combination_roi);
}
}
return stroke_combination;
}

cv::Mat ResizeToMNISTFormat(const cv::Mat &input_mat) {
cv::Mat output_mat;
cv::resize(input_mat, output_mat, cv::Size2i(28, 28), 0, 0, cv::INTER_CUBIC);
return output_mat;
}

std::vector<mathboard::Stroke>
FindIntersectingStrokes(const mathboard::Stroke &target_stroke,
const std::vector<mathboard::Stroke> &strokes) {
std::vector<mathboard::Stroke> intersecting_strokes;
for (const auto &stroke : strokes) {
// checks if rectangles are intersecting
const cv::Rect2i &intersection =
target_stroke.GetRect() & stroke.GetRect();
if (intersection.area() != 0) {
intersecting_strokes.push_back(stroke);
}
}
return intersecting_strokes;
}

} // namespace mathboard
69 changes: 53 additions & 16 deletions src/image_processing.hpp
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
#pragma once

// local
#include "grid.hpp"
#include "stroke.hpp"

// opencv
#include <opencv2/core/mat.hpp>

// std
#include <filesystem>
#include <vector>

namespace mathboard {
// take path to svg file and transform it into pixel representation using
// cv::Mat
cv::Mat RasterizeImage(const std::filesystem::path &filename);

// `input_array` has to be grayscale
// `input_mat` has to be grayscale
// crop image to tightly fit symbol on it
cv::Mat CropImageToSymbol(const cv::Mat &input_mat);
cv::Mat CropToSymbol(const cv::Mat &input_mat);

// Return grascaled version of input image
cv::Mat GrayScaleImage(const cv::Mat &input_mat);
Expand All @@ -23,6 +29,37 @@ cv::Mat BinarizeImage(const cv::Mat &input_mat);
// Returns image string
std::string RecognizeText(const cv::Mat &img);

// Combine all given strokes into a single matrix.
cv::Mat CombineStrokes(const std::vector<mathboard::Stroke> &strokes);

cv::Mat ResizeToMNISTFormat(const cv::Mat &input_mat);

std::vector<mathboard::Stroke>
FindIntersectingStrokes(const mathboard::Stroke &target_stroke,
const std::vector<mathboard::Stroke> &strokes);

// Generate all possible combinations of elements from the container,
// regardless of order.
// Example: input = [1, 2] -> output = [[1], [2], [1, 2]]
template <typename T>
std::vector<std::vector<T>>
GenerateCombinations(const std::vector<T> &elements) {
const std::size_t total = std::pow(2, elements.size());

std::vector<std::vector<T>> all_combinations;
all_combinations.resize(total - 1);
// start with one to avoid empty combination
for (std::size_t i = 1; i < total; ++i) {
for (std::size_t j = 0; j < elements.size(); ++j) {
// Check if the j-th bit is set
if (i & (1 << j)) {
all_combinations[i - 1].push_back(elements[j]);
}
}
}
return all_combinations;
}

// returns instance of Grid class with all images put on
// their positions ready to further interpreatation
// it sets grid cell size and boundaries of it
Expand All @@ -33,28 +70,28 @@ Grid<mathboard::Stroke> inline PlaceOnGrid(
cv::Point2f bot_right_corner{0, 0};
cv::Point2f top_left_corner{INFINITY, INFINITY};
for (std::size_t i = 0; i < strokes.size(); i++) {
const cv::Point2f line_pos = strokes[i].GetPosition();
const cv::Rect line_BB = strokes[i].GetBoundingBox();
const cv::Point2i stroke_pos = strokes[i].GetPosition();
const cv::Size2i stroke_size = strokes[i].GetSize();

if (bot_right_corner.x < line_pos.x + line_BB.width) {
bot_right_corner.x = line_pos.x + line_BB.width;
if (bot_right_corner.x < stroke_pos.x + stroke_size.width) {
bot_right_corner.x = stroke_pos.x + stroke_size.width;
}
if (top_left_corner.x > line_pos.x) {
top_left_corner.x = line_pos.x;
if (top_left_corner.x > stroke_pos.x) {
top_left_corner.x = stroke_pos.x;
}
if (bot_right_corner.y < line_pos.y + line_BB.height) {
bot_right_corner.y = line_pos.y + line_BB.height;
if (bot_right_corner.y < stroke_pos.y + stroke_size.height) {
bot_right_corner.y = stroke_pos.y + stroke_size.height;
}
if (top_left_corner.y > line_pos.y) {
top_left_corner.y = line_pos.y;
if (top_left_corner.y > stroke_pos.y) {
top_left_corner.y = stroke_pos.y;
}
}

// calculate average size of stroke
cv::Size2f average_stroke_size{0.0f, 0.0f};
for (std::size_t i = 0; i < strokes.size(); i++) {
average_stroke_size.width += strokes[i].GetWidth();
average_stroke_size.height += strokes[i].GetHeight();
cv::Size2i average_stroke_size = cv::Size2i(0, 0);
for (const auto &stroke : strokes) {
average_stroke_size.width += stroke.GetSize().width;
average_stroke_size.height += stroke.GetSize().width;
}
average_stroke_size.width /= static_cast<float>(strokes.size());
average_stroke_size.height /= static_cast<float>(strokes.size());
Expand Down
2 changes: 1 addition & 1 deletion src/main.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// local
#include "daemon.hpp"

int main(int argc, char **argv) { mathboard::Daemon(); }
int main(int argc, char **argv) { mathboard::Daemon(); }
25 changes: 16 additions & 9 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@

// libs
// tensorflow-lite
#include "spdlog/spdlog.h"
#include "tensorflow/lite/core/interpreter_builder.h"
#include "tensorflow/lite/examples/label_image/get_top_n.h"
#include "tensorflow/lite/kernels/register.h"
// spdlog
#include <spdlog/spdlog.h>
// OpenCV
#include <opencv2/core/hal/interface.h>

// std
#include <opencv2/core/hal/interface.h>
#include <vector>

namespace mathboard {
Expand All @@ -26,16 +28,21 @@ Model::Model(const std::filesystem::path &model_filename) {
}
m_Interpreter->AllocateTensors();
}
uint32_t Model::Predict(cv::Mat character) const {
if(character.rows != 28 || character.cols != 28
|| character.channels() != 1 || character.type() != CV_32F) {
spdlog::error("Wrong matrix format\n");
std::pair<float, int> Model::Predict(cv::Mat input_mat) const {
if (input_mat.rows != 28 || input_mat.cols != 28) {
spdlog::error("[Model::Predict]: Wrong matrix size");
}
if (input_mat.channels() != 1) {
spdlog::error("[Model::Predict]: Matrix isn't grayscale");
}
if (input_mat.type() != CV_32F) {
spdlog::error("[Model::Predict]: Wrong matrix data type");
}
const float treshold = 0.1f;
std::vector<std::pair<float, int>> top_results;

memcpy(m_Interpreter->typed_input_tensor<float>(0), character.data,
character.total() * character.elemSize());
memcpy(m_Interpreter->typed_input_tensor<float>(0), input_mat.data,
input_mat.total() * input_mat.elemSize());

// inference
m_Interpreter->Invoke();
Expand All @@ -48,6 +55,6 @@ uint32_t Model::Predict(cv::Mat character) const {
tflite::label_image::get_top_n<float>(
m_Interpreter->typed_output_tensor<float>(0), output_size, 1, treshold,
&top_results, kTfLiteFloat32);
return top_results.front().second;
return top_results.front();
}
} // namespace mathboard
Loading