1 change: 1 addition & 0 deletions crates/acyclib/src/device.rs
@@ -1,5 +1,6 @@
pub mod cpu;
pub mod function;
pub mod multi;
pub mod operation;
pub mod tensor;

51 changes: 51 additions & 0 deletions crates/acyclib/src/device/multi.rs
@@ -0,0 +1,51 @@
use std::sync::Arc;

use crate::device::{
Device,
cpu::{CpuError, CpuThread},
tensor::TensorRef,
};

pub trait MultiDeviceComm<D: Device> {
fn new(devices: Vec<Arc<D>>) -> Self;

fn reduce_sum_into_rank(&self, rank: usize, buffers: &[TensorRef<D>]) -> Result<(), D::DeviceError>;

fn scatter_rank_into_rest(&self, rank: usize, buffers: &[TensorRef<D>]) -> Result<(), D::DeviceError>;
}

pub trait MultiDevice: Device {
type Comm: MultiDeviceComm<Self>;
}

impl MultiDevice for CpuThread {
type Comm = ();
}

impl MultiDeviceComm<CpuThread> for () {
fn new(_: Vec<Arc<CpuThread>>) -> Self {}

fn reduce_sum_into_rank(&self, rank: usize, buffers: &[TensorRef<CpuThread>]) -> Result<(), CpuError> {
let mut buf = buffers[rank].dense_mut();

for (i, other) in buffers.iter().enumerate() {
if rank != i {
buf.add(1.0, &other.dense()).map_err(|_| CpuError)?;
}
}

Ok(())
}

fn scatter_rank_into_rest(&self, rank: usize, buffers: &[TensorRef<CpuThread>]) -> Result<(), CpuError> {
let buf = buffers[rank].dense();

for (i, other) in buffers.iter().enumerate() {
if rank != i {
other.dense_mut().copy_from(&buf)?;
}
}

Ok(())
}
}
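The unit-type comm above is the CPU realisation of data-parallel synchronisation: reduce_sum_into_rank accumulates every other replica's buffer into the chosen rank, and scatter_rank_into_rest copies that rank's buffer back out to the others. A minimal standalone sketch of the same two operations over plain Vec<f32> buffers, purely to illustrate the semantics (these are local functions, not acyclib's API):

// Standalone illustration of the reduce/scatter data flow above, with
// Vec<f32> standing in for TensorRef. Not acyclib code.
fn reduce_sum_into_rank(rank: usize, buffers: &mut [Vec<f32>]) {
    for i in 0..buffers.len() {
        if i != rank {
            let other = buffers[i].clone(); // clone to avoid aliasing buffers[rank]
            for (dst, src) in buffers[rank].iter_mut().zip(other.iter()) {
                *dst += *src; // accumulate every other rank into `rank`
            }
        }
    }
}

fn scatter_rank_into_rest(rank: usize, buffers: &mut [Vec<f32>]) {
    let src = buffers[rank].clone();
    for (i, other) in buffers.iter_mut().enumerate() {
        if i != rank {
            other.copy_from_slice(&src); // overwrite every other rank with `rank`'s data
        }
    }
}

fn main() {
    // Two "ranks" holding per-replica gradients.
    let mut grads = vec![vec![1.0_f32, 2.0], vec![3.0, 4.0]];
    reduce_sum_into_rank(0, &mut grads);   // rank 0 now holds [4.0, 6.0]
    scatter_rank_into_rest(0, &mut grads); // rank 1 receives the summed values
    assert_eq!(grads[0], grads[1]);
}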
2 changes: 2 additions & 0 deletions crates/acyclib/src/graph.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
pub mod builder;
pub mod ir;
pub mod like;
pub mod multi;

use std::{collections::HashMap, fmt::Debug, sync::Arc};

44 changes: 39 additions & 5 deletions crates/acyclib/src/graph/builder.rs
@@ -11,7 +11,12 @@ use std::{

use crate::{
dag::NodeId,
device::{Device, function::Reduce, tensor::Shape},
device::{
Device,
function::Reduce,
multi::{MultiDevice, MultiDeviceComm},
tensor::Shape,
},
graph::{
Graph, GraphNodeId, GraphNodeIdTy,
ir::{
@@ -21,6 +26,7 @@ use crate::{
},
passes::GraphIRPass,
},
multi::MultiDeviceGraph,
},
};

@@ -121,8 +127,8 @@ where
SparseAffineActivate: GraphIROperationCompilable<B>,
Select: GraphIROperationCompilable<B>,
{
pub fn build(self, device: D) -> Graph<D> {
let mut ir = self.ir.into_inner().unwrap();
fn optimise(&mut self) {
let mut ir = self.ir.try_lock().unwrap();
let root = ir.root().unwrap();

if ir.get(root.idx).unwrap().ty().batched {
@@ -136,7 +142,7 @@
let unoptim = format!("subgraph cluster_0 {{\nlabel=\"Unoptimised\";\n{opts}{unoptim}}}");

ir.optimise().unwrap();
for pass in self.custom_passes.into_inner().unwrap() {
for pass in self.custom_passes.try_lock().unwrap().iter() {
ir.apply_any_pass(pass.as_ref()).unwrap();
}

@@ -147,14 +153,18 @@ where
write!(&mut file, "digraph G {{\n{unoptim}\n{optim}}}").unwrap();
} else {
ir.optimise().unwrap();
for pass in self.custom_passes.into_inner().unwrap() {
for pass in self.custom_passes.try_lock().unwrap().iter() {
ir.apply_any_pass(pass.as_ref()).unwrap();
}
}

if self.dump_ir_on_build {
println!("{}", ir.formatted().unwrap());
}
}

fn compile(&self, device: D) -> Graph<D> {
let ir = self.ir.try_lock().unwrap();

let graph = ir.compile(device).unwrap();

@@ -187,4 +197,28 @@ where

graph
}

pub fn build(mut self, device: D) -> Graph<D> {
self.optimise();
self.compile(device)
}
}

impl<D: Device<Marker = B> + MultiDevice, B: BackendMarker<Backend = D>> GraphBuilder<B>
where
SparseAffineActivate: GraphIROperationCompilable<B>,
Select: GraphIROperationCompilable<B>,
{
pub fn build_multi(mut self, devices: Vec<D>) -> MultiDeviceGraph<D> {
if devices.is_empty() {
panic!("No devices specified for multi-device training!");
}

self.optimise();

let graphs = devices.into_iter().map(|d| self.compile(d)).collect::<Vec<_>>();
let comm = D::Comm::new(graphs.iter().map(|g| g.device()).collect());

MultiDeviceGraph { comm, graphs }
}
}
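The refactor above splits the old build into optimise (runs the IR passes once) and compile (instantiates a Graph on one device), so that build_multi can run the shared optimisation a single time and then compile one replica per device, panicking on an empty device list. A standalone sketch of that control flow; Spec and Artefact are illustrative stand-ins, not acyclib types:

struct Spec { passes_run: bool }   // stands in for the graph IR behind the builder
struct Artefact { device: usize }  // stands in for a compiled Graph<D>

impl Spec {
    fn optimise(&mut self) { self.passes_run = true; }
    fn compile(&self, device: usize) -> Artefact {
        assert!(self.passes_run); // compile always sees the optimised spec
        Artefact { device }
    }
}

fn build(mut spec: Spec, device: usize) -> Artefact {
    spec.optimise();
    spec.compile(device)
}

fn build_multi(mut spec: Spec, devices: Vec<usize>) -> Vec<Artefact> {
    assert!(!devices.is_empty(), "No devices specified for multi-device training!");
    spec.optimise();                                       // shared pass, runs once
    devices.into_iter().map(|d| spec.compile(d)).collect() // one replica per device
}

fn main() {
    let single = build(Spec { passes_run: false }, 0);
    let multi = build_multi(Spec { passes_run: false }, vec![0, 1]);
    println!("single on device {}, {} replicas", single.device, multi.len());
}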
2 changes: 1 addition & 1 deletion crates/acyclib/src/graph/ir/compile.rs
Expand Up @@ -99,7 +99,7 @@ where
Ok(())
}

pub fn compile(self, device: B::Backend) -> Result<Graph<B::Backend>, GraphIRCompileError> {
pub fn compile(&self, device: B::Backend) -> Result<Graph<B::Backend>, GraphIRCompileError> {
let root = self.root()?.idx;
let root_data = self.get(root).unwrap().ty();

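Changing compile to borrow the IR (&self) instead of consuming it is what lets build_multi above compile the same optimised IR once per device. A toy illustration of why the receiver matters (stand-in types, not acyclib's):

struct Ir;                        // stand-in for the optimised GraphIR
struct Compiled { device: usize } // stand-in for Graph<D>

impl Ir {
    // Borrowing the IR means it can be compiled more than once.
    fn compile(&self, device: usize) -> Compiled { Compiled { device } }
}

fn main() {
    let ir = Ir;
    let a = ir.compile(0);
    let b = ir.compile(1); // a second replica; impossible if compile took `self` by value
    println!("compiled for devices {} and {}", a.device, b.device);
}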
58 changes: 58 additions & 0 deletions crates/acyclib/src/graph/like.rs
@@ -0,0 +1,58 @@
use std::sync::Arc;

use crate::{
device::{Device, OperationError, tensor::TensorRef},
graph::{Graph, GraphNodeId},
};

pub trait GraphLike<D: Device> {
fn devices(&self) -> Vec<Arc<D>>;

fn primary(&self) -> &Graph<D>;

fn primary_mut(&mut self) -> &mut Graph<D>;

fn get_all(&self, id: GraphNodeId) -> Result<Vec<TensorRef<D>>, OperationError<<D as Device>::DeviceError>>;

fn get_output_value(&self) -> Result<f32, OperationError<D::DeviceError>>;

fn execute_fn(&mut self, name: &str) -> Result<(), OperationError<D::DeviceError>>;

fn reduce_sum_into_first(&self, buffers: &[TensorRef<D>]) -> Result<(), D::DeviceError>;

fn scatter_first_into_rest(&self, buffers: &[TensorRef<D>]) -> Result<(), D::DeviceError>;
}

impl<D: Device> GraphLike<D> for Graph<D> {
fn devices(&self) -> Vec<Arc<D>> {
vec![self.device()]
}

fn primary(&self) -> &Graph<D> {
self
}

fn primary_mut(&mut self) -> &mut Graph<D> {
self
}

fn get_all(&self, id: GraphNodeId) -> Result<Vec<TensorRef<D>>, OperationError<<D as Device>::DeviceError>> {
self.get(id).map(|x| vec![x])
}

fn get_output_value(&self) -> Result<f32, OperationError<<D as Device>::DeviceError>> {
self.get_output_val()
}

fn execute_fn(&mut self, name: &str) -> Result<(), OperationError<<D as Device>::DeviceError>> {
self.execute(name)
}

fn reduce_sum_into_first(&self, _: &[TensorRef<D>]) -> Result<(), <D as Device>::DeviceError> {
Ok(())
}

fn scatter_first_into_rest(&self, _: &[TensorRef<D>]) -> Result<(), <D as Device>::DeviceError> {
Ok(())
}
}
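GraphLike is the abstraction that lets the optimiser and trainer be written once for both a single Graph and the MultiDeviceGraph in the next file: the single-graph impl treats the two synchronisation hooks as no-ops, while the multi-device impl forwards them to its comm. A self-contained sketch of the same pattern with a local trait (names are illustrative, not acyclib's):

// A local trait mirroring the shape of GraphLike: downstream code is generic
// over it, so one training loop drives either one replica or many.
trait GraphLikeDemo {
    fn execute_fn(&mut self, name: &str);
    fn get_output_value(&self) -> f32;
}

struct SingleGraph { loss: f32 }
struct ReplicatedGraph { replicas: Vec<SingleGraph> }

impl GraphLikeDemo for SingleGraph {
    fn execute_fn(&mut self, _name: &str) { self.loss = 1.0; } // pretend a pass ran
    fn get_output_value(&self) -> f32 { self.loss }
}

impl GraphLikeDemo for ReplicatedGraph {
    fn execute_fn(&mut self, name: &str) {
        for g in &mut self.replicas {
            g.execute_fn(name); // run the same named pass on every replica
        }
    }
    fn get_output_value(&self) -> f32 {
        self.replicas.iter().map(|g| g.get_output_value()).sum() // sum per-replica losses
    }
}

// Written once against the trait, like Trainer/Optimiser being generic over
// G: GraphLike<D> in trainer.rs below.
fn step<G: GraphLikeDemo>(graph: &mut G) -> f32 {
    graph.execute_fn("forward");
    graph.get_output_value()
}

fn main() {
    let mut single = SingleGraph { loss: 0.0 };
    let mut multi = ReplicatedGraph {
        replicas: vec![SingleGraph { loss: 0.0 }, SingleGraph { loss: 0.0 }],
    };
    println!("single: {}, multi: {}", step(&mut single), step(&mut multi));
}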
59 changes: 59 additions & 0 deletions crates/acyclib/src/graph/multi.rs
@@ -0,0 +1,59 @@
use std::sync::Arc;

use crate::{
device::{
Device, OperationError,
multi::{MultiDevice, MultiDeviceComm},
tensor::TensorRef,
},
graph::{Graph, GraphNodeId, like::GraphLike},
};

pub struct MultiDeviceGraph<D: Device + MultiDevice> {
pub(super) comm: D::Comm,
pub(super) graphs: Vec<Graph<D>>,
}

impl<D: Device + MultiDevice> GraphLike<D> for MultiDeviceGraph<D> {
fn devices(&self) -> Vec<Arc<D>> {
self.graphs.iter().map(Graph::device).collect()
}

fn primary(&self) -> &Graph<D> {
&self.graphs[0]
}

fn primary_mut(&mut self) -> &mut Graph<D> {
&mut self.graphs[0]
}

fn get_all(&self, id: GraphNodeId) -> Result<Vec<TensorRef<D>>, OperationError<<D as Device>::DeviceError>> {
self.graphs.iter().map(|g| g.get(id)).collect()
}

fn get_output_value(&self) -> Result<f32, OperationError<<D as Device>::DeviceError>> {
let mut sum = 0.0;

for g in &self.graphs {
sum += g.get_output_val()?;
}

Ok(sum)
}

fn execute_fn(&mut self, name: &str) -> Result<(), OperationError<<D as Device>::DeviceError>> {
for g in &mut self.graphs {
g.execute(name)?;
}

Ok(())
}

fn reduce_sum_into_first(&self, buffers: &[TensorRef<D>]) -> Result<(), <D as Device>::DeviceError> {
self.comm.reduce_sum_into_rank(0, buffers)
}

fn scatter_first_into_rest(&self, buffers: &[TensorRef<D>]) -> Result<(), <D as Device>::DeviceError> {
self.comm.scatter_rank_into_rest(0, buffers)
}
}
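Note that get_output_value sums the per-replica outputs; the trainer below then divides by the full batch size, which assumes each replica's output is the loss accumulated over its own shard of the batch. A tiny sketch of that arithmetic (the numbers are made up):

// What the summed multi-device output turns into once the trainer divides by
// the global batch size. Assumes per-replica outputs are loss sums over shards.
fn mean_loss(per_replica_loss_sums: &[f32], global_batch_size: usize) -> f32 {
    let total: f32 = per_replica_loss_sums.iter().sum(); // get_output_value()
    total / global_batch_size as f32                     // division in trainer.rs
}

fn main() {
    // Two replicas, each having processed half of a 1024-position batch.
    let per_replica = [260.0_f32, 252.0];
    assert!((mean_loss(&per_replica, 1024) - 0.5).abs() < 1e-6);
}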
30 changes: 18 additions & 12 deletions crates/acyclib/src/trainer.rs
@@ -9,7 +9,10 @@ use schedule::TrainingSchedule;

use std::{sync::mpsc, thread, time::Instant};

use crate::device::{Device, OperationError};
use crate::{
device::{Device, OperationError},
graph::like::GraphLike,
};

#[derive(Debug)]
pub enum DataLoadingError {
@@ -23,6 +26,7 @@ pub enum TrainerError<D: Device> {
DataLoadingError(DataLoadingError),
GradientCalculationError(OperationError<D::DeviceError>),
Unexpected(OperationError<D::DeviceError>),
MoreDevicesThanBatchSize(usize, usize),
IoError,
}

@@ -32,12 +36,12 @@ impl<D: Device> From<DataLoadingError> for TrainerError<D> {
}
}

pub struct Trainer<D: Device, O: OptimiserState<D>, S> {
pub optimiser: Optimiser<D, O>,
pub struct Trainer<D: Device, G: GraphLike<D>, O: OptimiserState<D>, S> {
pub optimiser: Optimiser<D, G, O>,
pub state: S,
}

impl<D: Device, O: OptimiserState<D>, S> Trainer<D, O, S> {
impl<D: Device, G: GraphLike<D>, O: OptimiserState<D>, S> Trainer<D, G, O, S> {
pub fn train_custom(
&mut self,
schedule: TrainingSchedule,
@@ -52,7 +56,9 @@ impl<D: Device, O: OptimiserState<D>, S> Trainer<D, O, S> {
let lr = schedule.lr_schedule;
let steps = schedule.steps;

self.optimiser.graph.synchronise().unwrap();
if self.optimiser.graph.devices().len() > steps.batch_size {
return Err(TrainerError::MoreDevicesThanBatchSize(self.optimiser.graph.devices().len(), steps.batch_size));
}

let (sender, receiver) = mpsc::sync_channel::<PreparedBatchHost>(32);

@@ -88,7 +94,7 @@ impl<D: Device, O: OptimiserState<D>, S> Trainer<D, O, S> {
let first_batch =
receiver.recv().map_err(|_| TrainerError::DataLoadingError(DataLoadingError::NoBatchesReceived))?;

let mut batch_on_device = PreparedBatchDevice::new(self.optimiser.graph.device(), &first_batch)
let mut batch_on_device = PreparedBatchDevice::new(self.optimiser.graph.devices(), &first_batch)
.map_err(|_| TrainerError::DataLoadingError(DataLoadingError::CopyToDevice))?;

let mut batch_queued = true;
@@ -122,14 +128,14 @@ impl<D: Device, O: OptimiserState<D>, S> Trainer<D, O, S> {

batch_on_device.load_into_graph(&mut self.optimiser.graph)?;

fn step<D: Device, S: OptimiserState<D>>(
optim: &mut Optimiser<D, S>,
fn step<D: Device, G: GraphLike<D>, S: OptimiserState<D>>(
optim: &mut Optimiser<D, G, S>,
gradient_factor: f32,
learning_rate: f32,
) -> Result<(), OperationError<D::DeviceError>> {
optim.graph.execute("zero_grads")?;
optim.graph.execute("forward")?;
optim.graph.execute("backward")?;
optim.graph.execute_fn("zero_grads")?;
optim.graph.execute_fn("forward")?;
optim.graph.execute_fn("backward")?;
optim.update(gradient_factor, learning_rate)
}

@@ -143,7 +149,7 @@ impl<D: Device, O: OptimiserState<D>, S> Trainer<D, O, S> {
batch_queued = false;
}

let error = self.optimiser.graph.get_output_val().unwrap() / this_batch_size as f32;
let error = self.optimiser.graph.get_output_value().unwrap() / this_batch_size as f32;

running_loss += error;
superbatch_positions += this_batch_size;
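Taken together, one training step now runs zero_grads, forward, and backward on every replica through execute_fn, and the replicas are kept consistent via the reduce/scatter hooks; the exact call sites for those hooks (presumably inside Optimiser::update) are not part of this diff, so the ordering below is an assumption about the intended data flow, not a statement of it. A self-contained sketch of one data-parallel step over plain vectors:

// Plain vectors stand in for graphs and tensors; the "gradient" is just
// (param - target) per element so the example stays tiny.
fn data_parallel_step(params: &mut Vec<f32>, shards: &[Vec<f32>], lr: f32) {
    // 1. forward/backward: each replica computes a gradient on its own shard.
    let mut grads: Vec<Vec<f32>> = shards
        .iter()
        .map(|shard| params.iter().zip(shard).map(|(p, t)| p - t).collect())
        .collect();

    // 2. reduce_sum_into_first: sum every replica's gradient into replica 0.
    for i in 1..grads.len() {
        let other = grads[i].clone();
        for (g, o) in grads[0].iter_mut().zip(other) {
            *g += o;
        }
    }

    // 3. update: apply the summed gradient (scaled by 1/N) to the parameters.
    let n = shards.len() as f32;
    for (p, g) in params.iter_mut().zip(&grads[0]) {
        *p -= lr * g / n;
    }

    // 4. In acyclib the updated parameters would then be scattered back to the
    //    other replicas (scatter_first_into_rest) so all graphs stay in sync.
}

fn main() {
    let mut params = vec![1.0_f32, 1.0];
    let shards = vec![vec![0.0_f32, 0.0], vec![2.0, 2.0]]; // each replica sees different data
    data_parallel_step(&mut params, &shards, 0.1);
    println!("{:?}", params); // the two shards' gradients cancel here: [1.0, 1.0]
}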