Skip to content

Commit

Permalink
feat+refactor: big changes
Browse files Browse the repository at this point in the history
additional SR and IC models, checkpoint storage, and more.
  • Loading branch information
chompaa committed Dec 1, 2023
1 parent ba192a8 commit aabad40
Show file tree
Hide file tree
Showing 23 changed files with 666 additions and 428 deletions.
10 changes: 8 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
datasets/
output/
./*.png
./*.jpg
weights/
slices/
channels/
_misc/
data/
*.png
*.jpg
*.bmp

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
47 changes: 47 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Clarisolve

> A DL-based super-resolution and colorization tool.
## Requirements

To install the required packages, you can run:

```shell
pip install -r requirements.txt
```

## Usage

### Evaluation

A GUI tool is provided to super-resolve and colorize images. To run, use:

```shell
python main.py
```

CLI options are also present, see the `super-resolve.py` and `colorize.py` files for
details.

### Training

You can train any model yourself using `train.py` as follows:

```shell
python train.py --model { "srcnn", "srcnnc", "srres", "iccnn", "icres" } \
--train-data TRAIN_DATA \
--eval-data EVAL_DATA \
--output-dir OUTPUT_DIR \
[--checkpoint-path CHECKPOINT_PATH] \
[--learn-rate LEARN_RATE] \
[--end-epoch END_EPOCH] \
[--num-workers NUM_WORKERS] \
[--seed SEED]
```
Note that for SR models, a `.h5` file is required for both datasets, and for IC, a
directory is required.
### Datasets
A utility script `util/make.py` is provided for `.h5` file creation.
Binary file removed butterfly.png
Binary file not shown.
Binary file removed butterfly_GT.bmp
Binary file not shown.
45 changes: 31 additions & 14 deletions ic.py → colorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,37 +9,51 @@
import models


def ic(weights_file: str, output_folder: str, image_file: str):
def colorize(
model: torch.nn.Module, weights_file: str, output_folder: str, image_file: str
):
torch.backends.cudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = models.ICCNN().to(device)
model = model.to(device)

state_dict = model.state_dict()

for n, p in torch.load(
weights_file, map_location=lambda storage, _: storage
).items():
for n, p in torch.load(weights_file, map_location=lambda storage, _: storage)[
"model"
].items():
if n in state_dict.keys():
state_dict[n].copy_(p)
else:
raise KeyError(n)

model.eval()
model.eval().to(device)

image_name = image_file.split("/")[-1].split("\\")[-1]

with PIL.Image.open(image_file).convert("RGB") as image:
gray_image = skimage.color.rgb2gray(image)

gray_image = torch.from_numpy(gray_image).unsqueeze(0).float()
gray_image = torch.from_numpy(gray_image).to(device).unsqueeze(0).float()
# gray_image = torch.unsqueeze(gray_image, dim=0)

# we don't need to calculate gradients during inference (speeds up computation)
with torch.no_grad():
predictions = model(gray_image).clamp(0.0, 1.0)

output = torch.cat((gray_image, predictions), 0).numpy()
predictions = model(gray_image).to(device).clamp(0.0, 1.0).squeeze()

# expand/reduce predictions dimensionality if it's bigger than the gray image
if (
predictions.shape[1] != gray_image.shape[1]
or predictions.shape[2] != gray_image.shape[2]
):
predictions = torch.nn.functional.interpolate(
predictions.unsqueeze(0),
size=(gray_image.shape[1], gray_image.shape[2]),
mode="bilinear",
align_corners=False,
).squeeze()

output = torch.cat((gray_image, predictions), 0).cpu().numpy()

output = output.transpose((1, 2, 0))
output[:, :, 0:1] = output[:, :, 0:1] * 100
Expand All @@ -50,15 +64,18 @@ def ic(weights_file: str, output_folder: str, image_file: str):
PIL.Image.fromarray((output * 255).astype(np.uint8)).save(
f"{output_folder}{image_name.replace('.', f'_colorized.')}"
)
PIL.Image.fromarray((gray_image.squeeze().numpy() * 255).astype(np.uint8)).save(
f"{output_folder}{image_name.replace('.', f'_gray.')}"
)
PIL.Image.fromarray(
(gray_image.squeeze().cpu().numpy() * 255).astype(np.uint8)
).save(f"{output_folder}{image_name.replace('.', f'_gray.')}")


if __name__ == "__main__":
model_list = models.IC_MODELS

parser = argparse.ArgumentParser()
parser.add_argument("--model", choices=model_list.keys(), required=True)
parser.add_argument("--weights-file", type=str, required=True)
parser.add_argument("--image-file", type=str, required=True)
args = parser.parse_args()

ic(args.weights_file, "", args.image_file)
colorize(model_list[args.model](), args.weights_file, "", args.image_file)
24 changes: 16 additions & 8 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
QWidget,
)

from ic import ic
from isr import isr
import models
from colorize import colorize
from super_resolve import super_resolve


class ScaleSelector(QGroupBox):
Expand Down Expand Up @@ -63,8 +64,6 @@ def on_selected(self):
if scale_button.isChecked():
self.scale = int(scale_button.text())

print(self.scale)


class FolderSelector(QGroupBox):
def __init__(self, title, default_path):
Expand Down Expand Up @@ -294,17 +293,26 @@ def enhance(self):

open_on_complete = self.open_on_complete.isChecked()

colorize = self.colorize_selector.get_state()
colorize_image = self.colorize_selector.get_state()

if colorize:
ic(sic_weights, output_folder, image_path)
if colorize_image:
colorize(
models.IC_MODELS["icres"](), sic_weights, output_folder, image_path
)
image_name, image_ext = os.path.splitext(os.path.basename(image_path))
image_path = os.path.join(
output_folder, f"{image_name}_colorized{image_ext}"
)

if scale != 1:
isr(sisr_weights, output_folder, image_path, scale, True)
super_resolve(
models.SR_MODELS["srcnn"](),
sisr_weights,
output_folder,
image_path,
scale,
True,
)

if not open_on_complete:
return
Expand Down
44 changes: 0 additions & 44 deletions models.py

This file was deleted.

16 changes: 16 additions & 0 deletions models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from .iccnn import ICCNN
from .icres import ICRes
from .srcnn import SRCNN
from .srcnnc import SRCNNC
from .srres import SRRes

SR_MODELS = {
"srcnn": SRCNN,
"srcnnc": SRCNNC,
"srres": SRRes,
}

IC_MODELS = {
"iccnn": ICCNN,
"icres": ICRes,
}
38 changes: 38 additions & 0 deletions models/iccnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import torch


class ICCNN(torch.nn.Module):
def __init__(self):
super(ICCNN, self).__init__()

self.features = torch.nn.Sequential(
torch.nn.Conv2d(1, 32, kernel_size=3, padding=1, stride=2),
torch.nn.BatchNorm2d(32),
torch.nn.ReLU(inplace=True),
torch.nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=2),
torch.nn.BatchNorm2d(64),
torch.nn.ReLU(inplace=True),
torch.nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2),
)

self.upsample = torch.nn.Sequential(
torch.nn.Upsample(scale_factor=2),
torch.nn.Conv2d(128, 64, kernel_size=3, padding=1, stride=1),
torch.nn.BatchNorm2d(64),
torch.nn.ReLU(inplace=True),
torch.nn.Upsample(scale_factor=2),
torch.nn.Conv2d(64, 32, kernel_size=3, padding=1, stride=1),
torch.nn.BatchNorm2d(32),
torch.nn.ReLU(inplace=True),
torch.nn.Upsample(scale_factor=2),
torch.nn.Conv2d(32, 2, kernel_size=3, padding=1, stride=1),
)

def forward(self, output):
output = self.features(output)
output = self.upsample(output)

return output

def __str__(self):
return "iccnn"
Loading

0 comments on commit aabad40

Please sign in to comment.