Skip to content

Commit 215f277

Browse files
committed
Experimental: Add initial wavefront/obj parser for vertices
This PR is an early experimental implementation of wavefront obj parser in tensorflow-io for 3D objects. This PR is the first step to obtain raw vertices in float32 tensor with shape of `[n, 3]`. Additional follow up PRs will be needed to handle meshs with different shapes (not sure if ragged tensor will be a good fit in that case) Some background on obj file: Wavefront (obj) is a format widely used in 3D (another is ply) modeling (http://paulbourke.net/dataformats/obj/). It is simple (ASCII) with good support for many softwares. Machine learning in 3D has been an active field with some advances such as PolyGen (https://arxiv.org/abs/2002.10880) Processing obj files are needed to process 3D with tensorflow. In 3D the basic elements could be vertices or faces. This PR tries to cover vertices first so that vertices in obj file can be loaded into TF's graph for further processing within graph pipeline. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
1 parent ac75e1c commit 215f277

File tree

9 files changed

+206
-0
lines changed

9 files changed

+206
-0
lines changed

WORKSPACE

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,3 +1110,14 @@ http_archive(
11101110
"https://github.com/mongodb/mongo-c-driver/releases/download/1.16.2/mongo-c-driver-1.16.2.tar.gz",
11111111
],
11121112
)
1113+
1114+
http_archive(
1115+
name = "tinyobjloader",
1116+
build_file = "//third_party:tinyobjloader.BUILD",
1117+
sha256 = "b8c972dfbbcef33d55554e7c9031abe7040795b67778ad3660a50afa7df6ec56",
1118+
strip_prefix = "tinyobjloader-2.0.0rc8",
1119+
urls = [
1120+
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/tinyobjloader/tinyobjloader/archive/v2.0.0rc8.tar.gz",
1121+
"https://github.com/tinyobjloader/tinyobjloader/archive/v2.0.0rc8.tar.gz",
1122+
],
1123+
)

tensorflow_io/core/BUILD

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,22 @@ cc_library(
695695
alwayslink = 1,
696696
)
697697

698+
cc_library(
699+
name = "obj_ops",
700+
srcs = [
701+
"kernels/obj_kernels.cc",
702+
"ops/obj_ops.cc",
703+
],
704+
copts = tf_io_copts(),
705+
linkstatic = True,
706+
deps = [
707+
"@local_config_tf//:libtensorflow_framework",
708+
"@local_config_tf//:tf_header_lib",
709+
"@tinyobjloader",
710+
],
711+
alwayslink = 1,
712+
)
713+
698714
cc_binary(
699715
name = "python/ops/libtensorflow_io.so",
700716
copts = tf_io_copts(),
@@ -717,6 +733,7 @@ cc_binary(
717733
"//tensorflow_io/core:parquet_ops",
718734
"//tensorflow_io/core:pcap_ops",
719735
"//tensorflow_io/core:pulsar_ops",
736+
"//tensorflow_io/core:obj_ops",
720737
"//tensorflow_io/core:operation_ops",
721738
"//tensorflow_io/core:pubsub_ops",
722739
"//tensorflow_io/core:serialization_ops",
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/core/framework/op_kernel.h"
17+
#include "tensorflow/core/platform/logging.h"
18+
#include "tiny_obj_loader.h"
19+
20+
namespace tensorflow {
21+
namespace io {
22+
namespace {
23+
24+
class DecodeObjOp : public OpKernel {
25+
public:
26+
explicit DecodeObjOp(OpKernelConstruction* context) : OpKernel(context) {}
27+
28+
void Compute(OpKernelContext* context) override {
29+
const Tensor* input_tensor;
30+
OP_REQUIRES_OK(context, context->input("input", &input_tensor));
31+
OP_REQUIRES(context, TensorShapeUtils::IsScalar(input_tensor->shape()),
32+
errors::InvalidArgument("input must be scalar, got shape ",
33+
input_tensor->shape().DebugString()));
34+
const tstring& input = input_tensor->scalar<tstring>()();
35+
36+
tinyobj::ObjReader reader;
37+
38+
if (!reader.ParseFromString(input.c_str(), "")) {
39+
OP_REQUIRES(
40+
context, false,
41+
errors::Internal("Unable to read obj file: ", reader.Error()));
42+
}
43+
44+
if (!reader.Warning().empty()) {
45+
LOG(WARNING) << "TinyObjReader: " << reader.Warning();
46+
}
47+
48+
auto& attrib = reader.GetAttrib();
49+
50+
int64 count = attrib.vertices.size() / 3;
51+
52+
Tensor* output_tensor = nullptr;
53+
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({count, 3}),
54+
&output_tensor));
55+
// Loop over attrib.vertices:
56+
for (int64 i = 0; i < count; i++) {
57+
tinyobj::real_t x = attrib.vertices[i * 3 + 0];
58+
tinyobj::real_t y = attrib.vertices[i * 3 + 1];
59+
tinyobj::real_t z = attrib.vertices[i * 3 + 2];
60+
output_tensor->tensor<float, 2>()(i, 0) = x;
61+
output_tensor->tensor<float, 2>()(i, 1) = y;
62+
output_tensor->tensor<float, 2>()(i, 2) = z;
63+
}
64+
}
65+
};
66+
REGISTER_KERNEL_BUILDER(Name("IO>DecodeObj").Device(DEVICE_CPU), DecodeObjOp);
67+
68+
} // namespace
69+
} // namespace io
70+
} // namespace tensorflow

tensorflow_io/core/ops/obj_ops.cc

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/core/framework/common_shape_fns.h"
17+
#include "tensorflow/core/framework/op.h"
18+
#include "tensorflow/core/framework/shape_inference.h"
19+
20+
namespace tensorflow {
21+
namespace io {
22+
namespace {
23+
24+
REGISTER_OP("IO>DecodeObj")
25+
.Input("input: string")
26+
.Output("output: float32")
27+
.SetShapeFn([](shape_inference::InferenceContext* c) {
28+
shape_inference::ShapeHandle unused;
29+
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
30+
c->set_output(0, c->MakeShape({c->UnknownDim(), 3}));
31+
return Status::OK();
32+
});
33+
34+
} // namespace
35+
} // namespace io
36+
} // namespace tensorflow

tensorflow_io/core/python/api/experimental/image.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,5 @@
2727
decode_yuy2,
2828
decode_avif,
2929
decode_jp2,
30+
decode_obj,
3031
)

tensorflow_io/core/python/experimental/image_ops.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,3 +208,17 @@ def decode_jp2(contents, dtype=tf.uint8, name=None):
208208
A `Tensor` of type `uint8` and shape of `[height, width, 3]` (RGB).
209209
"""
210210
return core_ops.io_decode_jpeg2k(contents, dtype=dtype, name=name)
211+
212+
213+
def decode_obj(contents, name=None):
214+
"""
215+
Decode a Wavefront (obj) file into a float32 tensor.
216+
217+
Args:
218+
contents: A `Tensor` of type `string`. 0-D. The Wavefront (obj) file.
219+
name: A name for the operation (optional).
220+
221+
Returns:
222+
A `Tensor` of type `float32` and shape of `[n, 3]` for vertices.
223+
"""
224+
return core_ops.io_decode_obj(contents, name=name)

tests/test_obj.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4+
# use this file except in compliance with the License. You may obtain a copy of
5+
# the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations under
13+
# the License.
14+
# ==============================================================================
15+
"""Test Wavefront OBJ"""
16+
17+
import os
18+
import numpy as np
19+
import pytest
20+
21+
import tensorflow as tf
22+
import tensorflow_io as tfio
23+
24+
25+
def test_decode_obj():
26+
"""Test case for decode obj"""
27+
filename = os.path.join(
28+
os.path.dirname(os.path.abspath(__file__)), "test_obj", "sample.obj",
29+
)
30+
filename = "file://" + filename
31+
32+
obj = tfio.experimental.image.decode_obj(tf.io.read_file(filename))
33+
expected = np.array(
34+
[[-0.5, 0.0, 0.4], [-0.5, 0.0, -0.8], [-0.5, 1.0, -0.8], [-0.5, 1.0, 0.4]],
35+
dtype=np.float32,
36+
)
37+
assert np.array_equal(obj, expected)

tests/test_obj/sample.obj

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Simple Wavefront file
2+
v -0.500000 0.000000 0.400000
3+
v -0.500000 0.000000 -0.800000
4+
v -0.500000 1.000000 -0.800000
5+
v -0.500000 1.000000 0.400000
6+
f -4 -3 -2 -1

third_party/tinyobjloader.BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package(default_visibility = ["//visibility:public"])
2+
3+
licenses(["notice"]) # MIT license
4+
5+
cc_library(
6+
name = "tinyobjloader",
7+
srcs = [
8+
"tiny_obj_loader.cc",
9+
],
10+
hdrs = [
11+
"tiny_obj_loader.h",
12+
],
13+
copts = [],
14+
)

0 commit comments

Comments
 (0)