-
Notifications
You must be signed in to change notification settings - Fork 312
/
env.py
107 lines (95 loc) · 3.64 KB
/
env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import sys
import json
import torch
import numpy as np
import argparse
import torchvision.transforms as transforms
import cv2
from DRL.ddpg import decode
from utils.util import *
from PIL import Image
from torchvision import transforms, utils
# Side length, in pixels, of the square canvas and of every resized image.
width = 128
# Pixel count of one canvas.  The spelling is kept because this is a
# module-level name that code outside this file may reference.
convas_area = width ** 2

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Augmentation applied to training images in Paint.pre_data(): convert the
# array to a PIL image, then flip it horizontally at random.
aug = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
])

# Image stores and their sizes; Paint.load_data() appends to the lists and
# increments the counters.
img_train, img_test = [], []
train_num = test_num = 0
class Paint:
    """Batched painting environment backed by the module-level image stores.

    An instance holds ``batch_size`` canvases and target images (``gt``) as
    uint8 tensors of shape ``(batch_size, 3, width, width)`` on ``device``.
    An episode is started with :meth:`reset` and advanced with :meth:`step`;
    ``step`` reports ``done`` once ``max_step`` steps have been taken.
    """

    def __init__(self, batch_size, max_step):
        """Record the batch size and episode length; no images are loaded."""
        self.batch_size = batch_size
        self.max_step = max_step
        # A plain int: the original spelling ``(13)`` is not a tuple either.
        self.action_space = 13
        # Declared channel-last, although observation() returns a tensor of
        # shape (batch_size, 7, width, width).
        self.observation_space = (self.batch_size, width, width, 7)
        self.test = False

    def load_data(self):
        """Read CelebA images from disk into ``img_train`` / ``img_test``.

        Files ``./data/img_align_celeba/000001.jpg`` through ``200000.jpg``
        are read with OpenCV and resized to ``width`` x ``width``.  Images at
        loop indices 0..2000 are appended to ``img_test``, all later ones to
        ``img_train``; ``test_num`` / ``train_num`` are incremented to match.

        There is no ``except`` clause, so an error raised while reading or
        resizing an image propagates to the caller after the ``finally``
        clause has run.
        """
        global train_num, test_num
        for i in range(200000):
            img_id = '%06d' % (i + 1)
            try:
                # NOTE(review): cv2.imread presumably returns None for a
                # missing/unreadable file, in which case cv2.resize raises.
                img = cv2.imread('./data/img_align_celeba/' + img_id + '.jpg', cv2.IMREAD_UNCHANGED)
                img = cv2.resize(img, (width, width))
                if i > 2000:
                    train_num += 1
                    img_train.append(img)
                else:
                    test_num += 1
                    img_test.append(img)
            finally:
                # Progress message every 10000 files, printed even when the
                # current file raised.
                if (i + 1) % 10000 == 0:
                    print('loaded {} images'.format(i + 1))
        print('finish loading data, {} training images, {} testing images'.format(str(train_num), str(test_num)))

    def pre_data(self, id, test):
        """Return image ``id`` as a channel-first array.

        Args:
            id: Index into ``img_test`` when ``test`` is true, otherwise into
                ``img_train``.
            test: Selects the image store.  Training images additionally go
                through the module-level ``aug`` transform; test images are
                used unchanged.

        Returns:
            The image with its axes reordered from (0, 1, 2) to (2, 0, 1).
        """
        if test:
            img = img_test[id]
        else:
            img = aug(img_train[id])
        img = np.asarray(img)
        return np.transpose(img, (2, 0, 1))

    def reset(self, test=False, begin_num=False):
        """Start a new episode and return the first observation.

        Args:
            test: When true, target images are taken from the test store in
                order; otherwise they are drawn at random from the training
                store.
            begin_num: Offset added to the batch position when picking test
                images (wrapped modulo ``test_num``).  The default ``False``
                acts as 0.

        Returns:
            The tensor produced by :meth:`observation` for step 0.
        """
        self.test = test
        self.imgid = [0] * self.batch_size
        self.gt = torch.zeros([self.batch_size, 3, width, width], dtype=torch.uint8, device=device)
        for i in range(self.batch_size):
            if test:
                img_index = (i + begin_num) % test_num
            else:
                img_index = np.random.randint(train_num)
            self.imgid[i] = img_index
            self.gt[i] = torch.tensor(self.pre_data(img_index, test))
        # Per-sample mean of the squared target image scaled to [0, 1].
        self.tot_reward = ((self.gt.float() / 255) ** 2).mean(1).mean(1).mean(1)
        self.stepnum = 0
        # Every episode starts from an all-zero canvas.
        self.canvas = torch.zeros([self.batch_size, 3, width, width], dtype=torch.uint8, device=device)
        self.lastdis = self.ini_dis = self.cal_dis()
        return self.observation()

    def observation(self):
        """Return canvas, target and step counter stacked along dim 1.

        Shapes: canvas ``B x 3 x width x width``, gt ``B x 3 x width x width``
        and a constant plane ``B x 1 x width x width`` filled with the current
        step number, giving a uint8 tensor of shape ``B x 7 x width x width``.
        """
        step_plane = torch.ones([self.batch_size, 1, width, width], dtype=torch.uint8, device=device) * self.stepnum
        return torch.cat((self.canvas, self.gt, step_plane), 1)  # canvas, img, T

    def cal_trans(self, s, t):
        """Multiply ``s`` by ``t`` with dims 0 and 3 of ``s`` swapped.

        ``s`` has its dimensions 0 and 3 exchanged, is multiplied by ``t``
        (with broadcasting), and is then swapped back.  For a 4-D ``s`` and a
        1-D ``t`` of length ``s.shape[0]`` this scales ``s[i]`` by ``t[i]``.
        """
        return (s.transpose(0, 3) * t).transpose(0, 3)

    def step(self, action):
        """Apply ``action`` to the canvas and advance the episode by one step.

        The canvas is scaled to [0, 1], passed to ``decode`` together with
        ``action``, and the result is scaled back to uint8.

        Returns:
            A 4-tuple ``(observation, reward, done, None)`` where
            ``observation`` is detached from the autograd graph, ``reward``
            comes from :meth:`cal_reward`, and ``done`` is a boolean array of
            length ``batch_size`` that is true once ``stepnum == max_step``.
        """
        # NOTE(review): decode is imported from DRL.ddpg; its exact contract
        # is not visible here -- presumably it renders the action's strokes.
        self.canvas = (decode(action, self.canvas.float() / 255) * 255).byte()
        self.stepnum += 1
        ob = self.observation()
        done = (self.stepnum == self.max_step)
        reward = self.cal_reward()  # np.array([0.] * self.batch_size)
        return ob.detach(), reward, np.array([done] * self.batch_size), None

    def cal_dis(self):
        """Return the per-sample mean squared error between canvas and gt.

        Both tensors are compared as floats with the difference divided by
        255; the result has one value per batch element.
        """
        return (((self.canvas.float() - self.gt.float()) / 255) ** 2).mean(1).mean(1).mean(1)

    def cal_reward(self):
        """Return the normalized distance reduction since the previous call.

        The reward is ``(lastdis - dis) / (ini_dis + 1e-8)``; ``lastdis`` is
        updated to the current distance as a side effect.
        """
        dis = self.cal_dis()
        reward = (self.lastdis - dis) / (self.ini_dis + 1e-8)
        self.lastdis = dis
        # to_numpy is not defined in this file; presumably it is supplied by
        # the ``from utils.util import *`` star import.
        return to_numpy(reward)