Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions crates/test-programs/src/bin/cli_p1_much_stdout.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
fn main() {
let mut args = std::env::args().skip(1);
let string_to_write = args.next().unwrap();
let times_to_write = args.next().unwrap().parse::<u32>().unwrap();

let bytes = string_to_write.as_bytes();
for _ in 0..times_to_write {
let mut remaining = bytes;
while !remaining.is_empty() {
let iovec = wasip1::Ciovec {
buf: remaining.as_ptr(),
buf_len: remaining.len(),
};
let amt = unsafe { wasip1::fd_write(1, &[iovec]).unwrap() };
remaining = &remaining[amt..];
}
}
}
13 changes: 13 additions & 0 deletions crates/test-programs/src/bin/cli_p2_much_stdout.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
fn main() {
let mut args = std::env::args().skip(1);
let string_to_write = args.next().unwrap();
let times_to_write = args.next().unwrap().parse::<u32>().unwrap();

let bytes = string_to_write.as_bytes();
let stdout = wasip2::cli::stdout::get_stdout();
for _ in 0..times_to_write {
for chunk in bytes.chunks(4096) {
stdout.blocking_write_and_flush(chunk).unwrap();
}
}
}
26 changes: 26 additions & 0 deletions crates/test-programs/src/bin/cli_p3_much_stdout.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use test_programs::p3::{wasi, wit_stream};

struct Component;

test_programs::p3::export!(Component);

impl test_programs::p3::exports::wasi::cli::run::Guest for Component {
async fn run() -> Result<(), ()> {
let mut args = std::env::args().skip(1);
let string_to_write = args.next().unwrap();
let times_to_write = args.next().unwrap().parse::<u32>().unwrap();

let bytes = string_to_write.as_bytes();
let (mut tx, rx) = wit_stream::new();
wasi::cli::stdout::set_stdout(rx);
for _ in 0..times_to_write {
let result = tx.write_all(bytes.to_vec()).await;
assert!(result.is_empty());
}
Ok(())
}
}

fn main() {
unreachable!();
}
35 changes: 29 additions & 6 deletions crates/wasi/src/cli/stdout.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use crate::cli::{IsTerminal, StdoutStream};
use crate::p2;
use bytes::Bytes;
use std::io::{self, Write};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::AsyncWrite;
use wasmtime_wasi_io::streams::OutputStream;

Expand All @@ -15,7 +18,7 @@ impl StdoutStream for tokio::io::Stdout {
Box::new(StdioOutputStream::Stdout)
}
fn async_stream(&self) -> Box<dyn AsyncWrite + Send + Sync> {
Box::new(tokio::io::stdout())
Box::new(StdioOutputStream::Stdout)
}
}

Expand All @@ -30,7 +33,7 @@ impl StdoutStream for std::io::Stdout {
Box::new(StdioOutputStream::Stdout)
}
fn async_stream(&self) -> Box<dyn AsyncWrite + Send + Sync> {
Box::new(tokio::io::stdout())
Box::new(StdioOutputStream::Stdout)
}
}

Expand All @@ -45,7 +48,7 @@ impl StdoutStream for tokio::io::Stderr {
Box::new(StdioOutputStream::Stderr)
}
fn async_stream(&self) -> Box<dyn AsyncWrite + Send + Sync> {
Box::new(tokio::io::stderr())
Box::new(StdioOutputStream::Stderr)
}
}

Expand All @@ -60,7 +63,7 @@ impl StdoutStream for std::io::Stderr {
Box::new(StdioOutputStream::Stderr)
}
fn async_stream(&self) -> Box<dyn AsyncWrite + Send + Sync> {
Box::new(tokio::io::stderr())
Box::new(StdioOutputStream::Stderr)
}
}

Expand All @@ -71,7 +74,6 @@ enum StdioOutputStream {

impl OutputStream for StdioOutputStream {
fn write(&mut self, bytes: Bytes) -> p2::StreamResult<()> {
use std::io::Write;
match self {
StdioOutputStream::Stdout => std::io::stdout().write_all(&bytes),
StdioOutputStream::Stderr => std::io::stderr().write_all(&bytes),
Expand All @@ -80,7 +82,6 @@ impl OutputStream for StdioOutputStream {
}

fn flush(&mut self) -> p2::StreamResult<()> {
use std::io::Write;
match self {
StdioOutputStream::Stdout => std::io::stdout().flush(),
StdioOutputStream::Stderr => std::io::stderr().flush(),
Expand All @@ -93,6 +94,28 @@ impl OutputStream for StdioOutputStream {
}
}

impl AsyncWrite for StdioOutputStream {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(match *self {
StdioOutputStream::Stdout => std::io::stdout().write(buf),
StdioOutputStream::Stderr => std::io::stderr().write(buf),
})
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(match *self {
StdioOutputStream::Stdout => std::io::stdout().flush(),
StdioOutputStream::Stderr => std::io::stderr().flush(),
})
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}

#[async_trait::async_trait]
impl p2::Pollable for StdioOutputStream {
async fn ready(&mut self) {}
Expand Down
47 changes: 47 additions & 0 deletions tests/all/cli_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,7 @@ mod test_programs {
use http_body_util::BodyExt;
use hyper::header::HeaderValue;
use std::io::{BufRead, BufReader, Read, Write};
use std::iter;
use std::net::SocketAddr;
use std::process::{Child, Command, Stdio};
use test_programs_artifacts::*;
Expand Down Expand Up @@ -2173,6 +2174,52 @@ start a print 1234
assert_eq!(output, "\"hello?\"\n");
Ok(())
}

fn run_much_stdout(component: &str, extra_flags: &[&str]) -> Result<()> {
let total_write_size = 1 << 19;
let expected = iter::repeat('a').take(total_write_size).collect::<String>();

for i in 0..15 {
let string = iter::repeat('a').take(1 << i).collect::<String>();
let times = (total_write_size >> i).to_string();
println!("writing {} bytes {times} times", string.len());

let mut args = Vec::new();
args.push("run");
args.extend_from_slice(extra_flags);
args.push(component);
args.push(&string);
args.push(&times);
let output = run_wasmtime(&args)?;
println!(
"expected {} bytes, got {} bytes",
expected.len(),
output.len()
);
assert!(output == expected);
}

Ok(())
}

#[test]
fn cli_p1_much_stdout() -> Result<()> {
run_much_stdout(CLI_P1_MUCH_STDOUT_COMPONENT, &[])
}

#[test]
fn cli_p2_much_stdout() -> Result<()> {
run_much_stdout(CLI_P2_MUCH_STDOUT_COMPONENT, &[])
}

#[test]
#[cfg_attr(not(feature = "component-model-async"), ignore)]
fn cli_p3_much_stdout() -> Result<()> {
run_much_stdout(
CLI_P3_MUCH_STDOUT_COMPONENT,
&["-Wcomponent-model-async", "-Sp3"],
)
}
}

#[test]
Expand Down