45 changes: 27 additions & 18 deletions datafusion_ray/context.py
@@ -36,24 +36,26 @@ class RayDataFrame:
     def __init__(self, ray_internal_df: RayDataFrameInternal):
         self.df = ray_internal_df
         self.coordinator_id = self.df.coordinator_id
-        self._stages = []
+        self._stages = None
+        self._batches = None
 
-    def stages(self):
+    def stages(self, batch_size=8192):
         # create our coordinator now, which we need to create stages
         if not self._stages:
             self.coord = RayStageCoordinator.options(
                 name="RayQueryCoordinator:" + self.coordinator_id,
             ).remote(self.coordinator_id)
-            self._stages = self.df.stages()
+            self._stages = self.df.stages(batch_size)
         return self._stages
 
     def execution_plan(self):
         return self.df.execution_plan()
 
     def collect(self) -> list[pa.RecordBatch]:
-        reader = self.reader()
-        self.batches = list(reader)
-        return self.batches
+        if not self._batches:
+            reader = self.reader()
+            self._batches = list(reader)
+        return self._batches
 
     def show(self) -> None:
         table = pa.Table.from_batches(self.collect())
@@ -111,7 +113,7 @@ def __init__(self, coordinator_id: str) -> None:
         self.exchanger = RayExchanger.remote()
 
     def get_exchanger(self):
-        print("Coord: returning exchanger {self.exchanger}")
+        print(f"Coord: returning exchanger {self.exchanger}")
         return self.exchanger
 
     def new_stage(self, stage_id: str, plan_bytes: bytes):
@@ -123,7 +125,6 @@ def new_stage(self, stage_id: str, plan_bytes: bytes):
            print(f"creating new stage {stage_id} from bytes {len(plan_bytes)}")
            stage = RayStage.options(
                name="stage:" + stage_id,
-               # lifetime="detached",
            ).remote(stage_id, plan_bytes, self.my_id, self.exchanger)
            self.stages[stage_id] = stage
 
@@ -155,7 +156,7 @@ def run_stages(self):
                raise e
 
 
-@ray.remote(num_cpus=1)
+@ray.remote(num_cpus=0)
 class RayStage:
     def __init__(
         self, stage_id: str, plan_bytes: bytes, coordinator_id: str, exchanger
@@ -178,8 +179,13 @@ def consume(self):
             reader = self.pystage.execute(partition)
             for batch in reader:
                 ipc_batch = batch_to_ipc(batch)
+                o_ref = ray.put(ipc_batch)
+
+                # put a nested object, list[o_ref], so that Ray does not
+                # materialize it at the destination. The shuffler only
+                # needs to exchange object refs
                 ray.get(
-                    self.exchanger.put.remote(self.stage_id, partition, ipc_batch)
+                    self.exchanger.put.remote(self.stage_id, partition, [o_ref])
                 )
             # signal there are no more batches
             ray.get(self.exchanger.put.remote(self.stage_id, partition, None))
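A note on the [o_ref] nesting above, since it is the heart of this change: Ray resolves an ObjectRef passed as a top-level argument to a task or actor method, but passes it through unresolved when it is nested inside a container. Wrapping the ref in a single-element list therefore lets the exchanger queue refs without ever pulling the batch bytes onto the exchanger actor. A minimal sketch of the behavior (hypothetical actor, not part of this PR):

import ray

ray.init()

@ray.remote
class Echo:
    def take(self, item):
        # a top-level ObjectRef argument arrives materialized;
        # a ref nested in a list arrives still as an ObjectRef
        return item

echo = Echo.remote()
ref = ray.put("payload")

print(ray.get(echo.take.remote(ref)))    # "payload" -- materialized for the actor
print(ray.get(echo.take.remote([ref])))  # [ObjectRef(...)] -- passed through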
@@ -209,7 +215,7 @@ async def put(self, stage_id, output_partition, item):
 
         q = self.queues[key]
         await q.put(item)
-        print(f"RayExchanger got batch for {key}")
+        # print(f"RayExchanger got batch for {key}")
 
     async def get(self, stage_id, output_partition):
         key = f"{stage_id}-{output_partition}"
@@ -257,17 +263,20 @@ def __init__(self, exchanger, stage_id, partition):
 
     def __next__(self):
         obj_ref = self.exchanger.get.remote(self.stage_id, self.partition)
-        print(f"[RayIterable stage:{self.stage_id} p:{self.partition}] got ref")
-        ipc_batch = ray.get(obj_ref)
+        # print(f"[RayIterable stage:{self.stage_id} p:{self.partition}] got ref")
+        message = ray.get(obj_ref)
 
-        if ipc_batch is None:
+        if message is None:
             raise StopIteration
 
-        print(f"[RayIterable stage:{self.stage_id} p:{self.partition}] got ipc batch")
+        # otherwise we know it's a list holding a single object ref
+        ipc_batch = ray.get(message[0])
+
+        # print(f"[RayIterable stage:{self.stage_id} p:{self.partition}] got ipc batch")
         batch = ipc_to_batch(ipc_batch)
-        print(
-            f"[RayIterable stage:{self.stage_id} p:{self.partition}] converted to batch"
-        )
+        # print(
+        #     f"[RayIterable stage:{self.stage_id} p:{self.partition}] converted to batch"
+        # )
 
         return batch

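The None sentinel is what ends iteration above: consume() enqueues one None per partition after the last batch, and __next__ raises StopIteration when it dequeues it. The same producer/consumer handshake with a plain queue, outside Ray (a standalone sketch with hypothetical names):

import queue

q = queue.Queue()

def produce(items):
    for item in items:
        q.put(item)
    q.put(None)  # sentinel: no more items

def consume():
    while True:
        item = q.get()
        if item is None:  # sentinel seen, end of stream
            return
        yield item

produce(["a", "b"])
print(list(consume()))  # ['a', 'b']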
66 changes: 66 additions & 0 deletions examples/ray_stage.py
@@ -0,0 +1,66 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import argparse
import glob
import os
import time

import ray
from datafusion_ray import RayContext


def go(data_dir: str, concurrency: int):
ctx = RayContext()
ctx.set("datafusion.execution.target_partitions", str(concurrency))
ctx.set("datafusion.catalog.information_schema", "true")
ctx.set("datafusion.optimizer.enable_round_robin_repartition", "false")

for f in glob.glob(os.path.join(data_dir, "*parquet")):
print(f)
table, _ = os.path.basename(f).split(".")
ctx.register_parquet(table, f)

query = """SELECT customer.c_name, sum(orders.o_totalprice) as total_amount
FROM customer JOIN orders ON customer.c_custkey = orders.o_custkey
GROUP BY customer.c_name limit 10"""

# query = """SELECT count(customer.c_name), customer.c_mktsegment from customer group by customer.c_mktsegment limit 10"""

df = ctx.sql(query)
print(df.execution_plan().display_indent())
for stage in df.stages():
print(f"Stage ", stage.stage_id)
print(stage.execution_plan().display_indent())
b = stage.plan_bytes()
print(f"Stage bytes: {len(b)}")

df.show()

    # brief pause so actor output has a chance to flush before the driver exits
    time.sleep(3)


if __name__ == "__main__":
ray.init(namespace="example")
parser = argparse.ArgumentParser()
parser.add_argument("--data-dir", required=True, help="path to tpch*.parquet files")
parser.add_argument("--concurrency", required=True, type=int)
args = parser.parse_args()

go(args.data_dir, args.concurrency)
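Usage, assuming a directory of TPC-H parquet files named like customer.parquet and orders.parquet (the path below is hypothetical):

python examples/ray_stage.py --data-dir ./tpch-data --concurrency 4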
123 changes: 123 additions & 0 deletions src/dataframe.rs
@@ -0,0 +1,123 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use datafusion::common::tree_node::Transformed;
use datafusion::common::tree_node::TreeNode;
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion::physical_plan::displayable;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_python::physical_plan::PyExecutionPlan;
use pyo3::prelude::*;
use std::sync::Arc;

use crate::pystage::PyStage;
use crate::ray_stage::RayStageExec;
use crate::ray_stage_reader::RayStageReaderExec;

pub struct CoordinatorId(pub String);

#[pyclass]
pub struct RayDataFrame {
physical_plan: Arc<dyn ExecutionPlan>,
#[pyo3(get)]
coordinator_id: String,
}

impl RayDataFrame {
pub fn new(physical_plan: Arc<dyn ExecutionPlan>, coordinator_id: String) -> Self {
Self {
physical_plan,
coordinator_id,
}
}
}

#[pymethods]
impl RayDataFrame {
#[pyo3(signature = (batch_size=8192))]
fn stages(&self, batch_size: usize) -> PyResult<Vec<PyStage>> {
let mut stages = vec![];

        // TODO: This can likely be done more efficiently, in one pass, but I'm
        // struggling to get the TreeNodeRecursion return values to do what I
        // want. So, two steps for now.

        // Step 2 (defined first, because `up` below calls it): walk down a
        // stage's subtree and replace any stages found earlier in the tree
        // with RayStageReaderExecs
let down = |plan: Arc<dyn ExecutionPlan>| {
//println!("examining plan: {}", displayable(plan.as_ref()).one_line());

if let Some(stage_exec) = plan.as_any().downcast_ref::<RayStageExec>() {
let input = plan.children();
assert!(input.len() == 1, "RayStageExec must have exactly one child");
let input = input[0];

let replacement = Arc::new(RayStageReaderExec::try_new_from_input(
input.clone(),
stage_exec.stage_id.clone(),
self.coordinator_id.clone(),
)?) as Arc<dyn ExecutionPlan>;

Ok(Transformed::yes(replacement))
} else {
Ok(Transformed::no(plan))
}
};

// Step 1: we walk up the tree from the leaves to find the stages
let up = |plan: Arc<dyn ExecutionPlan>| {
println!("examining plan: {}", displayable(plan.as_ref()).one_line());

if let Some(stage_exec) = plan.as_any().downcast_ref::<RayStageExec>() {
let input = plan.children();
assert!(input.len() == 1, "RayStageExec must have exactly one child");
let input = input[0];

let fixed_plan = input.clone().transform_down(down)?.data;

                // insert a CoalesceBatchesExec here too so that we aren't
                // sending batches that are too small over the network
let final_plan = Arc::new(CoalesceBatchesExec::new(fixed_plan, batch_size))
as Arc<dyn ExecutionPlan>;

let stage = PyStage::new(
stage_exec.stage_id.clone(),
final_plan,
self.coordinator_id.clone(),
);

/*println!(
"made new stage {}: plan:\n{}",
stage_exec.stage_id,
displayable(stage.plan.as_ref()).indent(true)
);*/

stages.push(stage);
}

Ok(Transformed::no(plan))
};

self.physical_plan.clone().transform_up(up)?;

Ok(stages)
}

fn execution_plan(&self) -> PyResult<PyExecutionPlan> {
Ok(PyExecutionPlan::new(self.physical_plan.clone()))
}
}
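For readers puzzling over the two passes: transform_up visits nodes leaves-first, and each time it reaches a RayStageExec boundary it runs the downward pass over that subtree, swapping interior boundaries for reader nodes before capturing the subtree as a stage. A toy Python sketch of that shape (hypothetical Node type, not the DataFusion API):

class Node:
    def __init__(self, kind, children=(), stage_id=None):
        self.kind = kind  # e.g. "stage", "agg", "scan", "reader"
        self.children = list(children)
        self.stage_id = stage_id

def cut_stages(root, stages):
    # step 1: bottom-up walk looking for stage boundaries
    for child in root.children:
        cut_stages(child, stages)
    if root.kind == "stage":
        stages.append((root.stage_id, replace_inner(root.children[0])))

def replace_inner(node):
    # step 2: top-down walk swapping nested boundaries for readers
    if node.kind == "stage":
        return Node("reader", stage_id=node.stage_id)
    node.children = [replace_inner(c) for c in node.children]
    return node

plan = Node("stage", [Node("agg", [Node("stage", [Node("scan")], stage_id=0)])], stage_id=1)
stages = []
cut_stages(plan, stages)
# -> stage 0 is the scan; stage 1 is the agg over a reader for stage 0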
10 changes: 5 additions & 5 deletions src/physical.rs
@@ -53,7 +53,7 @@ impl PhysicalOptimizerRule for RayShuffleOptimizerRule {
         let mut stage_counter = 0;
 
         let up = |plan: Arc<dyn ExecutionPlan>| {
-            println!("examining plan: {}", displayable(plan.as_ref()).one_line());
+            //println!("examining plan: {}", displayable(plan.as_ref()).one_line());
 
             if plan.as_any().downcast_ref::<RepartitionExec>().is_some() {
                 let stage = Arc::new(RayStageExec::new(plan, stage_counter.to_string()));
@@ -67,10 +67,10 @@
         let plan = plan.transform_up(up)?.data;
         let final_plan = Arc::new(RayStageExec::new(plan, stage_counter.to_string()));
 
-        println!(
-            "optimized physical plan:\n{}",
-            displayable(final_plan.as_ref()).indent(false)
-        );
+        //println!(
+        //    "optimized physical plan:\n{}",
+        //    displayable(final_plan.as_ref()).indent(false)
+        //);
         Ok(final_plan)
     }
