@@ -60,21 +60,22 @@ bool EngineInitCmd::Exec() const {
     variants.push_back(asset_name);
   }
 
-  auto cuda_version = system_info_utils::GetCudaVersion();
-  LOG_INFO << "engineName_: " << engineName_;
-  LOG_INFO << "CUDA version: " << cuda_version;
-  std::string matched_variant = "";
+  auto cuda_driver_version = system_info_utils::GetCudaVersion();
+  LOG_INFO << "Engine: " << engineName_
+           << ", CUDA driver version: " << cuda_driver_version;
+
+  std::string matched_variant{""};
   if (engineName_ == "cortex.tensorrt-llm") {
     matched_variant = engine_matcher_utils::ValidateTensorrtLlm(
-        variants, system_info.os, cuda_version);
+        variants, system_info.os, cuda_driver_version);
   } else if (engineName_ == "cortex.onnx") {
     matched_variant = engine_matcher_utils::ValidateOnnx(
         variants, system_info.os, system_info.arch);
   } else if (engineName_ == "cortex.llamacpp") {
     auto suitable_avx = engine_matcher_utils::GetSuitableAvxVariant();
     matched_variant = engine_matcher_utils::Validate(
         variants, system_info.os, system_info.arch, suitable_avx,
-        cuda_version);
+        cuda_driver_version);
   }
   LOG_INFO << "Matched variant: " << matched_variant;
   if (matched_variant.empty()) {
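
For readers outside this file: the three `Validate*` helpers differ only in which host properties they require (TensorRT-LLM needs a matching CUDA build, ONNX only OS and arch, llama.cpp adds the AVX probe). A minimal sketch of the shared idea, using a hypothetical `PickVariant` that is not the real `engine_matcher_utils` API:

```cpp
// Sketch only: pick the first release asset whose name contains every
// expected token (e.g. OS, arch, AVX level, "cuda-<major>"). The actual
// engine_matcher_utils implementation may differ.
#include <string>
#include <vector>

std::string PickVariant(const std::vector<std::string>& variants,
                        const std::vector<std::string>& tokens) {
  for (const auto& v : variants) {
    bool all_match = true;
    for (const auto& t : tokens) {
      if (v.find(t) == std::string::npos) {
        all_match = false;
        break;
      }
    }
    if (all_match)
      return v;  // first asset matching all tokens wins
  }
  return "";  // empty result is treated as "no compatible variant"
}
```
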
@@ -128,22 +129,50 @@ bool EngineInitCmd::Exec() const {
           LOG_INFO << "Finished!";
         });
     if (system_info.os == "mac" || engineName_ == "cortex.onnx") {
-      return false;
+      // neither mac nor the onnx engine requires the cuda toolkit
+      return true;
     }
+
     // download cuda toolkit
     const std::string jan_host = "https://catalog.jan.ai";
     const std::string cuda_toolkit_file_name = "cuda.tar.gz";
     const std::string download_id = "cuda";
 
-    auto gpu_driver_version = system_info_utils::GetDriverVersion();
+    // TODO: we don't have an API to retrieve the list of cuda toolkit dependencies atm
+    // will have better logic after https://github.com/janhq/cortex/issues/1046 is finished
+    // for now, assume that we have only 11.7 and 12.4
+    std::string suitable_toolkit_version = "";
+    if (engineName_ == "cortex.tensorrt-llm") {
+      // for tensorrt-llm, we need to download cuda toolkit v12.4
+      suitable_toolkit_version = "12.4";
+    } else {
+      // llamacpp
+      if (cuda_driver_version.starts_with("11.")) {
+        suitable_toolkit_version = "11.7";
+      } else if (cuda_driver_version.starts_with("12.")) {
+        suitable_toolkit_version = "12.4";
+      }
+    }
 
-    auto cuda_runtime_version =
-        cuda_toolkit_utils::GetCompatibleCudaToolkitVersion(
-            gpu_driver_version, system_info.os, engineName_);
+    // compare the cuda driver version with the cuda toolkit version:
+    // the driver version must be greater than or equal to the toolkit version to ensure compatibility
+    if (semantic_version_utils::CompareSemanticVersion(
+            cuda_driver_version, suitable_toolkit_version) < 0) {
+      LOG_ERROR << "Your CUDA driver version " << cuda_driver_version
+                << " is not compatible with cuda toolkit version "
+                << suitable_toolkit_version;
+      return false;
+    }
+
+    std::string cuda_version_path{""};
+    if (!cuda_driver_version.empty()) {
+      cuda_version_path =
+          semantic_version_utils::ConvertToPath(suitable_toolkit_version);
+    }
 
     std::ostringstream cuda_toolkit_path;
-    cuda_toolkit_path << "dist/cuda-dependencies/" << 11.7 << "/"
-                      << system_info.os << "/"
+    cuda_toolkit_path << "dist/cuda-dependencies/" << cuda_version_path
+                      << "/" << system_info.os << "/"
                       << cuda_toolkit_file_name;
 
     LOG_DEBUG << "Cuda toolkit download url: " << jan_host
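
The compatibility gate above depends on only two behaviors of `semantic_version_utils` that can be inferred from its call sites: a strcmp-style three-way comparison, and a mapping from a version string to the directory segment used under `dist/cuda-dependencies/`. A hedged sketch under those assumptions (the hard-coded `11.7` path this change removes suggests the segment is just the `major.minor` string; the real implementation may differ):

```cpp
#include <sstream>
#include <string>
#include <vector>

namespace semantic_version_utils_sketch {

// Split "12.4" into {12, 4}; missing components default to 0.
std::vector<int> Split(const std::string& version) {
  std::vector<int> parts;
  std::stringstream ss(version);
  for (std::string item; std::getline(ss, item, '.');)
    parts.push_back(item.empty() ? 0 : std::stoi(item));
  return parts;
}

// Returns <0, 0, or >0, so "12.4" vs "12.4.0" compare equal and
// CompareSemanticVersion("11.8", "12.4") < 0 triggers the error path above.
int CompareSemanticVersion(const std::string& a, const std::string& b) {
  auto pa = Split(a), pb = Split(b);
  pa.resize(3, 0);
  pb.resize(3, 0);
  for (int i = 0; i < 3; ++i) {
    if (pa[i] != pb[i]) return pa[i] < pb[i] ? -1 : 1;
  }
  return 0;
}

// Assumption: the catalog lays toolkits out as dist/cuda-dependencies/12.4/...,
// so the path form is the version string itself.
std::string ConvertToPath(const std::string& version) {
  return version;
}

}  // namespace semantic_version_utils_sketch
```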