-
Notifications
You must be signed in to change notification settings - Fork 430
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Refactor] Create an Intermediate representation of compute shaders on the GPU (#1274)
- Loading branch information
1 parent
a9b6dbc
commit fb6cc2d
Showing
42 changed files
with
1,686 additions
and
1,105 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
use super::dialect::gpu; | ||
use std::fmt::Display; | ||
|
||
pub trait Compiler: Sync + Send + 'static { | ||
type Representation: Display; | ||
|
||
fn compile(shader: gpu::ComputeShader) -> Self::Representation; | ||
fn elem_size(elem: gpu::Elem) -> usize; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
use super::Operation; | ||
use serde::{Deserialize, Serialize}; | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize, new)] | ||
pub struct Body { | ||
pub operators: Vec<Operation>, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
mod body; | ||
mod operation; | ||
mod shader; | ||
mod variable; | ||
mod vectorization; | ||
|
||
pub(crate) use body::*; | ||
pub(crate) use operation::*; | ||
pub(crate) use shader::*; | ||
pub(crate) use variable::*; | ||
pub(crate) use vectorization::*; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
use super::Variable; | ||
use serde::{Deserialize, Serialize}; | ||
|
||
/// All operations that can be used in a GPU compute shader. | ||
#[derive(Debug, Clone, Serialize, Deserialize)] | ||
#[allow(dead_code)] // Some variants might not be used with different flags | ||
pub enum Operation { | ||
Add(BinaryOperation), | ||
Sub(BinaryOperation), | ||
Mul(BinaryOperation), | ||
Div(BinaryOperation), | ||
Abs(UnaryOperation), | ||
Exp(UnaryOperation), | ||
Log(UnaryOperation), | ||
Log1p(UnaryOperation), | ||
Cos(UnaryOperation), | ||
Sin(UnaryOperation), | ||
Tanh(UnaryOperation), | ||
Powf(BinaryOperation), | ||
Sqrt(UnaryOperation), | ||
Erf(UnaryOperation), | ||
Recip(UnaryOperation), | ||
Equal(BinaryOperation), | ||
Lower(BinaryOperation), | ||
Clamp(ClampOperation), | ||
Greater(BinaryOperation), | ||
LowerEqual(BinaryOperation), | ||
GreaterEqual(BinaryOperation), | ||
ConditionalAssign(ConditionalAssignOperation), | ||
AssignGlobal(UnaryOperation), | ||
AssignLocal(UnaryOperation), | ||
ReadGlobal(ReadGlobalOperation), | ||
ReadGlobalWithLayout(ReadGlobalWithLayoutOperation), | ||
} | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)] | ||
#[allow(dead_code)] // Some variants might not be used with different flags | ||
pub struct BinaryOperation { | ||
pub lhs: Variable, | ||
pub rhs: Variable, | ||
pub out: Variable, | ||
} | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)] | ||
#[allow(dead_code)] // Some variants might not be used with different flags | ||
pub struct UnaryOperation { | ||
pub input: Variable, | ||
pub out: Variable, | ||
} | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)] | ||
#[allow(dead_code)] // Some variants might not be used with different flags | ||
pub struct ClampOperation { | ||
pub input: Variable, | ||
pub min_value: Variable, | ||
pub max_value: Variable, | ||
pub out: Variable, | ||
} | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)] | ||
#[allow(dead_code)] // Some variants might not be used with different flags | ||
pub struct ConditionalAssignOperation { | ||
pub cond: Variable, | ||
pub lhs: Variable, | ||
pub rhs: Variable, | ||
pub out: Variable, | ||
} | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)] | ||
#[allow(dead_code)] // Some variants might not be used with different flags | ||
pub struct ReadGlobalOperation { | ||
pub variable: Variable, | ||
} | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)] | ||
#[allow(dead_code)] // Some variants might not be used with different flags | ||
pub struct ReadGlobalWithLayoutOperation { | ||
pub variable: Variable, | ||
pub tensor_read_pos: usize, | ||
pub tensor_layout_pos: usize, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
use super::Body; | ||
use crate::kernel::WORKGROUP_DEFAULT; | ||
use serde::{Deserialize, Serialize}; | ||
use std::fmt::Display; | ||
|
||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] | ||
pub enum Location { | ||
Storage, | ||
#[allow(dead_code)] | ||
Workgroup, | ||
} | ||
|
||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] | ||
pub enum Visibility { | ||
Read, | ||
ReadWrite, | ||
} | ||
|
||
#[derive(Debug, Clone, PartialEq, Eq, Copy, Serialize, Deserialize)] | ||
pub enum Elem { | ||
Float, | ||
Int, | ||
UInt, | ||
Bool, | ||
} | ||
|
||
impl Display for Elem { | ||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||
match self { | ||
Self::Float => f.write_str("float"), | ||
Self::Int => f.write_str("int"), | ||
Self::UInt => f.write_str("uint"), | ||
Self::Bool => f.write_str("bool"), | ||
} | ||
} | ||
} | ||
|
||
#[derive(Debug, Clone, PartialEq, Eq, Copy, Serialize, Deserialize)] | ||
pub enum Item { | ||
Vec4(Elem), | ||
Vec3(Elem), | ||
Vec2(Elem), | ||
Scalar(Elem), | ||
} | ||
|
||
impl Item { | ||
pub fn elem(&self) -> Elem { | ||
match self { | ||
Self::Vec4(elem) => *elem, | ||
Self::Vec3(elem) => *elem, | ||
Self::Vec2(elem) => *elem, | ||
Self::Scalar(elem) => *elem, | ||
} | ||
} | ||
} | ||
|
||
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] | ||
pub struct Binding { | ||
pub location: Location, | ||
pub visibility: Visibility, | ||
pub item: Item, | ||
pub size: Option<usize>, | ||
} | ||
|
||
#[derive(new, Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] | ||
pub struct WorkgroupSize { | ||
pub x: u32, | ||
pub y: u32, | ||
pub z: u32, | ||
} | ||
|
||
impl Default for WorkgroupSize { | ||
fn default() -> Self { | ||
Self { | ||
x: WORKGROUP_DEFAULT as u32, | ||
y: WORKGROUP_DEFAULT as u32, | ||
z: 1, | ||
} | ||
} | ||
} | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)] | ||
pub struct ComputeShader { | ||
pub inputs: Vec<Binding>, | ||
pub outputs: Vec<Binding>, | ||
pub named: Vec<(String, Binding)>, | ||
pub workgroup_size: WorkgroupSize, | ||
pub global_invocation_id: bool, | ||
pub num_workgroups: bool, | ||
pub body: Body, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
use super::Item; | ||
use serde::{Deserialize, Serialize}; | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)] | ||
pub enum Variable { | ||
Input(u16, Item), | ||
Scalar(u16, Item), | ||
Local(u16, Item), | ||
Output(u16, Item), | ||
Constant(f64, Item), | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
use super::{ | ||
BinaryOperation, ClampOperation, ConditionalAssignOperation, Item, Operation, | ||
ReadGlobalOperation, ReadGlobalWithLayoutOperation, UnaryOperation, Variable, | ||
}; | ||
|
||
/// The vectorization scheme to apply to items.
#[derive(Copy, Clone, Debug)]
#[allow(dead_code)]
pub enum Vectorization {
    /// Vectorize with vec4.
    Vec4,
    /// Vectorize with vec3.
    Vec3,
    /// Vectorize with vec2.
    Vec2,
    /// Keep scalars: no vectorization.
    Scalar,
}
|
||
impl Operation { | ||
pub fn vectorize(&self, vectorization: Vectorization) -> Self { | ||
match self { | ||
Operation::Add(op) => Operation::Add(op.vectorize(vectorization)), | ||
Operation::Sub(op) => Operation::Sub(op.vectorize(vectorization)), | ||
Operation::Mul(op) => Operation::Mul(op.vectorize(vectorization)), | ||
Operation::Div(op) => Operation::Div(op.vectorize(vectorization)), | ||
Operation::Abs(op) => Operation::Abs(op.vectorize(vectorization)), | ||
Operation::Exp(op) => Operation::Exp(op.vectorize(vectorization)), | ||
Operation::Log(op) => Operation::Log(op.vectorize(vectorization)), | ||
Operation::Log1p(op) => Operation::Log1p(op.vectorize(vectorization)), | ||
Operation::Cos(op) => Operation::Cos(op.vectorize(vectorization)), | ||
Operation::Sin(op) => Operation::Sin(op.vectorize(vectorization)), | ||
Operation::Tanh(op) => Operation::Tanh(op.vectorize(vectorization)), | ||
Operation::Powf(op) => Operation::Powf(op.vectorize(vectorization)), | ||
Operation::Sqrt(op) => Operation::Sqrt(op.vectorize(vectorization)), | ||
Operation::Erf(op) => Operation::Erf(op.vectorize(vectorization)), | ||
Operation::Recip(op) => Operation::Recip(op.vectorize(vectorization)), | ||
Operation::Equal(op) => Operation::Equal(op.vectorize(vectorization)), | ||
Operation::Lower(op) => Operation::Lower(op.vectorize(vectorization)), | ||
Operation::Clamp(op) => Operation::Clamp(op.vectorize(vectorization)), | ||
Operation::Greater(op) => Operation::Greater(op.vectorize(vectorization)), | ||
Operation::LowerEqual(op) => Operation::LowerEqual(op.vectorize(vectorization)), | ||
Operation::GreaterEqual(op) => Operation::GreaterEqual(op.vectorize(vectorization)), | ||
Operation::ConditionalAssign(op) => { | ||
Operation::ConditionalAssign(op.vectorize(vectorization)) | ||
} | ||
Operation::AssignGlobal(op) => Operation::AssignGlobal(op.vectorize(vectorization)), | ||
Operation::AssignLocal(op) => Operation::AssignLocal(op.vectorize(vectorization)), | ||
Operation::ReadGlobal(op) => Operation::ReadGlobal(op.vectorize(vectorization)), | ||
Operation::ReadGlobalWithLayout(op) => { | ||
Operation::ReadGlobalWithLayout(op.vectorize(vectorization)) | ||
} | ||
} | ||
} | ||
} | ||
|
||
impl BinaryOperation { | ||
pub fn vectorize(&self, vectorization: Vectorization) -> Self { | ||
let lhs = self.lhs.vectorize(vectorization); | ||
let rhs = self.rhs.vectorize(vectorization); | ||
let out = self.out.vectorize(vectorization); | ||
|
||
Self { lhs, rhs, out } | ||
} | ||
} | ||
|
||
impl UnaryOperation { | ||
pub fn vectorize(&self, vectorization: Vectorization) -> Self { | ||
let input = self.input.vectorize(vectorization); | ||
let out = self.out.vectorize(vectorization); | ||
|
||
Self { input, out } | ||
} | ||
} | ||
|
||
impl ClampOperation { | ||
pub fn vectorize(&self, vectorization: Vectorization) -> Self { | ||
let input = self.input.vectorize(vectorization); | ||
let out = self.out.vectorize(vectorization); | ||
let min_value = self.min_value.vectorize(vectorization); | ||
let max_value = self.max_value.vectorize(vectorization); | ||
|
||
Self { | ||
input, | ||
out, | ||
min_value, | ||
max_value, | ||
} | ||
} | ||
} | ||
|
||
impl ConditionalAssignOperation { | ||
pub fn vectorize(&self, vectorization: Vectorization) -> Self { | ||
let cond = self.cond.vectorize(vectorization); | ||
let lhs = self.lhs.vectorize(vectorization); | ||
let rhs = self.rhs.vectorize(vectorization); | ||
let out = self.out.vectorize(vectorization); | ||
|
||
Self { | ||
cond, | ||
lhs, | ||
rhs, | ||
out, | ||
} | ||
} | ||
} | ||
|
||
impl ReadGlobalOperation { | ||
pub fn vectorize(&self, vectorization: Vectorization) -> Self { | ||
let variable = self.variable.vectorize(vectorization); | ||
|
||
Self { variable } | ||
} | ||
} | ||
|
||
impl ReadGlobalWithLayoutOperation { | ||
pub fn vectorize(&self, vectorization: Vectorization) -> Self { | ||
let variable = self.variable.vectorize(vectorization); | ||
let tensor_read_pos = self.tensor_read_pos; | ||
let tensor_layout_pos = self.tensor_layout_pos; | ||
|
||
Self { | ||
variable, | ||
tensor_read_pos, | ||
tensor_layout_pos, | ||
} | ||
} | ||
} | ||
|
||
impl Variable { | ||
pub fn vectorize(&self, vectorize: Vectorization) -> Self { | ||
match self { | ||
Variable::Input(index, item) => Variable::Input(*index, item.vectorize(vectorize)), | ||
Variable::Local(index, item) => Variable::Local(*index, item.vectorize(vectorize)), | ||
Variable::Output(index, item) => Variable::Output(*index, item.vectorize(vectorize)), | ||
Variable::Constant(index, item) => { | ||
Variable::Constant(*index, item.vectorize(vectorize)) | ||
} | ||
Variable::Scalar(index, item) => Variable::Scalar(*index, *item), // Don't vectorize | ||
// scalar variables. | ||
} | ||
} | ||
} | ||
|
||
impl Item { | ||
pub fn vectorize(&self, vectorize: Vectorization) -> Item { | ||
match vectorize { | ||
Vectorization::Vec4 => Item::Vec4(self.elem()), | ||
Vectorization::Vec3 => Item::Vec3(self.elem()), | ||
Vectorization::Vec2 => Item::Vec2(self.elem()), | ||
Vectorization::Scalar => Item::Scalar(self.elem()), | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
/// GPU dialect module that contains a representation that can be used to program any GPU. | ||
/// | ||
/// This dialect should be used to perform most GPU-related optimizations, such as vectorization. | ||
/// | ||
/// [Compilers](crate::codegen::Compiler) can be used to transform that representation into a lower | ||
/// level one, such as [wgsl](crate::codegen::dialect::wgsl). | ||
pub(crate) mod gpu; | ||
/// WGSL dialect module that contains a representation that can be compiled to WebGPU shading | ||
/// language (wgsl). | ||
pub(crate) mod wgsl; |
Oops, something went wrong.