function [cp, param_grad] = conv_net(params, layers, data, labels)
% CONV_NET  Run one forward/backward pass of the network over a mini-batch,
% returning the batch-averaged cost (and percent correct) in cp and, when a
% second output is requested, the batch-averaged parameter gradients.

l = length(layers);
batch_size = layers{1}.batch_size;

%% Forward pass
output = convnet_forward(params, layers, data);

%% Loss layer
i = l;
assert(strcmp(layers{i}.type, 'LOSS'), 'last layer must be loss layer');

% mlrloss takes the last layer's weights and biases stacked into a single
% vector and returns the cost, its gradient w.r.t. that vector, the
% derivative w.r.t. the layer input, and a percent-correct score.
wb = [params{i-1}.w(:); params{i-1}.b(:)];
[cost, grad, input_od, percent] = mlrloss(wb, output{i-1}.data, labels, layers{i}.num, 0, 1);

%% Backward pass
if nargout >= 2
    % Unstack the loss-layer gradient and average it over the batch.
    param_grad{i-1}.w = reshape(grad(1:numel(params{i-1}.w)), size(params{i-1}.w));
    param_grad{i-1}.b = reshape(grad(end - numel(params{i-1}.b) + 1:end), size(params{i-1}.b));
    param_grad{i-1}.w = param_grad{i-1}.w / batch_size;
    param_grad{i-1}.b = param_grad{i-1}.b / batch_size;
end

cp.cost = cost / batch_size;
cp.percent = percent;

if nargout >= 2
    % Propagate the derivative (input_od) from the loss layer back through
    % the remaining layers; parameter-free layers get empty gradients.
    for i = l-1:-1:2
        switch layers{i}.type
            case 'CONV'
                output{i}.diff = input_od;
                [param_grad{i-1}, input_od] = conv_layer_backward(output{i}, output{i-1}, layers{i}, params{i-1});
            case 'POOLING'
                output{i}.diff = input_od;
                input_od = pooling_layer_backward(output{i}, output{i-1}, layers{i});
                param_grad{i-1}.w = [];
                param_grad{i-1}.b = [];
            case 'IP'
                output{i}.diff = input_od;
                [param_grad{i-1}, input_od] = inner_product_backward(output{i}, output{i-1}, layers{i}, params{i-1});
            case 'RELU'
                output{i}.diff = input_od;
                input_od = relu_backward(output{i}, output{i-1}, layers{i});
                param_grad{i-1}.w = [];
                param_grad{i-1}.b = [];
            case 'ELU'
                output{i}.diff = input_od;
                input_od = elu_backward(output{i}, output{i-1}, layers{i});
                param_grad{i-1}.w = [];
                param_grad{i-1}.b = [];
            otherwise
                error('conv_net:unknownLayer', 'unknown layer type %s', layers{i}.type);
        end
        % Average over the batch ([] / batch_size stays []).
        param_grad{i-1}.w = param_grad{i-1}.w / batch_size;
        param_grad{i-1}.b = param_grad{i-1}.b / batch_size;
    end
end
end
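
A minimal sketch of how conv_net can drive one training step. Everything outside the conv_net call is assumed for illustration (x_batch, y_batch, the learning rate lr, and previously initialized params/layers are hypothetical names); only the cp fields and the param_grad layout come from the function above:

% One plain gradient-descent step (illustrative; setup names are assumed).
lr = 0.01;                                     % learning rate (assumed value)
[cp, param_grad] = conv_net(params, layers, x_batch, y_batch);
fprintf('cost = %.4f, percent = %.2f\n', cp.cost, cp.percent);
for j = 1:numel(param_grad)
    if ~isempty(param_grad{j}.w)               % POOLING/RELU/ELU layers have empty grads
        params{j}.w = params{j}.w - lr * param_grad{j}.w;
        params{j}.b = params{j}.b - lr * param_grad{j}.b;
    end
end

Note that calling conv_net with a single output argument skips both nargout >= 2 blocks, so an evaluation-only pass pays no backward-pass cost.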