Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
finlay-liu authored Jul 27, 2023
1 parent 73e364c commit dec05c8
Showing 1 changed file with 351 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,351 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "c8d01073-09a8-4c68-b93b-aef463738bd0",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import os, sys, glob, argparse\n",
"import pandas as pd\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"\n",
"import cv2\n",
"from PIL import Image\n",
"from sklearn.model_selection import train_test_split, StratifiedKFold, KFold\n",
"\n",
"import torch\n",
"torch.manual_seed(0)\n",
"torch.backends.cudnn.deterministic = False\n",
"torch.backends.cudnn.benchmark = True\n",
"\n",
"import torchvision.models as models\n",
"import torchvision.transforms as transforms\n",
"import torchvision.datasets as datasets\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torch.autograd import Variable\n",
"from torch.utils.data.dataset import Dataset\n",
"\n",
"# Check if GPU is available\n",
"if torch.cuda.is_available():\n",
" device = torch.device(\"cuda\")\n",
"else:\n",
" device = torch.device(\"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "70d2263d-916c-431e-aca7-3e5bcd867f9b",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"train_path = glob.glob('./Stable Diffusion鉴别器挑战赛公开数据/train/*/*')\n",
"test_path = glob.glob('./Stable Diffusion鉴别器挑战赛公开数据/test/*')\n",
"\n",
"np.random.shuffle(train_path)\n",
"test_path.sort()"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "584f2188-cf45-4d72-abee-4dae0d362926",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"class XunFeiDataset(Dataset):\n",
" def __init__(self, img_path, transform=None):\n",
" self.img_path = img_path\n",
" if transform is not None:\n",
" self.transform = transform\n",
" else:\n",
" self.transform = None\n",
" def __getitem__(self, index):\n",
" img = cv2.imread(self.img_path[index])\n",
" if self.transform is not None:\n",
" img = self.transform(image = img)['image']\n",
" img = img.transpose([2,0,1])\n",
" return img, torch.from_numpy(np.array(int('real' in self.img_path[index])))\n",
" def __len__(self):\n",
" return len(self.img_path)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "bd5a22b6-3a32-4039-bcc7-a6c6f5f9cbaf",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"class XunFeiNet(nn.Module):\n",
" def __init__(self):\n",
" super(XunFeiNet, self).__init__()\n",
" model = models.resnet34(True)\n",
" model.avgpool = nn.AdaptiveAvgPool2d(1)\n",
" model.fc = nn.Linear(512, 2)\n",
" self.resnet = model\n",
" def forward(self, img):\n",
" out = self.resnet(img)\n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "47539bef-14a7-4881-85e9-0882ff295e6a",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def train(train_loader, model, criterion, optimizer):\n",
" model.train()\n",
" train_loss = 0.0\n",
" for i, (input, target) in enumerate(train_loader):\n",
" input = input.to(device)\n",
" target = target.to(device)\n",
"\n",
" # compute output\n",
" output = model(input)\n",
" loss = criterion(output, target)\n",
"\n",
" # compute gradient and do SGD step\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" if i % 100 == 0:\n",
" print('Train loss', loss.item())\n",
" \n",
" train_loss += loss.item()\n",
" \n",
" return train_loss/len(train_loader)\n",
" \n",
"def validate(val_loader, model, criterion):\n",
" model.eval()\n",
" \n",
" val_acc = 0.0\n",
" \n",
" with torch.no_grad():\n",
" for i, (input, target) in enumerate(val_loader):\n",
" input = input.to(device)\n",
" target = target.to(device)\n",
"\n",
" # compute output\n",
" output = model(input)\n",
" loss = criterion(output, target)\n",
" \n",
" val_acc += (output.argmax(1) == target).sum().item()\n",
" \n",
" return val_acc / len(val_loader.dataset)\n",
"\n",
"def predict(test_loader, model, criterion):\n",
" model.eval()\n",
" val_acc = 0.0\n",
" \n",
" test_pred = []\n",
" with torch.no_grad():\n",
" for i, (input, target) in enumerate(test_loader):\n",
" input = input.to(device)\n",
" target = target.to(device)\n",
"\n",
" # compute output\n",
" output = model(input)\n",
" test_pred.append(output.data.cpu().numpy())\n",
" \n",
" return np.vstack(test_pred)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "c14b410f-de94-4f52-be5b-1377c34bda4a",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import albumentations as A\n",
"\n",
"train_loader = torch.utils.data.DataLoader(\n",
" XunFeiDataset(train_path[:-1000],\n",
" A.Compose([\n",
" A.RandomRotate90(),\n",
" A.Resize(256, 256),\n",
" A.RandomCrop(224, 224),\n",
" A.HorizontalFlip(p=0.5),\n",
" A.RandomContrast(p=0.5),\n",
" A.RandomBrightnessContrast(p=0.5),\n",
" A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))\n",
" ])\n",
" ), batch_size=30, shuffle=True, num_workers=1, pin_memory=False\n",
")\n",
"\n",
"val_loader = torch.utils.data.DataLoader(\n",
" XunFeiDataset(train_path[-1000:],\n",
" A.Compose([\n",
" A.Resize(256, 256),\n",
" A.RandomCrop(224, 224),\n",
" A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))\n",
" ])\n",
" ), batch_size=30, shuffle=False, num_workers=1, pin_memory=False\n",
")\n",
"\n",
"test_loader = torch.utils.data.DataLoader(\n",
" XunFeiDataset(test_path,\n",
" A.Compose([\n",
" A.Resize(256, 256),\n",
" A.RandomCrop(224, 224),\n",
" A.HorizontalFlip(p=0.5),\n",
" A.RandomContrast(p=0.5),\n",
" A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))\n",
" ])\n",
" ), batch_size=2, shuffle=False, num_workers=1, pin_memory=False\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "7f7ee8cb-cfc0-4eb8-96ea-7302144ae7a4",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"model = XunFeiNet()\n",
"model = model.to(device)\n",
"criterion = nn.CrossEntropyLoss().cuda()\n",
"optimizer = torch.optim.AdamW(model.parameters(), 0.001)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "b2faad78-e794-4f31-8b56-1a64dbf89ccc",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train loss 0.7959139347076416\n",
"Train loss 0.29768866300582886\n",
"Train loss 0.2745967209339142\n",
"0.3281211804474394 0.9055555555555556 0.914\n",
"Train loss 0.3272189795970917\n",
"Train loss 0.14214976131916046\n",
"Train loss 0.2875927686691284\n",
"0.2151400999724865 0.9335555555555556 0.943\n"
]
}
],
"source": [
"for _ in range(2):\n",
" train_loss = train(train_loader, model, criterion, optimizer)\n",
" val_acc = validate(val_loader, model, criterion)\n",
" train_acc = validate(train_loader, model, criterion)\n",
" \n",
" print(train_loss, train_acc, val_acc)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "6787d12a-b923-4777-abf5-4d90904551d1",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"pred = None\n",
"\n",
"for _ in range(3):\n",
" if pred is None:\n",
" pred = predict(test_loader, model, criterion)\n",
" else:\n",
" pred += predict(test_loader, model, criterion)"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "1743990f-d27d-43cd-85fd-23a0ff0f2c36",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"submit = pd.DataFrame(\n",
" {\n",
" 'uuid': [x.split('/')[-1] for x in test_path],\n",
" 'label': pred.argmax(1)\n",
"})"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "cc565971-7a5f-4bb0-8ab5-c33281bc469c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"submit['label'] = submit['label'].map({1: 'real', 0: 'ai'})\n",
"submit.to_csv('submit.csv', index=None)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f2022bdd-5cee-4966-921a-aab21dd0ef45",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3.10"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit dec05c8

Please sign in to comment.