forked from idlegene/HEVC-deep-learning-pipeline
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathuse_model.py
127 lines (121 loc) · 5.11 KB
/
use_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import torch
import torch.nn as nn
from torchvision import transforms
import os,shutil
import numpy as np
from PIL import Image
import time
import math
'''
NEXT STEPS:
1. Use a DataLoader to speed up inference.
2. Have the HEVC encoder read the predictions for a whole frame at once (instead of one file per CTU).
'''
class ConvNet2(nn.Module):
    """Two-branch CNN mapping a 32x32 patch plus a 64x64 context patch to 16 logits.

    Input shapes follow from the layer arithmetic:
        x32: (N, 3, 32, 32) -- conv1 pools it by 2 down to 16x16.
        x64: (N, 3, 64, 64) -- conv64 pools it by 4 down to 16x16, so both
             maps share a spatial size and can be concatenated on channels.
    The fused (N, 32, 16, 16) map goes through conv2/conv3 to (N, 128, 4, 4),
    is flattened, and fc1 -> fc2 -> fc3 produce (N, 16) logits.

    The attribute names and the Sequential layout below are the state_dict
    keys of saved checkpoints (loaded with ``load_state_dict`` in this file),
    so they must stay exactly as they are.
    """

    def __init__(self):
        super().__init__()
        # Modules are created in the same order as the original layout so
        # that seeded random initialisation is unchanged.
        self.conv1 = self._conv_block(3, 16, kernel_size=5, pool=2)    # (3,32,32)  -> (16,16,16)
        self.conv2 = self._conv_block(32, 64, kernel_size=3, pool=2)   # (32,16,16) -> (64,8,8)
        self.conv3 = self._conv_block(64, 128, kernel_size=3, pool=2)  # (64,8,8)   -> (128,4,4)
        self.fc1 = nn.Sequential(nn.Linear(128 * 4 * 4, 256), nn.ReLU())
        self.fc2 = nn.Sequential(nn.Linear(256, 64), nn.ReLU())
        self.fc3 = nn.Linear(64, 16)
        self.conv64 = self._conv_block(3, 16, kernel_size=5, pool=4)   # (3,64,64)  -> (16,16,16)

    @staticmethod
    def _conv_block(in_channels, out_channels, kernel_size, pool):
        """Build a Conv -> BatchNorm -> ReLU -> MaxPool stage ("same" padding)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.BatchNorm2d(out_channels, affine=True),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=pool),
        )

    def forward(self, x32, x64):
        """Return (N, 16) logits for a batch of 32x32 patches and their 64x64 context."""
        local_features = self.conv1(x32)
        context_features = self.conv64(x64)
        fused = torch.cat([local_features, context_features], dim=1)
        features = self.conv3(self.conv2(fused))
        # Flatten per sample before the fully connected head.
        flat = features.view(x32.size(0), -1)
        return self.fc3(self.fc2(self.fc1(flat)))
# Run on the GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ConvNet2().to(DEVICE)
# NOTE: torch.load unpickles the checkpoint; only load trusted model files.
model.load_state_dict(torch.load('./rec/hevc_encoder_model.pt', map_location=DEVICE))
# Inference mode: without this the BatchNorm layers keep normalising with
# per-batch statistics (and keep updating their running stats) instead of
# using the statistics learned in training. torch.no_grad() alone does not
# change this.
model.eval()
print("loaded model from drive")
# Read the number of frames to encode from bitstream.cfg: the value is the
# text after the first ":" on the 8th line (index 7) of the file.
# NOTE(review): this relies on line position; presumably that line is the
# encoder's frame-count entry -- matching on the key name would be sturdier.
frame_tobe_encoded = None
with open("bitstream.cfg", 'r') as f:
    for i, line in enumerate(f):
        if i != 7:
            continue
        fields = line.split(":")
        if len(fields) < 2:
            raise ValueError(
                "bitstream.cfg line 8 is not in 'key : value' form: {!r}".format(line)
            )
        # strip() removes all surrounding whitespace, including '\r' and tabs.
        frame_tobe_encoded = fields[1].strip()
        # Nothing after this line is used, so stop reading.
        break
if frame_tobe_encoded is None:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError("bitstream.cfg has fewer than 8 lines; cannot read the frame count")
# Positions in the 16-entry per-CTU label list filled by each 32x32 quadrant,
# in the order the quadrants are visited (top-left, top-right, bottom-left,
# bottom-right). The list is thus a 4x4 raster-order grid of 16x16 sub-blocks.
_QUADRANT_SLOTS = (
    (0, 1, 4, 5),
    (2, 3, 6, 7),
    (8, 9, 12, 13),
    (10, 11, 14, 15),
)


def _normalize_quadrant_pred(pred):
    """Return *pred* with partial "0"/"1" runs promoted to a consistent value.

    *pred* is a 4-character string holding one class digit ("0"-"3") per
    16x16 sub-block of a 32x32 quadrant. A "0" survives only when all four
    digits are "0" (otherwise every "0" becomes "1"). A "1" survives only
    when all four digits are "1" (otherwise every "1" becomes "2").

    NOTE(review): presumably the digits are HEVC CU depths (0 = 64x64,
    1 = 32x32, 2 = 16x16, 3 = 8x8), which is why a partial 0/1 is invalid.
    """
    if "0" in pred and pred != "0000":
        pred = pred.replace("0", "1")
    if "1" in pred and pred != "1111":
        pred = pred.replace("1", "2")
    return pred


# One converter reused for every crop.
_to_tensor = transforms.ToTensor()

total_frames = len(list(os.listdir("./rec/frames")))
# Frames are numbered 1..N on disk; process at most frame_tobe_encoded of them.
frames_to_process = min(total_frames, int(frame_tobe_encoded))

with torch.no_grad():
    for frame_number in range(1, frames_to_process + 1):
        # Predictions for frame N go to ./pred/N-1/. exist_ok allows re-runs
        # and makedirs also creates ./pred itself if it is missing.
        os.makedirs("./pred/{}".format(frame_number - 1), exist_ok=True)
        frame_path = "./rec/frames/{}.jpg".format(frame_number)
        # Convert to RGB so greyscale/CMYK JPEGs still give the 3-channel input
        # the model expects; the source file handle is closed on exit.
        with Image.open(frame_path) as frame_file:
            img = frame_file.convert("RGB")
        img_width, img_height = img.size
        # Partial CTUs at the right/bottom edge are included; their crops
        # extend past the image (presumably zero-padded by PIL's crop).
        ctus_per_row = math.ceil(img_width / 64)
        ctu_numbers = ctus_per_row * math.ceil(img_height / 64)
        for i in range(ctu_numbers):
            ctu_x = (i % ctus_per_row) * 64
            ctu_y = (i // ctus_per_row) * 64
            # The 64x64 CTU context is identical for all four quadrants, so
            # crop, convert and move it to the device once per CTU.
            data64 = _to_tensor(
                img.crop((ctu_x, ctu_y, ctu_x + 64, ctu_y + 64))
            ).unsqueeze(0).to(DEVICE)
            label = ["0"] * 16
            for quadrant, slots in enumerate(_QUADRANT_SLOTS):
                start_x = ctu_x + (quadrant % 2) * 32
                start_y = ctu_y + (quadrant // 2) * 32
                data32 = _to_tensor(
                    img.crop((start_x, start_y, start_x + 32, start_y + 32))
                ).unsqueeze(0).to(DEVICE)
                output = model(data32, data64)
                # 16 logits = 4 sub-blocks x 4 classes: argmax of each group of four.
                pred = "".join(
                    str(digit) for digit in output[0].reshape(4, 4).argmax(dim=1).tolist()
                )
                pred = _normalize_quadrant_pred(pred)
                # A quadrant may only keep "0000" if the first sub-block of the
                # previously visited quadrant is also "0".
                if (quadrant > 0 and pred == "0000"
                        and label[_QUADRANT_SLOTS[quadrant - 1][0]] != "0"):
                    pred = "1111"
                for slot, digit in zip(slots, pred):
                    label[slot] = digit
            tmp_path = "./pred/{}/ctu.txt".format(frame_number - 1)
            # Output format: every label followed by a single space.
            with open(tmp_path, 'w', encoding='utf-8') as f:
                f.write("".join(value + " " for value in label))
            # Write-then-rename: presumably so a concurrent reader only ever sees a
            # complete ctu{i}.txt -- NOTE(review): confirm. os.replace (unlike
            # os.rename on Windows) also overwrites a file left by an earlier run.
            os.replace(tmp_path, "./pred/{}/ctu{}.txt".format(frame_number - 1, i))
        img.close()
        os.remove(frame_path)