Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,13 @@ http_archive(
build_file = "//third_party:libyuv.BUILD",
)

http_archive(
name = "stblib",
strip_prefix = "stb-master",
urls = ["https://github.com/nothings/stb/archive/master.zip"],
build_file = "//third_party:stblib.BUILD",
)


load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")

Expand Down
3 changes: 0 additions & 3 deletions tensorflow_lite_support/examples/task/vision/desktop/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ cc_binary(
"@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/core:lib",
],
)

Expand All @@ -45,7 +44,6 @@ cc_binary(
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/core:lib",
],
)

Expand All @@ -66,6 +64,5 @@ cc_binary(
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/core:lib",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ limitations under the License.
#include "absl/flags/parse.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
Expand All @@ -40,7 +39,8 @@ limitations under the License.
ABSL_FLAG(std::string, model_path, "",
"Absolute path to the '.tflite' image classifier model.");
ABSL_FLAG(std::string, image_path, "",
"Absolute path to the image to classify. The image EXIF orientation "
"Absolute path to the image to classify. The image must be RGB or "
"RGBA (grayscale is not supported). The image EXIF orientation "
"flag, if any, is NOT taken into account.");
ABSL_FLAG(int32, max_results, 5,
"Maximum number of classification results to display.");
Expand Down Expand Up @@ -116,18 +116,28 @@ absl::Status Classify() {
ImageClassifier::CreateFromOptions(options));

// Load image in a FrameBuffer.
ASSIGN_OR_RETURN(RgbImageData image,
ASSIGN_OR_RETURN(ImageData image,
DecodeImageFromFile(absl::GetFlag(FLAGS_image_path)));
std::unique_ptr<FrameBuffer> frame_buffer =
CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height});
std::unique_ptr<FrameBuffer> frame_buffer;
if (image.channels == 3) {
frame_buffer =
CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height});
} else if (image.channels == 4) {
frame_buffer =
CreateFromRgbaRawBuffer(image.pixel_data, {image.width, image.height});
} else {
return absl::InvalidArgumentError(absl::StrFormat(
"Expected image with 3 (RGB) or 4 (RGBA) channels, found %d",
image.channels));
}

// Run classification and display results.
ASSIGN_OR_RETURN(ClassificationResult result,
image_classifier->Classify(*frame_buffer));
DisplayResult(result);

// Cleanup and return.
RgbImageDataFree(&image);
ImageDataFree(&image);
return absl::OkStatus();
}

Expand All @@ -154,10 +164,6 @@ int main(int argc, char** argv) {
return 1;
}

// We need to call this to set up global state for Tensorflow, which is used
// internally for decoding various image formats (JPEG, PNG, etc).
tensorflow::port::InitMain(argv[0], &argc, &argv);

// Run classification.
absl::Status status = tflite::support::task::vision::Classify();
if (status.ok()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/strings/str_format.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
Expand All @@ -41,7 +40,8 @@ limitations under the License.
ABSL_FLAG(std::string, model_path, "",
"Absolute path to the '.tflite' image segmenter model.");
ABSL_FLAG(std::string, image_path, "",
"Absolute path to the image to segment. The image EXIF orientation "
"Absolute path to the image to segment. The image must be RGB or "
"RGBA (grayscale is not supported). The image EXIF orientation "
"flag, if any, is NOT taken into account.");
ABSL_FLAG(std::string, output_mask_png, "",
"Absolute path to the output category mask (confidence masks outputs "
Expand Down Expand Up @@ -75,9 +75,10 @@ absl::Status EncodeMaskToPngFile(const SegmentationResult& result) {
// Create RgbImageData for the output mask.
uint8* pixel_data = static_cast<uint8*>(
malloc(segmentation.width() * segmentation.height() * 3 * sizeof(uint8)));
RgbImageData mask = {.pixel_data = pixel_data,
.width = segmentation.width(),
.height = segmentation.height()};
ImageData mask = {.pixel_data = pixel_data,
.width = segmentation.width(),
.height = segmentation.height(),
.channels = 3};

// Populate RgbImageData from the raw mask and ColoredLabel-s.
for (int i = 0; i < segmentation.width() * segmentation.height(); ++i) {
Expand All @@ -90,12 +91,12 @@ absl::Status EncodeMaskToPngFile(const SegmentationResult& result) {

// Encode mask as PNG.
RETURN_IF_ERROR(
EncodeRgbImageToPngFile(mask, absl::GetFlag(FLAGS_output_mask_png)));
EncodeImageToPngFile(mask, absl::GetFlag(FLAGS_output_mask_png)));
std::cout << absl::StrFormat("Category mask saved to: %s\n",
absl::GetFlag(FLAGS_output_mask_png));

// Cleanup and return.
RgbImageDataFree(&mask);
ImageDataFree(&mask);
return absl::OkStatus();
}

Expand Down Expand Up @@ -138,10 +139,20 @@ absl::Status Segment() {
ImageSegmenter::CreateFromOptions(options));

// Load image in a FrameBuffer.
ASSIGN_OR_RETURN(RgbImageData image,
ASSIGN_OR_RETURN(ImageData image,
DecodeImageFromFile(absl::GetFlag(FLAGS_image_path)));
std::unique_ptr<FrameBuffer> frame_buffer =
CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height});
std::unique_ptr<FrameBuffer> frame_buffer;
if (image.channels == 3) {
frame_buffer =
CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height});
} else if (image.channels == 4) {
frame_buffer =
CreateFromRgbaRawBuffer(image.pixel_data, {image.width, image.height});
} else {
return absl::InvalidArgumentError(absl::StrFormat(
"Expected image with 3 (RGB) or 4 (RGBA) channels, found %d",
image.channels));
}

// Run segmentation and save category mask.
ASSIGN_OR_RETURN(SegmentationResult result,
Expand All @@ -152,7 +163,7 @@ absl::Status Segment() {
RETURN_IF_ERROR(DisplayColorLegend(result));

// Cleanup and return.
RgbImageDataFree(&image);
ImageDataFree(&image);
return absl::OkStatus();
}

Expand Down Expand Up @@ -181,10 +192,6 @@ int main(int argc, char** argv) {
return 1;
}

// We need to call this to set up global state for Tensorflow, which is used
// internally for decoding various image formats (JPEG, PNG, etc).
tensorflow::port::InitMain(argv[0], &argc, &argv);

// Run segmentation.
absl::Status status = tflite::support::task::vision::Segment();
if (status.ok()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/strings/str_format.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
Expand All @@ -44,7 +43,8 @@ limitations under the License.
ABSL_FLAG(std::string, model_path, "",
"Absolute path to the '.tflite' object detector model.");
ABSL_FLAG(std::string, image_path, "",
"Absolute path to the image to perform detection on. The image EXIF "
"Absolute path to the image to run detection on. The image must be "
"RGB or RGBA (grayscale is not supported). The image EXIF "
"orientation flag, if any, is NOT taken into account.");
ABSL_FLAG(std::string, output_png, "",
"Absolute path to a file where to draw the detection results on top "
Expand Down Expand Up @@ -111,7 +111,7 @@ ObjectDetectorOptions BuildOptions() {
}

absl::Status EncodeResultToPngFile(const DetectionResult& result,
const RgbImageData* image) {
const ImageData* image) {
for (int index = 0; index < result.detections_size(); ++index) {
// Get bounding box as left, top, right, bottom.
const BoundingBox& box = result.detections(index).bounding_box();
Expand All @@ -127,7 +127,7 @@ absl::Status EncodeResultToPngFile(const DetectionResult& result,
// is applied.
for (int y = std::max(0, top); y < std::min(image->height, bottom); ++y) {
for (int x = std::max(0, left); x < std::min(image->width, right); ++x) {
int pixel_index = 3 * (image->width * y + x);
int pixel_index = image->channels * (image->width * y + x);
if (x < left + kLineThickness || x > right - kLineThickness ||
y < top + kLineThickness || y > bottom - kLineThickness) {
image->pixel_data[pixel_index] = r;
Expand All @@ -139,7 +139,7 @@ absl::Status EncodeResultToPngFile(const DetectionResult& result,
}
// Encode to PNG and return.
RETURN_IF_ERROR(
EncodeRgbImageToPngFile(*image, absl::GetFlag(FLAGS_output_png)));
EncodeImageToPngFile(*image, absl::GetFlag(FLAGS_output_png)));
std::cout << absl::StrFormat("Results saved to: %s\n",
absl::GetFlag(FLAGS_output_png));
return absl::OkStatus();
Expand Down Expand Up @@ -183,10 +183,20 @@ absl::Status Detect() {
ObjectDetector::CreateFromOptions(options));

// Load image in a FrameBuffer.
ASSIGN_OR_RETURN(RgbImageData image,
ASSIGN_OR_RETURN(ImageData image,
DecodeImageFromFile(absl::GetFlag(FLAGS_image_path)));
std::unique_ptr<FrameBuffer> frame_buffer =
CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height});
std::unique_ptr<FrameBuffer> frame_buffer;
if (image.channels == 3) {
frame_buffer =
CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height});
} else if (image.channels == 4) {
frame_buffer =
CreateFromRgbaRawBuffer(image.pixel_data, {image.width, image.height});
} else {
return absl::InvalidArgumentError(absl::StrFormat(
"Expected image with 3 (RGB) or 4 (RGBA) channels, found %d",
image.channels));
}

// Run object detection and draw results on input image.
ASSIGN_OR_RETURN(DetectionResult result,
Expand All @@ -197,7 +207,7 @@ absl::Status Detect() {
DisplayResult(result);

// Cleanup and return.
RgbImageDataFree(&image);
ImageDataFree(&image);
return absl::OkStatus();
}

Expand Down Expand Up @@ -232,10 +242,6 @@ int main(int argc, char** argv) {
return 1;
}

// We need to call this to set up global state for Tensorflow, which is used
// internally for decoding various image formats (JPEG, PNG, etc).
tensorflow::port::InitMain(argv[0], &argc, &argv);

// Run detection.
absl::Status status = tflite::support::task::vision::Detect();
if (status.ok()) {
Expand Down
11 changes: 2 additions & 9 deletions tensorflow_lite_support/examples/task/vision/desktop/utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,10 @@ cc_library(
"//tensorflow_lite_support/cc/port:integral_types",
"//tensorflow_lite_support/cc/port:status_macros",
"//tensorflow_lite_support/cc/port:statusor",
"//tensorflow_lite_support/cc/task/core:external_file_handler",
"//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/cc:cc_ops",
"@org_tensorflow//tensorflow/cc:scope",
"@org_tensorflow//tensorflow/core:core_cpu",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:lib",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core:tensorflow",
"@stblib//:stb_image",
"@stblib//:stb_image_write",
],
)
Loading