Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
finlay-liu authored Jun 30, 2023
1 parent 92c82e6 commit ffb391f
Showing 1 changed file with 380 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,380 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "c8d01073-09a8-4c68-b93b-aef463738bd0",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"%pylab is deprecated, use %matplotlib inline and import the required libraries.\n",
"Populating the interactive namespace from numpy and matplotlib\n"
]
}
],
"source": [
"import os, sys, glob, argparse\n",
"import pandas as pd\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"\n",
"%pylab inline\n",
"\n",
"import cv2\n",
"from PIL import Image\n",
"from sklearn.model_selection import train_test_split, StratifiedKFold, KFold\n",
"\n",
"import torch\n",
"torch.manual_seed(0)\n",
"torch.backends.cudnn.deterministic = False\n",
"torch.backends.cudnn.benchmark = True\n",
"\n",
"import torchvision.models as models\n",
"import torchvision.transforms as transforms\n",
"import torchvision.datasets as datasets\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torch.autograd import Variable\n",
"from torch.utils.data.dataset import Dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "70d2263d-916c-431e-aca7-3e5bcd867f9b",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"train_path = glob.glob('./自动驾驶疲劳检测挑战赛公开数据/train/*/*')\n",
"test_path = glob.glob('./自动驾驶疲劳检测挑战赛公开数据/test/*')\n",
"\n",
"np.random.shuffle(train_path)\n",
"np.random.shuffle(test_path)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "584f2188-cf45-4d72-abee-4dae0d362926",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"DATA_CACHE = {}\n",
"class XunFeiDataset(Dataset):\n",
" def __init__(self, img_path, transform=None):\n",
" self.img_path = img_path\n",
" if transform is not None:\n",
" self.transform = transform\n",
" else:\n",
" self.transform = None\n",
" \n",
" def __getitem__(self, index):\n",
" img = cv2.imread(self.img_path[index])\n",
" \n",
" if self.transform is not None:\n",
" img = self.transform(image = img)['image']\n",
" \n",
" img = img.transpose([2,0,1])\n",
" img = img.astype(np.float32)\n",
" return img, torch.from_numpy(np.array(int('non-sleepy' in self.img_path[index])))\n",
" \n",
" def __len__(self):\n",
" return len(self.img_path)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "bd5a22b6-3a32-4039-bcc7-a6c6f5f9cbaf",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"class XunFeiNet(nn.Module):\n",
" def __init__(self):\n",
" super(XunFeiNet, self).__init__()\n",
" \n",
" model = models.resnet18(True)\n",
" model.avgpool = nn.AdaptiveAvgPool2d(1)\n",
" model.fc = nn.Linear(512, 2)\n",
" self.resnet = model\n",
" \n",
" def forward(self, img): \n",
" out = self.resnet(img)\n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "47539bef-14a7-4881-85e9-0882ff295e6a",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def train(train_loader, model, criterion, optimizer):\n",
" model.train()\n",
" train_loss = 0.0\n",
" for i, (input, target) in enumerate(train_loader):\n",
" input = input.cuda(non_blocking=True)\n",
" target = target.cuda(non_blocking=True)\n",
"\n",
" # compute output\n",
" output = model(input)\n",
" loss = criterion(output, target)\n",
"\n",
" # compute gradient and do SGD step\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" if i % 100 == 0:\n",
" print(loss.item())\n",
" \n",
" train_loss += loss.item()\n",
" \n",
" return train_loss/len(train_loader)\n",
" \n",
"def validate(val_loader, model, criterion):\n",
" model.eval()\n",
" \n",
" val_acc = 0.0\n",
" \n",
" with torch.no_grad():\n",
" end = time.time()\n",
" for i, (input, target) in enumerate(val_loader):\n",
" input = input.cuda()\n",
" target = target.cuda()\n",
"\n",
" # compute output\n",
" output = model(input)\n",
" loss = criterion(output, target)\n",
" \n",
" val_acc += (output.argmax(1) == target).sum().item()\n",
" \n",
" return val_acc / len(val_loader.dataset)\n",
"\n",
"def predict(test_loader, model, criterion):\n",
" model.eval()\n",
" val_acc = 0.0\n",
" \n",
" test_pred = []\n",
" with torch.no_grad():\n",
" end = time.time()\n",
" for i, (input, target) in enumerate(test_loader):\n",
" input = input.cuda()\n",
" target = target.cuda()\n",
"\n",
" # compute output\n",
" output = model(input)\n",
" test_pred.append(output.data.cpu().numpy())\n",
" \n",
" return np.vstack(test_pred)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "c14b410f-de94-4f52-be5b-1377c34bda4a",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/lyz/.local/lib/python3.9/site-packages/albumentations/augmentations/transforms.py:1639: FutureWarning: RandomContrast has been deprecated. Please use RandomBrightnessContrast\n",
" warnings.warn(\n"
]
}
],
"source": [
"import albumentations as A\n",
"\n",
"train_loader = torch.utils.data.DataLoader(\n",
" XunFeiDataset(train_path[:-10],\n",
" A.Compose([\n",
" A.RandomRotate90(),\n",
" A.Resize(128, 128),\n",
" A.RandomCrop(120, 120),\n",
" A.HorizontalFlip(p=0.5),\n",
" A.RandomContrast(p=0.5),\n",
" A.RandomBrightnessContrast(p=0.5),\n",
" ])\n",
" ), batch_size=20, shuffle=True, num_workers=1, pin_memory=False\n",
")\n",
"\n",
"val_loader = torch.utils.data.DataLoader(\n",
" XunFeiDataset(train_path[-10:],\n",
" A.Compose([\n",
" A.Resize(128, 128),\n",
" A.RandomCrop(120, 120),\n",
" # A.HorizontalFlip(p=0.5),\n",
" # A.RandomContrast(p=0.5),\n",
" ])\n",
" ), batch_size=20, shuffle=False, num_workers=1, pin_memory=False\n",
")\n",
"\n",
"test_loader = torch.utils.data.DataLoader(\n",
" XunFeiDataset(test_path,\n",
" A.Compose([\n",
" A.Resize(128, 128),\n",
" A.HorizontalFlip(p=0.5),\n",
" A.RandomContrast(p=0.5),\n",
" ])\n",
" ), batch_size=20, shuffle=False, num_workers=1, pin_memory=False\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "7f7ee8cb-cfc0-4eb8-96ea-7302144ae7a4",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/lyz/.local/lib/python3.9/site-packages/torchvision/models/_utils.py:135: UserWarning: Using 'weights' as positional parameter(s) is deprecated since 0.13 and will be removed in 0.15. Please use keyword parameter(s) instead.\n",
" warnings.warn(\n",
"/home/lyz/.local/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n",
" warnings.warn(msg)\n"
]
}
],
"source": [
"model = XunFeiNet()\n",
"model = model.to('cuda')\n",
"criterion = nn.CrossEntropyLoss().cuda()\n",
"optimizer = torch.optim.AdamW(model.parameters(), 0.001)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "b2faad78-e794-4f31-8b56-1a64dbf89ccc",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.8071246147155762\n",
"0.40889906883239746\n",
"0.20135173201560974\n",
"0.4055064118653536 0.8306613226452906 0.8\n"
]
}
],
"source": [
"for _ in range(1):\n",
" train_loss = train(train_loader, model, criterion, optimizer)\n",
" val_acc = validate(val_loader, model, criterion)\n",
" train_acc = validate(train_loader, model, criterion)\n",
" \n",
" print(train_loss, train_acc, val_acc)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "6787d12a-b923-4777-abf5-4d90904551d1",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"pred = None\n",
"\n",
"for _ in range(10):\n",
" if pred is None:\n",
" pred = predict(test_loader, model, criterion)\n",
" else:\n",
" pred += predict(test_loader, model, criterion)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "1743990f-d27d-43cd-85fd-23a0ff0f2c36",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"submit = pd.DataFrame(\n",
" {\n",
" 'name': [x.split('/')[-1] for x in test_path],\n",
" 'label': pred.argmax(1)\n",
"})"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "cc565971-7a5f-4bb0-8ab5-c33281bc469c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"submit['label'] = submit['label'].map({1:'non-sleepy', 0: 'sleepy'})\n",
"submit.to_csv('submit.csv', index=None)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3fc9e8ed-2c50-4f92-ad5b-d53083b20895",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3.10"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit ffb391f

Please sign in to comment.