Skip to content

Commit

Permalink
support brats2020 dataset for unet3d (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
ptoupas authored Dec 19, 2023
1 parent 33f8133 commit 40cbcb9
Show file tree
Hide file tree
Showing 5 changed files with 445 additions and 3 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ python threshold_relu_example.py
* `coco`: `yolov8n`
* `camvid`: `unet`
* `cityscapes`: `unet`
* `llgmri`: `unet`
* `ucf101`: `x3d_s`, `x3d_m`
* `brats20`: `unet3d`

## Quantization Results
## Quantization Results

### imagenet (val, top-1 acc)
| Model | Source | Float32 | Fixed16 | Fixed8 | BFP8 (Layer) | BFP8 (Channel) |
Expand Down
8 changes: 7 additions & 1 deletion models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def initialize_wrapper(dataset_name, model_name,
elif dataset_name == "cityscapes":
os.environ['CITYSCAPES_PATH'] = dataset_path
if model_name in ["unet"]:
from models.segmentation.cityscapes import MmsegmentationModelWrapper
from models.segmentation.cityscapes import \
MmsegmentationModelWrapper
model_wrapper = MmsegmentationModelWrapper(model_name)
elif dataset_name == "lggmri":
os.environ['LGGMRI_PATH'] = dataset_path
Expand All @@ -45,6 +46,11 @@ def initialize_wrapper(dataset_name, model_name,
if model_name in ["x3d_s", "x3d_m"]:
from models.action_recognition.ucf101 import MmactionModelWrapper
model_wrapper = MmactionModelWrapper(model_name)
elif dataset_name == "brats2020":
os.environ['BRATS2020_PATH'] = dataset_path
if model_name in ["unet3d"]:
from models.segmentation.brats2020 import Unet3DKaggleModelWrapper
model_wrapper = Unet3DKaggleModelWrapper(model_name)

if model_wrapper is None:
raise NotImplementedError("Unknown dataset/model combination")
Expand Down
4 changes: 3 additions & 1 deletion models/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
from abc import ABC, abstractmethod


class TorchModelWrapper(nn.Module, ABC):
def __init__(self, model_name):
Expand Down
Loading

0 comments on commit 40cbcb9

Please sign in to comment.