From ffb391f6e7ae1554986416356a6e5eb952bd5869 Mon Sep 17 00:00:00 2001 From: Yuzhong Liu Date: Fri, 30 Jun 2023 17:11:55 +0800 Subject: [PATCH] Add files via upload --- ...21\346\210\230\350\265\233_baseline.ipynb" | 380 ++++++++++++++++++ 1 file changed, 380 insertions(+) create mode 100644 "competition/\347\247\221\345\244\247\350\256\257\351\243\236AI\345\274\200\345\217\221\350\200\205\345\244\247\350\265\2332023/\350\207\252\345\212\250\351\251\276\351\251\266\347\226\262\345\212\263\346\243\200\346\265\213\346\214\221\346\210\230\350\265\233_baseline.ipynb" diff --git "a/competition/\347\247\221\345\244\247\350\256\257\351\243\236AI\345\274\200\345\217\221\350\200\205\345\244\247\350\265\2332023/\350\207\252\345\212\250\351\251\276\351\251\266\347\226\262\345\212\263\346\243\200\346\265\213\346\214\221\346\210\230\350\265\233_baseline.ipynb" "b/competition/\347\247\221\345\244\247\350\256\257\351\243\236AI\345\274\200\345\217\221\350\200\205\345\244\247\350\265\2332023/\350\207\252\345\212\250\351\251\276\351\251\266\347\226\262\345\212\263\346\243\200\346\265\213\346\214\221\346\210\230\350\265\233_baseline.ipynb" new file mode 100644 index 0000000..0c06efb --- /dev/null +++ "b/competition/\347\247\221\345\244\247\350\256\257\351\243\236AI\345\274\200\345\217\221\350\200\205\345\244\247\350\265\2332023/\350\207\252\345\212\250\351\251\276\351\251\266\347\226\262\345\212\263\346\243\200\346\265\213\346\214\221\346\210\230\350\265\233_baseline.ipynb" @@ -0,0 +1,380 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "c8d01073-09a8-4c68-b93b-aef463738bd0", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "%pylab is deprecated, use %matplotlib inline and import the required libraries.\n", + "Populating the interactive namespace from numpy and matplotlib\n" + ] + } + ], + "source": [ + "import os, sys, glob, argparse\n", + "import pandas as pd\n", + "import numpy as np\n", + "from tqdm import tqdm\n", + "\n", + "%pylab inline\n", + "\n", + "import cv2\n", + "from PIL import Image\n", + "from sklearn.model_selection import train_test_split, StratifiedKFold, KFold\n", + "\n", + "import torch\n", + "torch.manual_seed(0)\n", + "torch.backends.cudnn.deterministic = False\n", + "torch.backends.cudnn.benchmark = True\n", + "\n", + "import torchvision.models as models\n", + "import torchvision.transforms as transforms\n", + "import torchvision.datasets as datasets\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "from torch.autograd import Variable\n", + "from torch.utils.data.dataset import Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "70d2263d-916c-431e-aca7-3e5bcd867f9b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "train_path = glob.glob('./自动驾驶疲劳检测挑战赛公开数据/train/*/*')\n", + "test_path = glob.glob('./自动驾驶疲劳检测挑战赛公开数据/test/*')\n", + "\n", + "np.random.shuffle(train_path)\n", + "np.random.shuffle(test_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "584f2188-cf45-4d72-abee-4dae0d362926", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "DATA_CACHE = {}\n", + "class XunFeiDataset(Dataset):\n", + " def __init__(self, img_path, transform=None):\n", + " self.img_path = img_path\n", + " if transform is not None:\n", + " self.transform = transform\n", + " else:\n", + " self.transform = None\n", + " \n", + " def __getitem__(self, index):\n", + " img = cv2.imread(self.img_path[index])\n", + " \n", + " if self.transform is not None:\n", + " img = self.transform(image = img)['image']\n", + " \n", + " img = img.transpose([2,0,1])\n", + " img = img.astype(np.float32)\n", + " return img, torch.from_numpy(np.array(int('non-sleepy' in self.img_path[index])))\n", + " \n", + " def __len__(self):\n", + " return len(self.img_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "bd5a22b6-3a32-4039-bcc7-a6c6f5f9cbaf", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "class XunFeiNet(nn.Module):\n", + " def __init__(self):\n", + " super(XunFeiNet, self).__init__()\n", + " \n", + " model = models.resnet18(True)\n", + " model.avgpool = nn.AdaptiveAvgPool2d(1)\n", + " model.fc = nn.Linear(512, 2)\n", + " self.resnet = model\n", + " \n", + " def forward(self, img): \n", + " out = self.resnet(img)\n", + " return out" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "47539bef-14a7-4881-85e9-0882ff295e6a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def train(train_loader, model, criterion, optimizer):\n", + " model.train()\n", + " train_loss = 0.0\n", + " for i, (input, target) in enumerate(train_loader):\n", + " input = input.cuda(non_blocking=True)\n", + " target = target.cuda(non_blocking=True)\n", + "\n", + " # compute output\n", + " output = model(input)\n", + " loss = criterion(output, target)\n", + "\n", + " # compute gradient and do SGD step\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " if i % 100 == 0:\n", + " print(loss.item())\n", + " \n", + " train_loss += loss.item()\n", + " \n", + " return train_loss/len(train_loader)\n", + " \n", + "def validate(val_loader, model, criterion):\n", + " model.eval()\n", + " \n", + " val_acc = 0.0\n", + " \n", + " with torch.no_grad():\n", + " end = time.time()\n", + " for i, (input, target) in enumerate(val_loader):\n", + " input = input.cuda()\n", + " target = target.cuda()\n", + "\n", + " # compute output\n", + " output = model(input)\n", + " loss = criterion(output, target)\n", + " \n", + " val_acc += (output.argmax(1) == target).sum().item()\n", + " \n", + " return val_acc / len(val_loader.dataset)\n", + "\n", + "def predict(test_loader, model, criterion):\n", + " model.eval()\n", + " val_acc = 0.0\n", + " \n", + " test_pred = []\n", + " with torch.no_grad():\n", + " end = time.time()\n", + " for i, (input, target) in enumerate(test_loader):\n", + " input = input.cuda()\n", + " target = target.cuda()\n", + "\n", + " # compute output\n", + " output = model(input)\n", + " test_pred.append(output.data.cpu().numpy())\n", + " \n", + " return np.vstack(test_pred)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "c14b410f-de94-4f52-be5b-1377c34bda4a", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/lyz/.local/lib/python3.9/site-packages/albumentations/augmentations/transforms.py:1639: FutureWarning: RandomContrast has been deprecated. Please use RandomBrightnessContrast\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "import albumentations as A\n", + "\n", + "train_loader = torch.utils.data.DataLoader(\n", + " XunFeiDataset(train_path[:-10],\n", + " A.Compose([\n", + " A.RandomRotate90(),\n", + " A.Resize(128, 128),\n", + " A.RandomCrop(120, 120),\n", + " A.HorizontalFlip(p=0.5),\n", + " A.RandomContrast(p=0.5),\n", + " A.RandomBrightnessContrast(p=0.5),\n", + " ])\n", + " ), batch_size=20, shuffle=True, num_workers=1, pin_memory=False\n", + ")\n", + "\n", + "val_loader = torch.utils.data.DataLoader(\n", + " XunFeiDataset(train_path[-10:],\n", + " A.Compose([\n", + " A.Resize(128, 128),\n", + " A.RandomCrop(120, 120),\n", + " # A.HorizontalFlip(p=0.5),\n", + " # A.RandomContrast(p=0.5),\n", + " ])\n", + " ), batch_size=20, shuffle=False, num_workers=1, pin_memory=False\n", + ")\n", + "\n", + "test_loader = torch.utils.data.DataLoader(\n", + " XunFeiDataset(test_path,\n", + " A.Compose([\n", + " A.Resize(128, 128),\n", + " A.HorizontalFlip(p=0.5),\n", + " A.RandomContrast(p=0.5),\n", + " ])\n", + " ), batch_size=20, shuffle=False, num_workers=1, pin_memory=False\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "7f7ee8cb-cfc0-4eb8-96ea-7302144ae7a4", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/lyz/.local/lib/python3.9/site-packages/torchvision/models/_utils.py:135: UserWarning: Using 'weights' as positional parameter(s) is deprecated since 0.13 and will be removed in 0.15. Please use keyword parameter(s) instead.\n", + " warnings.warn(\n", + "/home/lyz/.local/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n", + " warnings.warn(msg)\n" + ] + } + ], + "source": [ + "model = XunFeiNet()\n", + "model = model.to('cuda')\n", + "criterion = nn.CrossEntropyLoss().cuda()\n", + "optimizer = torch.optim.AdamW(model.parameters(), 0.001)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "b2faad78-e794-4f31-8b56-1a64dbf89ccc", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.8071246147155762\n", + "0.40889906883239746\n", + "0.20135173201560974\n", + "0.4055064118653536 0.8306613226452906 0.8\n" + ] + } + ], + "source": [ + "for _ in range(1):\n", + " train_loss = train(train_loader, model, criterion, optimizer)\n", + " val_acc = validate(val_loader, model, criterion)\n", + " train_acc = validate(train_loader, model, criterion)\n", + " \n", + " print(train_loss, train_acc, val_acc)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "6787d12a-b923-4777-abf5-4d90904551d1", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "pred = None\n", + "\n", + "for _ in range(10):\n", + " if pred is None:\n", + " pred = predict(test_loader, model, criterion)\n", + " else:\n", + " pred += predict(test_loader, model, criterion)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "1743990f-d27d-43cd-85fd-23a0ff0f2c36", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "submit = pd.DataFrame(\n", + " {\n", + " 'name': [x.split('/')[-1] for x in test_path],\n", + " 'label': pred.argmax(1)\n", + "})" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "cc565971-7a5f-4bb0-8ab5-c33281bc469c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "submit['label'] = submit['label'].map({1:'non-sleepy', 0: 'sleepy'})\n", + "submit.to_csv('submit.csv', index=None)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3fc9e8ed-2c50-4f92-ad5b-d53083b20895", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3.10" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.10" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}