Skip to content

Commit edeb731

Browse files
committed
resolution progressive model
1 parent 14eadd5 commit edeb731

File tree

6 files changed

+80
-19
lines changed

6 files changed

+80
-19
lines changed

__pycache__/data.cpython-37.pyc

808 Bytes
Binary file not shown.

__pycache__/models.cpython-37.pyc

825 Bytes
Binary file not shown.

data.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def get_embedding(rays, n_freqs=10, n_steps=64, start=0, stop=6):
2626
embed_vals.append(np.sin(2**L * np.pi * points[..., [d]]))
2727
return np.concatenate(embed_vals, -1)
2828

29-
def data_generator(data, batch_size=8, patch_size=8, random_rays=True):
29+
def ray_data_generator(data, batch_size=8, patch_size=8, random_rays=True):
3030
ps = patch_size
3131
while True:
3232
if random_rays:
@@ -70,19 +70,33 @@ def embedded_data_generator(data, batch_size=8, patch_size=8, random_rays=True,
7070
batch_x = get_embedding(batch_rays, n_freqs, n_steps, start, stop)
7171
yield batch_x, batch_y
7272

73+
def camera_data_generator(data, batch_size=8):
    """Endlessly yield (camera pose, image) batches.

    data: object exposing `H`, `W`, `transforms` (a sequence of dicts
        holding a 'c2w_matrix' array) and `imgs` (indexed in step with
        `transforms`)
    batch_size: number of views drawn per batch

    Each yielded pair is `(batch_x, batch_y)`:
        batch_x: (batch_size, 6) array; columns 0-2 are the last column
            of the top three rows of the c2w matrix, columns 3-5 are the
            (0, 0, -1) vector rotated by its upper-left 3x3 block
        batch_y: (batch_size, data.H, data.W, 3) array of the matching images

    Views are picked independently with `np.random.randint`, so one view
    may appear more than once in a batch.
    """
    # Direction that gets rotated by each view's 3x3 block.
    forward = np.array([[0, 0, -1]])

    while True:
        poses = np.zeros((batch_size, 6))
        images = np.zeros((batch_size, data.H, data.W, 3))

        for slot in range(batch_size):
            view = np.random.randint(len(data.transforms))
            pose = data.transforms[view]['c2w_matrix']

            origin = pose[:3, -1]
            direction = (forward * pose[:3, :3]).sum(axis=-1)

            poses[slot, :3] = origin
            poses[slot, 3:] = direction
            images[slot] = data.imgs[view]

        yield poses, images
85+
86+
7387
class Data:
74-
def __init__(self, scene='lego', mode='train'):
88+
def __init__(self, scene='lego', mode='train', resize=None):
7589
"""Load data
7690
7791
scene: 'lego'
7892
mode: 'train', 'test', 'val'
93+
resize: None or value for width & height (ex: 512)
7994
"""
8095

8196
data_path = 'data/{}/{}'.format(scene, mode)
8297
self.imgs = []
8398
for i in range(100):
84-
img = load_img('{}/r_{}.png'.format(data_path, i))
85-
self.imgs.append(img_to_array(img) / 255.)
99+
self.imgs.append(load_img('{}/r_{}.png'.format(data_path, i)))
86100

87101
with open('data/{}/transforms_{}.json'.format(scene, mode), 'r') as f:
88102
transforms_json = json.load(f)
@@ -93,7 +107,14 @@ def __init__(self, scene='lego', mode='train'):
93107
'c2w_matrix': np.array(v['transform_matrix'])
94108
} for v in transforms_json['frames']]
95109

96-
self.H, self.W = self.imgs[0].shape[:2]
110+
if resize == None:
111+
self.W, self.H = self.imgs[0].size
112+
else:
113+
self.W, self.H = resize, resize
114+
self.imgs = [img.resize((self.W, self.H)) for img in self.imgs]
115+
116+
self.imgs = [img_to_array(img) / 255. for img in self.imgs]
117+
97118
self.focal = 0.5 * self.W / np.tan(0.5 * self.camera_angle_x)
98119
self.near = 2.
99120
self.far = 6.

generate_video.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from data import Data, get_rays, get_embedding
99

1010

11-
model_path = 'models/m3.h5'
12-
video_path = 'videos/v3.mp4'
11+
model_path = 'models/m4.h5'
12+
video_path = 'videos/v4.mp4'
1313
model = load_model(model_path)
1414

1515
data = Data('lego', 'test')
@@ -46,7 +46,18 @@ def gen_v3():
4646
save_img(path, y)
4747
frame_paths.append(path)
4848

49-
gen_v3()
49+
def gen_v4():
    """Render one frame per entry of `data.transforms` with the loaded model.

    For each transform, builds a 6-value input from its 'c2w_matrix'
    (last column of the top three rows, followed by the (0, 0, -1) vector
    rotated by the upper-left 3x3 block), runs `model.predict` on it as a
    batch of one, saves the result to 'video_frames/<index>.png' and
    appends that path to the module-level `frame_paths` list.
    """
    forward = np.array([[0, 0, -1]])
    n_frames = len(data.transforms)

    for frame_idx, transform in tqdm(enumerate(data.transforms), total=n_frames):
        pose = transform['c2w_matrix']
        origin = pose[:3, -1]
        direction = np.sum(forward * pose[:3, :3], -1)

        # The model takes a batch, so add a leading axis and unwrap the result.
        model_in = np.concatenate((origin, direction), -1)[np.newaxis, ...]
        frame = model.predict(model_in)[0]

        path = 'video_frames/{}.png'.format(frame_idx)
        save_img(path, frame)
        frame_paths.append(path)
59+
60+
gen_v4()
5061

5162
os.makedirs('videos', exist_ok=True)
5263
mvp.ImageSequenceClip(frame_paths, fps=30.0).write_videofile(video_path)

models.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11

22
import os
33
from keras.models import Model
4-
from keras.layers import Input, Conv2D, Conv3D, Lambda
4+
from keras.layers import Input, Reshape, Lambda
5+
from keras.layers import Conv2D, Conv3D, UpSampling2D, ZeroPadding2D
56

67

78
class V1:
@@ -13,8 +14,9 @@ def __init__(self):
1314

1415
h_conv1 = self.l_conv1(self.l_in)
1516
h_conv2 = self.l_conv2(h_conv1)
17+
1618
self.model = Model(self.l_in, h_conv2)
17-
19+
1820
def serialize_lua(self):
1921
[d1_kernel, d1_bias] = self.l_conv1.get_weights()
2022
d1_kernel = d1_kernel[0, 0] # all 1x1
@@ -52,7 +54,7 @@ def serialize_lua(self):
5254

5355
class V2:
5456
def __init__(self):
55-
"""Model 2: 6 -> 3x 128 -> 3"""
57+
"""Model 2: 6 -> 3 dense(128) -> 3"""
5658
self.l_in = Input(shape=(None, None, 6))
5759
self.l_conv1 = Conv2D(128, kernel_size=(1, 1), activation='relu', kernel_initializer='he_normal')
5860
self.l_conv2 = Conv2D(128, kernel_size=(1, 1), activation='relu', kernel_initializer='he_normal')
@@ -63,12 +65,13 @@ def __init__(self):
6365
h_conv2 = self.l_conv2(h_conv1)
6466
h_conv3 = self.l_conv3(h_conv2)
6567
h_conv4 = self.l_conv4(h_conv3)
68+
6669
self.model = Model(self.l_in, h_conv4)
6770

6871

6972
class V3:
7073
def __init__(self):
71-
"""Model 3: (64, 36) -> 6x strided 1d convs - f=64, ks=2 -> 3"""
74+
"""Model 3: (64, 36) -> 6 2-strided conv1d(64) -> 3"""
7275
self.l_in = Input(shape=(None, None, 64, 36))
7376
self.conv_layers = []
7477
for i in range(6):
@@ -78,5 +81,26 @@ def __init__(self):
7881
h = self.l_in
7982
for cl in self.conv_layers:
8083
h = cl(h)
84+
8185
h_out = Lambda(lambda x: x[..., 0, :], output_shape=lambda s: s[:-2] + s[-1:])(h)
8286
self.model = Model(self.l_in, h_out)
87+
88+
89+
class V4:
    def __init__(self):
        """Model 4: 6 -> (1, 1, 6) -> 4 (upsampling2d(4) + conv2d(8, 5x5)) -> conv2d(8, 5x5) -> conv2d(3, 3x3)

        Every convolution is 'valid' and preceded by zero padding of
        (kernel - 1) / 2, so only the upsampling layers change the spatial size.
        """
        self.l_in = Input(shape=(6,))
        self.conv_layers = []

        n_hidden = 5
        features = Reshape((1, 1, 6))(self.l_in)
        for stage in range(n_hidden):
            # The last hidden stage convolves without upsampling.
            if stage < n_hidden - 1:
                features = UpSampling2D((4, 4))(features)
            features = ZeroPadding2D((2, 2))(features)
            hidden_conv = Conv2D(8, kernel_size=(5, 5), padding='valid', activation='relu', kernel_initializer='he_normal')
            self.conv_layers.append(hidden_conv)
            features = hidden_conv(features)

        # Output layer: 3 channels, no activation.
        out_conv = Conv2D(3, kernel_size=(3, 3))
        self.conv_layers.append(out_conv)
        features = ZeroPadding2D((1, 1))(features)
        self.model = Model(self.l_in, out_conv(features))

train.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11

22
import os
33
from keras.optimizers import Adam
4-
from models import V3
5-
from data import Data, embedded_data_generator
4+
from models import V4
5+
from data import Data, camera_data_generator
66

77

8-
v3 = V3()
9-
model = v3.model
8+
print('Building model...')
9+
v4 = V4()
10+
model = v4.model
1011
model.compile(optimizer=Adam(0.001), loss='mse')
1112

12-
data = Data('lego', 'train')
13+
print('Loading Data...')
14+
data = Data('lego', 'train', resize=256)
1315

14-
model.fit_generator(embedded_data_generator(data), steps_per_epoch=1.25e5, epochs=1)
16+
print('Training...')
17+
# one full pass over the 100 training images at batch_size=8 would be 12.5 steps
18+
model.fit_generator(camera_data_generator(data), steps_per_epoch=100, epochs=5)
1519
os.makedirs('models', exist_ok=True)
16-
model.save('models/m3.h5')
20+
model.save('models/m4.h5')
21+
print('Training complete')

0 commit comments

Comments
 (0)