Skip to content

Commit

Permalink
better
Browse files (browse the repository at this point in the history)
  • Loading branch information
michal-miotk committed Jul 31, 2024
1 parent 67b3c9a commit 39f64af
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@ KERNEL(lstm_seq)(
const uint b = get_global_id(1);
const int weight_offsets[4] = {GEMM_OFFSET_F, GEMM_OFFSET_I, GEMM_OFFSET_Z, GEMM_OFFSET_O};
const int gate_num = 4;
printf("b %d is hidden is %d hsize is %d \n", b, hidden_idx, HIDDEN_SIZE);
OUTPUT_TYPE local_hidden_state = 0;
OUTPUT_TYPE hidden_result[gate_num];
OUTPUT_TYPE input_result[gate_num];
OUTPUT_TYPE gate_output[gate_num];
printf("b %d is hidden is %d hsize is %d seq len is %d\n", b, hidden_idx, HIDDEN_SIZE, sequence_lengths[INPUT3_GET_INDEX_SAFE(b, 0, 0, 0)]);
ACCUMULATOR_TYPE hidden_result[gate_num];
ACCUMULATOR_TYPE input_result[gate_num];
ACCUMULATOR_TYPE gate_output[gate_num];

for(int k=0;k<gate_num;k++){
gate_output[k] = 0;
Expand Down Expand Up @@ -66,15 +65,14 @@ KERNEL(lstm_seq)(
}

if (i==0){
cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] = gate_output[0]*initial_cell_state[INPUT2_GET_INDEX_SAFE(b, 0, hidden_idx, 0)];
cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] += gate_output[1]*gate_output[2];
cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] = (OUTPUT_TYPE)(gate_output[0]*initial_cell_state[INPUT2_GET_INDEX_SAFE(b, 0, hidden_idx, 0)]);
cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] += (OUTPUT_TYPE)(gate_output[1]*gate_output[2]);
}else{
cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] *= gate_output[0];
cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] += gate_output[1]*gate_output[2];
cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] *= (OUTPUT_TYPE)gate_output[0];
cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] += (OUTPUT_TYPE)(gate_output[1]*gate_output[2]);
}
local_hidden_state = gate_output[3]*ACTIVATION_H(cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)], ACTIVATION_PARAMS_H);
hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)] = local_hidden_state;
hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, i, hidden_idx)] = (OUTPUT_TYPE)(gate_output[3]*ACTIVATION_H(cell_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)], ACTIVATION_PARAMS_H));
}
//printf("R is %p B is %p ; hidden history %p cell state %p batch %d\n", &R[0], &B[0], &hidden_history[0], &cell_state[0], b);
//printf("result is %f %f \n", hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, 0, 0)], hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, 1, 0)]);
//printf("result is %f %f fb %d\n", hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, 0, hidden_idx)], hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, 1, hidden_idx)], b);
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ JitConstants LSTMSeqKernelBase::GetJitConstants(const lstm_seq_params& params) c
jit.AddConstants({MakeJitConstant("MAX_SEQ_LENGTH", params.inputs[0].Feature().v)});
jit.AddConstants({MakeJitConstant("INPUT_SIZE", params.inputs[0].Y().v)});
jit.AddConstants({MakeJitConstant("HIDDEN_SIZE", params.inputs[1].Y().v)});
auto ftype = GetUnitType(params);
auto ftype = params.inputs[0].GetDType();
// if ReLU activation present, we have to reset accumulator type for the kernel to FP32
// to avoid possible overflows on FP16, since ReLU doesn't limit upper border of its result
for (size_t i = 0; i < params.activations.size(); i++) {
Expand Down

0 comments on commit 39f64af

Please sign in to comment.