Skip to content

Commit d1fcb67

Browse files
committed
convert superpoint onnx
1 parent deae84b commit d1fcb67

File tree

5 files changed

+80
-0
lines changed

5 files changed

+80
-0
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[submodule "scripts/superpoint/SuperPointPretrainedNetwork"]
2+
path = scripts/superpoint/SuperPointPretrainedNetwork
3+
url = https://github.com/magicleap/SuperPointPretrainedNetwork

scripts/superpoint/README.md

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# convert pre-trained superpoint pytorch weights to onnx format
2+
3+
---
4+
5+
## dependencies
6+
7+
---
8+
9+
- python: 3x
10+
11+
-
12+
13+
```bash
14+
python3 -m pip install -r requirements.txt
15+
```
16+
17+
## :running: how to run
18+
19+
---
20+
21+
- update submodule
22+
23+
```bash
24+
git submodule update --init --recursive
25+
```
26+
27+
- export onnx weights
28+
29+
```
30+
python3 convert_to_onnx.py
31+
```

scripts/superpoint/convert_to_onnx.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#!/usr/bin/env python
2+
import os
3+
4+
import torch
5+
import torch.onnx
6+
7+
from SuperPointPretrainedNetwork.demo_superpoint import SuperPointNet
8+
9+
_CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
10+
_WEIGHTS_PATH = os.path.join(
11+
_CURRENT_DIR, "SuperPointPretrainedNetwork/superpoint_v1.pth"
12+
)
13+
14+
15+
def main():
16+
assert os.path.isfile(_WEIGHTS_PATH)
17+
model = SuperPointNet()
18+
model.load_state_dict(torch.load(_WEIGHTS_PATH))
19+
model.eval()
20+
21+
batch_size = 1
22+
height = 16
23+
width = 16
24+
x = torch.randn(batch_size, 1, height, width)
25+
26+
torch.onnx.export(
27+
model,
28+
x,
29+
"super_point.onnx",
30+
export_params=True,
31+
opset_version=10,
32+
do_constant_folding=True,
33+
input_names=["input"],
34+
output_names=["output"],
35+
dynamic_axes={
36+
"input": {0: "batch_size", 2: "height", 3: "width"},
37+
"output": {0: "batch_size"},
38+
},
39+
)
40+
41+
42+
if __name__ == "__main__":
43+
main()

scripts/superpoint/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
opencv-python
2+
torch>=0.4

0 commit comments

Comments
 (0)