Skip to content

Commit

Permalink
Training code
Browse files Browse the repository at this point in the history
  • Loading branch information
georg-bn committed May 17, 2024
1 parent e5e3ef3 commit 766b685
Show file tree
Hide file tree
Showing 15 changed files with 1,304 additions and 9 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "DeDoDe"]
path = DeDoDe
url = git@github.com:georg-bn/DeDoDe
1 change: 1 addition & 0 deletions DeDoDe
Submodule DeDoDe added at a0451b
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@ A steerer for D-dimensional keypoint descriptions is a DxD matrix that transform
<img src="example_images/method.png" width="500">

For running the code, create a new virtual environment of your preference (e.g. conda) with `python>=3.9`, `jupyter notebook` and GPU-enabled PyTorch.
Then install the `rotation_steerers` package using pip (this automatically installs DeDoDe from GitHub as well, see `setup.py`):
```
pip install .
```
Clone this repository with DeDoDe as submodule (`git clone --recurse-submodules git@github.com:georg-bn/rotation-steerers.git`).
Then install using `bash setup.sh`.

The weights are uploaded to [releases](https://github.com/georg-bn/rotation-steerers/releases). To download model weights needed for the demo and put them in a new folder `model_weights`, run
```
bash download_weights.sh
Expand Down Expand Up @@ -78,7 +77,9 @@ matches_A, matches_B = matcher.to_pixel_coords(

See the example notebook [demo.ipynb](demo.ipynb) for more simple matching examples.

We will publish training code and further model weights shortly.
## Training
Follow the instructions from [DeDoDe](DeDoDe/README.md) to get the data and annotations set up.
Representative experiments are present in the [experiments](experiments) folder.

## Short summary
A steerer is a linear map that modifies keypoint descriptions as if they were obtained from a rotated image.
Expand Down
100 changes: 100 additions & 0 deletions experiments/SettingA-C4-B.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import os
from argparse import ArgumentParser

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import ConcatDataset
import torch.nn as nn

from DeDoDe.datasets.megadepth import MegadepthBuilder
from DeDoDe.encoder import VGG
from DeDoDe.decoder import ConvRefiner, Decoder
from DeDoDe import dedode_detector_L
from DeDoDe.benchmarks import MegadepthNLLBenchmark
from DeDoDe.model_zoo import dedode_detector_L, dedode_descriptor_B, dedode_descriptor_G

from rotation_steerers.train import train_k_steps
from rotation_steerers.descriptor_loss import DescriptorLoss
from rotation_steerers.checkpoint import CheckPoint
from rotation_steerers.steerers import DiscreteSteerer


def train(detector_weights, descriptor):
    """Fit only a C4 steerer for a frozen, pre-trained DeDoDe-B descriptor (Setting A).

    Args:
        detector_weights: state dict for the DeDoDe-L keypoint detector that
            supplies keypoints for the descriptor loss.
        descriptor: pre-trained DeDoDe-B descriptor module; its weights are not
            part of the optimizer, so only the steerer generator is trained.
    """
    NUM_PROTOTYPES = 256  # == descriptor size
    model = descriptor.cuda()
    # The descriptor is kept frozen (eval mode, not in the optimizer).
    model.eval()

    # Initialize the steerer generator as a random 256x256 matrix by borrowing
    # nn.Linear's default weight initialization.
    generator = torch.nn.Linear(in_features=NUM_PROTOTYPES,
                                out_features=NUM_PROTOTYPES,
                                bias=False).cuda().weight.data

    steerer = DiscreteSteerer(generator)

    # Only the steerer's parameters are optimized.
    params = [
        {"params": steerer.parameters(), "lr": 1e-3},
    ]
    optim = AdamW(params, weight_decay=0)
    n0, N, k = 0, 10_000, 1000  # start step, total steps, steps per checkpoint
    lr_scheduler = CosineAnnealingLR(optim, T_max=N)
    # `os` is already imported at module level; name the experiment after this
    # file so checkpoints are easy to trace back to the generating script.
    experiment_name = os.path.splitext(os.path.basename(__file__))[0]
    checkpointer = CheckPoint("workspace/", name=experiment_name, only_steerer=True)

    # Resume from the latest checkpoint in the workspace, if one exists.
    model, optim, lr_scheduler, n0 = checkpointer.load(model, optim, lr_scheduler, n0, steerer=steerer)

    detector = dedode_detector_L(weights=detector_weights)
    loss = DescriptorLoss(detector=detector, normalize_descriptions=True, inv_temp=20)

    H, W = 512, 512
    mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore=True, use_detections=False)
    use_horizontal_flip_aug = False
    # Two overlap regimes are concatenated: easy pairs (>= 0.01 overlap) and
    # harder, high-overlap pairs (>= 0.35).
    megadepth_train1 = mega.build_scenes(
        split="train_loftr", min_overlap=0.01, ht=H, wt=W, shake_t=32, use_horizontal_flip_aug=use_horizontal_flip_aug,
    )
    megadepth_train2 = mega.build_scenes(
        split="train_loftr", min_overlap=0.35, ht=H, wt=W, shake_t=32, use_horizontal_flip_aug=use_horizontal_flip_aug,
    )

    megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
    mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)

    megadepth_test = mega.build_scenes(
        split="test_loftr", min_overlap=0.01, ht=H, wt=W, shake_t=32, use_horizontal_flip_aug=use_horizontal_flip_aug,
    )
    # NOTE(review): the benchmark is constructed but never run in this script.
    mega_test = MegadepthNLLBenchmark(ConcatDataset(megadepth_test))

    # Save an initial checkpoint (label -1) before any optimization happens.
    checkpointer.save(model, optim, lr_scheduler, -1, steerer=steerer, label=-1)
    for n in range(n0, N, k):
        # Fresh weighted sampler per k-step chunk so scene weighting is respected.
        mega_sampler = torch.utils.data.WeightedRandomSampler(
            mega_ws, num_samples=8 * k, replacement=False
        )
        mega_dataloader = iter(
            torch.utils.data.DataLoader(
                megadepth_train,
                batch_size=8,
                sampler=mega_sampler,
                num_workers=8,
            )
        )
        # grad_scaler=None: no mixed-precision loss scaling in this setting.
        train_k_steps(
            n, k, mega_dataloader, model, loss, optim, lr_scheduler, grad_scaler=None, rot90=True, steerer=steerer,
        )
        checkpointer.save(model, optim, lr_scheduler, n, steerer=steerer, label=n)


if __name__ == "__main__":
    # Environment configuration must happen before any CUDA work starts.
    os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1"  # For BF16 computations
    os.environ["OMP_NUM_THREADS"] = "16"

    cli = ArgumentParser()
    cli.add_argument("--detector_path", default="dedode_detector_C4.pth")
    cli.add_argument("--descriptor_path", default="dedode_descriptor_B.pth")
    opts = cli.parse_args()

    # Load the detector state dict and build the descriptor from its weights.
    detector_state = torch.load(opts.detector_path)
    descriptor_model = dedode_descriptor_B(weights=torch.load(opts.descriptor_path))

    train(detector_weights=detector_state, descriptor=descriptor_model)
152 changes: 152 additions & 0 deletions experiments/SettingB-C4-Perm-B.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import os
from argparse import ArgumentParser

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import ConcatDataset
import torch.nn as nn

from DeDoDe.datasets.megadepth import MegadepthBuilder
from DeDoDe.descriptors.dedode_descriptor import DeDoDeDescriptor
from DeDoDe.encoder import VGG
from DeDoDe.decoder import ConvRefiner, Decoder
from DeDoDe import dedode_detector_L
from DeDoDe.benchmarks import MegadepthNLLBenchmark

from rotation_steerers.train import train_k_steps
from rotation_steerers.descriptor_loss import DescriptorLoss
from rotation_steerers.checkpoint import CheckPoint
from rotation_steerers.steerers import DiscreteSteerer


def train(detector_weights):
    """Train a DeDoDe descriptor (B-size decoder) jointly with a C4 steerer
    initialized as a block-diagonal 4-cycle permutation (Setting B).

    Args:
        detector_weights: state dict for the DeDoDe-L keypoint detector that
            supplies keypoints for the descriptor loss.
    """
    NUM_PROTOTYPES = 256  # == descriptor size
    residual = True
    hidden_blocks = 5
    # float16 autocast; bfloat16 alternative left as reference:
    # torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    amp_dtype = torch.float16
    amp = True
    # Coarse-to-fine refiners keyed by stride; each level emits context
    # features plus NUM_PROTOTYPES description channels.
    conv_refiner = nn.ModuleDict(
        {
            "8": ConvRefiner(
                512,
                512,
                256 + NUM_PROTOTYPES,
                hidden_blocks=hidden_blocks,
                residual=residual,
                amp=amp,
                amp_dtype=amp_dtype,
            ),
            "4": ConvRefiner(
                256 + 256,
                256,
                128 + NUM_PROTOTYPES,
                hidden_blocks=hidden_blocks,
                residual=residual,
                amp=amp,
                amp_dtype=amp_dtype,
            ),
            "2": ConvRefiner(
                128 + 128,
                64,
                32 + NUM_PROTOTYPES,
                hidden_blocks=hidden_blocks,
                residual=residual,
                amp=amp,
                amp_dtype=amp_dtype,
            ),
            "1": ConvRefiner(
                64 + 32,
                32,
                1 + NUM_PROTOTYPES,
                hidden_blocks=hidden_blocks,
                residual=residual,
                amp=amp,
                amp_dtype=amp_dtype,
            ),
        }
    )
    # `os` is already imported at module level; name the experiment after this
    # file so checkpoints are easy to trace back to the generating script.
    experiment_name = os.path.splitext(os.path.basename(__file__))[0]
    encoder = VGG(size="19", pretrained=True, amp=amp, amp_dtype=amp_dtype)
    decoder = Decoder(conv_refiner, num_prototypes=NUM_PROTOTYPES)
    model = DeDoDeDescriptor(encoder=encoder, decoder=decoder).cuda()

    # C4 steerer generator: block-diagonal matrix of 4x4 cyclic permutations,
    # i.e. the 256-dim description is split into 64 groups of 4 dimensions
    # that are cyclically permuted by the steerer.
    generator = torch.block_diag(
        *(
            torch.tensor(
                [[0., 1, 0, 0],
                 [0, 0, 1, 0],
                 [0, 0, 0, 1],
                 [1, 0, 0, 0]],
                device='cuda',
            )
            for _ in range(NUM_PROTOTYPES // 4)
        )
    )
    steerer = DiscreteSteerer(generator)

    params = [
        {"params": model.encoder.parameters(), "lr": 1e-5},
        {"params": model.decoder.parameters(), "lr": 2e-4},
        # NOTE(review): whether the steerer actually updates depends on
        # DiscreteSteerer registering the generator as a Parameter — confirm.
        {"params": steerer.parameters(), "lr": 2e-4},
    ]
    optim = AdamW(params, weight_decay=1e-5)
    n0, N, k = 0, 100_000, 1000  # start step, total steps, steps per checkpoint
    lr_scheduler = CosineAnnealingLR(optim, T_max=N)
    checkpointer = CheckPoint("workspace/", name=experiment_name)

    # Resume from the latest checkpoint in the workspace, if one exists.
    model, optim, lr_scheduler, n0 = checkpointer.load(model, optim, lr_scheduler, n0, steerer=steerer)

    detector = dedode_detector_L(weights=detector_weights)
    loss = DescriptorLoss(detector=detector, normalize_descriptions=True, inv_temp=20)

    H, W = 512, 512
    mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore=True, use_detections=False)
    use_horizontal_flip_aug = False
    # Two overlap regimes are concatenated: easy pairs (>= 0.01 overlap) and
    # harder, high-overlap pairs (>= 0.35).
    megadepth_train1 = mega.build_scenes(
        split="train_loftr", min_overlap=0.01, ht=H, wt=W, shake_t=32, use_horizontal_flip_aug=use_horizontal_flip_aug,
    )
    megadepth_train2 = mega.build_scenes(
        split="train_loftr", min_overlap=0.35, ht=H, wt=W, shake_t=32, use_horizontal_flip_aug=use_horizontal_flip_aug,
    )

    megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
    mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)

    megadepth_test = mega.build_scenes(
        split="test_loftr", min_overlap=0.01, ht=H, wt=W, shake_t=32, use_horizontal_flip_aug=use_horizontal_flip_aug,
    )
    mega_test = MegadepthNLLBenchmark(ConcatDataset(megadepth_test))
    grad_scaler = torch.cuda.amp.GradScaler()  # loss scaling for float16 autocast

    for n in range(n0, N, k):
        # Fresh weighted sampler per k-step chunk so scene weighting is respected.
        mega_sampler = torch.utils.data.WeightedRandomSampler(
            mega_ws, num_samples=8 * k, replacement=False
        )
        mega_dataloader = iter(
            torch.utils.data.DataLoader(
                megadepth_train,
                batch_size=8,
                sampler=mega_sampler,
                num_workers=8,
            )
        )
        train_k_steps(
            n, k, mega_dataloader, model, loss, optim, lr_scheduler, grad_scaler=grad_scaler, rot90=True, steerer=steerer,
        )
        checkpointer.save(model, optim, lr_scheduler, n, steerer=steerer)
        # Evaluate NLL on the held-out MegaDepth split after every chunk.
        mega_test.benchmark(detector=detector, descriptor=model)


if __name__ == "__main__":
    # Environment configuration must happen before any CUDA work starts.
    os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1"  # For BF16 computations
    os.environ["OMP_NUM_THREADS"] = "16"

    cli = ArgumentParser()
    cli.add_argument("--detector_path", default="dedode_detector_C4.pth")
    opts = cli.parse_args()

    # Load the detector state dict and hand it to the training entry point.
    detector_state = torch.load(opts.detector_path)
    train(detector_state)
Loading

0 comments on commit 766b685

Please sign in to comment.