-
Notifications
You must be signed in to change notification settings - Fork 190
/
realesrgan_preproc_tta.comp
113 lines (89 loc) · 3.02 KB
/
realesrgan_preproc_tta.comp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#version 450
// sfp is the element type of every output buffer: float16_t when the build
// defines NCNN_fp16_storage to a non-zero value, plain float otherwise.
#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#define sfp float16_t
#else
#define sfp float
#endif
// 8-bit storage is only needed when the input buffer is declared as uint8_t.
#if NCNN_int8_storage
#extension GL_EXT_shader_8bit_storage: require
#endif
// Specialization constant. When it is 1, the 8-bit input path in main() reads
// channel (2 - gz) instead of gz for every channel except channel 3.
// It is not consulted by the float input path.
layout (constant_id = 0) const int bgr = 0;
// Input samples. main() indexes the uint8_t variant as
// (y * w + x) * channels + channel, and the float variant as
// channel * cstep + y * w + x.
#if NCNN_int8_storage
layout (binding = 0) readonly buffer bottom_blob { uint8_t bottom_blob_data[]; };
#else
layout (binding = 0) readonly buffer bottom_blob { float bottom_blob_data[]; };
#endif
// Eight outputs. main() stores the same scaled sample into each of them at a
// different index: blobs 0-3 use row stride outw (as-is, x mirrored, x and y
// mirrored, y mirrored); blobs 4-7 use row stride outh with gx as the row
// (the transposed counterparts).
layout (binding = 1) writeonly buffer top_blob0 { sfp top_blob0_data[]; };
layout (binding = 2) writeonly buffer top_blob1 { sfp top_blob1_data[]; };
layout (binding = 3) writeonly buffer top_blob2 { sfp top_blob2_data[]; };
layout (binding = 4) writeonly buffer top_blob3 { sfp top_blob3_data[]; };
layout (binding = 5) writeonly buffer top_blob4 { sfp top_blob4_data[]; };
layout (binding = 6) writeonly buffer top_blob5 { sfp top_blob5_data[]; };
layout (binding = 7) writeonly buffer top_blob6 { sfp top_blob6_data[]; };
layout (binding = 8) writeonly buffer top_blob7 { sfp top_blob7_data[]; };
// Receives channel 3 unscaled, indexed as y * alphaw + x.
layout (binding = 9) writeonly buffer alpha_blob { sfp alpha_blob_data[]; };
// Push constants. NOTE(review): member order and types presumably have to
// match what the host code pushes - confirm before reordering anything.
layout (push_constant) uniform parameter
{
int w; // input width: row stride of the input and mirror bound for x
int h; // input height: mirror bound for y
int cstep; // per-channel stride of the float input (unused by the 8-bit path)
int outw; // output width: upper bound for gx, row stride of blobs 0-3
int outh; // output height: upper bound for gy, row stride of blobs 4-7
int outcstep; // per-channel stride shared by all eight outputs
int pad_top; // subtracted from gy when mapping to input / alpha rows
int pad_left; // subtracted from gx when mapping to input / alpha columns
int crop_x; // added to gx when mapping to an input column
int crop_y; // added to gy when mapping to an input row
int channels; // upper bound for gz; per-pixel stride of the 8-bit input
int alphaw; // width (and row stride) of the alpha_blob window
int alphah; // height of the alpha_blob window
} p;
// Pre-processing pass that fills the eight output blobs and the alpha blob.
//
// One invocation handles one element (gx, gy, gz) of the output tile:
//   1. Map the output coordinate back to an input coordinate
//      (+crop offset, -padding) and mirror it into [0, w-1] x [0, h-1].
//   2. Fetch the sample: interleaved bytes when NCNN_int8_storage is set
//      (honouring the bgr specialization constant), planar floats otherwise.
//   3. Channel 3 is copied unscaled into alpha_blob, but only for positions
//      inside the alphaw x alphah window once the padding is removed.
//      Every other channel is multiplied by 1/255 and written to all eight
//      outputs: blobs 0-3 are the tile as-is / x mirrored / x and y mirrored /
//      y mirrored; blobs 4-7 are the same four with rows and columns swapped
//      (row stride outh).
//
// Invocations outside outw x outh x channels exit without touching memory.
void main()
{
    int gx = int(gl_GlobalInvocationID.x);
    int gy = int(gl_GlobalInvocationID.y);
    int gz = int(gl_GlobalInvocationID.z);

    if (gx >= p.outw || gy >= p.outh || gz >= p.channels)
        return;

    // Output position -> input position (may lie outside the image).
    int x = gx + p.crop_x - p.pad_left;
    int y = gy + p.crop_y - p.pad_top;

    // Mirror around the first and the last column/row.
    x = (p.w - 1) - abs(abs(x) - (p.w - 1));
    y = (p.h - 1) - abs(abs(y) - (p.h - 1));

    // The fold above is a single reflection: the result never exceeds w-1 /
    // h-1, but it becomes negative when the overshoot is larger than w-1 / h-1
    // (padding wider than the image, or a one-pixel-wide image). A negative
    // coordinate would index in front of the input buffer, so pin it to the
    // first column/row. For coordinates that were already valid this is a
    // no-op.
    x = max(x, 0);
    y = max(y, 0);

#if NCNN_int8_storage
    // Interleaved layout: `channels` bytes per pixel.
    int pixel_base = (y * p.w + x) * p.channels;

    // With bgr == 1 the colour channels are read in reverse order; channel 3
    // keeps its position.
    int src_channel = (bgr == 1 && gz != 3) ? (2 - gz) : gz;

    float v = float(uint(bottom_blob_data[pixel_base + src_channel]));
#else
    // Planar layout: one plane of cstep elements per channel.
    float v = bottom_blob_data[gz * p.cstep + y * p.w + x];
#endif

    if (gz == 3)
    {
        // Position relative to the unpadded tile.
        int ax = gx - p.pad_left;
        int ay = gy - p.pad_top;

        if (ax >= 0 && ax < p.alphaw && ay >= 0 && ay < p.alphah)
        {
            alpha_blob_data[ay * p.alphaw + ax] = sfp(v);
        }
        return;
    }

    const float norm_val = 1.f / 255.f;
    sfp s = sfp(v * norm_val);

    int plane = gz * p.outcstep;

    // Mirrored coordinates inside the output tile.
    int mx = p.outw - 1 - gx;
    int my = p.outh - 1 - gy;

    // Row stride outw: as-is, x mirrored, x and y mirrored, y mirrored.
    top_blob0_data[plane + gy * p.outw + gx] = s;
    top_blob1_data[plane + gy * p.outw + mx] = s;
    top_blob2_data[plane + my * p.outw + mx] = s;
    top_blob3_data[plane + my * p.outw + gx] = s;

    // Rows and columns swapped (row stride outh), same four variants.
    top_blob4_data[plane + gx * p.outh + gy] = s;
    top_blob5_data[plane + gx * p.outh + my] = s;
    top_blob6_data[plane + mx * p.outh + my] = s;
    top_blob7_data[plane + mx * p.outh + gy] = s;
}