-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
92c82e6
commit ffb391f
Showing
1 changed file
with
380 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,380 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "c8d01073-09a8-4c68-b93b-aef463738bd0", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"%pylab is deprecated, use %matplotlib inline and import the required libraries.\n", | ||
"Populating the interactive namespace from numpy and matplotlib\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import os, sys, glob, argparse\n", | ||
"import pandas as pd\n", | ||
"import numpy as np\n", | ||
"from tqdm import tqdm\n", | ||
"\n", | ||
"%pylab inline\n", | ||
"\n", | ||
"import cv2\n", | ||
"from PIL import Image\n", | ||
"from sklearn.model_selection import train_test_split, StratifiedKFold, KFold\n", | ||
"\n", | ||
"import torch\n", | ||
"torch.manual_seed(0)\n", | ||
"torch.backends.cudnn.deterministic = False\n", | ||
"torch.backends.cudnn.benchmark = True\n", | ||
"\n", | ||
"import torchvision.models as models\n", | ||
"import torchvision.transforms as transforms\n", | ||
"import torchvision.datasets as datasets\n", | ||
"import torch.nn as nn\n", | ||
"import torch.nn.functional as F\n", | ||
"import torch.optim as optim\n", | ||
"from torch.autograd import Variable\n", | ||
"from torch.utils.data.dataset import Dataset" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "70d2263d-916c-431e-aca7-3e5bcd867f9b", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"train_path = glob.glob('./自动驾驶疲劳检测挑战赛公开数据/train/*/*')\n", | ||
"test_path = glob.glob('./自动驾驶疲劳检测挑战赛公开数据/test/*')\n", | ||
"\n", | ||
"np.random.shuffle(train_path)\n", | ||
"np.random.shuffle(test_path)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "584f2188-cf45-4d72-abee-4dae0d362926", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"DATA_CACHE = {}\n", | ||
"class XunFeiDataset(Dataset):\n", | ||
" def __init__(self, img_path, transform=None):\n", | ||
" self.img_path = img_path\n", | ||
" if transform is not None:\n", | ||
" self.transform = transform\n", | ||
" else:\n", | ||
" self.transform = None\n", | ||
" \n", | ||
" def __getitem__(self, index):\n", | ||
" img = cv2.imread(self.img_path[index])\n", | ||
" \n", | ||
" if self.transform is not None:\n", | ||
" img = self.transform(image = img)['image']\n", | ||
" \n", | ||
" img = img.transpose([2,0,1])\n", | ||
" img = img.astype(np.float32)\n", | ||
" return img, torch.from_numpy(np.array(int('non-sleepy' in self.img_path[index])))\n", | ||
" \n", | ||
" def __len__(self):\n", | ||
" return len(self.img_path)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "bd5a22b6-3a32-4039-bcc7-a6c6f5f9cbaf", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"class XunFeiNet(nn.Module):\n", | ||
" def __init__(self):\n", | ||
" super(XunFeiNet, self).__init__()\n", | ||
" \n", | ||
" model = models.resnet18(True)\n", | ||
" model.avgpool = nn.AdaptiveAvgPool2d(1)\n", | ||
" model.fc = nn.Linear(512, 2)\n", | ||
" self.resnet = model\n", | ||
" \n", | ||
" def forward(self, img): \n", | ||
" out = self.resnet(img)\n", | ||
" return out" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "47539bef-14a7-4881-85e9-0882ff295e6a", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"def train(train_loader, model, criterion, optimizer):\n", | ||
" model.train()\n", | ||
" train_loss = 0.0\n", | ||
" for i, (input, target) in enumerate(train_loader):\n", | ||
" input = input.cuda(non_blocking=True)\n", | ||
" target = target.cuda(non_blocking=True)\n", | ||
"\n", | ||
" # compute output\n", | ||
" output = model(input)\n", | ||
" loss = criterion(output, target)\n", | ||
"\n", | ||
" # compute gradient and do SGD step\n", | ||
" optimizer.zero_grad()\n", | ||
" loss.backward()\n", | ||
" optimizer.step()\n", | ||
"\n", | ||
" if i % 100 == 0:\n", | ||
" print(loss.item())\n", | ||
" \n", | ||
" train_loss += loss.item()\n", | ||
" \n", | ||
" return train_loss/len(train_loader)\n", | ||
" \n", | ||
"def validate(val_loader, model, criterion):\n", | ||
" model.eval()\n", | ||
" \n", | ||
" val_acc = 0.0\n", | ||
" \n", | ||
" with torch.no_grad():\n", | ||
" end = time.time()\n", | ||
" for i, (input, target) in enumerate(val_loader):\n", | ||
" input = input.cuda()\n", | ||
" target = target.cuda()\n", | ||
"\n", | ||
" # compute output\n", | ||
" output = model(input)\n", | ||
" loss = criterion(output, target)\n", | ||
" \n", | ||
" val_acc += (output.argmax(1) == target).sum().item()\n", | ||
" \n", | ||
" return val_acc / len(val_loader.dataset)\n", | ||
"\n", | ||
"def predict(test_loader, model, criterion):\n", | ||
" model.eval()\n", | ||
" val_acc = 0.0\n", | ||
" \n", | ||
" test_pred = []\n", | ||
" with torch.no_grad():\n", | ||
" end = time.time()\n", | ||
" for i, (input, target) in enumerate(test_loader):\n", | ||
" input = input.cuda()\n", | ||
" target = target.cuda()\n", | ||
"\n", | ||
" # compute output\n", | ||
" output = model(input)\n", | ||
" test_pred.append(output.data.cpu().numpy())\n", | ||
" \n", | ||
" return np.vstack(test_pred)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"id": "c14b410f-de94-4f52-be5b-1377c34bda4a", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/home/lyz/.local/lib/python3.9/site-packages/albumentations/augmentations/transforms.py:1639: FutureWarning: RandomContrast has been deprecated. Please use RandomBrightnessContrast\n", | ||
" warnings.warn(\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import albumentations as A\n", | ||
"\n", | ||
"train_loader = torch.utils.data.DataLoader(\n", | ||
" XunFeiDataset(train_path[:-10],\n", | ||
" A.Compose([\n", | ||
" A.RandomRotate90(),\n", | ||
" A.Resize(128, 128),\n", | ||
" A.RandomCrop(120, 120),\n", | ||
" A.HorizontalFlip(p=0.5),\n", | ||
" A.RandomContrast(p=0.5),\n", | ||
" A.RandomBrightnessContrast(p=0.5),\n", | ||
" ])\n", | ||
" ), batch_size=20, shuffle=True, num_workers=1, pin_memory=False\n", | ||
")\n", | ||
"\n", | ||
"val_loader = torch.utils.data.DataLoader(\n", | ||
" XunFeiDataset(train_path[-10:],\n", | ||
" A.Compose([\n", | ||
" A.Resize(128, 128),\n", | ||
" A.RandomCrop(120, 120),\n", | ||
" # A.HorizontalFlip(p=0.5),\n", | ||
" # A.RandomContrast(p=0.5),\n", | ||
" ])\n", | ||
" ), batch_size=20, shuffle=False, num_workers=1, pin_memory=False\n", | ||
")\n", | ||
"\n", | ||
"test_loader = torch.utils.data.DataLoader(\n", | ||
" XunFeiDataset(test_path,\n", | ||
" A.Compose([\n", | ||
" A.Resize(128, 128),\n", | ||
" A.HorizontalFlip(p=0.5),\n", | ||
" A.RandomContrast(p=0.5),\n", | ||
" ])\n", | ||
" ), batch_size=20, shuffle=False, num_workers=1, pin_memory=False\n", | ||
")\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"id": "7f7ee8cb-cfc0-4eb8-96ea-7302144ae7a4", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/home/lyz/.local/lib/python3.9/site-packages/torchvision/models/_utils.py:135: UserWarning: Using 'weights' as positional parameter(s) is deprecated since 0.13 and will be removed in 0.15. Please use keyword parameter(s) instead.\n", | ||
" warnings.warn(\n", | ||
"/home/lyz/.local/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n", | ||
" warnings.warn(msg)\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"model = XunFeiNet()\n", | ||
"model = model.to('cuda')\n", | ||
"criterion = nn.CrossEntropyLoss().cuda()\n", | ||
"optimizer = torch.optim.AdamW(model.parameters(), 0.001)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"id": "b2faad78-e794-4f31-8b56-1a64dbf89ccc", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"0.8071246147155762\n", | ||
"0.40889906883239746\n", | ||
"0.20135173201560974\n", | ||
"0.4055064118653536 0.8306613226452906 0.8\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"for _ in range(1):\n", | ||
" train_loss = train(train_loader, model, criterion, optimizer)\n", | ||
" val_acc = validate(val_loader, model, criterion)\n", | ||
" train_acc = validate(train_loader, model, criterion)\n", | ||
" \n", | ||
" print(train_loss, train_acc, val_acc)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"id": "6787d12a-b923-4777-abf5-4d90904551d1", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"pred = None\n", | ||
"\n", | ||
"for _ in range(10):\n", | ||
" if pred is None:\n", | ||
" pred = predict(test_loader, model, criterion)\n", | ||
" else:\n", | ||
" pred += predict(test_loader, model, criterion)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"id": "1743990f-d27d-43cd-85fd-23a0ff0f2c36", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"submit = pd.DataFrame(\n", | ||
" {\n", | ||
" 'name': [x.split('/')[-1] for x in test_path],\n", | ||
" 'label': pred.argmax(1)\n", | ||
"})" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"id": "cc565971-7a5f-4bb0-8ab5-c33281bc469c", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"submit['label'] = submit['label'].map({1:'non-sleepy', 0: 'sleepy'})\n", | ||
"submit.to_csv('submit.csv', index=None)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "3fc9e8ed-2c50-4f92-ad5b-d53083b20895", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3.10" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.10" | ||
}, | ||
"widgets": { | ||
"application/vnd.jupyter.widget-state+json": { | ||
"state": {}, | ||
"version_major": 2, | ||
"version_minor": 0 | ||
} | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |