Skip to content

Commit

Permalink
minor simplification in get_dataset_perplexity function
Browse files: browse the repository at this point in the history
  • Loading branch information
ankan-ban committed Nov 2, 2023
1 parent dab073d commit e3f6986
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions perplexity.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -56,7 +56,7 @@ void run_transformer(bool gen_token, Config* p, RunState* s, TransformerWeights*
// ----------------------------------------------------------------------------
float get_dataset_perplexity(char* dataset, Tokenizer* tokenizer, Config* config, RunState* state, TransformerWeights* weights, Sampler *pSampler) {
int bytes = strlen(dataset);
- int* datasetTokens = (int*)malloc(bytes * sizeof(int));
+ int* datasetTokens = &(state->shared_data->tokens[1]);

printf("\nTokenizing Dataset...");
int totalTokens;
Expand All @@ -76,7 +76,6 @@ float get_dataset_perplexity(char* dataset, Tokenizer* tokenizer, Config* config
cudaMemset(state->pos, 0, sizeof(int));
state->shared_data->pos = 0;
state->shared_data->tokens[0] = bos_token;
- memcpy(&(state->shared_data->tokens[1]), datasetTokens, sizeof(int) * numTokens);
for (int pos = 0; pos < numTokens; pos++) {
run_transformer(false, config, state, weights, true, pSampler);
cudaDeviceSynchronize();
Expand All @@ -93,7 +92,6 @@ float get_dataset_perplexity(char* dataset, Tokenizer* tokenizer, Config* config

printf("\nPerplexity computed on %d tokens: %f\n\n", numTokens, pplx);
free(logits_arr);
- free(datasetTokens);
return pplx;
}

Expand Down

0 comments on commit e3f6986

Please sign in to comment.