Skip to content

Commit 740afc2

Browse files
committed
convert loftr weights to onnx format
1 parent b022c78 commit 740afc2

File tree

6 files changed

+213
-0
lines changed

6 files changed

+213
-0
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,6 @@
44
[submodule "scripts/superglue/SuperGluePretrainedNetwork"]
55
path = scripts/superglue/SuperGluePretrainedNetwork
66
url = https://github.com/magicleap/SuperGluePretrainedNetwork.git
7+
[submodule "scripts/loftr/LoFTR"]
8+
path = scripts/loftr/LoFTR
9+
url = https://github.com/xmba15/LoFTR

scripts/loftr/LoFTR

Submodule LoFTR added at 6cdf544

scripts/loftr/README.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# convert pre-trained loftr pytorch weights to onnx format
2+
3+
---
4+
5+
## dependencies
6+
7+
---
8+
9+
- python: 3.x
10+
11+
-
12+
13+
```bash
14+
git submodule update --init --recursive
15+
16+
python3 -m pip install -r requirements.txt
17+
```
18+
19+
## :running: how to run
20+
21+
---
22+
23+
- download [LoFTR](https://github.com/zju3dv/LoFTR) weights indoor_ds_new.ckpt from [HERE](https://drive.google.com/drive/folders/1xu2Pq6mZT5hmFgiYMBT9Zt8h1yO-3SIp)
24+
25+
- export onnx weights
26+
27+
```
28+
python3 convert_to_onnx.py --model_path /path/to/indoor_ds_new.ckpt
29+
```
30+
## Note ##
31+
32+
- LoFTR's latest commit (`b4ee7eb0359d0062e794c99f73e27639d7c7ac9f`) seems to be compatible only with the new weights (Ref: https://github.com/zju3dv/LoFTR/issues/48). Hence, this onnx cpp application is only compatible with the *indoor_ds_new.ckpt* weights.

scripts/loftr/convert_to_onnx.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import os
2+
3+
import onnxruntime
4+
import torch
5+
6+
from loftr_wrapper import LoFTRWrapper as LoFTR
7+
8+
9+
def get_args():
    """Parse the command-line options of the LoFTR-to-ONNX converter.

    Returns:
        The parsed ``argparse.Namespace``; its only attribute is
        ``model_path`` (a required string option).
    """
    from argparse import ArgumentParser

    cli = ArgumentParser("convert loftr torch weights to onnx format")
    cli.add_argument("--model_path", type=str, required=True)
    return cli.parse_args()
16+
17+
18+
def main():
    """Export the LoFTR checkpoint given by ``--model_path`` to ``loftr.onnx``.

    The model is traced with two random single-channel 480x640 images. The
    image height/width and the number of returned keypoints are declared as
    dynamic axes. After the export, the file is re-opened with onnxruntime
    and its input/output descriptions are printed as a sanity check (no
    inference is actually run).
    """
    args = get_args()
    onnx_path = "loftr.onnx"

    model = LoFTR()
    # map_location="cpu" lets a checkpoint that was saved from a GPU be
    # loaded on a machine without CUDA; the export below runs on CPU anyway.
    checkpoint = torch.load(args.model_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    batch_size = 1
    height = 480
    width = 640

    # Dummy inputs used only for tracing; insertion order matches the
    # positional parameters of the wrapper's forward(image0, image1).
    data = {
        "image0": torch.randn(batch_size, 1, height, width),
        "image1": torch.randn(batch_size, 1, height, width),
    }

    torch.onnx.export(
        model,
        # Pass the inputs as a plain tuple of tensors, the documented form
        # of `args`, instead of relying on how a bare dict is interpreted.
        tuple(data.values()),
        onnx_path,
        export_params=True,
        opset_version=12,
        do_constant_folding=True,
        input_names=list(data.keys()),
        output_names=["keypoints0", "keypoints1", "confidence"],
        dynamic_axes={
            "image0": {2: "height", 3: "width"},
            "image1": {2: "height", 3: "width"},
            "keypoints0": {0: "num_keypoints"},
            "keypoints1": {0: "num_keypoints"},
            "confidence": {0: "num_keypoints"},
        },
    )

    print(f"\nonnx model is saved to: {os.path.abspath(onnx_path)}")

    print("\ntest inference using onnxruntime")
    # Naming the provider explicitly keeps this working on onnxruntime builds
    # that ship several execution providers and refuse to pick one implicitly.
    sess = onnxruntime.InferenceSession(
        onnx_path, providers=["CPUExecutionProvider"]
    )
    for model_input in sess.get_inputs():
        print("input: ", model_input)

    print("\n")
    for model_output in sess.get_outputs():
        print("output: ", model_output)
61+
62+
63+
if __name__ == "__main__":
    from warnings import filterwarnings

    # Silence every warning for the duration of the script (presumably to
    # keep the export output readable).
    filterwarnings("ignore")
    main()

scripts/loftr/loftr_wrapper.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
#!/usr/bin/env python
2+
import copy
3+
import os
4+
import sys
5+
from typing import Any, Dict
6+
7+
import torch
8+
from einops.einops import rearrange
9+
10+
# Directory that holds this module (symlinks resolved).
_CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))

# Make the LoFTR checkout next to this file importable so that its top-level
# ``src`` package resolves in the import below. This must run before it.
sys.path.append(os.path.join(_CURRENT_DIR, "LoFTR"))

from src.loftr import LoFTR, default_cfg  # noqa: E402

# Private copy of the upstream defaults, so enabling the flag below does not
# modify ``default_cfg`` itself.
# NOTE(review): ``temp_bug_fix`` is presumably what the ``*_new.ckpt`` weights
# expect - confirm against the LoFTR repository.
DEFAULT_CFG = copy.deepcopy(default_cfg)
DEFAULT_CFG["coarse"]["temp_bug_fix"] = True
17+
18+
19+
class LoFTRWrapper(LoFTR):
    """LoFTR with a tensor-in / tensor-out ``forward``.

    ``forward`` here takes the two images as separate tensor arguments and
    returns a small dict of tensors, which is the shape of interface the
    ONNX export script in this directory relies on. The working ``data``
    dict that the LoFTR sub-modules read from and write to is built
    internally.
    """

    def __init__(
        self,
        config: Dict[str, Any] = DEFAULT_CFG,
    ):
        """Build the underlying LoFTR model.

        Args:
            config: LoFTR configuration; defaults to ``DEFAULT_CFG``.
        """
        # Give the base class its own copy so the shared module-level default
        # (or a caller-owned config) is never aliased by the model.
        super().__init__(copy.deepcopy(config))

    def forward(
        self,
        image0: torch.Tensor,
        image1: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Match ``image0`` against ``image1``.

        Args:
            image0: first image batch; dims 2 and 3 are read as height/width.
            image1: second image batch, same layout as ``image0``.

        Returns:
            Dict with the keys ``keypoints0``, ``keypoints1`` and
            ``confidence``, taken from the ``mkpts0_f``, ``mkpts1_f`` and
            ``mconf`` entries of the internal ``data`` dict.

        Raises:
            TypeError: if one of those entries is not a ``torch.Tensor``.
        """
        data = {
            "image0": image0,
            "image1": image1,
            "bs": image0.size(0),
            "hw0_i": image0.shape[2:],
            "hw1_i": image1.shape[2:],
        }

        # 1. local feature backbone
        # NOTE(review): when traced for ONNX export this Python-level branch
        # is presumably fixed at export time - confirm if inputs may differ.
        if data["hw0_i"] == data["hw1_i"]:  # faster & better BN convergence
            feats_c, feats_f = self.backbone(
                torch.cat([data["image0"], data["image1"]], dim=0)
            )
            feat_c0, feat_c1 = feats_c.split(data["bs"])
            feat_f0, feat_f1 = feats_f.split(data["bs"])
        else:  # handle different input shapes
            feat_c0, feat_f0 = self.backbone(data["image0"])
            feat_c1, feat_f1 = self.backbone(data["image1"])

        data.update(
            {
                "hw0_c": feat_c0.shape[2:],
                "hw1_c": feat_c1.shape[2:],
                "hw0_f": feat_f0.shape[2:],
                "hw1_f": feat_f1.shape[2:],
            }
        )

        # 2. coarse-level loftr module
        # add featmap with positional encoding, then flatten it to sequence [N, HW, C]
        feat_c0 = rearrange(self.pos_encoding(feat_c0), "n c h w -> n (h w) c")
        feat_c1 = rearrange(self.pos_encoding(feat_c1), "n c h w -> n (h w) c")

        # ``data`` is built above without mask entries, so no masks are ever
        # available here; the sub-modules always receive None.
        mask_c0 = mask_c1 = None
        feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)

        # 3. match coarse-level (return value unused; results are read back
        # from ``data`` further down)
        self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1)

        # 4. fine-level refinement
        feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(
            feat_f0, feat_f1, feat_c0, feat_c1, data
        )
        if feat_f0_unfold.size(0) != 0:  # at least one coarse level predicted
            feat_f0_unfold, feat_f1_unfold = self.loftr_fine(
                feat_f0_unfold, feat_f1_unfold
            )

        # 5. match fine-level
        self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)

        # Expose only the final matches, under export-friendly names.
        rename_keys: Dict[str, str] = {
            "mkpts0_f": "keypoints0",
            "mkpts1_f": "keypoints1",
            "mconf": "confidence",
        }
        out: Dict[str, torch.Tensor] = {}
        for k, v in rename_keys.items():
            _d = data[k]
            if not isinstance(_d, torch.Tensor):
                raise TypeError(
                    f"Expected torch.Tensor for item `{k}`. Gotcha {type(_d)}"
                )
            out[v] = _d

        return out

scripts/loftr/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
kornia==0.6.10
2+
onnxruntime

0 commit comments

Comments
 (0)