-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
73e364c
commit dec05c8
Showing
1 changed file
with
351 additions
and
0 deletions.
There are no files selected for viewing
351 changes: 351 additions & 0 deletions
351
competition/科大讯飞AI开发者大赛2023/Stable Diffusion鉴别器挑战赛_baseline.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,351 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "c8d01073-09a8-4c68-b93b-aef463738bd0", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import os, sys, glob, argparse\n", | ||
"import pandas as pd\n", | ||
"import numpy as np\n", | ||
"from tqdm import tqdm\n", | ||
"\n", | ||
"import cv2\n", | ||
"from PIL import Image\n", | ||
"from sklearn.model_selection import train_test_split, StratifiedKFold, KFold\n", | ||
"\n", | ||
"import torch\n", | ||
"torch.manual_seed(0)\n", | ||
"torch.backends.cudnn.deterministic = False\n", | ||
"torch.backends.cudnn.benchmark = True\n", | ||
"\n", | ||
"import torchvision.models as models\n", | ||
"import torchvision.transforms as transforms\n", | ||
"import torchvision.datasets as datasets\n", | ||
"import torch.nn as nn\n", | ||
"import torch.nn.functional as F\n", | ||
"import torch.optim as optim\n", | ||
"from torch.autograd import Variable\n", | ||
"from torch.utils.data.dataset import Dataset\n", | ||
"\n", | ||
"# Check if GPU is available\n", | ||
"if torch.cuda.is_available():\n", | ||
" device = torch.device(\"cuda\")\n", | ||
"else:\n", | ||
" device = torch.device(\"cpu\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 31, | ||
"id": "70d2263d-916c-431e-aca7-3e5bcd867f9b", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"train_path = glob.glob('./Stable Diffusion鉴别器挑战赛公开数据/train/*/*')\n", | ||
"test_path = glob.glob('./Stable Diffusion鉴别器挑战赛公开数据/test/*')\n", | ||
"\n", | ||
"np.random.shuffle(train_path)\n", | ||
"test_path.sort()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 42, | ||
"id": "584f2188-cf45-4d72-abee-4dae0d362926", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"class XunFeiDataset(Dataset):\n", | ||
" def __init__(self, img_path, transform=None):\n", | ||
" self.img_path = img_path\n", | ||
" if transform is not None:\n", | ||
" self.transform = transform\n", | ||
" else:\n", | ||
" self.transform = None\n", | ||
" def __getitem__(self, index):\n", | ||
" img = cv2.imread(self.img_path[index])\n", | ||
" if self.transform is not None:\n", | ||
" img = self.transform(image = img)['image']\n", | ||
" img = img.transpose([2,0,1])\n", | ||
" return img, torch.from_numpy(np.array(int('real' in self.img_path[index])))\n", | ||
" def __len__(self):\n", | ||
" return len(self.img_path)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 43, | ||
"id": "bd5a22b6-3a32-4039-bcc7-a6c6f5f9cbaf", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"class XunFeiNet(nn.Module):\n", | ||
" def __init__(self):\n", | ||
" super(XunFeiNet, self).__init__()\n", | ||
" model = models.resnet34(True)\n", | ||
" model.avgpool = nn.AdaptiveAvgPool2d(1)\n", | ||
" model.fc = nn.Linear(512, 2)\n", | ||
" self.resnet = model\n", | ||
" def forward(self, img):\n", | ||
" out = self.resnet(img)\n", | ||
" return out" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 44, | ||
"id": "47539bef-14a7-4881-85e9-0882ff295e6a", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"def train(train_loader, model, criterion, optimizer):\n", | ||
" model.train()\n", | ||
" train_loss = 0.0\n", | ||
" for i, (input, target) in enumerate(train_loader):\n", | ||
" input = input.to(device)\n", | ||
" target = target.to(device)\n", | ||
"\n", | ||
" # compute output\n", | ||
" output = model(input)\n", | ||
" loss = criterion(output, target)\n", | ||
"\n", | ||
" # compute gradient and do SGD step\n", | ||
" optimizer.zero_grad()\n", | ||
" loss.backward()\n", | ||
" optimizer.step()\n", | ||
"\n", | ||
" if i % 100 == 0:\n", | ||
" print('Train loss', loss.item())\n", | ||
" \n", | ||
" train_loss += loss.item()\n", | ||
" \n", | ||
" return train_loss/len(train_loader)\n", | ||
" \n", | ||
"def validate(val_loader, model, criterion):\n", | ||
" model.eval()\n", | ||
" \n", | ||
" val_acc = 0.0\n", | ||
" \n", | ||
" with torch.no_grad():\n", | ||
" for i, (input, target) in enumerate(val_loader):\n", | ||
" input = input.to(device)\n", | ||
" target = target.to(device)\n", | ||
"\n", | ||
" # compute output\n", | ||
" output = model(input)\n", | ||
" loss = criterion(output, target)\n", | ||
" \n", | ||
" val_acc += (output.argmax(1) == target).sum().item()\n", | ||
" \n", | ||
" return val_acc / len(val_loader.dataset)\n", | ||
"\n", | ||
"def predict(test_loader, model, criterion):\n", | ||
" model.eval()\n", | ||
" val_acc = 0.0\n", | ||
" \n", | ||
" test_pred = []\n", | ||
" with torch.no_grad():\n", | ||
" for i, (input, target) in enumerate(test_loader):\n", | ||
" input = input.to(device)\n", | ||
" target = target.to(device)\n", | ||
"\n", | ||
" # compute output\n", | ||
" output = model(input)\n", | ||
" test_pred.append(output.data.cpu().numpy())\n", | ||
" \n", | ||
" return np.vstack(test_pred)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 45, | ||
"id": "c14b410f-de94-4f52-be5b-1377c34bda4a", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import albumentations as A\n", | ||
"\n", | ||
"train_loader = torch.utils.data.DataLoader(\n", | ||
" XunFeiDataset(train_path[:-1000],\n", | ||
" A.Compose([\n", | ||
" A.RandomRotate90(),\n", | ||
" A.Resize(256, 256),\n", | ||
" A.RandomCrop(224, 224),\n", | ||
" A.HorizontalFlip(p=0.5),\n", | ||
" A.RandomContrast(p=0.5),\n", | ||
" A.RandomBrightnessContrast(p=0.5),\n", | ||
" A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))\n", | ||
" ])\n", | ||
" ), batch_size=30, shuffle=True, num_workers=1, pin_memory=False\n", | ||
")\n", | ||
"\n", | ||
"val_loader = torch.utils.data.DataLoader(\n", | ||
" XunFeiDataset(train_path[-1000:],\n", | ||
" A.Compose([\n", | ||
" A.Resize(256, 256),\n", | ||
" A.RandomCrop(224, 224),\n", | ||
" A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))\n", | ||
" ])\n", | ||
" ), batch_size=30, shuffle=False, num_workers=1, pin_memory=False\n", | ||
")\n", | ||
"\n", | ||
"test_loader = torch.utils.data.DataLoader(\n", | ||
" XunFeiDataset(test_path,\n", | ||
" A.Compose([\n", | ||
" A.Resize(256, 256),\n", | ||
" A.RandomCrop(224, 224),\n", | ||
" A.HorizontalFlip(p=0.5),\n", | ||
" A.RandomContrast(p=0.5),\n", | ||
" A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))\n", | ||
" ])\n", | ||
" ), batch_size=2, shuffle=False, num_workers=1, pin_memory=False\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 46, | ||
"id": "7f7ee8cb-cfc0-4eb8-96ea-7302144ae7a4", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"model = XunFeiNet()\n", | ||
"model = model.to(device)\n", | ||
"criterion = nn.CrossEntropyLoss().cuda()\n", | ||
"optimizer = torch.optim.AdamW(model.parameters(), 0.001)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 47, | ||
"id": "b2faad78-e794-4f31-8b56-1a64dbf89ccc", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Train loss 0.7959139347076416\n", | ||
"Train loss 0.29768866300582886\n", | ||
"Train loss 0.2745967209339142\n", | ||
"0.3281211804474394 0.9055555555555556 0.914\n", | ||
"Train loss 0.3272189795970917\n", | ||
"Train loss 0.14214976131916046\n", | ||
"Train loss 0.2875927686691284\n", | ||
"0.2151400999724865 0.9335555555555556 0.943\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"for _ in range(2):\n", | ||
" train_loss = train(train_loader, model, criterion, optimizer)\n", | ||
" val_acc = validate(val_loader, model, criterion)\n", | ||
" train_acc = validate(train_loader, model, criterion)\n", | ||
" \n", | ||
" print(train_loss, train_acc, val_acc)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 48, | ||
"id": "6787d12a-b923-4777-abf5-4d90904551d1", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"pred = None\n", | ||
"\n", | ||
"for _ in range(3):\n", | ||
" if pred is None:\n", | ||
" pred = predict(test_loader, model, criterion)\n", | ||
" else:\n", | ||
" pred += predict(test_loader, model, criterion)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 49, | ||
"id": "1743990f-d27d-43cd-85fd-23a0ff0f2c36", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"submit = pd.DataFrame(\n", | ||
" {\n", | ||
" 'uuid': [x.split('/')[-1] for x in test_path],\n", | ||
" 'label': pred.argmax(1)\n", | ||
"})" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 50, | ||
"id": "cc565971-7a5f-4bb0-8ab5-c33281bc469c", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"submit['label'] = submit['label'].map({1: 'real', 0: 'ai'})\n", | ||
"submit.to_csv('submit.csv', index=None)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f2022bdd-5cee-4966-921a-aab21dd0ef45", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3.10" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.10" | ||
}, | ||
"widgets": { | ||
"application/vnd.jupyter.widget-state+json": { | ||
"state": {}, | ||
"version_major": 2, | ||
"version_minor": 0 | ||
} | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |