Skip to content

Commit

Permalink
Merge branch 'INSTX-5919_make_auto_device_resolution_available' into …
Browse files Browse the repository at this point in the history
…'master'

INSTX-5919 Make auto device resolution available to server

See merge request machine-learning/dorado!1137
  • Loading branch information
hpendry-ont committed Jul 30, 2024
2 parents 7d74246 + c4da8b5 commit df605b6
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 14 deletions.
3 changes: 2 additions & 1 deletion dorado/cli/basecaller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "read_pipeline/ReadFilterNode.h"
#include "read_pipeline/ReadToBamTypeNode.h"
#include "read_pipeline/ResumeLoader.h"
#include "torch_utils/auto_detect_device.h"
#include "utils/SampleSheet.h"
#include "utils/arg_parse_ext.h"
#include "utils/bam_utils.h"
Expand Down Expand Up @@ -656,7 +657,7 @@ int basecaller(int argc, char* argv[]) {
return EXIT_FAILURE;
}
if (device == cli::AUTO_DETECT_DEVICE) {
device = cli::get_auto_detected_device();
device = utils::get_auto_detected_device();
}

auto hts_file = cli::extract_hts_file(parser);
Expand Down
10 changes: 0 additions & 10 deletions dorado/cli/cli_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,16 +150,6 @@ inline bool validate_device_string(const std::string& device) {
return false;
}

inline std::string get_auto_detected_device() {
#if DORADO_METAL_BUILD
return "metal";
#elif DORADO_CUDA_BUILD
return torch::cuda::is_available() ? "cuda:all" : "cpu";
#else
return "cpu";
#endif
}

} // namespace cli

} // namespace dorado
3 changes: 2 additions & 1 deletion dorado/cli/correct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "read_pipeline/CorrectionNode.h"
#include "read_pipeline/ErrorCorrectionMapperNode.h"
#include "read_pipeline/HtsWriter.h"
#include "torch_utils/auto_detect_device.h"
#include "torch_utils/torch_utils.h"
#include "utils/arg_parse_ext.h"
#include "utils/fs_utils.h"
Expand Down Expand Up @@ -92,7 +93,7 @@ int correct(int argc, char* argv[]) {
#if DORADO_METAL_BUILD
device = "cpu";
#else
device = cli::get_auto_detected_device();
device = utils::get_auto_detected_device();
#endif
}

Expand Down
4 changes: 2 additions & 2 deletions dorado/cli/duplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
#include "read_pipeline/ProgressTracker.h"
#include "read_pipeline/ReadFilterNode.h"
#include "read_pipeline/ReadToBamTypeNode.h"
#include "torch_utils/auto_detect_device.h"
#include "utils/SampleSheet.h"
#include "utils/arg_parse_ext.h"
#include "utils/bam_utils.h"
#include "utils/basecaller_utils.h"

#if DORADO_CUDA_BUILD
#include "torch_utils/cuda_utils.h"
#endif
Expand Down Expand Up @@ -365,7 +365,7 @@ int duplex(int argc, char* argv[]) {
return EXIT_FAILURE;
}
if (device == cli::AUTO_DETECT_DEVICE) {
device = cli::get_auto_detected_device();
device = utils::get_auto_detected_device();
}

auto model(parser.visible.get<std::string>("model"));
Expand Down
1 change: 1 addition & 0 deletions dorado/torch_utils/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_library(dorado_torch_utils
auto_detect_device.h
duplex_utils.cpp
duplex_utils.h
gpu_monitor.cpp
Expand Down
21 changes: 21 additions & 0 deletions dorado/torch_utils/auto_detect_device.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#pragma once

#if DORADO_CUDA_BUILD
#include <torch/cuda.h>
#endif

#include <string>

namespace dorado::utils {

inline std::string get_auto_detected_device() {
#if DORADO_METAL_BUILD
return "metal";
#elif DORADO_CUDA_BUILD
return torch::cuda::is_available() ? "cuda:all" : "cpu";
#else
return "cpu";
#endif
}

} // namespace dorado::utils

0 comments on commit df605b6

Please sign in to comment.