Skip to content

Commit

Permalink
[Refactor] Create an Intermediate representation of compute shaders o…
Browse files Browse the repository at this point in the history
…n the GPU (#1274)
  • Loading branch information
nathanielsimard authored Feb 8, 2024
1 parent a9b6dbc commit fb6cc2d
Show file tree
Hide file tree
Showing 42 changed files with 1,686 additions and 1,105 deletions.
9 changes: 9 additions & 0 deletions burn-wgpu/src/codegen/compiler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
use super::dialect::gpu;
use std::fmt::Display;

pub trait Compiler: Sync + Send + 'static {
type Representation: Display;

fn compile(shader: gpu::ComputeShader) -> Self::Representation;
fn elem_size(elem: gpu::Elem) -> usize;
}
7 changes: 7 additions & 0 deletions burn-wgpu/src/codegen/dialect/gpu/body.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
use super::Operation;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, new)]
pub struct Body {
pub operators: Vec<Operation>,
}
11 changes: 11 additions & 0 deletions burn-wgpu/src/codegen/dialect/gpu/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
mod body;
mod operation;
mod shader;
mod variable;
mod vectorization;

pub(crate) use body::*;
pub(crate) use operation::*;
pub(crate) use shader::*;
pub(crate) use variable::*;
pub(crate) use vectorization::*;
81 changes: 81 additions & 0 deletions burn-wgpu/src/codegen/dialect/gpu/operation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
use super::Variable;
use serde::{Deserialize, Serialize};

/// All operations that can be used in a GPU compute shader.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)] // Some variants might not be used with different flags
pub enum Operation {
Add(BinaryOperation),
Sub(BinaryOperation),
Mul(BinaryOperation),
Div(BinaryOperation),
Abs(UnaryOperation),
Exp(UnaryOperation),
Log(UnaryOperation),
Log1p(UnaryOperation),
Cos(UnaryOperation),
Sin(UnaryOperation),
Tanh(UnaryOperation),
Powf(BinaryOperation),
Sqrt(UnaryOperation),
Erf(UnaryOperation),
Recip(UnaryOperation),
Equal(BinaryOperation),
Lower(BinaryOperation),
Clamp(ClampOperation),
Greater(BinaryOperation),
LowerEqual(BinaryOperation),
GreaterEqual(BinaryOperation),
ConditionalAssign(ConditionalAssignOperation),
AssignGlobal(UnaryOperation),
AssignLocal(UnaryOperation),
ReadGlobal(ReadGlobalOperation),
ReadGlobalWithLayout(ReadGlobalWithLayoutOperation),
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)] // Some variants might not be used with different flags
pub struct BinaryOperation {
pub lhs: Variable,
pub rhs: Variable,
pub out: Variable,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)] // Some variants might not be used with different flags
pub struct UnaryOperation {
pub input: Variable,
pub out: Variable,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)] // Some variants might not be used with different flags
pub struct ClampOperation {
pub input: Variable,
pub min_value: Variable,
pub max_value: Variable,
pub out: Variable,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)] // Some variants might not be used with different flags
pub struct ConditionalAssignOperation {
pub cond: Variable,
pub lhs: Variable,
pub rhs: Variable,
pub out: Variable,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)] // Some variants might not be used with different flags
pub struct ReadGlobalOperation {
pub variable: Variable,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)] // Some variants might not be used with different flags
pub struct ReadGlobalWithLayoutOperation {
pub variable: Variable,
pub tensor_read_pos: usize,
pub tensor_layout_pos: usize,
}
91 changes: 91 additions & 0 deletions burn-wgpu/src/codegen/dialect/gpu/shader.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
use super::Body;
use crate::kernel::WORKGROUP_DEFAULT;
use serde::{Deserialize, Serialize};
use std::fmt::Display;

#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum Location {
Storage,
#[allow(dead_code)]
Workgroup,
}

#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum Visibility {
Read,
ReadWrite,
}

#[derive(Debug, Clone, PartialEq, Eq, Copy, Serialize, Deserialize)]
pub enum Elem {
Float,
Int,
UInt,
Bool,
}

impl Display for Elem {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Float => f.write_str("float"),
Self::Int => f.write_str("int"),
Self::UInt => f.write_str("uint"),
Self::Bool => f.write_str("bool"),
}
}
}

#[derive(Debug, Clone, PartialEq, Eq, Copy, Serialize, Deserialize)]
pub enum Item {
Vec4(Elem),
Vec3(Elem),
Vec2(Elem),
Scalar(Elem),
}

impl Item {
pub fn elem(&self) -> Elem {
match self {
Self::Vec4(elem) => *elem,
Self::Vec3(elem) => *elem,
Self::Vec2(elem) => *elem,
Self::Scalar(elem) => *elem,
}
}
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct Binding {
pub location: Location,
pub visibility: Visibility,
pub item: Item,
pub size: Option<usize>,
}

#[derive(new, Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub struct WorkgroupSize {
pub x: u32,
pub y: u32,
pub z: u32,
}

impl Default for WorkgroupSize {
fn default() -> Self {
Self {
x: WORKGROUP_DEFAULT as u32,
y: WORKGROUP_DEFAULT as u32,
z: 1,
}
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComputeShader {
pub inputs: Vec<Binding>,
pub outputs: Vec<Binding>,
pub named: Vec<(String, Binding)>,
pub workgroup_size: WorkgroupSize,
pub global_invocation_id: bool,
pub num_workgroups: bool,
pub body: Body,
}
11 changes: 11 additions & 0 deletions burn-wgpu/src/codegen/dialect/gpu/variable.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use super::Item;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Variable {
Input(u16, Item),
Scalar(u16, Item),
Local(u16, Item),
Output(u16, Item),
Constant(f64, Item),
}
154 changes: 154 additions & 0 deletions burn-wgpu/src/codegen/dialect/gpu/vectorization.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
use super::{
BinaryOperation, ClampOperation, ConditionalAssignOperation, Item, Operation,
ReadGlobalOperation, ReadGlobalWithLayoutOperation, UnaryOperation, Variable,
};

/// Define a vectorization scheme.
#[allow(dead_code)]
#[derive(Copy, Clone, Debug)]
pub enum Vectorization {
/// Use vec4 for vectorization.
Vec4,
/// Use vec3 for vectorization.
Vec3,
/// Use vec2 for vectorization.
Vec2,
/// Don't vectorize.
Scalar,
}

impl Operation {
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
match self {
Operation::Add(op) => Operation::Add(op.vectorize(vectorization)),
Operation::Sub(op) => Operation::Sub(op.vectorize(vectorization)),
Operation::Mul(op) => Operation::Mul(op.vectorize(vectorization)),
Operation::Div(op) => Operation::Div(op.vectorize(vectorization)),
Operation::Abs(op) => Operation::Abs(op.vectorize(vectorization)),
Operation::Exp(op) => Operation::Exp(op.vectorize(vectorization)),
Operation::Log(op) => Operation::Log(op.vectorize(vectorization)),
Operation::Log1p(op) => Operation::Log1p(op.vectorize(vectorization)),
Operation::Cos(op) => Operation::Cos(op.vectorize(vectorization)),
Operation::Sin(op) => Operation::Sin(op.vectorize(vectorization)),
Operation::Tanh(op) => Operation::Tanh(op.vectorize(vectorization)),
Operation::Powf(op) => Operation::Powf(op.vectorize(vectorization)),
Operation::Sqrt(op) => Operation::Sqrt(op.vectorize(vectorization)),
Operation::Erf(op) => Operation::Erf(op.vectorize(vectorization)),
Operation::Recip(op) => Operation::Recip(op.vectorize(vectorization)),
Operation::Equal(op) => Operation::Equal(op.vectorize(vectorization)),
Operation::Lower(op) => Operation::Lower(op.vectorize(vectorization)),
Operation::Clamp(op) => Operation::Clamp(op.vectorize(vectorization)),
Operation::Greater(op) => Operation::Greater(op.vectorize(vectorization)),
Operation::LowerEqual(op) => Operation::LowerEqual(op.vectorize(vectorization)),
Operation::GreaterEqual(op) => Operation::GreaterEqual(op.vectorize(vectorization)),
Operation::ConditionalAssign(op) => {
Operation::ConditionalAssign(op.vectorize(vectorization))
}
Operation::AssignGlobal(op) => Operation::AssignGlobal(op.vectorize(vectorization)),
Operation::AssignLocal(op) => Operation::AssignLocal(op.vectorize(vectorization)),
Operation::ReadGlobal(op) => Operation::ReadGlobal(op.vectorize(vectorization)),
Operation::ReadGlobalWithLayout(op) => {
Operation::ReadGlobalWithLayout(op.vectorize(vectorization))
}
}
}
}

impl BinaryOperation {
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
let lhs = self.lhs.vectorize(vectorization);
let rhs = self.rhs.vectorize(vectorization);
let out = self.out.vectorize(vectorization);

Self { lhs, rhs, out }
}
}

impl UnaryOperation {
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
let input = self.input.vectorize(vectorization);
let out = self.out.vectorize(vectorization);

Self { input, out }
}
}

impl ClampOperation {
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
let input = self.input.vectorize(vectorization);
let out = self.out.vectorize(vectorization);
let min_value = self.min_value.vectorize(vectorization);
let max_value = self.max_value.vectorize(vectorization);

Self {
input,
out,
min_value,
max_value,
}
}
}

impl ConditionalAssignOperation {
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
let cond = self.cond.vectorize(vectorization);
let lhs = self.lhs.vectorize(vectorization);
let rhs = self.rhs.vectorize(vectorization);
let out = self.out.vectorize(vectorization);

Self {
cond,
lhs,
rhs,
out,
}
}
}

impl ReadGlobalOperation {
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
let variable = self.variable.vectorize(vectorization);

Self { variable }
}
}

impl ReadGlobalWithLayoutOperation {
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
let variable = self.variable.vectorize(vectorization);
let tensor_read_pos = self.tensor_read_pos;
let tensor_layout_pos = self.tensor_layout_pos;

Self {
variable,
tensor_read_pos,
tensor_layout_pos,
}
}
}

impl Variable {
pub fn vectorize(&self, vectorize: Vectorization) -> Self {
match self {
Variable::Input(index, item) => Variable::Input(*index, item.vectorize(vectorize)),
Variable::Local(index, item) => Variable::Local(*index, item.vectorize(vectorize)),
Variable::Output(index, item) => Variable::Output(*index, item.vectorize(vectorize)),
Variable::Constant(index, item) => {
Variable::Constant(*index, item.vectorize(vectorize))
}
Variable::Scalar(index, item) => Variable::Scalar(*index, *item), // Don't vectorize
// scalar variables.
}
}
}

impl Item {
pub fn vectorize(&self, vectorize: Vectorization) -> Item {
match vectorize {
Vectorization::Vec4 => Item::Vec4(self.elem()),
Vectorization::Vec3 => Item::Vec3(self.elem()),
Vectorization::Vec2 => Item::Vec2(self.elem()),
Vectorization::Scalar => Item::Scalar(self.elem()),
}
}
}
10 changes: 10 additions & 0 deletions burn-wgpu/src/codegen/dialect/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
/// GPU dialect module that contains a representation that can be used to program any GPU.
///
/// This dialect should be used to perform most GPU-related optimizations, such as vectorization.
///
/// [Compilers](crate::codegen::Compiler) can be used to transform that representation into a lower
/// level one, such as [wgsl](crate::codegen::dialect::wgsl).
pub(crate) mod gpu;
/// WGSL dialect module that contains a representation that can be compiled to WebGPU shading
/// language (wgsl).
pub(crate) mod wgsl;
Loading

0 comments on commit fb6cc2d

Please sign in to comment.