-
Notifications
You must be signed in to change notification settings - Fork 391
/
Copy pathpost-operation.c
52 lines (47 loc) · 1.83 KB
/
post-operation.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <stddef.h>
#include <string.h>

#include <xnnpack/allocator.h>
#include <xnnpack/microparams.h>
#include <xnnpack/params.h>
#include <xnnpack/post-operation.h>
// Scratch storage large enough to hold the initialized params of any supported
// post operation. Add a member here when a new post operation type is added.
union post_operation_params_scratch {
  union xnn_f32_hswish_params hswish_params;
};

// Initializes the params of a single post operation into `scratch` and returns
// the number of bytes of `scratch` that were written. Returns 0 when no params
// initializer is available for the operation (e.g. the hswish init function
// pointer in xnn_params is NULL). Both the sizing pass and the copy pass of
// allocate_and_initialize_post_operation_params go through this helper, so the
// buffer size they compute and the bytes they copy cannot drift apart.
static size_t initialize_post_operation_params(
    const struct xnn_post_operation* post_op,
    union post_operation_params_scratch* scratch) {
  switch (post_op->op_type) {
    case xnn_post_operation_type_hardswish:
      if (xnn_params.f32.hswish.init.f32_hswish != NULL) {
        return xnn_params.f32.hswish.init.f32_hswish(&scratch->hswish_params);
      }
      break;
    default:
      XNN_UNREACHABLE;
  }
  return 0;
}

// Builds one buffer holding the initialized params of every post operation,
// packed back to back (no padding) in the order they appear in
// `post_operations`. Operations whose params initializer is unavailable
// contribute no bytes.
//
// num_post_operations: number of entries in `post_operations`.
// post_operations: the post operations whose params should be initialized.
//
// Returns the buffer obtained from xnn_allocate_zero_memory, or NULL if that
// allocation fails. Note that with a total size of 0 the allocator may also
// return NULL. The caller owns the returned buffer (presumably to be released
// with xnn_release_memory — TODO confirm against callers).
char* allocate_and_initialize_post_operation_params(
    size_t num_post_operations,
    const struct xnn_post_operation* post_operations) {
  union post_operation_params_scratch scratch;

  // First pass: measure how many bytes the packed params will occupy.
  size_t total_size = 0;
  for (size_t i = 0; i < num_post_operations; i++) {
    total_size += initialize_post_operation_params(&post_operations[i], &scratch);
  }

  char* post_operation_params = xnn_allocate_zero_memory(total_size);
  if (post_operation_params == NULL) {
    // Bail out rather than memcpy into a NULL destination below.
    return NULL;
  }

  // Second pass: re-initialize each op's params and copy them compactly.
  char* cur_params = post_operation_params;
  for (size_t i = 0; i < num_post_operations; i++) {
    const size_t initialized_size =
        initialize_post_operation_params(&post_operations[i], &scratch);
    memcpy(cur_params, &scratch, initialized_size);
    cur_params += initialized_size;
  }
  return post_operation_params;
}