Skip to content

Commit

Permalink
testing autoencoder architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
brandicec06 committed Apr 26, 2021
1 parent db0e1b2 commit d80c7da
Show file tree
Hide file tree
Showing 246 changed files with 300 additions and 0 deletions.
Binary file added .DS_Store
Binary file not shown.
Binary file added data/.DS_Store
Binary file not shown.
Binary file not shown.
37 changes: 37 additions & 0 deletions do_not_use_load_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

from torch.utils.data import Dataset, DataLoader
from PIL import *

import os

class CustomDataSet(Dataset):
def __init__(self, main_dir, transform):
self.main_dir = main_dir
self.transform = transform
all_imgs = os.listdir(main_dir)
# self.total_imgs = natsort.natsorted(all_imgs)
self.total_imgs = all_imgs

def __len__(self):
return len(self.total_imgs)

def __getitem__(self, idx):
img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
image = Image.open(img_loc).convert("RGB")
tensor_image = self.transform(image)
return tensor_image
# return image


img_folder_path = "data/load_test"
batch_size = 10

my_dataset = CustomDataSet(img_folder_path, transform=transforms.ToTensor())
train_loader = DataLoader(my_dataset , batch_size=batch_size, shuffle=True)
print(my_dataset[1])
263 changes: 263 additions & 0 deletions network_test/autoencoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import numpy as np

from torch.utils.data import Dataset, DataLoader
from PIL import *

import os
import glob
import math
import random
import re

np.random.seed(2)


#############################
#### Reduce Image Size ######
#############################


img_folder_path = "small_train"
org_folder_path = "load_test"
org_folder_pathc1 = "load_test_c1"
org_folder_pathc2 = "load_test_c2"

def reduce_img(path, train_path, test_path):
image_list = []
for i, filename in enumerate(glob.glob(path + '/*.jpg')):

if filename.lower().endswith('.jpg'):
g = filename.split('.')[0]

img=Image.open(filename).convert("RGB")
new_name = filename.split('/')[1]

img_c1 = Image.open(filename).convert("RGB")

w,h = img.size
# print(w,h)
# img = Image.open(filename)
border = 230
x = 230
y = 230
w_ = 1328
h_ = 240
img = img.crop((x,y,x+w_,y+h_))
new_img = img.resize((math.ceil(w *.3), int(h*.3)))
if i % 7 == 0:
new_img.save(test_path + str(i) + "_" + new_name, optimize=True)
else:
new_img.save(train_path + new_name, optimize=True)



# reduce_img(org_folder_path, 'small_train/', 'small_test/')
# reduce_img(org_folder_pathc1, 'small_train_c1/', 'small_test_c1/')
# reduce_img(org_folder_pathc2, 'small_train_c2/', 'small_test_c2/')
# exit()

#############################
#############################



class Autoencoder(nn.Module):
def __init__(self):
super(Autoencoder, self).__init__()
self.encoder = nn.Sequential( # like the Composition layer you built
nn.Conv2d(6, 16, 3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(16, 32, 3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(32, 64, 7)
)

self.decoder = nn.Sequential(
nn.ConvTranspose2d(64, 32, 7),
nn.ReLU(),
nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
nn.ReLU(),
nn.ConvTranspose2d(16, 3, 3, stride=2, padding=1, output_padding=1),
# nn.Sigmoid()
# nn.Tanh
)

def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x


#############################
###### Load Custom Data #####
#############################

class CustomDataSet(Dataset):
def __init__(self, main_dir,c1_dir, c2_dir, transform):
self.main_dir = main_dir
self.c1_dir = c1_dir
self.c2_dir = c2_dir

self.transform = transform
all_imgs = os.listdir(main_dir)
c1_imgs = os.listdir(c1_dir)
c2_imgs = os.listdir(c2_dir)

all_imgs = self.sort_img(all_imgs)
c1_imgs = self.sort_img(c1_imgs)
c2_imgs = self.sort_img(c2_imgs)



self.total_imgs = all_imgs
self.c1_imgs = c1_imgs
self.c2_imgs = c2_imgs

def sort_img(self, l):
keys = []
names = []
for j, i in enumerate(l):
k = i.split('_')[0]
keys.append(int(k))
names.append(i)

_, sorted_list = (list(t) for t in zip(*sorted(zip(keys, names))))
# print(sorted_list[:5])
return sorted_list

def __len__(self):
return len(self.total_imgs)


def __getitem__(self, idx):
img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
img_loc_c1 = os.path.join(self.c1_dir, self.c1_imgs[idx])
img_loc_c2 = os.path.join(self.c2_dir, self.c2_imgs[idx])

image = Image.open(img_loc).convert("RGB")
image_c1 = Image.open(img_loc_c1).convert("RGB")
image_c2 = Image.open(img_loc_c2).convert("RGB")

x = torch.cat((self.transform(image_c1), self.transform(image_c2)), dim=0)

y = self.transform(image)
return x, y


## Train Data folders
img_folder_path = "small_train"
c1_folder_path = "small_train_c1"
c2_folder_path = "small_train_c2"

## Test Data Folders
img_folder_path_test = "small_test"
c1_folder_path_test = "small_test_c1"
c2_folder_path_test = "small_test_c2"

train_data = CustomDataSet(img_folder_path, c1_folder_path, c2_folder_path, transform=transforms.ToTensor())
test_data = CustomDataSet(img_folder_path_test, c1_folder_path_test, c2_folder_path_test, transform=transforms.ToTensor())



#########################################
########### Train Model #################
#########################################


def train(model, num_epochs=5, batch_size=4, learning_rate=1e-3):
torch.manual_seed(42)
criterion = nn.MSELoss() # mean square error loss

optimizer = torch.optim.Adam(model.parameters(),
lr=learning_rate,
weight_decay=1e-5) # <--

train_loader = DataLoader(train_data , batch_size=batch_size, shuffle=True)

outputs = []
for epoch in range(num_epochs):
for data in train_loader:

img, targ = data

recon = model.forward(img)

loss = criterion(recon, targ)
loss.backward()
optimizer.step()
optimizer.zero_grad()

print('Epoch:{}, Loss:{:.4f}'.format(epoch+1, float(loss)))
# outputs.append((epoch, img, recon),)
outputs.append((epoch, targ, recon),)
return outputs


model = Autoencoder()
max_epochs = 30
outputs = train(model, num_epochs=max_epochs)
# print(len(outputs))
for k in range(0, max_epochs, 5):
# print(k)
# print(outputs[k][1])
plt.figure(figsize=(9, 2))
imgs = outputs[k][1].detach().numpy()
recon = outputs[k][2].detach().numpy()
# print(len(imgs))
for i, item in enumerate(imgs):
if i >= 9: break
plt.subplot(2, 9, i+1)
plt.imshow(item[0], cmap='jet')


for i, item in enumerate(recon):
if i >= 9: break
plt.subplot(2, 9, 9+i+1)
plt.imshow(item[0], cmap = 'jet')
plt.show()


#########################################
############ Test Model #################
#########################################
print('Viewing Test Images')
def test(model, batch_size=64, learning_rate=1e-3):
torch.manual_seed(42)

test_loader = torch.utils.data.DataLoader(test_data,
batch_size=batch_size,
shuffle=True)
outputs = []

for data in test_loader:
img, _ = data
recon = model(img)

outputs.append((img, recon),)
return outputs


outputs_test = test(model)


for k in range(0, len(outputs_test), 5):
plt.figure(figsize=(9, 2))
imgs = outputs_test[k][0].detach().numpy()
recon = outputs_test[k][1].detach().numpy()
for i, item in enumerate(imgs):
if i >= 9: break
plt.subplot(2, 9, i+1)
plt.imshow(item[0])


for i, item in enumerate(recon):
if i >= 9: break
plt.subplot(2, 9, 9+i+1)
plt.imshow(item[0])
plt.show()
Binary file added network_test/load_test/10_solar_east.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/11_solar_north 2.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/12_solar_north 3.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/13_solar_north 4.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/14_solar_north 5.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/15_solar_north 6.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/16_solar_north 7.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/17_solar_north 8.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/18_solar_north 9.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/19_solar_north 10.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/1_solar_east 2.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/20_solar_north.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/21_solar_south 2.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/22_solar_south 3.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/23_solar_south 4.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/24_solar_south 5.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/25_solar_south 6.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/26_solar_south 7.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/27_solar_south 8.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/28_solar_south 9.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/29_solar_south 10.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/2_solar_east 3.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/30_solar_south.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/31_solar_west 2.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/32_solar_west 3.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added network_test/load_test/33_solar_west 4.jpg
Binary file added network_test/load_test/34_solar_west 5.jpg
Binary file added network_test/load_test/35_solar_west 6.jpg
Binary file added network_test/load_test/36_solar_west 7.jpg
Binary file added network_test/load_test/37_solar_west 8.jpg
Binary file added network_test/load_test/38_solar_west 9.jpg
Binary file added network_test/load_test/39_solar_west 10.jpg
Binary file added network_test/load_test/3_solar_east 4.jpg
Binary file added network_test/load_test/40_solar_west.jpg
Binary file added network_test/load_test/4_solar_east 5.jpg
Binary file added network_test/load_test/5_solar_east 6.jpg
Binary file added network_test/load_test/6_solar_east 7.jpg
Binary file added network_test/load_test/7_solar_east 8.jpg
Binary file added network_test/load_test/8_solar_east 9.jpg
Binary file added network_test/load_test/9_solar_east 10.jpg
Binary file added network_test/load_test_c1/.DS_Store
Binary file not shown.
Binary file added network_test/load_test_c1/10_depth_east.jpg
Binary file added network_test/load_test_c1/11_depth_north 2.jpg
Binary file added network_test/load_test_c1/12_depth_north 3.jpg
Binary file added network_test/load_test_c1/13_depth_north 4.jpg
Binary file added network_test/load_test_c1/14_depth_north 5.jpg
Binary file added network_test/load_test_c1/15_depth_north 6.jpg
Binary file added network_test/load_test_c1/16_depth_north 7.jpg
Binary file added network_test/load_test_c1/17_depth_north 8.jpg
Binary file added network_test/load_test_c1/18_depth_north 9.jpg
Binary file added network_test/load_test_c1/19_depth_north 10.jpg
Binary file added network_test/load_test_c1/1_depth_east 2.jpg
Binary file added network_test/load_test_c1/20_depth_north.jpg
Binary file added network_test/load_test_c1/21_depth_south 2.jpg
Binary file added network_test/load_test_c1/22_depth_south 3.jpg
Binary file added network_test/load_test_c1/23_depth_south 4.jpg
Binary file added network_test/load_test_c1/24_depth_south 5.jpg
Binary file added network_test/load_test_c1/25_depth_south 6.jpg
Binary file added network_test/load_test_c1/26_depth_south 7.jpg
Binary file added network_test/load_test_c1/27_depth_south 8.jpg
Binary file added network_test/load_test_c1/28_depth_south 9.jpg
Binary file added network_test/load_test_c1/29_depth_south 10.jpg
Binary file added network_test/load_test_c1/2_depth_east 3.jpg
Binary file added network_test/load_test_c1/30_depth_south.jpg
Binary file added network_test/load_test_c1/31_depth_west 2.jpg
Binary file added network_test/load_test_c1/32_depth_west 3.jpg
Binary file added network_test/load_test_c1/33_depth_west 4.jpg
Binary file added network_test/load_test_c1/34_depth_west 5.jpg
Binary file added network_test/load_test_c1/35_depth_west 6.jpg
Binary file added network_test/load_test_c1/36_depth_west 7.jpg
Binary file added network_test/load_test_c1/37_depth_west 8.jpg
Binary file added network_test/load_test_c1/38_depth_west 9.jpg
Binary file added network_test/load_test_c1/39_depth_west 10.jpg
Binary file added network_test/load_test_c1/3_depth_east 4.jpg
Binary file added network_test/load_test_c1/40_depth_west.jpg
Binary file added network_test/load_test_c1/4_depth_east 5.jpg
Binary file added network_test/load_test_c1/5_depth_east 6.jpg
Binary file added network_test/load_test_c1/6_depth_east 7.jpg
Binary file added network_test/load_test_c1/7_depth_east 8.jpg
Binary file added network_test/load_test_c1/8_depth_east 9.jpg
Binary file added network_test/load_test_c1/9_depth_east 10.jpg
Binary file added network_test/small_test/0_33_solar_west 4.jpg
Binary file added network_test/small_test/14_9_solar_east 10.jpg
Binary file added network_test/small_test/21_28_solar_south 9.jpg
Binary file added network_test/small_test/28_14_solar_north 5.jpg
Binary file added network_test/small_test/35_39_solar_west 10.jpg
Binary file added network_test/small_test/7_29_solar_south 10.jpg
Binary file added network_test/small_test_c1/0_34_depth_west 5.jpg
Binary file added network_test/small_test_c1/14_31_depth_west 2.jpg
Binary file added network_test/small_test_c1/28_5_depth_east 6.jpg
Binary file added network_test/small_train/10_solar_east.jpg
Binary file added network_test/small_train/11_solar_north 2.jpg
Binary file added network_test/small_train/12_solar_north 3.jpg
Binary file added network_test/small_train/13_solar_north 4.jpg
Binary file added network_test/small_train/15_solar_north 6.jpg
Binary file added network_test/small_train/16_solar_north 7.jpg
Binary file added network_test/small_train/17_solar_north 8.jpg
Binary file added network_test/small_train/18_solar_north 9.jpg
Binary file added network_test/small_train/19_solar_north 10.jpg
Binary file added network_test/small_train/1_solar_east 2.jpg
Binary file added network_test/small_train/20_solar_north.jpg
Binary file added network_test/small_train/21_solar_south 2.jpg
Binary file added network_test/small_train/22_solar_south 3.jpg
Binary file added network_test/small_train/23_solar_south 4.jpg
Binary file added network_test/small_train/24_solar_south 5.jpg
Binary file added network_test/small_train/25_solar_south 6.jpg
Binary file added network_test/small_train/26_solar_south 7.jpg
Binary file added network_test/small_train/27_solar_south 8.jpg
Binary file added network_test/small_train/2_solar_east 3.jpg
Binary file added network_test/small_train/30_solar_south.jpg
Binary file added network_test/small_train/31_solar_west 2.jpg
Binary file added network_test/small_train/32_solar_west 3.jpg
Binary file added network_test/small_train/34_solar_west 5.jpg
Binary file added network_test/small_train/35_solar_west 6.jpg
Binary file added network_test/small_train/36_solar_west 7.jpg
Binary file added network_test/small_train/37_solar_west 8.jpg
Binary file added network_test/small_train/38_solar_west 9.jpg
Binary file added network_test/small_train/3_solar_east 4.jpg
Binary file added network_test/small_train/40_solar_west.jpg
Binary file added network_test/small_train/4_solar_east 5.jpg
Binary file added network_test/small_train/5_solar_east 6.jpg
Binary file added network_test/small_train/6_solar_east 7.jpg
Binary file added network_test/small_train/7_solar_east 8.jpg
Binary file added network_test/small_train/8_solar_east 9.jpg
Binary file added network_test/small_train_c1/10_depth_east.jpg
Binary file added network_test/small_train_c1/11_depth_north 2.jpg
Binary file added network_test/small_train_c1/12_depth_north 3.jpg
Binary file added network_test/small_train_c1/13_depth_north 4.jpg
Binary file added network_test/small_train_c1/15_depth_north 6.jpg
Binary file added network_test/small_train_c1/18_depth_north 9.jpg
Binary file added network_test/small_train_c1/19_depth_north 10.jpg
Binary file added network_test/small_train_c1/1_depth_east 2.jpg
Binary file added network_test/small_train_c1/20_depth_north.jpg
Binary file added network_test/small_train_c1/21_depth_south 2.jpg
Binary file added network_test/small_train_c1/22_depth_south 3.jpg
Binary file added network_test/small_train_c1/23_depth_south 4.jpg
Binary file added network_test/small_train_c1/24_depth_south 5.jpg
Binary file added network_test/small_train_c1/25_depth_south 6.jpg
Binary file added network_test/small_train_c1/26_depth_south 7.jpg
Binary file added network_test/small_train_c1/27_depth_south 8.jpg
Binary file added network_test/small_train_c1/28_depth_south 9.jpg
Binary file added network_test/small_train_c1/29_depth_south 10.jpg
Binary file added network_test/small_train_c1/2_depth_east 3.jpg
Binary file added network_test/small_train_c1/30_depth_south.jpg
Binary file added network_test/small_train_c1/32_depth_west 3.jpg
Binary file added network_test/small_train_c1/33_depth_west 4.jpg
Binary file added network_test/small_train_c1/35_depth_west 6.jpg
Binary file added network_test/small_train_c1/36_depth_west 7.jpg
Binary file added network_test/small_train_c1/37_depth_west 8.jpg
Binary file added network_test/small_train_c1/38_depth_west 9.jpg
Binary file added network_test/small_train_c1/39_depth_west 10.jpg
Binary file added network_test/small_train_c1/3_depth_east 4.jpg
Binary file added network_test/small_train_c1/40_depth_west.jpg
Binary file added network_test/small_train_c1/4_depth_east 5.jpg
Binary file added network_test/small_train_c1/6_depth_east 7.jpg
Binary file added network_test/small_train_c1/7_depth_east 8.jpg
Binary file added network_test/small_train_c1/8_depth_east 9.jpg
Binary file added network_test/small_train_c1/9_depth_east 10.jpg

0 comments on commit d80c7da

Please sign in to comment.