Skip to content

Commit

Permalink
Merge pull request #34 from Eventual-Inc/sql
Browse files Browse the repository at this point in the history
feat: Add a new `sql` and `ssh` command
  • Loading branch information
raunakab authored Jan 16, 2025
2 parents 132ef1d + f7bb92a commit abf036d
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 24 deletions.
118 changes: 94 additions & 24 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ enum SubCommand {
/// Ray cluster.
Connect(Connect),

/// SSH into the head of the remote Ray cluster.
Ssh(ConfigPath),

/// Submit a SQL query string to the Ray cluster.
///
/// This is executed using Daft's SQL API support.
Sql(Sql),

/// Spin down a given cluster and put the nodes to "sleep".
///
/// This will *not* delete the nodes, only stop them. The nodes can be
Expand All @@ -94,9 +102,6 @@ struct Init {

#[derive(Debug, Parser, Clone, PartialEq, Eq)]
struct List {
#[clap(flatten)]
config_path: ConfigPath,

/// The region which to list all the available clusters for.
#[arg(long)]
region: Option<StrRef>,
Expand All @@ -108,6 +113,9 @@ struct List {
/// Only list the running instances.
#[arg(long)]
running: bool,

#[clap(flatten)]
config_path: ConfigPath,
}

#[derive(Debug, Parser, Clone, PartialEq, Eq)]
Expand All @@ -129,6 +137,15 @@ struct Connect {
config_path: ConfigPath,
}

// Arguments of the `sql` sub-command (the payload of `SubCommand::Sql`).
//
// NOTE: plain `//` comments are used here on purpose; `///` comments on a
// clap-derived struct are picked up as user-facing help text.
#[derive(Debug, Parser, Clone, PartialEq, Eq)]
struct Sql {
// No `#[arg(long)]`/`#[arg(short)]` attribute here, so clap should treat
// this as a positional argument (NOTE(review): confirm against clap docs).
/// The SQL string to submit to the remote Ray cluster.
sql: StrRef,

// Re-uses the shared `ConfigPath` arguments by flattening them into this
// command's own argument list.
#[clap(flatten)]
config_path: ConfigPath,
}

#[derive(Debug, Parser, Clone, PartialEq, Eq)]
struct ConfigPath {
/// Path to configuration file.
Expand Down Expand Up @@ -486,12 +503,16 @@ impl TeardownBehaviour {
}
}

fn create_temp_ray_file() -> anyhow::Result<(TempDir, PathRef)> {
fn create_temp_file(name: &str) -> anyhow::Result<(TempDir, PathRef)> {
let temp_dir = TempDir::new("daft-launcher")?;
let mut ray_path = temp_dir.path().to_owned();
ray_path.push("ray.yaml");
let ray_path = Arc::from(ray_path);
Ok((temp_dir, ray_path))
let mut temp_path = temp_dir.path().to_owned();
temp_path.push(name);
let temp_path = Arc::from(temp_path);
Ok((temp_dir, temp_path))
}

/// Convenience wrapper around `create_temp_file` for the file name `ray.yaml`.
///
/// Returns the same `(TempDir, PathRef)` pair that `create_temp_file` produces.
// NOTE(review): callers bind the `TempDir` to a `_temp_dir` local, presumably
// so the directory stays alive while the path is in use — confirm.
fn create_temp_ray_file() -> anyhow::Result<(TempDir, PathRef)> {
create_temp_file("ray.yaml")
}

async fn run_ray_up_or_down_command(
Expand Down Expand Up @@ -702,6 +723,25 @@ async fn get_head_node_ip(ray_path: impl AsRef<Path>) -> anyhow::Result<Ipv4Addr
Ok(addr)
}

/// Opens an `ssh` session to the head node of the Ray cluster.
///
/// The head node address is looked up via `get_head_node_ip(ray_path)`; the
/// login user and private key come from `daft_config.setup`. The spawned
/// `ssh` process is awaited until it exits.
///
/// # Errors
///
/// Fails if the head node address cannot be resolved, if the `ssh` process
/// cannot be spawned or awaited, or if it exits with a non-success status.
async fn ssh(ray_path: impl AsRef<Path>, daft_config: &DaftConfig) -> anyhow::Result<()> {
    let login = daft_config.setup.ssh_user.as_ref();
    let head_ip = get_head_node_ip(ray_path).await?;
    let destination = format!("{login}@{head_ip}");

    // `kill_on_drop` makes sure the child does not outlive this future if it
    // is dropped before the session finishes.
    let mut session = Command::new("ssh")
        .arg("-i")
        .arg(daft_config.setup.ssh_private_key.as_ref())
        .arg(destination)
        .kill_on_drop(true)
        .spawn()?;
    let status = session.wait().await?;

    if !status.success() {
        anyhow::bail!("Failed to ssh into the ray cluster");
    }
    Ok(())
}

async fn establish_ssh_portforward(
ray_path: impl AsRef<Path>,
daft_config: &DaftConfig,
Expand Down Expand Up @@ -750,6 +790,26 @@ async fn establish_ssh_portforward(
Ok(child)
}

/// Runs `ray job submit` against `http://localhost:8265` and waits for it to
/// finish.
///
/// * `working_dir` is forwarded as the `--working-dir` option.
/// * `command_segments` are appended verbatim after the `--` separator and
///   form the command to submit.
///
/// `PYTHONUNBUFFERED=1` is set in the environment of the spawned `ray`
/// process.
///
/// # Errors
///
/// Fails if the `ray` process cannot be spawned or awaited, or if it exits
/// with a non-success status.
async fn submit(working_dir: &Path, command_segments: impl AsRef<[&str]>) -> anyhow::Result<()> {
    let mut ray_job = Command::new("ray");
    ray_job
        .env("PYTHONUNBUFFERED", "1")
        .args(["job", "submit", "--address", "http://localhost:8265"])
        .arg("--working-dir")
        .arg(working_dir)
        .arg("--")
        .args(command_segments.as_ref());

    let status = ray_job.spawn()?.wait().await?;
    if !status.success() {
        anyhow::bail!("Failed to submit job to the ray cluster");
    }
    Ok(())
}

async fn run(daft_launcher: DaftLauncher) -> anyhow::Result<()> {
match daft_launcher.sub_command {
SubCommand::Init(Init { path }) => {
Expand All @@ -765,12 +825,12 @@ async fn run(daft_launcher: DaftLauncher) -> anyhow::Result<()> {
let _ = read_and_convert(&config, None).await?;
}
SubCommand::Export(ConfigPath { config }) => {
let (_, ray_config) = read_and_convert(&config, Some(TeardownBehaviour::Stop)).await?;
let (_, ray_config) = read_and_convert(&config, None).await?;
let ray_config_str = serde_yaml::to_string(&ray_config)?;
println!("{ray_config_str}");
}
SubCommand::Up(ConfigPath { config }) => {
let (_, ray_config) = read_and_convert(&config, Some(TeardownBehaviour::Stop)).await?;
let (_, ray_config) = read_and_convert(&config, None).await?;
assert_is_logged_in_with_aws().await?;

let (_temp_dir, ray_path) = create_temp_ray_file()?;
Expand Down Expand Up @@ -803,20 +863,11 @@ async fn run(daft_launcher: DaftLauncher) -> anyhow::Result<()> {
let (_temp_dir, ray_path) = create_temp_ray_file()?;
write_ray_config(ray_config, &ray_path).await?;
let _child = establish_ssh_portforward(ray_path, &daft_config, None).await?;

let exit_status = Command::new("ray")
.env("PYTHONUNBUFFERED", "1")
.args(["job", "submit", "--address", "http://localhost:8265"])
.arg("--working-dir")
.arg(daft_job.working_dir.as_ref())
.arg("--")
.args(daft_job.command.split(' '))
.spawn()?
.wait()
.await?;
if !exit_status.success() {
anyhow::bail!("Failed to submit job to the ray cluster");
};
submit(
daft_job.working_dir.as_ref(),
daft_job.command.as_ref().split(' ').collect::<Vec<_>>(),
)
.await?;
}
SubCommand::Connect(Connect { port, config_path }) => {
let (daft_config, ray_config) = read_and_convert(&config_path.config, None).await?;
Expand All @@ -829,6 +880,25 @@ async fn run(daft_launcher: DaftLauncher) -> anyhow::Result<()> {
.wait_with_output()
.await?;
}
SubCommand::Ssh(ConfigPath { config }) => {
let (daft_config, ray_config) = read_and_convert(&config, None).await?;
assert_is_logged_in_with_aws().await?;

let (_temp_dir, ray_path) = create_temp_ray_file()?;
write_ray_config(ray_config, &ray_path).await?;
ssh(ray_path, &daft_config).await?;
}
SubCommand::Sql(Sql { sql, config_path }) => {
let (daft_config, ray_config) = read_and_convert(&config_path.config, None).await?;
assert_is_logged_in_with_aws().await?;

let (_temp_dir, ray_path) = create_temp_ray_file()?;
write_ray_config(ray_config, &ray_path).await?;
let _child = establish_ssh_portforward(ray_path, &daft_config, None).await?;
let (temp_sql_dir, sql_path) = create_temp_file("sql.py")?;
fs::write(sql_path, include_str!("sql.py")).await?;
submit(temp_sql_dir.path(), vec!["python", "sql.py", sql.as_ref()]).await?;
}
SubCommand::Stop(ConfigPath { config }) => {
let (_, ray_config) = read_and_convert(&config, Some(TeardownBehaviour::Stop)).await?;
assert_is_logged_in_with_aws().await?;
Expand Down
6 changes: 6 additions & 0 deletions src/sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import daft
import sys

# The SQL query text is taken from the first command-line argument; running
# the script without one raises IndexError here.
sql_query = sys.argv[1]
# Select Daft's Ray runner before any query is built.
# NOTE(review): presumably this makes the query execute on the Ray cluster the
# job was submitted to — confirm against the Daft documentation.
daft.context.set_runner_ray()
# Hand the query string to Daft's SQL API and call `.show()` on the result to
# display it.
daft.sql(sql_query).show()

0 comments on commit abf036d

Please sign in to comment.