Skip to content

Commit

Permalink
fix: add max retries flag
Browse files Browse the repository at this point in the history
  • Loading branch information
rhysnewell committed Nov 16, 2023
1 parent 28e44a2 commit 76020d7
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "rosella"
version = "0.5.0"
version = "0.5.1"
authors = ["Rhys Newell <rhys.newell94@gmail.com>"]
license = "GPL-3.0"
description = "Metagenome assembled genome recovery from metagenomes using UMAP and HDBSCAN"
Expand Down
21 changes: 21 additions & 0 deletions src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,15 @@ fn binning_params_section() -> Section {
default_roff("100")
)),
)
.option(
Opt::new("INT")
.long("--max-retries")
.help(&format!(
"Maximum number of times to retry refining a genome \
if it fails. [default: {}] \n",
default_roff("5")
)),
)
}

fn refining_options() -> Section {
Expand Down Expand Up @@ -829,6 +838,12 @@ pub fn build_cli() -> Command {
.value_parser(clap::value_parser!(usize))
.default_value("100"),
)
.arg(
Arg::new("max-retries")
.long("max-retries")
.value_parser(clap::value_parser!(usize))
.default_value("5"),
)
.arg(
Arg::new("max-nb-connections")
.long("max-nb-connections")
Expand Down Expand Up @@ -1235,6 +1250,12 @@ pub fn build_cli() -> Command {
.value_parser(clap::value_parser!(f64))
.default_value("15.0"),
)
.arg(
Arg::new("max-retries")
.long("max-retries")
.value_parser(clap::value_parser!(usize))
.default_value("5"),
)
.arg(
Arg::new("bin-tag")
.long("bin-tag")
Expand Down
1 change: 1 addition & 0 deletions src/kmers/kmer_counting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ impl KmerCounter {

let mut kmer_table = Vec::with_capacity(n_contigs);
let mut contig_names = Vec::with_capacity(n_contigs);
let mut n_contigs = 0;
while let Some(record) = reader.next() {
let seqrec = record?;
let contig_name = std::str::from_utf8(seqrec.id())?.to_string();
Expand Down
8 changes: 6 additions & 2 deletions src/recover/recover_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ struct RecoverEngine {
min_bin_size: usize,
min_contig_size: usize,
filtered_contigs: HashSet<String>,
max_retries: usize
}

impl RecoverEngine {
Expand Down Expand Up @@ -128,6 +129,7 @@ impl RecoverEngine {
let min_bin_size = m.get_one::<usize>("min-bin-size").unwrap().clone();

let n_contigs = coverage_table.table.nrows();
let max_retries = m.get_one::<usize>("max-retries").unwrap().clone();
Ok(
Self {
output_directory,
Expand All @@ -145,6 +147,7 @@ impl RecoverEngine {
min_contig_size,
// filtered_contigs,
filtered_contigs: HashSet::new(),
max_retries
}
)
}
Expand Down Expand Up @@ -199,7 +202,7 @@ impl RecoverEngine {

info!("Writing clusters.");
self.write_clusters(cluster_results, true)?;

self.run_refinery()?;
}

Expand Down Expand Up @@ -259,7 +262,8 @@ impl RecoverEngine {
checkm_results: None,
threads: rayon::current_num_threads(),
// bin_unbinned: true,
bin_unbinned: false
bin_unbinned: false,
max_retries: self.max_retries
};

refinery.run("refined_0")?;
Expand Down
58 changes: 54 additions & 4 deletions src/refine/refinery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub struct RefineEngine {
pub(crate) min_bin_size: usize,
pub(crate) n_neighbours: usize,
pub(crate) bin_unbinned: bool,
pub(crate) max_retries: usize
}

impl RefineEngine {
Expand Down Expand Up @@ -65,6 +66,7 @@ impl RefineEngine {
let min_contig_size = *m.get_one::<usize>("min-contig-size").unwrap();
let min_bin_size = *m.get_one::<usize>("min-bin-size").unwrap();
let n_neighbours = *m.get_one::<usize>("n-neighbours").unwrap();
let max_retries = *m.get_one::<usize>("max-retries").unwrap();


Ok(Self {
Expand All @@ -79,6 +81,7 @@ impl RefineEngine {
min_bin_size,
n_neighbours,
bin_unbinned: false,
max_retries
})
}

Expand Down Expand Up @@ -110,6 +113,7 @@ impl RefineEngine {
let extra_threads = max(self.threads / self.mags_to_refine.len(), 1);

info!("Beginning refinement of {} MAGs", self.mags_to_refine.len());
let flight_version = self.get_flight_version();
let progress_bar = ProgressBar::new(self.mags_to_refine.len() as u64);
progress_bar.set_style(ProgressStyle::default_bar()
.template("[{elapsed_precise}] [{bar:40.cyan/blue}] {pos:>7}/{len:7} {msg}"));
Expand All @@ -120,7 +124,7 @@ impl RefineEngine {
.par_iter()
.filter(|genome| !genome.contains(UNBINNED))
.map(|genome| {
let result = self.run_flight_refine(genome, extra_threads);
let result = self.run_flight_refine(genome, extra_threads, &flight_version);
progress_bar.inc(1);
(result, genome)
}).collect::<Vec<(_, _)>>();
Expand All @@ -131,7 +135,7 @@ impl RefineEngine {
.iter()
.filter(|genome| genome.contains(UNBINNED))
.map(|genome| {
let result = self.run_flight_refine(genome, self.threads);
let result = self.run_flight_refine(genome, self.threads, &flight_version);
progress_bar.inc(1);
(result, genome)
}).collect::<Vec<(_, _)>>();
Expand Down Expand Up @@ -162,6 +166,7 @@ impl RefineEngine {
}
}

debug!("Unified cluster map: {:?}", &unified_cluster_map);
// get outliers
let outliers = unified_cluster_map.remove(&0);
let cluster_results = self.get_cluster_result(unified_cluster_map, outliers.unwrap_or_else(|| HashSet::new()));
Expand All @@ -187,14 +192,32 @@ impl RefineEngine {
Ok(())
}

fn run_flight_refine(&self, genome_path: &str, extra_threads: usize) -> Result<String> {
/// Queries the `flight` executable for its version.
///
/// Runs `flight --version` and parses the first three dot-separated
/// components of its standard output as unsigned integers.
///
/// # Returns
///
/// `vec![major, minor, patch]` on success. If `flight` cannot be executed or
/// its output cannot be parsed as a three-part numeric version, an empty
/// vector is returned instead of panicking. Slices order lexicographically,
/// so an empty vector compares lower than any minimum version a caller
/// checks against, and version-gated flags are then simply left off.
fn get_flight_version(&self) -> Vec<usize> {
    let output = match Command::new("flight").arg("--version").output() {
        Ok(output) => output,
        Err(err) => {
            debug!("Unable to run `flight --version`: {}", err);
            return Vec::new();
        }
    };

    // Lossy decoding: a stray non-UTF-8 byte must not abort the whole run.
    let raw = String::from_utf8_lossy(&output.stdout);
    let flight_version = raw.trim();
    debug!("Flight version: {}", flight_version);

    // Accept a bare "1.6.2" as well as forms such as "flight 1.6.2" or
    // "v1.6.2" by keeping the last whitespace-separated token and dropping a
    // leading 'v'. NOTE(review): the exact output format of
    // `flight --version` is not visible here -- confirm.
    let token = flight_version
        .split_whitespace()
        .last()
        .unwrap_or("")
        .trim_start_matches('v');

    let parts = token.split('.').collect::<Vec<_>>();
    debug!("Version: {:?}", parts);

    // Only major/minor/patch are used; extra components are ignored.
    let parsed = parts
        .iter()
        .take(3)
        .map(|part| part.parse::<usize>())
        .collect::<Result<Vec<_>, _>>();

    match parsed {
        Ok(version) if version.len() == 3 => version,
        _ => {
            debug!(
                "Could not parse flight version from {:?}; treating version as unknown",
                flight_version
            );
            Vec::new()
        }
    }
}

fn run_flight_refine(&self, genome_path: &str, extra_threads: usize, version: &[usize]) -> Result<String> {

// get output prefix from genome path by remove path and extensions
let output_prefix = genome_path
.split("/")
.collect::<Vec<_>>()
.last().unwrap().to_string();


let mut flight_cmd = Command::new("flight");
flight_cmd.arg("refine");
// flight_cmd.arg("--assembly").arg(&self.assembly);
Expand All @@ -213,6 +236,14 @@ impl RefineEngine {
}


if version >= &[1, 6, 2] {
debug!("Using max_retries flag");
flight_cmd.arg("--max_retries").arg(format!("{}", self.max_retries));
} else {
debug!("Not using max_retries flag")
}


flight_cmd.stdout(std::process::Stdio::piped());
flight_cmd.stderr(std::process::Stdio::piped());

Expand All @@ -235,6 +266,25 @@ impl RefineEngine {
}
}
bail!("Flight failed with exit code: {}", exit_status);
} else {
if let Some(stdout) = child.stdout.take() {
let stdout = std::io::BufReader::new(stdout);
for line in stdout.lines() {
let line = line?;
let message = line.split("INFO: ").collect::<Vec<_>>();
debug!("{}", message[message.len() - 1]);
}
}

// same for stderr
if let Some(stderr) = child.stderr.take() {
let stderr = std::io::BufReader::new(stderr);
for line in stderr.lines() {
let line = line?;
let message = line.split("INFO: ").collect::<Vec<_>>();
debug!("{}", message[message.len() - 1]);
}
}
}

let output_json = format!("{}/{}.json", self.output_directory, output_prefix);
Expand Down Expand Up @@ -338,7 +388,7 @@ impl RefineEngine {
}
cluster_results.par_sort_unstable();

debug!("Cluster results: {:?}", &cluster_results[0..10]);
debug!("Cluster results: {:?}", &cluster_results);
cluster_results
}

Expand Down

0 comments on commit 76020d7

Please sign in to comment.