feat(trtllm): detect stop_words from generation_config.json
mfuntowicz committed Oct 23, 2024
1 parent 6376fec commit 9cee00e
Showing 2 changed files with 35 additions and 15 deletions.
8 changes: 8 additions & 0 deletions backends/trtllm/include/backend.h
@@ -82,6 +82,14 @@ namespace huggingface::tgi::backends {
uint64_t seed
) noexcept;

/**
* Attempt to retrieve the stop words to use for generation from the generation_config.json, if present.
* @param generationConfigPath Path to the generation_config.json file
* @return List of single-token stop word sequences if eos_token_id is defined as an array, std::nullopt otherwise
*/
std::optional<std::list<std::vector<TokenId>>>
GetStopWordsFromConfig(const std::filesystem::path &generationConfigPath) noexcept;

/**
*
*/
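For context, a minimal standalone sketch of how the new helper can be exercised, assuming the backend headers are on the include path and that TokenId aliases tle::TokenIdType (a 32-bit integer); the eos_token_id values are illustrative only:

#include <cstdio>
#include <filesystem>
#include <fstream>
#include "backend.h"

int main() {
    // Write a sample generation_config.json with an eos_token_id array,
    // the shape the helper's "/eos_token_id" JSON pointer expects.
    const auto path = std::filesystem::temp_directory_path() / "generation_config.json";
    std::ofstream(path) << R"({"eos_token_id": [1, 107, 128009]})";

    // Each EOS id comes back as its own single-token stop sequence.
    if (const auto stopWords = huggingface::tgi::backends::GetStopWordsFromConfig(path)) {
        for (const auto &sequence : *stopWords)
            for (const auto token : sequence)
                std::printf("stop token: %d\n", static_cast<int>(token));
    }
    return 0;
}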
42 changes: 27 additions & 15 deletions backends/trtllm/lib/backend.cpp
@@ -103,6 +103,31 @@ tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
);
}

std::optional<std::list<std::vector<huggingface::tgi::backends::TokenId>>>
huggingface::tgi::backends::GetStopWordsFromConfig(
const std::filesystem::path &generationConfigPath) noexcept {
if (exists(generationConfigPath)) {
const auto generationConfig = json::parse(std::ifstream(generationConfigPath));
if (const auto eosTokenIds = generationConfig["/eos_token_id"_json_pointer]; eosTokenIds.is_array()) {
SPDLOG_INFO(FMT_STRING("Found {:d} EOS tokens"), eosTokenIds.size());
std::list<std::vector<huggingface::tgi::backends::TokenId>> stopWords(eosTokenIds.size());

const auto to_single_token = [](const auto tokenIdObj) -> decltype(stopWords)::value_type {
return {tokenIdObj.template get<tle::TokenIdType>()};
};

std::transform(eosTokenIds.cbegin(), eosTokenIds.cend(), stopWords.begin(), to_single_token);
return stopWords;
} else {
SPDLOG_INFO("Invalid EOS tokens entry found (not an array)");
}
} else {
SPDLOG_INFO("No EOS tokens found, generation_config.json doesn't exist");
}

return std::nullopt;
}
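
The shape of the result is worth spelling out: each scalar id in the eos_token_id array becomes a one-element vector, i.e. a single-token stop word. A standalone sketch of that mapping using nlohmann::json directly, with TokenId standing in for tle::TokenIdType:

#include <cstdint>
#include <list>
#include <vector>
#include <nlohmann/json.hpp>

using TokenId = std::int32_t;  // stand-in for tle::TokenIdType

int main() {
    // eos_token_id as it may appear in generation_config.json
    const auto eosTokenIds = nlohmann::json::parse(R"([1, 107, 128009])");

    // One single-token stop sequence per EOS id, mirroring the std::transform above.
    std::list<std::vector<TokenId>> stopWords;
    for (const auto &id : eosTokenIds)
        stopWords.push_back({id.get<TokenId>()});

    // stopWords now holds {{1}, {107}, {128009}}.
    return 0;
}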

huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
const std::filesystem::path &enginesFolder,
const std::filesystem::path &executorWorker
@@ -125,21 +150,8 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();

// Attempt to discover stopWords from the generation_config.json
if (auto generationConfigPath = enginesFolder / "generation_config.json"; exists(generationConfigPath)) {
const auto generationConfig = json::parse(std::ifstream(generationConfigPath));
if (const auto eosTokenIds = generationConfig["/eos_token_ids"_json_pointer]; eosTokenIds.is_array()) {
SPDLOG_INFO(FMT_STRING("Found {:d} EOS tokens"), eosTokenIds.size());
stopWords = std::list<decltype(stopWords)::value_type>(eosTokenIds.size());

const auto to_single_token = [](const auto tokenIdObj) -> decltype(stopWords)::value_type {
return {tokenIdObj.template get<tle::TokenIdType>()};
};
std::transform(eosTokenIds.cbegin(), eosTokenIds.cend(), stopWords.begin(), to_single_token);
}
} else {
SPDLOG_INFO("No EOS tokens found, generation_config.json doesn't exist");
stopWords = {};
}
const auto generationConfigPath = enginesFolder / "generation_config.json";
stopWords = GetStopWordsFromConfig(generationConfigPath).value_or(std::list<std::vector<TokenId>>());
}

[[nodiscard("Returned number of requests needs to be consumed")]]
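A design note on the constructor change: returning std::optional keeps the fallback policy at the call site, where value_or degrades a missing or malformed generation_config.json into an empty stop-word list rather than a startup failure. The same pattern in isolation (a sketch; the StopWords alias is assumed for readability):

#include <cstdint>
#include <list>
#include <optional>
#include <vector>

using TokenId = std::int32_t;                       // stand-in for tle::TokenIdType
using StopWords = std::list<std::vector<TokenId>>;  // assumed alias

StopWords ResolveStopWords(const std::optional<StopWords> &fromConfig) {
    // No config, or an invalid one, simply means no implicit stop words.
    return fromConfig.value_or(StopWords{});
}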
