Skip to content

Commit

Permalink
Fix backward compatability of pos_enc bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
angshine committed Aug 17, 2021
1 parent 3955e99 commit da756ef
Show file tree
Hide file tree
Showing 18 changed files with 164 additions and 7 deletions.
6 changes: 6 additions & 0 deletions configs/loftr/indoor/buggy_pos_enc/loftr_ds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from src.config.default import _CN as cfg

cfg.LOFTR.COARSE.TEMP_BUG_FIX = False
cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'

cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 17, 20, 23, 26, 29]
8 changes: 8 additions & 0 deletions configs/loftr/indoor/buggy_pos_enc/loftr_ds_dense.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from src.config.default import _CN as cfg

cfg.LOFTR.COARSE.TEMP_BUG_FIX = False
cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'

cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False

cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 17, 20, 23, 26, 29]
6 changes: 6 additions & 0 deletions configs/loftr/indoor/buggy_pos_enc/loftr_ot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from src.config.default import _CN as cfg

cfg.LOFTR.COARSE.TEMP_BUG_FIX = False
cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'sinkhorn'

cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 17, 20, 23, 26, 29]
8 changes: 8 additions & 0 deletions configs/loftr/indoor/buggy_pos_enc/loftr_ot_dense.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from src.config.default import _CN as cfg

cfg.LOFTR.COARSE.TEMP_BUG_FIX = False
cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'sinkhorn'

cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False

cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 17, 20, 23, 26, 29]
1 change: 1 addition & 0 deletions configs/loftr/indoor/scannet/loftr_ds_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from src.config.default import _CN as cfg

cfg.LOFTR.COARSE.TEMP_BUG_FIX = False
cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'

cfg.LOFTR.MATCH_COARSE.BORDER_RM = 0
18 changes: 18 additions & 0 deletions configs/loftr/indoor/scannet/loftr_ds_eval_new.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
""" A config only for reproducing the ScanNet evaluation results.
We remove border matches by default, but the originally implemented
`remove_border()` has a bug, leading to only two sides of
all borders are actually removed. However, the [bug fix](https://github.com/zju3dv/LoFTR/commit/e9146c8144dea5f3cbdd98b225f3e147a171c216)
makes the scannet evaluation results worse (auc@10=40.8 => 39.5), which should be
caused by tiny result fluctuation of few image pairs. This config set `BORDER_RM` to 0
to be consistent with the results in our paper.
Update: This config is for testing the re-trained model with the pos-enc bug fixed.
"""

from src.config.default import _CN as cfg

cfg.LOFTR.COARSE.TEMP_BUG_FIX = True
cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'

cfg.LOFTR.MATCH_COARSE.BORDER_RM = 0
16 changes: 16 additions & 0 deletions configs/loftr/outdoor/buggy_pos_enc/loftr_ds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from src.config.default import _CN as cfg

cfg.LOFTR.COARSE.TEMP_BUG_FIX = False
cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'

cfg.TRAINER.CANONICAL_LR = 8e-3
cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs
cfg.TRAINER.WARMUP_RATIO = 0.1
cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24]

# pose estimation
cfg.TRAINER.RANSAC_PIXEL_THR = 0.5

cfg.TRAINER.OPTIMIZER = "adamw"
cfg.TRAINER.ADAMW_DECAY = 0.1
cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3
17 changes: 17 additions & 0 deletions configs/loftr/outdoor/buggy_pos_enc/loftr_ds_dense.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from src.config.default import _CN as cfg

cfg.LOFTR.COARSE.TEMP_BUG_FIX = False
cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False

cfg.TRAINER.CANONICAL_LR = 8e-3
cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs
cfg.TRAINER.WARMUP_RATIO = 0.1
cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24]

# pose estimation
cfg.TRAINER.RANSAC_PIXEL_THR = 0.5

cfg.TRAINER.OPTIMIZER = "adamw"
cfg.TRAINER.ADAMW_DECAY = 0.1
cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3
16 changes: 16 additions & 0 deletions configs/loftr/outdoor/buggy_pos_enc/loftr_ot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from src.config.default import _CN as cfg

cfg.LOFTR.COARSE.TEMP_BUG_FIX = False
cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'sinkhorn'

cfg.TRAINER.CANONICAL_LR = 8e-3
cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs
cfg.TRAINER.WARMUP_RATIO = 0.1
cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24]

# pose estimation
cfg.TRAINER.RANSAC_PIXEL_THR = 0.5

cfg.TRAINER.OPTIMIZER = "adamw"
cfg.TRAINER.ADAMW_DECAY = 0.1
cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3
17 changes: 17 additions & 0 deletions configs/loftr/outdoor/buggy_pos_enc/loftr_ot_dense.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from src.config.default import _CN as cfg

cfg.LOFTR.COARSE.TEMP_BUG_FIX = False
cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'sinkhorn'
cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False

cfg.TRAINER.CANONICAL_LR = 8e-3
cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs
cfg.TRAINER.WARMUP_RATIO = 0.1
cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24]

# pose estimation
cfg.TRAINER.RANSAC_PIXEL_THR = 0.5

cfg.TRAINER.OPTIMIZER = "adamw"
cfg.TRAINER.ADAMW_DECAY = 0.1
cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3
30 changes: 30 additions & 0 deletions scripts/reproduce_test/indoor_ds_new.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#!/bin/bash -l
# a indoor_ds model with the pos_enc impl bug fixed.

SCRIPTPATH=$(dirname $(readlink -f "$0"))
PROJECT_DIR="${SCRIPTPATH}/../../"

# conda activate loftr
export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH
cd $PROJECT_DIR

data_cfg_path="configs/data/scannet_test_1500.py"
main_cfg_path="configs/loftr/indoor/scannet/loftr_ds_eval_new.py"
ckpt_path="weights/indoor_ds_new.ckpt"
dump_dir="dump/loftr_ds_indoor_new"
profiler_name="inference"
n_nodes=1 # mannually keep this the same with --nodes
n_gpus_per_node=-1
torch_num_workers=4
batch_size=1 # per gpu

python -u ./test.py \
${data_cfg_path} \
${main_cfg_path} \
--ckpt_path=${ckpt_path} \
--dump_dir=${dump_dir} \
--gpus=${n_gpus_per_node} --num_nodes=${n_nodes} --accelerator="ddp" \
--batch_size=${batch_size} --num_workers=${torch_num_workers}\
--profiler_name=${profiler_name} \
--benchmark

2 changes: 1 addition & 1 deletion scripts/reproduce_test/indoor_ot.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH
cd $PROJECT_DIR

data_cfg_path="configs/data/scannet_test_1500.py"
main_cfg_path="configs/loftr/indoor/loftr_ot.py"
main_cfg_path="configs/loftr/indoor/buggy_pos_enc/loftr_ot.py"
ckpt_path="weights/indoor_ot.ckpt"
dump_dir="dump/loftr_ot_indoor"
profiler_name="inference"
Expand Down
2 changes: 1 addition & 1 deletion scripts/reproduce_test/outdoor_ds.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH
cd $PROJECT_DIR

data_cfg_path="configs/data/megadepth_test_1500.py"
main_cfg_path="configs/loftr/outdoor/loftr_ds.py"
main_cfg_path="configs/loftr/outdoor/buggy_pos_enc/loftr_ds.py"
ckpt_path="weights/outdoor_ds.ckpt"
dump_dir="dump/loftr_ds_outdoor"
profiler_name="inference"
Expand Down
2 changes: 1 addition & 1 deletion scripts/reproduce_test/outdoor_ot.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH
cd $PROJECT_DIR

data_cfg_path="configs/data/megadepth_test_1500.py"
main_cfg_path="configs/loftr/outdoor/loftr_ot.py"
main_cfg_path="configs/loftr/outdoor/buggy_pos_enc/loftr_ot.py"
ckpt_path="weights/outdoor_ot.ckpt"
dump_dir="dump/loftr_ot_outdoor"
profiler_name="inference"
Expand Down
1 change: 1 addition & 0 deletions src/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
_CN.LOFTR.COARSE.NHEAD = 8
_CN.LOFTR.COARSE.LAYER_NAMES = ['self', 'cross'] * 4
_CN.LOFTR.COARSE.ATTENTION = 'linear' # options: ['linear', 'full']
_CN.LOFTR.COARSE.TEMP_BUG_FIX = True

# 3. Coarse-Matching config
_CN.LOFTR.MATCH_COARSE = CN()
Expand Down
6 changes: 5 additions & 1 deletion src/lightning/lightning_loftr.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None):

# Pretrained weights
if pretrained_ckpt:
self.matcher.load_state_dict(torch.load(pretrained_ckpt, map_location='cpu')['state_dict'])
state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict']
for k in list(state_dict.keys()):
if k.startswith('matcher.'):
state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k)
self.matcher.load_state_dict(state_dict, strict=True)
logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint")

# Testing
Expand Down
4 changes: 3 additions & 1 deletion src/loftr/loftr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ def __init__(self, config):

# Modules
self.backbone = build_backbone(config)
self.pos_encoding = PositionEncodingSine(config['coarse']['d_model'])
self.pos_encoding = PositionEncodingSine(
config['coarse']['d_model'],
temp_bug_fix=config['coarse']['temp_bug_fix'])
self.loftr_coarse = LocalFeatureTransformer(config['coarse'])
self.coarse_matching = CoarseMatching(config['match_coarse'])
self.fine_preprocess = FinePreprocess(config)
Expand Down
11 changes: 9 additions & 2 deletions src/loftr/utils/position_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,24 @@ class PositionEncodingSine(nn.Module):
This is a sinusoidal position encoding that generalized to 2-dimensional images
"""

def __init__(self, d_model, max_shape=(256, 256)):
def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True):
"""
Args:
max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41),
the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact
on the final performance. For now, we keep both impls for backward compatability.
We will remove the buggy impl after re-training all variants of our released models.
"""
super().__init__()

pe = torch.zeros((d_model, *max_shape))
y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
if temp_bug_fix:
div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
else: # a buggy implementation (for backward compatability only)
div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2))
div_term = div_term[:, None, None] # [C//4, 1, 1]
pe[0::4, :, :] = torch.sin(x_position * div_term)
pe[1::4, :, :] = torch.cos(x_position * div_term)
Expand Down

0 comments on commit da756ef

Please sign in to comment.