Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit f74fd1b

Browse files
committed
fix: dynamically get cuda toolkit version
1 parent 670a477 commit f74fd1b

File tree

3 files changed

+55
-16
lines changed

3 files changed

+55
-16
lines changed

engine/commands/engine_init_cmd.cc

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,21 +60,22 @@ bool EngineInitCmd::Exec() const {
6060
variants.push_back(asset_name);
6161
}
6262

63-
auto cuda_version = system_info_utils::GetCudaVersion();
64-
LOG_INFO << "engineName_: " << engineName_;
65-
LOG_INFO << "CUDA version: " << cuda_version;
66-
std::string matched_variant = "";
63+
auto cuda_driver_version = system_info_utils::GetCudaVersion();
64+
LOG_INFO << "Engine: " << engineName_
65+
<< ", CUDA driver version: " << cuda_driver_version;
66+
67+
std::string matched_variant{""};
6768
if (engineName_ == "cortex.tensorrt-llm") {
6869
matched_variant = engine_matcher_utils::ValidateTensorrtLlm(
69-
variants, system_info.os, cuda_version);
70+
variants, system_info.os, cuda_driver_version);
7071
} else if (engineName_ == "cortex.onnx") {
7172
matched_variant = engine_matcher_utils::ValidateOnnx(
7273
variants, system_info.os, system_info.arch);
7374
} else if (engineName_ == "cortex.llamacpp") {
7475
auto suitable_avx = engine_matcher_utils::GetSuitableAvxVariant();
7576
matched_variant = engine_matcher_utils::Validate(
7677
variants, system_info.os, system_info.arch, suitable_avx,
77-
cuda_version);
78+
cuda_driver_version);
7879
}
7980
LOG_INFO << "Matched variant: " << matched_variant;
8081
if (matched_variant.empty()) {
@@ -128,22 +129,50 @@ bool EngineInitCmd::Exec() const {
128129
LOG_INFO << "Finished!";
129130
});
130131
if (system_info.os == "mac" || engineName_ == "cortex.onnx") {
131-
return false;
132+
// macOS and the ONNX engine do not require the CUDA toolkit
133+
return true;
132134
}
135+
133136
// download cuda toolkit
134137
const std::string jan_host = "https://catalog.jan.ai";
135138
const std::string cuda_toolkit_file_name = "cuda.tar.gz";
136139
const std::string download_id = "cuda";
137140

138-
auto gpu_driver_version = system_info_utils::GetDriverVersion();
141+
// TODO: we don't have API to retrieve list of cuda toolkit dependencies atm
142+
// will have better logic after https://github.com/janhq/cortex/issues/1046 finished
143+
// for now, assume that we have only 11.7 and 12.4
144+
auto suitable_toolkit_version = "";
145+
if (engineName_ == "cortex.tensorrt-llm") {
146+
// for tensorrt-llm, we need to download cuda toolkit v12.4
147+
suitable_toolkit_version = "12.4";
148+
} else {
149+
// llamacpp
150+
if (cuda_driver_version.starts_with("11.")) {
151+
suitable_toolkit_version = "11.7";
152+
} else if (cuda_driver_version.starts_with("12.")) {
153+
suitable_toolkit_version = "12.4";
154+
}
155+
}
139156

140-
auto cuda_runtime_version =
141-
cuda_toolkit_utils::GetCompatibleCudaToolkitVersion(
142-
gpu_driver_version, system_info.os, engineName_);
157+
// compare cuda driver version with cuda toolkit version
158+
// the cuda driver version must be greater than or equal to the toolkit version to ensure compatibility
159+
if (semantic_version_utils::CompareSemanticVersion(
160+
cuda_driver_version, suitable_toolkit_version) < 0) {
161+
LOG_ERROR << "Your Cuda driver version " << cuda_driver_version
162+
<< " is not compatible with cuda toolkit version "
163+
<< suitable_toolkit_version;
164+
return false;
165+
}
166+
167+
std::string cuda_version_path{""};
168+
if (!cuda_driver_version.empty()) {
169+
cuda_version_path = semantic_version_utils::ConvertToPath(
170+
suitable_toolkit_version);
171+
}
143172

144173
std::ostringstream cuda_toolkit_path;
145-
cuda_toolkit_path << "dist/cuda-dependencies/" << 11.7 << "/"
146-
<< system_info.os << "/"
174+
cuda_toolkit_path << "dist/cuda-dependencies/" << cuda_version_path
175+
<< "/" << system_info.os << "/"
147176
<< cuda_toolkit_file_name;
148177

149178
LOG_DEBUG << "Cuda toolkit download url: " << jan_host

engine/utils/engine_matcher_utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1+
#include <trantor/utils/Logger.h>
12
#include <algorithm>
2-
#include <iostream>
33
#include <iterator>
44
#include <regex>
55
#include <string>
@@ -177,4 +177,4 @@ inline std::string Validate(const std::vector<std::string>& variants,
177177

178178
return cuda_compatible;
179179
}
180-
} // namespace engine_matcher_utils
180+
} // namespace engine_matcher_utils

engine/utils/semantic_version_utils.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,14 @@ inline int CompareSemanticVersion(const std::string& version1,
3131
}
3232
return 0;
3333
}
34-
} // namespace semantic_version_utils
34+
35+
// Converts a dotted version string into the form used in download URLs by
// replacing the first '.' with '-' (e.g. "11.7" -> "11-7").
// Only the first dot is replaced; a string without a dot is returned unchanged.
inline std::string ConvertToPath(const std::string& version) {
  std::string result = version;
  // Keep the index as size_type: storing find()'s result in an int narrows it
  // and turns the npos check into a signed/unsigned comparison.
  const std::string::size_type pos = result.find('.');
  if (pos != std::string::npos) {
    result[pos] = '-';
  }
  return result;
}
44+
} // namespace semantic_version_utils

0 commit comments

Comments
 (0)