diff --git a/Colab_Notebooks/Subwaystation_Segmentation.ipynb b/Colab_Notebooks/Subwaystation_Segmentation.ipynb deleted file mode 100644 index 679b91a..0000000 --- a/Colab_Notebooks/Subwaystation_Segmentation.ipynb +++ /dev/null @@ -1 +0,0 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Subwaystation_Segmentation.ipynb","provenance":[],"collapsed_sections":[],"machine_shape":"hm"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"code","metadata":{"id":"Zqj7nzr_excp","colab_type":"code","colab":{}},"source":["import torch\n","import torchvision\n","from torchvision import transforms\n","from torchvision.transforms import ToTensor\n","from PIL import Image\n","from os import listdir\n","import random\n","import torch.optim as optim\n","from torch.autograd import Variable\n","import torch.nn.functional as F\n","import torch.nn as nn\n","import random\n","\n","import numpy as np\n","from scipy import misc\n","from PIL import Image\n","import glob\n","import imageio\n","import os\n","\n","import cv2\n","\n","import matplotlib.pyplot as plt\n","\n","from google.colab import files"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Sfpmw5KwfzUd","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":131},"executionInfo":{"status":"ok","timestamp":1599816353396,"user_tz":-120,"elapsed":53626,"user":{"displayName":"Martin Ludwig","photoUrl":"","userId":"16164250748375470154"}},"outputId":"853c849d-d357-450f-a38e-0800a351c5f6"},"source":["from google.colab import drive\n","drive.mount('/content/drive')"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly&response_type=code\n","\n","Enter your authorization code:\n","4/4AFLMe_6U04KYq6w0iwts1ktjW21MgSwsEayjD_ECQsC50DnhLkH0Rs\n","Mounted at /content/drive\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"3FVW8Z_EvmTY","colab_type":"code","colab":{}},"source":["class SegNet(nn.Module):\n"," \"\"\"neural network architecture inspired by SegNet\"\"\"\n","\n"," def __init__(self):\n"," super(SegNet, self).__init__()\n"," \n"," #Encoder\n"," self.conv1 = nn.Conv2d(3, 64, (3,3), padding=1)\n"," self.conv2 = nn.Conv2d(64, 64, (3,3), padding=1)\n"," self.enc1_bn = nn.BatchNorm2d(64)\n"," self.maxpool1 = nn.MaxPool2d(2,2)\n","\n"," self.conv3 = nn.Conv2d(64, 128, (3,3), padding=1)\n"," self.conv4 = nn.Conv2d(128, 128, (3,3), padding=1)\n"," self.enc2_bn = nn.BatchNorm2d(128)\n"," self.maxpool2 = nn.MaxPool2d((2,2),2)\n","\n"," self.conv5 = nn.Conv2d(128, 256, (3,3), padding=1)\n"," self.conv6 = nn.Conv2d(256, 256, (3,3), padding=1)\n"," self.conv7 = nn.Conv2d(256, 256, (3,3), padding=1)\n"," self.enc3_bn = nn.BatchNorm2d(256)\n"," self.maxpool3 = nn.MaxPool2d((2,2),2)\n","\n"," self.conv8 = nn.Conv2d(256, 512, (3,3), padding=1)\n"," self.conv9 = nn.Conv2d(512, 512, (3,3), padding=1)\n"," self.conv10 = nn.Conv2d(512, 512, (3,3), padding=1)\n"," self.enc4_bn = nn.BatchNorm2d(512)\n"," self.maxpool4 = nn.MaxPool2d((2,2),2)\n","\n"," self.conv11 = nn.Conv2d(512, 512, (3,3), 
padding=1)\n"," self.conv12 = nn.Conv2d(512, 512, (3,3), padding=1)\n"," self.conv13 = nn.Conv2d(512, 512, (3,3), padding=1)\n"," self.enc5_bn = nn.BatchNorm2d(512)\n"," self.maxpool5 = nn.MaxPool2d((2,2),2)\n","\n"," #Decoder\n"," self.upsample1 = nn.Upsample(scale_factor=2)\n"," self.conv14 = nn.Conv2d(512,512, (3,3), padding=1)\n"," self.conv15 = nn.Conv2d(512,512, (3,3), padding=1)\n"," self.conv16 = nn.Conv2d(512,512, (3,3), padding=1)\n"," self.dec1_bn = nn.BatchNorm2d(512)\n","\n"," self.upsample2 = nn.Upsample(scale_factor=2)\n"," self.conv17 = nn.Conv2d(512,512, (3,3), padding=1)\n"," self.conv18 = nn.Conv2d(512,512, (3,3), padding=1)\n"," self.conv19 = nn.Conv2d(512,256, (3,3), padding=1)\n"," self.dec2_bn = nn.BatchNorm2d(256)\n","\n"," self.upsample3 = nn.Upsample(scale_factor=2)\n"," self.conv20 = nn.Conv2d(256,256, (3,3), padding=1)\n"," self.conv21 = nn.Conv2d(256,256, (3,3), padding=1)\n"," self.conv22 = nn.Conv2d(256,128, (3,3), padding=1)\n"," self.dec3_bn = nn.BatchNorm2d(128)\n","\n"," self.upsample4 = nn.Upsample(scale_factor=2)\n"," self.conv23 = nn.Conv2d(128,128, (3,3), padding=1)\n"," self.conv24 = nn.Conv2d(128,64, (3,3), padding=1)\n"," self.dec4_bn = nn.BatchNorm2d(64)\n","\n"," self.upsample5 = nn.Upsample(scale_factor=2)\n"," self.conv25 = nn.Conv2d(64,64, (3,3), padding=1)\n"," self.conv26 = nn.Conv2d(64,5, (3,3), padding=1)\n","\n"," self.softmax = nn.Softmax(dim=1)\n"," \n","\n"," def forward(self, x):\n"," #Encoder\n"," x = F.relu(self.enc1_bn(self.conv2(self.conv1(x))))\n"," print(x.size())\n"," x = self.maxpool1(x)\n"," print(x.size())\n","\n"," x = F.relu(self.enc2_bn(self.conv4(self.conv3(x))))\n"," print(x.size())\n"," x = self.maxpool2(x)\n"," print(x.size())\n"," \n"," x = F.relu(self.enc3_bn(self.conv7(self.conv6(self.conv5(x)))))\n"," print(x.size())\n"," x = self.maxpool3(x)\n"," print(x.size())\n","\n"," x = F.relu(self.enc4_bn(self.conv10(self.conv9(self.conv8(x)))))\n"," print(x.size())\n"," x = self.maxpool4(x)\n"," print(x.size())\n","\n"," x = F.relu(self.enc5_bn(self.conv13(self.conv12(self.conv11(x)))))\n"," print(x.size())\n"," x = self.maxpool5(x)\n"," print(x.size())\n","\n"," print()\n"," #Decoder\n"," x = F.relu(self.dec1_bn(self.conv16(self.conv15(self.conv14(self.upsample1(x))))))\n"," print(x.size())\n"," x = F.relu(self.dec2_bn(self.conv19(self.conv18(self.conv17(self.upsample2(x))))))\n"," print(x.size())\n"," x = F.relu(self.dec3_bn(self.conv22(self.conv21(self.conv20(self.upsample3(x))))))\n"," print(x.size())\n"," x = F.relu(self.dec4_bn(self.conv24(self.conv23(self.upsample4(x)))))\n"," print(x.size())\n"," x = self.conv26(self.conv25(self.upsample4(x)))\n"," print(x.size())\n"," \n"," return x"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"HO6IxRmVeiFj","colab_type":"code","colab":{}},"source":["def create_data(data_start, data_size, batch_size, input_path, target_path, target_dict, real_sequence, is_train):\n"," \"\"\"create data for training/validation from img and xml to tensor\"\"\"\n","\n"," transform = transforms.Compose([transforms.Resize((320, 576)),\n"," transforms.ToTensor()])\n","\n"," input_list = []\n"," target_list = []\n"," data = []\n","\n"," weights = [0,0,0,0,0] #weights for cross entropy loss\n","\n"," pixel_class = [] #single pixel class \n","\n"," inputs = os.listdir(input_path)\n"," inputs.sort()\n","\n"," targets = os.listdir(target_path) \n"," targets.sort()\n"," \n"," \n","\n"," for x in range(data_start, data_size):\n","\n"," if(len(real_sequence) == 0):\n"," 
break\n","\n"," #print(\"len sequence\",len(real_sequence))\n","\n"," index = random.choice(real_sequence)\n"," real_sequence.remove(index)\n"," \n"," print(x)\n","\n","\n"," #if(len(data) == 8 and not is_train):\n"," # break\n","\n"," #if(len(data) == 4):\n"," # break\n","\n"," input = Image.open(input_path + inputs[index])\n"," input_list.append(transform(input))\n"," #input_list.append(ToTensor()(input))\n","\n"," target = Image.open(target_path + targets[index])\n"," target_tensor = torch.round(transform(target))\n"," #target_tensor = torch.round(ToTensor()(target))\n","\n"," if (is_train):\n"," target_tensor_final = torch.zeros(320,576, dtype=torch.long) #cross entropy loss allowed only torch.long\n"," else:\n"," target_tensor_final = torch.zeros(5,320,576, dtype=torch.long)\n","\n"," for i in range(320):\n"," for j in range(576):\n"," pixel_class = target_dict[tuple(target_tensor[:,i,j].tolist())]\n"," \n"," #print(\"pixel class\", pixel_class)\n"," #print(\"tensor\", torch.tensor(pixel_class, dtype=torch.long))\n"," #print(\"target size\", target_tensor_final.size())\n"," \n"," if (is_train):\n"," weights[pixel_class] += 1\n"," target_tensor_final[i,j] = torch.tensor(pixel_class, dtype=torch.long)\n"," else:\n"," target_tensor_final[:,i,j] = torch.tensor(pixel_class, dtype=torch.long)\n"," weights[pixel_class.index(1)] += 1\n","\n"," target_list.append(target_tensor_final)\n","\n"," if len(input_list) >= batch_size:\n"," data.append((torch.stack(input_list), torch.stack(target_list)))\n"," \n"," input_list = []\n"," target_list = []\n","\n"," print('Loaded batch ', len(data), 'of ', int(len(inputs) / batch_size))\n"," print('Percentage Done: ',\n"," 100 * (len(data) / int(len(inputs) / batch_size)), '%')\n","\n"," weights = torch.tensor(weights, dtype=torch.float64)\n"," #weights = 1/(weights/weights.min()) #press weights in [0,1], with maximum value for each class \n"," return data, weights"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"EviGvOgaX8r_","colab_type":"code","colab":{}},"source":["def train(train_data, model, optimizer, criterion, device):\n"," \"\"\"\n"," Trains/updates the model for one epoch on the training dataset.\n","\n"," Parameters:\n"," train_data (torch tensor): trainset\n"," model (torch.nn.module): Model to be trained\n"," optimizer (torch.optim.optimizer): optimizer instance like SGD or Adam\n"," criterion (torch.nn.modules.loss): loss function like CrossEntropyLoss\n"," device (string): cuda or cpu\n"," \"\"\"\n","\n"," # switch to train mode\n"," model.train()\n","\n"," # iterate through the dataset loader\n"," i = 0\n"," losses = []\n"," for (inp, target) in train_data:\n"," \n"," # transfer inputs and targets to the GPU (if it is available)\n"," inp = inp.to(device)\n"," target = target.to(device)\n"," \n"," # compute output, i.e. 
the model forward\n"," output = model(inp)\n"," \n"," # calculate the loss\n"," loss = criterion(output, target)\n"," #print(\"loss\", loss)\n"," losses.append(loss)\n"," \n"," \n"," print(\"loss {:.2}\".format(loss))\n"," # compute gradient and do the SGD step\n"," # we reset the optimizer with zero_grad to \"flush\" former gradients\n"," optimizer.zero_grad()\n"," loss.backward()\n"," optimizer.step()\n"," \n"," avg_loss = torch.mean(torch.stack(losses)).item()\n"," print(\"avg.loss {:.2}\".format(avg_loss))\n"," return losses"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"utLbMONSGLCB","colab_type":"code","colab":{}},"source":["def calc_accuracy(output, target):\n"," \"\"\"calculate accuracy from tensor(b,c,x,y) for every category c\"\"\"\n"," accs = []\n"," acc_tensor = (output == target).int()\n"," for c in range(target.size(1)):\n"," correct_num = acc_tensor[:,c].sum().item() #item convert tensor in integer\n"," #print(correct_num)\n"," total_num = acc_tensor[:,c].numel()\n"," #print(total_num)\n"," accs.append(correct_num/total_num)\n"," return accs"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"OqKInFHXGYNP","colab_type":"code","colab":{}},"source":["def calc_precision(output, target):\n"," \"\"\"calculate precision from tensor(b,c,x,y) for every category c\"\"\"\n","\n"," precs = []\n"," for c in range(target.size(1)):\n"," true_positives = ((output[:,c] - (output[:,c] != 1).int()) == target[:,c]).int().sum().item()\n"," #print(true_positives)\n"," false_positives = ((output[:,c] - (output[:,c] != 1).int()) == (target[:,c] != 1).int()).int().sum().item()\n"," #print(false_positives)\n","\n"," if(true_positives == 0):\n"," precs.append(1.0)\n"," else:\n"," precs.append(true_positives / (true_positives + false_positives))\n"," \n"," return precs"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"LK8pziSIGsuY","colab_type":"code","colab":{}},"source":["def calc_recall(output, target):\n"," \"\"\"calculate recall from tensor(b,c,x,y) for every category c\"\"\"\n"," \n"," recs = []\n"," for c in range(target.size(1)):\n"," relevants = (target[:,c] == 1).int().sum().item()\n"," #print(relevants)\n"," true_positives = ((output[:,c] - (output[:,c] != 1).int()) == target[:,c]).int().sum().item()\n"," #print(true_positives)\n"," \n"," if (relevants == 0):\n"," recs.append(1.0)\n"," else:\n"," recs.append(true_positives/relevants)\n"," \n"," return recs"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"9RubCUwwG5Vm","colab_type":"code","colab":{}},"source":["def convert_to_one_hot(tensor, device):\n"," \"\"\"converts a tensor from size (b,c,x,y) to (b,c,x,y) one hot tensor for c categorys\"\"\"\n","\n"," for i in range(tensor.size(0)):\n"," max_idx = torch.argmax(tensor[i], 0, keepdim=True)\n"," one_hot = torch.FloatTensor(tensor[i].shape).to(device)\n"," one_hot.zero_()\n"," tensor[i] = one_hot.scatter_(0, max_idx, 1)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"wy8MoSpHBD3D","colab_type":"code","colab":{}},"source":["def validate(val_dataset, model, device, categories):\n"," \"\"\"\n"," validate the model with some validationfunctions on the test/validation dataset.\n","\n"," Parameters:\n"," val_data (torch tensor): test/validation dataset\n"," model (torch.nn.module): Model to be trained\n"," loss (torch.nn.modules.loss): loss function like CrossEntropyLoss\n"," device (string): cuda or cpu\n"," categories (list): names of categories\n"," 
\"\"\"\n"," model.eval()\n","\n"," # avoid computation of gradients and necessary storing of intermediate layer activations\n"," with torch.no_grad():\n","\n"," accs_avg = [0,0,0,0,0]\n"," precs_avg = [0,0,0,0,0]\n"," recs_avg = [0,0,0,0,0]\n"," counter = 0\n"," \n","\n"," for (inp, target) in val_dataset:\n"," # transfer to device\n"," inp = inp.to(device)\n"," target = target.to(device)\n","\n"," # compute output\n"," output = model(inp)\n","\n"," #print(\"before extra softmax\")\n"," #print(output[:,:,100,100])\n","\n"," output = model.softmax(output)\n"," #print(\"after extra softmax\")\n"," #print(output[:,:,100,100])\n","\n"," # convert from probabilities to one hot vectors\n"," convert_to_one_hot(output, device)\n","\n"," #print(\"after convert to one hot\")\n"," #print(output[:,:,100,100])\n","\n"," accs = calc_accuracy(output, target)\n"," precs = calc_precision(output, target)\n"," recs = calc_recall(output, target) \n","\n"," #print(\"loss {:.2} IOU {:.2}\".format(loss,iou))\n"," \n"," for i in range(len(categories)):\n"," print(\"category {:10} accuracy {:.2} precision {:.2} recall {:.2} \".format(categories[i], accs[i], precs[i], recs[i]))\n"," accs_avg[i] += accs[i]\n"," precs_avg[i] += precs[i] \n"," recs_avg[i] += recs[i]\n"," \n"," print()\n"," counter += 1\n","\n"," for i in range(len(categories)):\n"," accs_avg[i] /= counter\n"," precs_avg[i] /= counter\n"," recs_avg[i] /= counter\n"," \n"," print(\"avg.category {:10} accuracy {:.2} precision {:.2} recall {:.2} \".format(categories[i], accs_avg[i], precs_avg[i], recs_avg[i]))\n","\n"," return [accs_avg, precs_avg, recs_avg]"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"zHj4zDuSR7Fv","colab_type":"code","colab":{}},"source":["def create_rgb_output(data, model, device, dict_reverse):\n"," \"\"\"create rgb pictures from model output for data (rgb-image) on device\n"," parameter:\n"," data: torch.tensor (b,3,x,y)\n"," model: torch#######################################################################\n","\n"," \"\"\"\n"," output = model(data.to(device))\n"," final_output = model.softmax(output)\n"," convert_to_one_hot(final_output, device)\n","\n"," real_output_tensor = torch.zeros(data.size(0),3,data.size(2), data.size(3), dtype=torch.float64)\n","\n"," for x in range(data.size(0)):\n"," for i in range(data.size(2)):\n"," for j in range(data.size(3)):\n"," real_output_tensor[x][:,i,j] = torch.tensor(dict_reverse[tuple(final_output[x,:,i,j].tolist())])\n","\n"," return real_output_tensor"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"jLo8TiB9R_UX","colab_type":"code","colab":{}},"source":["def plot_tensor(tensor):\n"," \"\"\"plot tensor(3,x,y) as rgb-image\"\"\"\n","\n"," plt.imshow(tensor.permute(1,2,0))"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"xOEYnrGpZlto","colab_type":"code","colab":{}},"source":["input_path = 'drive/My Drive/Data/Input/' #file for input images\n","target_path = 'drive/My Drive/Data/Target/' #file for target images\n","\n","batch_size = 8\n","\n","#for creating rgb pixel to class category (one_hot)\n","dict_val = {(0.0, 0.0, 0.0): (0.0, 1.0, 0.0, 0.0, 0.0), #black\n"," (0.0, 0.0, 1.0): (0.0, 1.0, 0.0, 0.0, 0.0), #black (fail)\n"," (0.0, 1.0, 0.0): (0.0, 0.0, 1.0, 0.0, 0.0), #green\n"," (0.0, 1.0, 1.0): (1.0, 0.0, 0.0, 0.0, 0.0), #white (fail)\n"," (1.0, 0.0, 0.0): (0.0, 0.0, 0.0, 1.0, 0.0), #red\n"," (1.0, 0.0, 1.0): (1.0, 0.0, 0.0, 0.0, 0.0), #white (fail)\n"," (1.0, 1.0, 0.0): (0.0, 0.0, 0.0, 0.0, 1.0), 
#yellow\n"," (1.0, 1.0, 1.0): (1.0, 0.0, 0.0, 0.0, 0.0)} #white\n","\n","#for making model output to real output\n","dict_reverse = {(0.0, 1.0, 0.0, 0.0, 0.0) : (0.0, 0.0, 0.0), #black\n"," (0.0, 0.0, 1.0, 0.0, 0.0) : (0.0, 1.0, 0.0), #green\n"," (0.0, 0.0, 0.0, 1.0, 0.0) : (1.0, 0.0, 0.0), #red\n"," (0.0, 0.0, 0.0, 0.0, 1.0) : (1.0, 1.0, 0.0), #yellow\n"," (1.0, 0.0, 0.0, 0.0, 0.0) : (1.0, 1.0, 1.0)} #white\n","\n","#for creating rgb pixel to class category (single value, cross entropyloss only allows single value)\n","dict_train = {(0.0, 0.0, 0.0): 1, #black\n"," (0.0, 0.0, 1.0): 1, #black (fail)\n"," (0.0, 1.0, 0.0): 2, #green\n"," (0.0, 1.0, 1.0): 0, #white (fail)\n"," (1.0, 0.0, 0.0): 3, #red\n"," (1.0, 0.0, 1.0): 0, #white (fail)\n"," (1.0, 1.0, 0.0): 4, #yellow\n"," (1.0, 1.0, 1.0): 0} #white\n","\n","categories = [\"white\", \"black\", \"green\", \"red\", \"yellow\"]"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"HLcJmU26-OAs","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":1000},"outputId":"53a2c0c6-3344-4d96-cf38-23d5468e50d5"},"source":["real_sequence = list(range(len(os.listdir(input_path)))) #create a list from [0,...,number of input pictures] !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n","\n","indices = [i*2000 for i in range(21)] #size of train tensors always has to be rejusted !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n","\n","\n","for i in range(1,len(indices)): \n"," train_data, weights = create_data(indices[i-1],indices[i],batch_size, input_path, target_path, dict_train, real_sequence, True)\n"," torch.save(train_data, content_path + \"Train_Tensor\" + str(i) + \".pt\")\n"," torch.save(weights, content_path + \"Train_Weights\" + str(i) + \".pt\")\n","\n","real_sequence = list(range(len(os.listdir(input_path))))\n","val_data, _ = create_data(0,1000, batch_size, input_path, target_path, dict_val, real_sequence, False) #always has to be rejustec\n","torch.save(val_data, content_path + \"Val_Tensor_Test.pt\")\n"],"execution_count":null,"outputs":[{"output_type":"stream","text":["0\n","1\n","2\n","3\n","4\n","5\n","6\n","7\n","Loaded batch 1 of 5\n","Percentage Done: 20.0 %\n","8\n","9\n","10\n","11\n","12\n","13\n","14\n","15\n","Loaded batch 2 of 5\n","Percentage Done: 40.0 %\n","16\n","17\n","18\n","19\n","20\n","21\n","22\n","23\n","Loaded batch 3 of 5\n","Percentage Done: 60.0 %\n","24\n","25\n","26\n","27\n","28\n","29\n","30\n","31\n","Loaded batch 4 of 5\n","Percentage Done: 80.0 %\n","32\n","33\n","34\n","35\n","36\n","37\n","38\n","39\n","Loaded batch 5 of 5\n","Percentage Done: 100.0 %\n","0\n","1\n","2\n","3\n","4\n","5\n","6\n","7\n","Loaded batch 1 of 5\n","Percentage Done: 20.0 %\n","8\n","9\n","10\n","11\n","12\n","13\n","14\n","15\n","Loaded batch 2 of 5\n","Percentage Done: 40.0 %\n","16\n","17\n","18\n","19\n","20\n","21\n","22\n","23\n","Loaded batch 3 of 5\n","Percentage Done: 60.0 %\n","24\n","25\n","26\n","27\n","28\n","29\n","30\n","31\n","Loaded batch 4 of 5\n","Percentage Done: 80.0 %\n","32\n","33\n","34\n","35\n","36\n","37\n","38\n","39\n","Loaded batch 5 of 5\n","Percentage Done: 100.0 
%\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"NVLDwkBugDU_","colab_type":"code","colab":{}},"source":[""],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"f2VEpVq0ZYD-","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":884},"outputId":"b1f537a1-31e5-4d92-b994-ff158b26f458"},"source":["# set a boolean flag that indicates whether a cuda capable GPU is available\n","# we will need this for transferring our tensors to the device and\n","# for persistent memory in the data loader\n","is_gpu = torch.cuda.is_available()\n","print(\"GPU is available:\", is_gpu)\n","print(\"If you are receiving False, try setting your runtime to GPU\")\n","\n","# set the device to cuda if a GPU is available\n","device = torch.device(\"cuda\" if is_gpu else \"cpu\")\n","\n","#create model\n","model = SegNet().to(device)\n","\n","#model.load_state_dict(torch.load(\"/content/drive/My Drive/weights15.pt\"))#####################################################################\n","\n","#define loss function\n","tensor_number = 20\n","weights = torch.load(content_path + \"/Train_Weights_Test1.pt\")\n","\n","for i in range(2,tensor_number):\n"," weights += torch.load(content_path + \"/drive/My Drive/Train_Weights\" + str(i+1) + \".pt\")\n","\n","weights = 1/(weights/weights.min()) #press weights in [0,1], with maximum value for each class\n","weights = weights.type(torch.FloatTensor)\n","weights = weights.to(device)\n","print(\"weights\", weights)\n","\n","criterion = nn.CrossEntropyLoss(weights)\n","\n","#set optimizer for backpropagation\n","optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9, weight_decay=5e-4)\n","\n","print(model)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["GPU is available: True\n","If you are receiving False, try setting your runtime to GPU\n","weights tensor([0.0053, 1.0000, 0.2481, 0.5966, 0.8477], device='cuda:0')\n","SegNet(\n"," (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (enc1_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n"," (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n"," (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (conv4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (enc2_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n"," (maxpool2): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)\n"," (conv5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (conv6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (conv7): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (enc3_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n"," (maxpool3): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)\n"," (conv8): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (conv9): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (conv10): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (enc4_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n"," (maxpool4): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)\n"," 
(conv11): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (conv12): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (conv13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (enc5_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n"," (maxpool5): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)\n"," (upsample1): Upsample(scale_factor=2.0, mode=nearest)\n"," (conv14): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (conv15): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (conv16): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (dec1_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n"," (upsample2): Upsample(scale_factor=2.0, mode=nearest)\n"," (conv17): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (conv18): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (conv19): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (dec2_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n"," (upsample3): Upsample(scale_factor=2.0, mode=nearest)\n"," (conv20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (conv21): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (conv22): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (dec3_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n"," (upsample4): Upsample(scale_factor=2.0, mode=nearest)\n"," (conv23): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (conv24): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (dec4_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n"," (upsample5): Upsample(scale_factor=2.0, mode=nearest)\n"," (conv25): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (conv26): Conv2d(64, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n"," (softmax): Softmax(dim=1)\n",")\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"6AlmdUAxq9My","colab_type":"code","colab":{}},"source":[""],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"6oML5n20Zc5s","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":1000},"outputId":"365f17b6-57c3-45be-f04a-7083bf142bde"},"source":["total_epochs = 50\n","val_list = []\n","loss_list = []\n","#train_list = []\n","#percent_val = 0.5########################################################################################################\n","\n","val_data = torch.load(content + \"Val_Tensor.pt\")\n","for epoch in range(0, total_epochs):\n","\n"," print(\"EPOCH:\", epoch + 1)\n"," print(\"TRAIN\")\n"," for i in range(1,21): #tensor_number):\n"," print(\"train_data_number:\", i+1)\n"," train_data = torch.load(content_path + \"Train_Tensor\" +str(i) +\".pt\")\n"," loss_list.append(train(train_data, model, optimizer, criterion, device))\n"," print(\"VALIDATION\")\n"," val_list.append(validate(val_data, model, device, categories))\n"," \n"," if ((epoch) % 5 == 0):\n"," torch.save(model.state_dict(), content_path + \"Model_weights_\" + str(epoch) + \".pt\")\n"," torch.save(val_list, content_path + \"val_list.pt\")\n"," torch.save(loss_list, content_path + 
\"loss_list.pt\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["EPOCH: 1\n","TRAIN\n","train_data_number: 2\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 1.7\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 1.6\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 1.6\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 1.5\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 1.4\n","VALIDATION\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.058 precision 0.99 recall 0.013 \n","category black accuracy 0.99 precision 1.0 recall 0.0 \n","category green accuracy 0.98 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.019 precision 0.006 recall 
1.0 \n","\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.033 precision 0.99 recall 0.013 \n","category black accuracy 1.0 precision 1.0 recall 0.0 \n","category green accuracy 0.99 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.015 precision 0.0016 recall 1.0 \n","\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.063 precision 0.98 recall 0.013 \n","category black accuracy 0.99 precision 1.0 recall 0.0 \n","category green accuracy 0.98 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.024 precision 0.011 recall 1.0 \n","\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.054 precision 0.99 recall 0.013 \n","category black accuracy 1.0 precision 1.0 recall 0.0 \n","category green accuracy 0.98 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.02 precision 0.0067 recall 1.0 \n","\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.056 precision 0.98 recall 0.013 \n","category black accuracy 1.0 precision 1.0 recall 0.0 \n","category green accuracy 0.98 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.018 precision 0.0049 recall 1.0 \n","\n","avg.category white accuracy 0.053 precision 0.99 recall 0.013 \n","avg.category black accuracy 0.99 precision 1.0 recall 0.0 \n","avg.category green accuracy 0.98 precision 1.0 recall 0.0 \n","avg.category red accuracy 0.99 precision 1.0 recall 0.0 \n","avg.category yellow accuracy 
0.019 precision 0.0061 recall 1.0 \n","EPOCH: 2\n","TRAIN\n","train_data_number: 2\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 1.4\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 1.4\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 1.3\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 1.2\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 1.2\n","VALIDATION\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.94 precision 0.95 recall 0.99 \n","category black accuracy 0.99 precision 1.0 recall 0.0 \n","category green accuracy 0.98 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.98 precision 5.3e-05 recall 0.00011 \n","\n","torch.Size([8, 64, 320, 
576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.97 precision 0.98 recall 0.99 \n","category black accuracy 1.0 precision 1.0 recall 0.0 \n","category green accuracy 0.99 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.99 precision 5.3e-05 recall 0.00043 \n","\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.94 precision 0.95 recall 0.99 \n","category black accuracy 0.99 precision 1.0 recall 0.0 \n","category green accuracy 0.98 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.98 precision 0.0017 recall 0.002 \n","\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.95 precision 0.96 recall 0.99 \n","category black accuracy 1.0 precision 1.0 recall 0.0 \n","category green accuracy 0.98 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.98 precision 0.0017 recall 0.0033 \n","\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.94 precision 0.96 recall 0.99 \n","category black accuracy 1.0 precision 1.0 recall 0.0 \n","category green accuracy 0.98 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.98 precision 0.00095 recall 0.0025 \n","\n","avg.category white accuracy 0.95 precision 0.96 recall 0.99 \n","avg.category black accuracy 0.99 precision 1.0 recall 0.0 \n","avg.category green accuracy 0.98 precision 1.0 recall 0.0 \n","avg.category red accuracy 0.99 precision 1.0 recall 0.0 \n","avg.category yellow accuracy 0.98 precision 0.0009 recall 0.0017 
\n","EPOCH: 3\n","TRAIN\n","train_data_number: 2\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 1.2\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 1.2\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 1.1\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 1.1\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 1.0\n","VALIDATION\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.94 precision 0.95 recall 0.99 \n","category black accuracy 0.99 precision 1.0 recall 0.0 \n","category green accuracy 0.98 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.98 precision 5.3e-05 recall 0.00011 \n","\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 
288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.97 precision 0.98 recall 0.99 \n","category black accuracy 1.0 precision 1.0 recall 0.0 \n","category green accuracy 0.99 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.99 precision 5.3e-05 recall 0.00043 \n","\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.94 precision 0.95 recall 0.99 \n","category black accuracy 0.99 precision 1.0 recall 0.0 \n","category green accuracy 0.98 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.98 precision 0.0017 recall 0.002 \n","\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.95 precision 0.96 recall 0.99 \n","category black accuracy 1.0 precision 1.0 recall 0.0 \n","category green accuracy 0.98 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.98 precision 0.0017 recall 0.0033 \n","\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.94 precision 0.96 recall 0.99 \n","category black accuracy 1.0 precision 1.0 recall 0.0 \n","category green accuracy 0.98 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.98 precision 0.00095 recall 0.0025 \n","\n","avg.category white accuracy 0.95 precision 0.96 recall 0.99 \n","avg.category black accuracy 0.99 precision 1.0 recall 0.0 \n","avg.category green accuracy 0.98 precision 1.0 recall 0.0 \n","avg.category red accuracy 0.99 precision 1.0 recall 0.0 \n","avg.category yellow accuracy 0.98 precision 0.0009 recall 0.0017 \n","EPOCH: 
4\n","TRAIN\n","train_data_number: 2\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 1.0\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 1.1\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 1.0\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 0.92\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 0.9\n","VALIDATION\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.95 precision 0.95 recall 0.99 \n","category black accuracy 0.99 precision 1.0 recall 0.0 \n","category green accuracy 0.98 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.98 precision 1.0 recall 0.0 \n","\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 
128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.97 precision 0.98 recall 0.99 \n","category black accuracy 1.0 precision 1.0 recall 0.0 \n","category green accuracy 0.99 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.99 precision 1.0 recall 0.0 \n","\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.94 precision 0.95 recall 0.99 \n","category black accuracy 0.99 precision 1.0 recall 0.0 \n","category green accuracy 0.98 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.98 precision 1.0 recall 0.0 \n","\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.95 precision 0.96 recall 0.99 \n","category black accuracy 1.0 precision 1.0 recall 0.0 \n","category green accuracy 0.98 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.98 precision 1.0 recall 0.0 \n","\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.95 precision 0.96 recall 0.99 \n","category black accuracy 1.0 precision 1.0 recall 0.0 \n","category green accuracy 0.98 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.99 precision 1.0 recall 0.0 \n","\n","avg.category white accuracy 0.95 precision 0.96 recall 0.99 \n","avg.category black accuracy 0.99 precision 1.0 recall 0.0 \n","avg.category green accuracy 0.98 precision 1.0 recall 0.0 \n","avg.category red accuracy 0.99 precision 1.0 recall 0.0 \n","avg.category yellow accuracy 0.98 precision 1.0 recall 0.0 \n","EPOCH: 5\n","TRAIN\n","train_data_number: 2\n","torch.Size([8, 64, 320, 
576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 0.91\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 0.93\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 0.91\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 0.82\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","loss 0.81\n","VALIDATION\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.95 precision 0.95 recall 0.99 \n","category black accuracy 0.99 precision 1.0 recall 0.0 \n","category green accuracy 0.98 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.99 precision 1.0 recall 0.0 \n","\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 
144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.97 precision 0.98 recall 0.99 \n","category black accuracy 1.0 precision 1.0 recall 0.0 \n","category green accuracy 0.99 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.99 precision 1.0 recall 0.0 \n","\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.94 precision 0.95 recall 0.99 \n","category black accuracy 0.99 precision 1.0 recall 0.0 \n","category green accuracy 0.98 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.98 precision 1.0 recall 0.0 \n","\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.95 precision 0.96 recall 0.99 \n","category black accuracy 1.0 precision 1.0 recall 0.0 \n","category green accuracy 0.98 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.99 precision 1.0 recall 0.0 \n","\n","torch.Size([8, 64, 320, 576])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 128, 160, 288])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 256, 80, 144])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 512, 40, 72])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 512, 10, 18])\n","\n","torch.Size([8, 512, 20, 36])\n","torch.Size([8, 256, 40, 72])\n","torch.Size([8, 128, 80, 144])\n","torch.Size([8, 64, 160, 288])\n","torch.Size([8, 5, 320, 576])\n","category white accuracy 0.95 precision 0.96 recall 0.99 \n","category black accuracy 1.0 precision 1.0 recall 0.0 \n","category green accuracy 0.98 precision 1.0 recall 0.0 \n","category red accuracy 0.99 precision 1.0 recall 0.0 \n","category yellow accuracy 0.99 precision 1.0 recall 0.0 \n","\n","avg.category white accuracy 0.95 precision 0.96 recall 0.99 \n","avg.category black accuracy 0.99 precision 1.0 recall 0.0 \n","avg.category green accuracy 0.98 precision 1.0 recall 0.0 \n","avg.category red accuracy 0.99 precision 1.0 recall 0.0 \n","avg.category yellow accuracy 0.99 precision 1.0 recall 0.0 
\n"],"name":"stdout"},{"output_type":"execute_result","data":{"application/vnd.google.colaboratory.intrinsic+json":{"type":"string"},"text/plain":["'val_list.append(validate(test_data, model, device, percent_val))\\n train_list.append(validate(train_data, model, device, percent_val))\\n if((epoch + 1) % 10 == 0):\\n torch.save(model, \"weights_small\" + str(epoch) + \".pt\")\\n if((epoch+1) == 125):\\n torch.save(model, \"/content/drive/My Drive/\" + \"weights_small50.pt\")\\n for lis in train_list:\\n print(lis)\\n for lis in val_list:\\n print(lis)\\n \\n\\ntorch.save(val_list, \"/content/drive/My Drive/val_list.pt\")\\n'"]},"metadata":{"tags":[]},"execution_count":28}]},{"cell_type":"code","metadata":{"id":"RmglFAk3ueD6","colab_type":"code","colab":{}},"source":[""],"execution_count":null,"outputs":[]}]} \ No newline at end of file