Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Multinomial sampling #2850

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions crates/burn-candle/src/ops/candle_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,94 @@ pub(crate) fn fill_like<E: CandleElement>(value: E, reference_tensor: &Tensor) -
reference_tensor.device(),
)
}

/// Draws `num_samples` category indices (with replacement) from the weight
/// vector `props`, entirely with tensor ops on `device`.
///
/// Implementation based on the PyTorch CUDA kernel:
/// https://fossies.org/linux/pytorch/aten/src/ATen/native/cuda/MultinomialKernel.cu
/// Each uniform draw in [0, 1) is binary-searched into the normalized CDF of
/// `props`; all samples are searched in parallel via `index_select`.
///
/// `props` need not be normalized — it is divided by its sum here. Returns an
/// empty tensor when `props` is empty.
pub(crate) fn multinomial<E: CandleElement>(
    props: &[f64],
    num_samples: usize,
    device: &Device,
) -> Tensor {
    if props.is_empty() {
        // Nothing to sample from: return an empty tensor of the requested dtype.
        return Tensor::from_iter([0f64; 0].into_iter(), device)
            .unwrap()
            .to_dtype(E::DTYPE)
            .unwrap();
    }
    let n_categories = props.len();
    let p = Tensor::from_iter(props.iter().copied(), device).unwrap();
    let p_cum = p.cumsum(0).unwrap();
    let p_sum = p_cum.get(n_categories - 1).unwrap();
    // Normalize BOTH the pdf and the cdf: the cdf must end at 1.0 to be
    // comparable against uniform samples in [0, 1). (Normalizing only `p`
    // would leave the binary search running on an unnormalized cdf.)
    let p = p.broadcast_div(&p_sum).unwrap();
    let p_cum = p_cum.broadcast_div(&p_sum).unwrap();

    // Binary-search every uniform sample into the cdf in parallel.
    let randos = Tensor::rand(0.0, 1.0, num_samples, device).unwrap();
    let mut starts = Tensor::zeros(randos.shape(), DType::I64, device).unwrap();
    let mut ends = Tensor::full(n_categories as i64 - 1, randos.shape(), device).unwrap();
    // Helper constants for scalar-per-lane arithmetic.
    let twos = Tensor::full(2i64, randos.shape(), device).unwrap();
    let ones = Tensor::full(1i64, randos.shape(), device).unwrap();

    // Number of lanes whose search interval [start, end] is still open.
    let count_open = |starts: &Tensor, ends: &Tensor| -> f32 {
        ends.gt(starts)
            .unwrap()
            .to_dtype(DType::F32)
            .unwrap()
            .sum_all()
            .unwrap()
            .to_scalar::<f32>()
            .unwrap()
    };

    while count_open(&starts, &ends) > 0.0 {
        let mids = (&starts + (&ends - &starts).unwrap().div(&twos).unwrap()).unwrap();
        let mid_vals = p_cum.index_select(&mids, 0).unwrap();
        // cdf[mid] < r => the answer lies strictly to the right of mid.
        let go_right = mid_vals.lt(&randos).unwrap();
        let new_starts = (&mids + &ones).unwrap();
        starts = go_right.where_cond(&new_starts, &starts).unwrap();
        ends = go_right.where_cond(&ends, &mids).unwrap();
    }

    // Clamp: a lane that landed on the category count (numerical edge case in
    // the reference kernel) maps to the last category. Note the clamp is
    // against the number of CATEGORIES, not the number of samples.
    let size = Tensor::full(n_categories as i64, randos.shape(), device).unwrap();
    let at_end = starts.eq(&size).unwrap();
    let last = (&size - &ones).unwrap();
    starts = at_end.where_cond(&last, &starts).unwrap();

    // Walk back over zero-probability categories, mirroring the reference:
    // while (start >= 1 && dist[start] == 0) start -= 1;
    loop {
        let above_zero = starts.gt(0i64).unwrap();
        let prob_is_zero = p.index_select(&starts, 0).unwrap().eq(0f64).unwrap();
        // Both boolean (u8) masks true <=> their sum equals 2.
        let both_true = (&above_zero + &prob_is_zero).unwrap().eq(2u8).unwrap();
        let any_true = both_true
            .to_dtype(DType::F32)
            .unwrap()
            .sum_all()
            .unwrap()
            .to_scalar::<f32>()
            .unwrap();
        if any_true == 0.0 {
            break;
        }
        starts = both_true
            .where_cond(&(&starts - &ones).unwrap(), &starts)
            .unwrap();
    }
    starts.to_dtype(E::DTYPE).unwrap()
}
7 changes: 6 additions & 1 deletion crates/burn-candle/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
Candle, CandleTensor,
};

use super::base::{expand, permute, sign};
use super::{base::{expand, permute, sign}, candle_utils::multinomial};

impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F, I> {
fn int_empty(shape: Shape, device: &Device<Self>) -> IntTensor<Self> {
Expand Down Expand Up @@ -359,6 +359,11 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
candle_core::Tensor::randn(mean.elem::<F>(), std.elem::<F>(), shape, device)
.unwrap(),
),
Distribution::Multinomial(probs) => {
let num_samples = shape.iter().product::<usize>();
let out = multinomial::<I>(&probs, num_samples, device).reshape(shape).unwrap();
CandleTensor::new(out)
}
}
}

Expand Down
7 changes: 6 additions & 1 deletion crates/burn-candle/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
Candle, CandleTensor,
};

use super::base::{expand, permute, sign};
use super::{base::{expand, permute, sign}, candle_utils::multinomial};

impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle<F, I> {
fn float_from_data(data: TensorData, device: &Device<Self>) -> CandleTensor {
Expand Down Expand Up @@ -56,6 +56,11 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
candle_core::Tensor::randn(mean.elem::<F>(), std.elem::<F>(), shape, device)
.unwrap(),
),
Distribution::Multinomial(probs) => {
let num_samples = shape.iter().product::<usize>();
let out = multinomial::<F>(&probs, num_samples, device).reshape(shape).unwrap();
CandleTensor::new(out)
},
}
}

Expand Down
2 changes: 2 additions & 0 deletions crates/burn-cubecl/src/kernel/prng/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
mod base;
mod bernoulli;
mod multinomial;
mod normal;
mod uniform;

pub use base::*;
pub use bernoulli::*;
pub use multinomial::*;
pub use normal::*;
pub use uniform::*;
71 changes: 71 additions & 0 deletions crates/burn-cubecl/src/kernel/prng/multinomial.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
use burn_tensor::Shape;
use cubecl::{linalg::tensor::TensorHandle, prelude::*};

use crate::{
kernel::prng::{cast_uint_to_float, lcg_step, taus_step_0, taus_step_1, taus_step_2},
tensor::CubeTensor,
CubeElement, CubeRuntime,
};

use super::{random, PrngArgs, PrngRuntime};

/// Launch arguments for the multinomial PRNG kernel.
#[derive(CubeLaunch)]
pub(crate) struct Multinomial<E: Numeric + CubeType> {
    // Category weight vector sampled against in `inner_loop`.
    // NOTE(review): presumably expected to be normalized (sum to 1) — confirm.
    props: Tensor<E>,
}

#[cube]
impl<E: CubeElement + CubeType> PrngRuntime<E> for Multinomial<E> {
    /// Per-invocation sampling loop: draws `n_values_per_thread` category
    /// indices from `args.props` via inverse-CDF sampling and writes them to
    /// `output` with the standard strided layout (`i * n_invocations + base`).
    ///
    /// NOTE(review): the previous body was copied from the uniform kernel — it
    /// referenced undefined `upper_bound`/`lower_bound` and never read `props`.
    /// This version assumes `props` sums to 1 — TODO confirm; if draws exceed
    /// the total mass the last category is returned.
    fn inner_loop(
        args: Multinomial<E>,
        write_index_base: u32,
        n_invocations: u32,
        #[comptime] n_values_per_thread: u32,
        state_0: &mut u32,
        state_1: &mut u32,
        state_2: &mut u32,
        state_3: &mut u32,
        output: &mut Tensor<E>,
    ) {
        let props = args.props;
        let n_categories = props.len();

        let should_unroll = n_values_per_thread <= 8;

        #[unroll(should_unroll)]
        for i in 0..n_values_per_thread {
            // Hybrid Tausworthe + LCG generator (same scheme as the other
            // prng kernels in this module).
            *state_0 = taus_step_0(*state_0);
            *state_1 = taus_step_1(*state_1);
            *state_2 = taus_step_2(*state_2);
            *state_3 = lcg_step(*state_3);

            let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
            let f32_random = cast_uint_to_float(int_random);
            let random = E::cast_from(f32_random);

            // Inverse-CDF: accumulate mass until it first exceeds the uniform
            // draw; that category is the sample. No `break` in cube code, so a
            // flag latches the first hit.
            let mut cumulative = E::cast_from(0.0f32);
            let mut sample = n_categories - 1;
            let mut found = false;
            for j in 0..n_categories {
                cumulative += props[j];
                if !found && random < cumulative {
                    sample = j;
                    found = true;
                }
            }

            let write_index = i * n_invocations + write_index_base;

            output[write_index] = E::cast_from(sample);
        }
    }
}


/// Pseudo-random generator with multinomial (categorical) distribution:
/// fills a tensor of `shape` with category indices drawn according to the
/// weights in `props`.
pub fn random_multinomial<R: CubeRuntime, E: CubeElement>(
    shape: Shape,
    device: &R::Device,
    props: Tensor<E>,
) -> CubeTensor<R> {
    random(shape, device, Multinomial { props })
}
/// Pseudo-random generator with multinomial (categorical) distribution,
/// taking the output shape and device from another tensor.
pub fn random_like_multinomial<R: CubeRuntime, E: CubeElement>(
    tensor: &CubeTensor<R>,
    props: Tensor<E>,
) -> CubeTensor<R> {
    random_multinomial::<R, E>(tensor.shape.clone(), &tensor.device, props)
}
3 changes: 3 additions & 0 deletions crates/burn-cubecl/src/ops/float_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ where
Distribution::Normal(mean, std) => {
random_normal(shape, device, mean.elem::<F>(), std.elem())
}
Distribution::Multinomial(probs) => {
random_multinomial(shape, device, probs)
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions crates/burn-fusion/src/stream/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ impl RelativeOpsScalar<f32> for FloatOperationIr {
}),
FloatOperationIr::Random(desc) => FloatOperationIr::Random(RandomOpIr {
out: desc.out.to_relative(converter),
distribution: desc.distribution,
distribution: desc.distribution.clone(),
}),
FloatOperationIr::Recip(desc) => FloatOperationIr::Recip(UnaryOpIr {
input: desc.input.to_relative(converter),
Expand Down Expand Up @@ -981,7 +981,7 @@ impl<E: Element> RelativeOpsScalar<E> for NumericOperationIr<E> {
}),
NumericOperationIr::IntRandom(desc) => NumericOperationIr::IntRandom(RandomOpIr {
out: desc.out.to_relative(converter),
distribution: desc.distribution,
distribution: desc.distribution.clone(),
}),
NumericOperationIr::Powf(desc) => NumericOperationIr::Powf(BinaryOpIr {
lhs: desc.lhs.to_relative(converter),
Expand Down
1 change: 1 addition & 0 deletions crates/burn-ir/src/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1782,6 +1782,7 @@ impl core::hash::Hash for RandomOpIr {
Distribution::Bernoulli(_) => 2u8.hash(state),
Distribution::Uniform(_, _) => 3u8.hash(state),
Distribution::Normal(_, _) => 4u8.hash(state),
Distribution::Multinomial(_) => 5u8.hash(state),
}
}
}
Expand Down
18 changes: 18 additions & 0 deletions crates/burn-tch/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,24 @@ impl<E: TchElement, Q: QuantElement> IntTensorOps<Self> for LibTorch<E, Q> {
let mut tensor = TchTensor::empty::<i64>(shape, *device);
tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap()
}
Distribution::Multinomial(probs) => {
let num_samples = shape.num_elements() as i64;
let num_probs = probs.len();
let prob_data = TensorData::new(probs, vec![num_probs]);
let mut tensor = TchTensor::from_data::<E>(prob_data, (*device).into());
tensor
.mut_ops(|tensor| {
tensor.f_multinomial(num_samples, false).unwrap().reshape(
shape
.dims
.clone()
.into_iter()
.map(|dim| dim as i64)
.collect::<Vec<_>>(),
)
})
.unwrap()
}
}
}

Expand Down
18 changes: 18 additions & 0 deletions crates/burn-tch/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,24 @@ impl<E: TchElement, Q: QuantElement> FloatTensorOps<Self> for LibTorch<E, Q> {
let mut tensor = TchTensor::empty::<E>(shape, *device);
tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap()
}
Distribution::Multinomial(probs) => {
let num_samples = shape.num_elements() as i64;
let num_probs = probs.len();
let prob_data = TensorData::new(probs, vec![num_probs]);
let mut tensor = TchTensor::from_data::<E>(prob_data, (*device).into());
tensor
.mut_ops(|tensor| {
tensor.f_multinomial(num_samples, false).unwrap().reshape(
shape
.dims
.clone()
.into_iter()
.map(|dim| dim as i64)
.collect::<Vec<_>>(),
)
})
.unwrap()
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/burn-tensor/src/tensor/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ impl TensorData {
let mut data = Vec::with_capacity(num_elements);

for _ in 0..num_elements {
data.push(E::random(distribution, rng));
data.push(E::random(distribution.clone(), rng));
}

TensorData::new(data, shape)
Expand Down
19 changes: 15 additions & 4 deletions crates/burn-tensor/src/tensor/distribution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use rand::{distr::StandardUniform, Rng, RngCore};
use crate::{Element, ElementConversion};

/// Distribution for random value of a tensor.
#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum Distribution {
/// Uniform distribution from 0 (inclusive) to 1 (exclusive).
Default,
Expand All @@ -16,14 +16,17 @@ pub enum Distribution {

/// Normal distribution with the given mean and standard deviation.
Normal(f64, f64),

/// Multinomial distribution with the given probabilities.
Multinomial(Vec<f64>),
}

/// Distribution sampler for random value of a tensor.
#[derive(new)]
pub struct DistributionSampler<'a, E, R>
where
StandardUniform: rand::distr::Distribution<E>,
E: rand::distr::uniform::SampleUniform,
E: rand::distr::uniform::SampleUniform + std::cmp::PartialOrd,
R: RngCore,
{
kind: DistributionSamplerKind<E>,
Expand All @@ -34,7 +37,7 @@ where
pub enum DistributionSamplerKind<E>
where
StandardUniform: rand::distr::Distribution<E>,
E: rand::distr::uniform::SampleUniform,
E: rand::distr::uniform::SampleUniform + std::cmp::PartialOrd,
{
/// Standard distribution.
Standard(rand::distr::StandardUniform),
Expand All @@ -47,13 +50,17 @@ where

/// Normal distribution.
Normal(rand_distr::Normal<f64>),

/// Multinomial (categorical) distribution.
Multinomial(rand::distr::weighted::WeightedIndex<f64>),
}

impl<E, R> DistributionSampler<'_, E, R>
where
StandardUniform: rand::distr::Distribution<E>,
E: rand::distr::uniform::SampleUniform,
E: Element,
E: std::cmp::PartialOrd,
R: RngCore,
{
    /// Samples a random value from the distribution.
Expand All @@ -69,6 +76,7 @@ where
}
}
DistributionSamplerKind::Normal(distribution) => self.rng.sample(distribution).elem(),
DistributionSamplerKind::Multinomial(distribution) => (self.rng.sample(distribution) as f64).elem(),
}
}
}
Expand All @@ -86,7 +94,7 @@ impl Distribution {
pub fn sampler<R, E>(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R>
where
R: RngCore,
E: Element + rand::distr::uniform::SampleUniform,
E: Element + rand::distr::uniform::SampleUniform + std::cmp::PartialOrd,
StandardUniform: rand::distr::Distribution<E>,
{
let kind = match self {
Expand All @@ -102,6 +110,9 @@ impl Distribution {
Distribution::Normal(mean, std) => {
DistributionSamplerKind::Normal(rand_distr::Normal::new(mean, std).unwrap())
}
Distribution::Multinomial(vec) =>
DistributionSamplerKind::Multinomial(rand::distr::weighted::WeightedIndex::new(vec).unwrap()),

};

DistributionSampler::new(kind, rng)
Expand Down
Loading