Commit

upload model weights
coca-huang committed Apr 24, 2021
1 parent c0fcba2 commit ce85952
Showing 7 changed files with 40 additions and 123 deletions.
119 changes: 7 additions & 112 deletions .gitignore
@@ -1,120 +1,15 @@
# Created by .ignore support plugin (hsz.mobi)

### User Template
data/
### PyTorch template
tmp/
.idea/
*.log
*.pyc

### VisualStudioCode template
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
*.code-workspace

# Local History for Visual Studio Code
.history/
log/
!log/log.conf

### JetBrains template
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# Generated files
.idea/**/contentModel.xml

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake
.idea/
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser
### VisualStudioCode template
.vscode/
.history/

### macOS template
# General
.DS_Store
.AppleDouble
.LSOverride

# Icon must end with two \r
Icon

# Thumbnails
._*

# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
2 changes: 1 addition & 1 deletion README.md
@@ -15,7 +15,7 @@ MFEIF: Learning a Deep Multi-scale Feature Ensemble and an Edge-attention Guidan

### Installation

Clone this github repository to your local.
Clone this GitHub repository to your local machine.

### Test

16 changes: 16 additions & 0 deletions functions/feather_fuse.py
@@ -0,0 +1,16 @@
import torch
import torch.nn as nn
from torch import Tensor


class FeatherFuse(nn.Module):
    def __init__(self):
        super(FeatherFuse, self).__init__()

    @staticmethod
    def forward(ir_b: [Tensor], vi_b: [Tensor], mode='min-mean') -> [Tensor]:
        b_1 = torch.min(ir_b[0], vi_b[0])
        b_2 = torch.min(ir_b[1], vi_b[1])
        b_3 = (ir_b[0] + vi_b[0] + b_1) / 3
        b_4 = (ir_b[1] + vi_b[1] + b_2) / 3
        return (b_1, b_2) if mode == 'min' else (b_3, b_4)
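
For reference, a minimal usage sketch of the new FeatherFuse module (not part of the commit; the feature shapes below are invented for illustration). In 'min' mode it returns the element-wise minimum of each branch pair; the default 'min-mean' averages the two inputs together with their minimum.

import torch

ff = FeatherFuse()
ir_b = (torch.rand(1, 64, 64, 64), torch.rand(1, 128, 32, 32))  # assumed branch shapes
vi_b = (torch.rand(1, 64, 64, 64), torch.rand(1, 128, 32, 32))
b_min_1, b_min_2 = ff(ir_b, vi_b, mode='min')  # element-wise minimum per branch
b_mix_1, b_mix_2 = ff(ir_b, vi_b)              # 'min-mean': (ir + vi + min) / 3
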
26 changes: 16 additions & 10 deletions test.py
@@ -9,13 +9,14 @@
from torch.utils.data import DataLoader
from tqdm import tqdm

from functions.feather_fuse import FeatherFuse
from models.attention import Attention
from models.constructor import Constructor
from models.extractor import Extractor
from models.fuse_dataset import FuseDataset
# load config
from tools.imsave import imsave

# load config
with open('config.yml', 'r') as f:
    config = yaml.safe_load(f)

@@ -26,7 +27,7 @@

# load device config
cuda = config['environment']['cuda']
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# load snapshots folder
sf = config['test']['snapshots_folder']
@@ -38,7 +39,7 @@

# load extractor network
lpm = config['test']['load_pretrain_model'][0]
pm = os.path.join(sf, 'fuse_ext_{}.pth'.format(str(lpm).zfill(3))) if lpm else None
pm = os.path.join(sf, 'epoch_ext_{}.pth'.format(str(lpm).zfill(3))) if lpm else None
net_ext = Extractor()
net_ext = nn.DataParallel(net_ext)
net_ext.load_state_dict(torch.load(pm, map_location='cpu')) if pm else None
@@ -47,7 +48,7 @@

# load constructor network
lpm = config['test']['load_pretrain_model'][1]
pm = os.path.join(sf, 'fuse_con_{}.pth'.format(str(lpm).zfill(3))) if lpm else None
pm = os.path.join(sf, 'epoch_con_{}.pth'.format(str(lpm).zfill(3))) if lpm else None
net_con = Constructor()
net_con = nn.DataParallel(net_con)
net_con.load_state_dict(torch.load(pm, map_location='cpu')) if pm else None
@@ -56,7 +57,7 @@

# load attention network
lpm = config['test']['load_pretrain_model'][2]
pm = os.path.join(sf, 'fuse_att_{}.pth'.format(str(lpm).zfill(3))) if lpm else None
pm = os.path.join(sf, 'epoch_att_{}.pth'.format(str(lpm).zfill(3))) if lpm else None
net_att = Attention()
net_att = nn.DataParallel(net_att)
net_att.load_state_dict(torch.load(pm, map_location='cpu')) if pm else None
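
The three snapshot-loading blocks above repeat the same steps: build the checkpoint path, wrap the network in nn.DataParallel, and restore the state dict on CPU. A hypothetical helper (not part of this commit) sketching that shared pattern:

def load_net(net, prefix, epoch, folder):
    # mirror test.py: wrap in DataParallel first, then restore the checkpoint on CPU
    net = nn.DataParallel(net)
    if epoch:
        pm = os.path.join(folder, '{}_{}.pth'.format(prefix, str(epoch).zfill(3)))
        net.load_state_dict(torch.load(pm, map_location='cpu'))
    return net

# e.g. net_ext = load_net(Extractor(), 'epoch_ext', config['test']['load_pretrain_model'][0], sf)
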
@@ -69,9 +70,14 @@
data = FuseDataset(it, iz, cuda, True)
loader = DataLoader(data, 1, False)

# start test
tr = [] # time record
# softmax
sm = nn.Softmax(dim=1)

# load feather fuse
ff = FeatherFuse()

# start test
tr = []
with torch.no_grad():
    for ir, vi, label in tqdm(loader):
        st = time.time()
@@ -83,16 +89,16 @@
        vi_att = net_att(vi)

        fus_1 = ir_1 * ir_att + vi_1 * vi_att
        fus_b_1 = ir_b_1 + vi_b_1
        fus_b_2 = ir_b_2 + vi_b_2
        fus_1 = sm(fus_1)
        fus_b_1, fus_b_2 = ff((ir_b_1, ir_b_2), (vi_b_1, vi_b_2))

        fus_2 = net_con(fus_1, fus_b_1, fus_b_2)

        torch.cuda.synchronize()
        et = time.time()
        tr.append(et - st)

        p = os.path.join(rf, 'FUS_{}.jpg'.format(label[0]))
        p = os.path.join(rf, '{}.jpg'.format(label[0]))
        imsave(p, fus_2)

# time record
Binary file added weights/epoch_att_495.pth
Binary file added weights/epoch_con_495.pth
Binary file added weights/epoch_ext_495.pth
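
The uploaded checkpoints match the renamed epoch_* pattern that test.py now expects. A minimal sketch (not part of the commit) of restoring one of them directly, following the same DataParallel-then-load order used in test.py:

import torch
import torch.nn as nn
from models.extractor import Extractor

net_ext = nn.DataParallel(Extractor())  # test.py wraps each network in DataParallel before loading
net_ext.load_state_dict(torch.load('weights/epoch_ext_495.pth', map_location='cpu'))
net_ext.eval()
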
