[webgpu] Implement Split operator #23198
Open
jchen10 wants to merge 1 commit into microsoft:main from jchen10:split
+264
−5
@@ -0,0 +1,162 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/tensor/split.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

namespace onnxruntime {
namespace webgpu {

namespace {

// Helper function to calculate the output index based on the input index and the sizes of the splits.
void CalculateOutputIndex(std::ostream& os, size_t output_count) {
  os << "fn calculate_output_index(index: u32) -> u32 {\n"
     << "  for (var i: u32 = 0u; i < " << output_count << "u; i += 1u) {\n"
     << "    if (index < " << GetElementAt("uniforms.sizes_in_split_axis", "i", output_count) << ") {\n"
     << "      return i;\n"
     << "    }\n"
     << "  }\n"
     << "  return " << output_count << "u;\n"
     << "}\n";
}

// Helper function to write the buffer data for each output.
void WriteBufferData(std::ostream& os, const ShaderVariableHelper& input,
                     gsl::span<const ShaderVariableHelper*> outputs) {
  os << "fn write_buffer_data(output_number: u32, global_idx: u32, indices: output_0_indices_t) {\n";
  for (size_t i = 0; i < outputs.size(); ++i) {
    const auto buffer_write = outputs[i]->SetByIndices("indices", input.GetByOffset("global_idx"));
    if (outputs.size() == 1) {
      os << buffer_write;
    } else if (i == 0) {
      os << "  if (output_number == 0u) {\n"
         << "    " << buffer_write << "\n";
    } else if (i == outputs.size() - 1) {
      os << "  } else {\n"
         << "    " << buffer_write << "\n";
    } else {
      os << "  } else if (output_number == " << i << "u) {\n"
         << "    " << buffer_write << "\n";
    }
  }
  os << "  }\n"
     << "}\n";
}

}  // namespace

Status SplitProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);

  size_t output_count = Outputs().size();
  std::vector<const ShaderVariableHelper*> outputs;
  outputs.reserve(output_count);
  for (size_t i = 0; i < output_count; ++i) {
    outputs.push_back(
        &shader.AddOutput("output_" + std::to_string(i), ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias));
  }

  // Add implementation of fn calculate_output_index.
  CalculateOutputIndex(shader.AdditionalImplementation(), output_count);
  // Add implementation of fn write_buffer_data.
  WriteBufferData(shader.AdditionalImplementation(), input, outputs);

  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.input_size")
                            << "  var indices = " << input.OffsetToIndices("global_idx") << ";\n"
                            << "  var index = indices[" << axis_ << "];\n"
                            << "  let output_number = calculate_output_index(index);\n"
                            << "  if (output_number != 0u) {\n"
                            << "    index -= uniforms.sizes_in_split_axis[output_number - 1u];\n"
                            << "    indices[" << axis_ << "] = index;\n"
                            << "  }\n"
                            << "  write_buffer_data(output_number, global_idx, indices);\n";

  return Status::OK();
}

Status Split::ComputeInternal(ComputeContext& context) const {
  const Tensor* input = context.Input<Tensor>(0);
  auto& input_shape = input->Shape();
  auto num_outputs = context.OutputCount();

  int64_t axis = axis_;
  std::vector<int64_t> split_sizes;

  split_sizes.assign(split_sizes_.begin(), split_sizes_.end());
  // Compute split_sizes from the 'split' input tensor.
  if (split_sizes_.size() == 0 && context.InputCount() > 1) {
    const Tensor* split_tensor = context.Input<Tensor>(1);
    // Check if split_tensor is valid.
    if (split_tensor != nullptr) {
      ORT_ENFORCE(split_tensor->Shape().NumDimensions() == 1, "The split tensor must be a vector tensor.");
      // Get split_sizes from the input tensor.
      auto nDims = static_cast<size_t>(split_tensor->Shape()[0]);
      const auto* data = split_tensor->Data<int64_t>();
      split_sizes.assign(data, data + nDims);
    }
  }

  // The variables below are not actually used in the current implementation.
  int before_dims = 0;
  int after_dims_including_split_axis = 0;
  int after_dims_excluding_split = 0;
  // This handles the case where the axis is negative. It also splits outputs evenly according to num_outputs if
  // split_sizes is empty.
  ORT_RETURN_IF_ERROR(PrepareForCompute(input_shape, num_outputs, axis, before_dims, after_dims_including_split_axis,
                                        after_dims_excluding_split, split_sizes));

  SplitProgram program{gsl::narrow_cast<uint32_t>(axis)};
  program.AddInput({input, ProgramTensorMetadataDependency::TypeAndRank});

  auto output_dimensions = input_shape.AsShapeVector();
  for (int i = 0; i < num_outputs; ++i) {
    // Update the size of the dimension for the axis we're splitting on.
    auto split_size = narrow<int>(split_sizes[i]);
    output_dimensions[narrow<size_t>(axis)] = split_size;

    Tensor* output = context.Output(i, TensorShape{output_dimensions});
    program.AddOutput({output, ProgramTensorMetadataDependency::Rank});
  }

  uint32_t input_size = gsl::narrow<uint32_t>(input_shape.Size());
  // Early return if the input tensor is empty.
  if (input_size == 0) {
    return Status::OK();
  }

  uint32_t previous_sum = 0;
  std::vector<uint32_t> sizes_in_split_axis;
  // sizes_in_split_axis are the cumulative sizes of the splits in the split axis.
  for (auto split_size : split_sizes) {
    previous_sum += gsl::narrow<uint32_t>(split_size);
    sizes_in_split_axis.push_back(previous_sum);
  }

  program
      .SetDispatchGroupSize((input_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
      .CacheHint(std::to_string(axis))
      .AddUniformVariables(
          {input_size, gsl::span<const uint32_t>(sizes_in_split_axis.data(), sizes_in_split_axis.size())});
  return context.RunProgram(program);
}

#define WEBGPU_SPLIT_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE)                                         \
  ONNX_OPERATOR_KERNEL_EX(OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider,                        \
                          KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \
                          KERNEL_CLASS);

#define WEBGPU_SPLIT_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE)                        \
  ONNX_OPERATOR_VERSIONED_KERNEL_EX(OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider,       \
                                    KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \
                                    KERNEL_CLASS);

WEBGPU_SPLIT_VERSIONED_KERNEL(Split, 1, 1, Split_1, WebGpuSupportedNumberTypes())
WEBGPU_SPLIT_VERSIONED_KERNEL(Split, 2, 10, Split_2_10, WebGpuSupportedNumberTypes())
WEBGPU_SPLIT_VERSIONED_KERNEL(Split, 11, 12, Split_11_12, WebGpuSupportedNumberTypes())
WEBGPU_SPLIT_VERSIONED_KERNEL(Split, 13, 17, Split_13_17, WebGpuSupportedNumberTypes())
WEBGPU_SPLIT_KERNEL(Split, 18, Split_18, WebGpuSupportedNumberTypes());

}  // namespace webgpu
}  // namespace onnxruntime
@@ -0,0 +1,96 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/common.h"
#include "core/providers/cpu/tensor/split.h"

namespace onnxruntime {
namespace webgpu {

class SplitProgram final : public Program<SplitProgram> {
 public:
  SplitProgram(const uint32_t axis) : Program{"Split"}, axis_{axis} {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"input_size", ProgramUniformVariableDataType::Uint32},
                                          {"sizes_in_split_axis", ProgramUniformVariableDataType::Uint32});

 private:
  uint32_t axis_;
};

class Split : public WebGpuKernel, public SplitBase {
 public:
  Split(const OpKernelInfo& info, uint32_t opset) : WebGpuKernel(info), SplitBase(info, opset) {
    std::vector<int32_t> split_sizes;
    // Check if split_sizes is provided as an attribute.
    if (split_sizes_.size() > 0) {
      ORT_ENFORCE(split_sizes_.size() == info.node().OutputDefs().size(), "Number of outputs (",
                  info.node().OutputDefs().size(), ") does not match split_sizes (", split_sizes_.size(), ")");
      split_sizes.resize(split_sizes_.size());
      for (size_t i = 0; i < split_sizes_.size(); ++i) {
        split_sizes[i] = gsl::narrow_cast<int32_t>(split_sizes_[i]);
      }
    } else if (info.GetInputCount() < 2) {
      // No valid split_sizes is provided as an attribute or input tensor. In this case, we try to compute it from
      // the input shape, the output shapes, and num_outputs.

      // Handle negative axis.
      const auto num_dimensions = gsl::narrow_cast<int64_t>(info.node().InputDefs()[0]->Shape()->dim_size());
      const auto axis = HandleNegativeAxis(axis_, num_dimensions);

      auto total_split_size = info.node().InputDefs()[0]->Shape()->dim(gsl::narrow_cast<int32_t>(axis)).dim_value();
      int64_t split_size_sum = 0;
      if (num_outputs_ >= 0) {
        ORT_ENFORCE(num_outputs_ == gsl::narrow_cast<int64_t>(info.node().OutputDefs().size()),
                    "Invalid num_outputs value of ", num_outputs_, ". Actual number of outputs is ",
                    info.node().OutputDefs().size());
      }

      // Compute split_sizes from the output shapes.
      for (auto output : info.node().OutputDefs()) {
        auto split_size = output->Shape()->dim(gsl::narrow_cast<int32_t>(axis)).dim_value();
        split_sizes.push_back(gsl::narrow_cast<int32_t>(split_size));
        split_size_sum += split_size;
      }
      ORT_ENFORCE(split_size_sum == total_split_size, "Sum of split sizes (", split_size_sum,
                  ") does not match input size (", total_split_size, ")");
    }
  }

 protected:
  Status ComputeInternal(ComputeContext& context) const override;
};

class Split_1 final : public Split {
 public:
  Split_1(const OpKernelInfo& info) : Split(info, 1) {}
};

class Split_2_10 final : public Split {
 public:
  Split_2_10(const OpKernelInfo& info) : Split(info, 2) {}
};

class Split_11_12 final : public Split {
 public:
  Split_11_12(const OpKernelInfo& info) : Split(info, 11) {}
};

class Split_13_17 final : public Split {
 public:
  Split_13_17(const OpKernelInfo& info) : Split(info, 13) {}
};

class Split_18 final : public Split {
 public:
  Split_18(const OpKernelInfo& info) : Split(info, 18) {}
};

}  // namespace webgpu
}  // namespace onnxruntime
Review comment:
Why do some validation in the Split constructor rather than in the ComputeInternal function? I see PrepareForCompute has already done some of the same validation.

Reply (jchen10):
If my understanding is right, validation should in general go in the op's constructor, which is called during the session's initialization phase. For the Split op, the splits can be specified either via the 'split' attribute or the 'split' input. When specified via the input, the values are unavailable until the session's run phase, which is when the ComputeInternal function runs. This follows what the CPU Split op does.