Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 66 additions & 19 deletions implants/imix/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ use tokio::sync::mpsc::{self, UnboundedSender};
#[derive(Debug)]
struct StreamPrinter {
tx: UnboundedSender<String>,
error_tx: UnboundedSender<String>,
}

impl StreamPrinter {
fn new(tx: UnboundedSender<String>) -> Self {
Self { tx }
fn new(tx: UnboundedSender<String>, error_tx: UnboundedSender<String>) -> Self {
Self { tx, error_tx }
}
}

Expand All @@ -29,7 +30,7 @@ impl Printer for StreamPrinter {

fn print_err(&self, _span: &Span, s: &str) {
// We format with newline to match BufferPrinter behavior
let _ = self.tx.send(format!("{}\n", s));
let _ = self.error_tx.send(format!("{}\n", s));
}
}

Expand Down Expand Up @@ -168,7 +169,8 @@ fn execute_task(
) {
// Setup StreamPrinter and Interpreter
let (tx, rx) = mpsc::unbounded_channel();
let printer = Arc::new(StreamPrinter::new(tx));
let (error_tx, error_rx) = mpsc::unbounded_channel();
let printer = Arc::new(StreamPrinter::new(tx, error_tx));
let mut interp = setup_interpreter(task_context.clone(), &tome, agent.clone(), printer.clone());

// Report Start
Expand All @@ -180,6 +182,7 @@ fn execute_task(
agent.clone(),
runtime_handle.clone(),
rx,
error_rx,
);

// Run Interpreter with panic protection
Expand Down Expand Up @@ -267,27 +270,71 @@ fn spawn_output_consumer(
agent: Arc<dyn Agent>,
runtime_handle: tokio::runtime::Handle,
mut rx: mpsc::UnboundedReceiver<String>,
mut error_rx: mpsc::UnboundedReceiver<String>,
) -> tokio::task::JoinHandle<()> {
runtime_handle.spawn(async move {
#[cfg(debug_assertions)]
log::info!("task={} Started output stream", task_context.task_id);
let task_id = task_context.task_id;
while let Some(msg) = rx.recv().await {
match agent.report_task_output(ReportTaskOutputRequest {
output: Some(TaskOutput {
id: task_id,
output: msg,
error: None,
exec_started_at: None,
exec_finished_at: None,
}),
context: Some(task_context.clone().into()),
}) {
Ok(_) => {}
Err(_e) => {
#[cfg(debug_assertions)]
log::error!("task={task_id} failed to report output: {_e}");
let mut rx_open = true;
let mut error_rx_open = true;

loop {
tokio::select! {
val = rx.recv(), if rx_open => {
match val {
Some(msg) => {
match agent.report_task_output(ReportTaskOutputRequest {
output: Some(TaskOutput {
id: task_id,
output: msg,
error: None,
exec_started_at: None,
exec_finished_at: None,
}),
context: Some(task_context.clone().into()),
}) {
Ok(_) => {}
Err(_e) => {
#[cfg(debug_assertions)]
log::error!("task={task_id} failed to report output: {_e}");
}
}
}
None => {
rx_open = false;
}
}
}
val = error_rx.recv(), if error_rx_open => {
match val {
Some(msg) => {
match agent.report_task_output(ReportTaskOutputRequest {
output: Some(TaskOutput {
id: task_id,
output: String::new(),
error: Some(TaskError { msg }),
exec_started_at: None,
exec_finished_at: None,
}),
context: Some(task_context.clone().into()),
}) {
Ok(_) => {}
Err(_e) => {
#[cfg(debug_assertions)]
log::error!("task={task_id} failed to report error: {_e}");
}
}
}
None => {
error_rx_open = false;
}
}
}
}

if !rx_open && !error_rx_open {
break;
}
}
})
Expand Down
58 changes: 58 additions & 0 deletions implants/imix/src/tests/task_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,61 @@ async fn test_task_registry_list_and_stop() {
"Task should be removed from list"
);
}

#[tokio::test]
async fn test_task_eprint_behavior() {
let agent = Arc::new(MockAgent::new());
let task_id = 111;
let code = "eprint(\"This is an error\")\nprint(\"This is output\")";

let task = c2::Task {
id: task_id,
tome: Some(Tome {
eldritch: code.to_string(),
..Default::default()
}),
quest_name: "eprint_test".to_string(),
..Default::default()
};

let registry = TaskRegistry::new();
registry.spawn(task, agent.clone());

tokio::time::sleep(Duration::from_secs(3)).await;

let reports = agent.output_reports.lock().unwrap();

// Check if "This is an error" appears in output or error field
let error_in_output = reports.iter().any(|r| {
r.output
.as_ref()
.map(|o| o.output.contains("This is an error"))
.unwrap_or(false)
});

let error_in_error = reports.iter().any(|r| {
r.output
.as_ref()
.map(|o| {
if let Some(err) = &o.error {
err.msg.contains("This is an error")
} else {
false
}
})
.unwrap_or(false)
});

println!("Error in output: {}", error_in_output);
println!("Error in error field: {}", error_in_error);

// Current behavior (before fix): eprint goes to output
// Desired behavior: eprint goes to error field

// So if I assert what I want:
assert!(error_in_error, "eprint should be reported as TaskError");
assert!(
!error_in_output,
"eprint should NOT be reported as regular output"
);
}
Loading