Skip to content

Commit

Permalink
Merge branch 'main' into unit
Browse files Browse the repository at this point in the history
  • Loading branch information
NiklasGustafsson committed Nov 5, 2024
2 parents b48a5e5 + 829aae2 commit 77bf378
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions src/TorchVision/Functional.cs
Original file line number Diff line number Diff line change
Expand Up @@ -443,21 +443,21 @@ public static Tensor erase(Tensor img, int top, int left, int height, int width,
/// The image is expected to have […, H, W] shape, where … means an arbitrary number of leading dimensions.
/// </summary>
/// <returns></returns>
public static Tensor gaussian_blur(Tensor input, IList<long> kernelSize, ReadOnlySpan<float> sigma)
public static Tensor gaussian_blur(Tensor input, IList<long> kernelSize, IList<float> sigma)
{
var dtype = torch.is_integral(input.dtype) ? ScalarType.Float32 : input.dtype;

if (kernelSize.Count == 1) {
kernelSize = new long[] { kernelSize[0], kernelSize[0] };
}

if (sigma == null || sigma.Length == 0)
if (sigma == null || sigma.Count == 0)
{
sigma = new float[] {
0.3f * ((kernelSize[0] - 1) * 0.5f - 1) + 0.8f,
0.3f * ((kernelSize[1] - 1) * 0.5f - 1) + 0.8f,
};
} else if (sigma.Length == 1) {
} else if (sigma.Count == 1) {
sigma = new float[] {
sigma[0],
sigma[0],
Expand Down Expand Up @@ -892,7 +892,7 @@ private static Tensor GetGaussianKernel1d(long size, float sigma)
return pdf / sum;
}

private static Tensor GetGaussianKernel2d(IList<long> kernelSize, ReadOnlySpan<float> sigma, ScalarType dtype, torch.Device device)
private static Tensor GetGaussianKernel2d(IList<long> kernelSize, IList<float> sigma, ScalarType dtype, torch.Device device)
{
using var tX1 = GetGaussianKernel1d(kernelSize[0], sigma[0]);
using var tX2 = tX1.to(dtype, device);
Expand Down
2 changes: 1 addition & 1 deletion src/TorchVision/GaussianBlur.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ internal GaussianBlur(IList<long> kernelSize, float sigma_min, float sigma_max)
public override Tensor forward(Tensor input)
{
var s = sigma.HasValue ? sigma.Value : torch.empty(1).uniform_(sigma_min, sigma_max).item<float>();
return transforms.functional.gaussian_blur(input, kernelSize, stackalloc[]{s, s});
return transforms.functional.gaussian_blur(input, kernelSize, new []{s,s});
}

protected long[] kernelSize;
Expand Down

0 comments on commit 77bf378

Please sign in to comment.