Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions src/Furnace.Backends.Torch/Torch.RawTensor.fs
Original file line number Diff line number Diff line change
Expand Up @@ -1115,19 +1115,34 @@ type TorchRawTensor(tt: torch.Tensor, shape: Shape, dtype: Dtype, device: Device
checkMutable()
tt.add_(toTorchScalar t2) |> ignore

// TODO - this should be faster
// Optimized AddSliceInPlace - reduced allocations and conversions
override t1.AddSliceInPlace(location, t2) =
checkMutable()
Shape.checkCanAddSlice t1.Shape location t2.Shape
let shape1 = t1.Shape
let shape2 = t2.Shape
let expandedShape2 = Shape.unsqueezeAs shape2 shape1
let t2Expanded = t2.TorchTensor.expand(toTorchShape expandedShape2)

// Pre-compute torch shape to avoid repeated conversions
let torchExpandedShape2 =
let result = Array.zeroCreate expandedShape2.Length
for i = 0 to expandedShape2.Length - 1 do
result[i] <- int64 expandedShape2[i]
result

let t2Expanded = t2.TorchTensor.expand(torchExpandedShape2)
let mutable t1Slice = tt // will share memory with res

// Optimize the slicing loop - cache shape values and reduce conditional checks
for d in 0 .. location.Length - 1 do
let locationD = location[d]
let len2 = expandedShape2[d]
if location[d] <> 0 || len2 <> shape1[d] then
t1Slice <- t1Slice.narrow(int64 d, int64 location[d], int64 len2)
let shape1D = shape1[d]

// Only narrow if we're not accessing the full dimension
if locationD <> 0 || len2 <> shape1D then
t1Slice <- t1Slice.narrow(int64 d, int64 locationD, int64 len2)

t1Slice.add_(t2Expanded) |> ignore

override _.SubInPlace(t2) = checkMutable(); tt.sub_(t2.TorchTensor) |> ignore
Expand Down