Skip to content

Commit

Permalink
Merge pull request #18 from GilesBathgate/free-memory
Browse files Browse the repository at this point in the history
Remove unneeded buffer
  • Loading branch information
ankan-ban authored Apr 29, 2024
2 parents 684c797 + ef07c80 commit ca913d6
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 4 deletions.
1 change: 0 additions & 1 deletion common.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ typedef struct {
half* x; // activation at current time stamp (dim,)
half* xb; // same, but inside a residual branch (dim,)
half* hb; // buffer for hidden dimension in the ffn (hidden_dim,)
half* hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
half* q; // query (dim,)
half* att; // buffer for scores/attention values (n_heads, seq_len)
half* logits; // output logits
Expand Down
4 changes: 1 addition & 3 deletions llama2_q4.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ void malloc_run_state(RunState* s, Config* p, bool allocLogitsArray) {
cudaMalloc((void**)&s->x, p->dim * sizeof(half));
cudaMalloc((void**)&s->xb, p->dim * sizeof(half));
cudaMalloc((void**)&s->hb, p->hidden_dim * sizeof(half));
cudaMalloc((void**)&s->hb2, p->hidden_dim * sizeof(half));
cudaMalloc((void**)&s->q, p->dim * sizeof(half));
cudaMalloc((void**)&s->att, p->n_heads * p->dim * sizeof(half));
cudaMalloc((void**)&s->logits, p->vocab_size * sizeof(half));
Expand All @@ -51,7 +50,7 @@ void malloc_run_state(RunState* s, Config* p, bool allocLogitsArray) {
cudaMallocHost((void**)&s->shared_data, sizeof(SharedData));

// ensure all mallocs went fine
if (!s->x || !s->xb || !s->pos || !s->hb || !s->hb2 || !s->q
if (!s->x || !s->xb || !s->pos || !s->hb || !s->q
|| !s->att || !s->logits || !s->key_cache
|| !s->value_cache || !s->shared_data) {
printf("malloc failed for allocaing run state!\n");
Expand All @@ -72,7 +71,6 @@ void free_run_state(RunState* s) {
cudaFree(s->xb);
cudaFree(s->pos);
cudaFree(s->hb);
cudaFree(s->hb2);
cudaFree(s->q);
cudaFree(s->att);
cudaFree(s->logits);
Expand Down

0 comments on commit ca913d6

Please sign in to comment.