Skip to content

Commit

Permalink
Merge branch 'INSTX-5832-ios-gpu-tracing' into 'master'
Browse files Browse the repository at this point in the history
[INSTX-5832] Add stuff to help with iOS tracing

See merge request machine-learning/dorado!1142
  • Loading branch information
blawrence-ont committed Aug 1, 2024
2 parents 7f42b8f + cd5c31d commit ac0dfe2
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 22 deletions.
5 changes: 3 additions & 2 deletions cmake/Metal.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ set(METAL_SOURCES dorado/basecall/metal/nn.metal)

if (IOS)
set(XCRUN_SDK ${SDK_NAME})
set(METAL_STD_VERSION "ios-metal2.3") # iOS 14
set(METAL_STD_VERSION "metal3.0") # iOS 16
else()
set(XCRUN_SDK macosx)
set(METAL_STD_VERSION "macos-metal2.3") # macOS 11.0
Expand All @@ -25,7 +25,8 @@ foreach(source ${METAL_SOURCES})
-Wall -Wextra -pedantic
-Wno-c++17-extensions # [[maybe_unused]] is C++17
-std=${METAL_STD_VERSION}
-ffast-math
-gline-tables-only -frecord-sources # embed source for trace analysis
-O2 -ffast-math
-c "${CMAKE_CURRENT_SOURCE_DIR}/${source}"
-o "${air_path}"
DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/${source}"
Expand Down
7 changes: 5 additions & 2 deletions dorado/basecall/MetalCaller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ int MetalLSTMCaller::benchmark_batch_sizes(const CRFModelConfig &model_config,
const int kNumSmallerSizes = 16;
const float test_size_increment = static_cast<float>(max_batch_size - min_batch_size) /
static_cast<float>(kNumSmallerSizes);
for (int i = 0; i < kNumSmallerSizes; ++i) {
for (int i = 0; i <= kNumSmallerSizes; ++i) {
const int test_batch_size =
utils::pad_to(min_batch_size + static_cast<int>(i * test_size_increment),
static_cast<int>(MTL_CORE_BATCH_SIZE));
Expand Down Expand Up @@ -425,6 +425,7 @@ bool MetalLSTMCaller::run_scan_kernels(MTL::CommandBuffer *const cb, int try_cou
// the effective batch size is m_out_batch_size.
std::vector<int32_t> scan_args_{m_out_chunk_size, m_out_batch_size, m_states};
auto scan_args = create_vec_buffer(m_device.get(), scan_args_);
name_mtl_object(scan_args, "scan_kernel_args");

for (int i = 0; i < m_out_split; ++i) {
// TODO: optimise grid size
Expand Down Expand Up @@ -546,6 +547,8 @@ bool MetalTxCaller::run_scan_kernels(MTL::CommandBuffer *const cb, int try_count
// ScanArgs expects scores TNC tensor sizes
std::vector<int32_t> scan_args_{m_out_chunk_size, m_batch_size, m_states};
auto scan_args = create_vec_buffer(m_device.get(), scan_args_);
name_mtl_object(scan_args, "scan_kernel_args");

// TODO: optimise grid size
launch_kernel_no_wait(
m_bwd_scan_float_cps.get(), cb,
Expand All @@ -566,7 +569,7 @@ bool MetalTxCaller::call_task(NNTask &task, std::mutex &inter_caller_mutex, int
.contiguous()
.to(m_scores_dtype);

MTL::CommandBuffer *const cb = m_command_queue->commandBuffer();
MTL::CommandBuffer *const cb = next_command_buffer(m_command_queue.get(), try_count);
if (m_decode_complete_event) {
// wait for the previous decode task to complete - this acts as a mutex
// previous scores are processed in the decode threads
Expand Down
2 changes: 1 addition & 1 deletion dorado/basecall/decode/beam_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ std::tuple<std::string, std::string> generate_sequence(const std::vector<uint8_t
qstring[i] = static_cast<char>(33.5f + qscore);
}

return make_tuple(sequence, qstring);
return make_tuple(std::move(sequence), std::move(qstring));
}

// Incorporates NUM_NEW_BITS into a Castagnoli CRC32, aka CRC32C
Expand Down
44 changes: 34 additions & 10 deletions dorado/basecall/nn/MetalCRFModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

#include <stdexcept>

// Splitting up command buffers can be useful since it allows Xcode to make GPU captures.
#define USE_SPLIT_LSTM_COMMAND_BUFFERS 0

namespace {
constexpr int kLstmGates = 4;
// SIMD tile size dictated by the Metal spec.
Expand Down Expand Up @@ -83,7 +86,8 @@ MetalConv1dImpl::MetalConv1dImpl(int layer,
// The last 2 arguments are unused.
const std::vector<int32_t> args{in_size, win_size, out_size, stride, win_size / 2,
chunk_size, batch_size, 0, 0};
m_args.push_back(create_vec_buffer(device, args));
auto &buffer = m_args.emplace_back(create_vec_buffer(device, args));
name_mtl_object(buffer, "conv_args");
} else {
// We cut up the time span for individual kernel launches for conv3 since it is by far
// the most time consuming, and sup times can be of the order of seconds, which
Expand All @@ -99,7 +103,8 @@ MetalConv1dImpl::MetalConv1dImpl(int layer,
const std::vector<int32_t> args{in_size, win_size, out_size,
stride, win_size / 2, chunk_size,
batch_size, time_step_begin, time_step_end};
m_args.push_back(create_vec_buffer(device, args));
auto &buffer = m_args.emplace_back(create_vec_buffer(device, args));
name_mtl_object(buffer, "conv_args");
}
spdlog::debug("conv3 output_time_step_count {} => {} kernel launches",
output_time_step_count, num_pieces);
Expand Down Expand Up @@ -222,8 +227,8 @@ MetalBlockImpl::MetalBlockImpl(int chunk_size_,
const int time_step_end = std::min((i + 1) * kMaxTimeSteps, lstm_chunk_size);
std::vector<int32_t> args{batch_size / kTileSize, lstm_chunk_size, time_step_begin,
time_step_end};
auto args_buffer = create_vec_buffer(m_device, args);
m_args_lstm.push_back(args_buffer);
auto &buffer = m_args_lstm.emplace_back(create_vec_buffer(m_device, args));
name_mtl_object(buffer, "lstm_args");
}
spdlog::debug("lstm_chunk_size {} => {} LSTM kernel launches", lstm_chunk_size, num_pieces);
}
Expand All @@ -237,10 +242,13 @@ MetalBlockImpl::MetalBlockImpl(int chunk_size_,
const int32_t in_batch_tile_offset = out_batch_tiles * i;
std::vector<int32_t> args_linear_ = {in_batch_tiles, in_batch_tile_offset, out_batch_tiles,
lstm_chunk_size};
args_linear.at(i) = create_vec_buffer(m_device, args_linear_);
auto &buffer = args_linear.at(i);
buffer = create_vec_buffer(m_device, args_linear_);
name_mtl_object(buffer, "linear_args");
}
args_linear2 = create_vec_buffer<int32_t>(
device, {out_batch_tiles, 0, out_batch_tiles, lstm_chunk_size});
name_mtl_object(args_linear2, "linear2_args");

switch (config.lstm_size) {
case 128:
Expand All @@ -253,7 +261,7 @@ MetalBlockImpl::MetalBlockImpl(int chunk_size_,
kernel_simd_groups = 32;
break;
case 384:
kernel_simd_groups = 24;
kernel_simd_groups = TARGET_OS_IPHONE ? 32 : 24;
break;
case 512:
kernel_simd_groups = 32;
Expand Down Expand Up @@ -369,6 +377,9 @@ MetalBlockImpl::MetalBlockImpl(int chunk_size_,
m_device, size_t(lstm_chunk_size + 3) * batch_size * config.lstm_size * dtype_bytes);
mat_state = create_buffer(m_device, batch_size * config.lstm_size * dtype_bytes);
mat_temp = create_buffer(m_device, mat_temp_elems * dtype_bytes);
name_mtl_object(mat_working_mem, "mat_working_mem");
name_mtl_object(mat_state, "mat_state");
name_mtl_object(mat_temp, "mat_temp");
}

void MetalBlockImpl::load_weights() {
Expand Down Expand Up @@ -400,6 +411,7 @@ void MetalBlockImpl::load_weights() {
t_w = torch::stack({t_w[2], t_w[0], t_w[1], t_w[3]}, 2);

rnn->t_weights_bias.view_as(t_w) = t_w;
name_mtl_object(mtl_for_tensor(rnn->t_weights_bias), "rnn_weights");
}

// Load and prepare linear layer weights.
Expand Down Expand Up @@ -433,7 +445,7 @@ MTL::CommandBuffer *MetalBlockImpl::forward_async(at::Tensor &in,
uint64_t linear_hold_off_id,
int try_count,
std::vector<at::Tensor> &out) {
auto command_buffer = m_command_queue->commandBuffer();
auto command_buffer = next_command_buffer(m_command_queue.get(), try_count);

if (in.dtype() != torch::kF16) {
throw std::runtime_error("Input tensor must be float16.");
Expand All @@ -448,7 +460,9 @@ MTL::CommandBuffer *MetalBlockImpl::forward_async(at::Tensor &in,
std::string lstm_label = "lstm_rnn0";
for (auto &rnn : {rnn1, rnn2, rnn3, rnn4, rnn5}) {
lstm_label.back()++;
command_buffer = m_command_queue->commandBuffer();
#if !USE_SPLIT_LSTM_COMMAND_BUFFERS
command_buffer = next_command_buffer(m_command_queue.get(), try_count);
#endif

const int kResBufSize =
static_cast<int>(dtype_bytes * kernel_simd_groups * 2 * kTileSize * kTileSize);
Expand All @@ -459,17 +473,27 @@ MTL::CommandBuffer *MetalBlockImpl::forward_async(at::Tensor &in,
const std::vector<MTL::Buffer *> buffers{args_lstm.get(), mat_working_mem.get(),
mtl_for_tensor(rnn->t_weights_bias),
mat_state.get()};
#if USE_SPLIT_LSTM_COMMAND_BUFFERS
command_buffer = next_command_buffer(m_command_queue.get(), try_count);
#endif
launch_kernel_no_wait(lstm_cps[rnn->reverse].get(), command_buffer, buffers,
tg_buffer_lens, kernel_thread_groups,
kernel_simd_groups * kSIMDGroupWidth);
#if USE_SPLIT_LSTM_COMMAND_BUFFERS
if (!finishCommandBuffer(lstm_label.c_str(), command_buffer, try_count)) {
return nullptr;
}
#endif
}

if (!finishCommandBuffer(lstm_label, command_buffer, try_count)) {
#if !USE_SPLIT_LSTM_COMMAND_BUFFERS
if (!finishCommandBuffer(lstm_label.c_str(), command_buffer, try_count)) {
return nullptr;
}
#endif
}

command_buffer = m_command_queue->commandBuffer();
command_buffer = next_command_buffer(m_command_queue.get(), try_count);

// The output buffers of conv/LSTM layers are not used by the decoding, so
// can be overwritten by subsequent batches as soon as they have been consumed by
Expand Down
19 changes: 13 additions & 6 deletions dorado/torch_utils/metal_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,16 @@ void launch_kernel(ComputePipelineState *const pipeline,
}
}

// Obtains a new command buffer from |queue|.
//
// When try_count is 0 the queue's default command buffer is returned. For any other value the
// buffer is created from a descriptor whose error options request encoder execution status,
// since a retry implies an earlier failure that we want more detail about.
MTL::CommandBuffer *next_command_buffer(MTL::CommandQueue *queue, int try_count) {
    const bool is_retry = (try_count != 0);
    if (is_retry) {
        // Previous attempt failed: ask Metal for the extra error reporting this time around.
        auto verbose_descriptor = NS::TransferPtr(MTL::CommandBufferDescriptor::alloc()->init());
        verbose_descriptor->setErrorOptions(MTL::CommandBufferErrorOptionEncoderExecutionStatus);
        return queue->commandBuffer(verbose_descriptor.get());
    }

    // First attempt: no additional diagnostics requested.
    return queue->commandBuffer();
}

void launch_kernel_no_wait(ComputePipelineState *const pipeline,
CommandBuffer *const command_buffer,
const std::vector<Buffer *> &buffers,
Expand All @@ -235,7 +245,8 @@ void launch_kernel_no_wait(ComputePipelineState *const pipeline,
compute_encoder->endEncoding();
}

bool finishCommandBuffer(std::string_view label, MTL::CommandBuffer *cb, int try_count) {
bool finishCommandBuffer(const char *label, MTL::CommandBuffer *cb, int try_count) {
name_mtl_object(cb, label);
cb->commit();
cb->waitUntilCompleted();

Expand All @@ -249,11 +260,7 @@ bool finishCommandBuffer(std::string_view label, MTL::CommandBuffer *cb, int try
spdlog::warn("Metal command buffer {} failed: status {} (try {})", label,
fmt::underlying(status), try_count);
if (status == MTL::CommandBufferStatusError) {
const auto *const error_ptr = cb->error();
if (error_ptr) {
spdlog::warn("Command buffer error code: {} ({})", error_ptr->code(),
error_ptr->localizedDescription()->utf8String());
}
report_error(cb->error(), "finishCommandBuffer");
}
}
return success;
Expand Down
9 changes: 8 additions & 1 deletion dorado/torch_utils/metal_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@

namespace dorado::utils {

// Attaches a label to |obj| by calling obj->setLabel() with an NS::String built from |name|.
// |obj| may be anything that supports `->setLabel(...)` (raw pointer or smart pointer alike).
// |name| is interpreted using ASCII encoding.
template <typename MetalObject>
void name_mtl_object(MetalObject &&obj, const char *name) {
    auto *const label = NS::String::string(name, NS::ASCIIStringEncoding);
    obj->setLabel(label);
}

// Returns an uninitialised MTL::Buffer of length bytes.
NS::SharedPtr<MTL::Buffer> create_buffer(MTL::Device *device, size_t length);

Expand Down Expand Up @@ -46,6 +51,8 @@ void launch_kernel(MTL::ComputePipelineState *cps,
long threadgroups,
long threads_per_threadroup);

MTL::CommandBuffer *next_command_buffer(MTL::CommandQueue *queue, int try_count);

void launch_kernel_no_wait(MTL::ComputePipelineState *cps,
MTL::CommandBuffer *cb,
const std::vector<MTL::Buffer *> &buffers,
Expand All @@ -54,7 +61,7 @@ void launch_kernel_no_wait(MTL::ComputePipelineState *cps,
long threads_per_threadgroup);

// Returns true on success.
bool finishCommandBuffer(std::string_view label, MTL::CommandBuffer *cb, int try_count);
bool finishCommandBuffer(const char *label, MTL::CommandBuffer *cb, int try_count);

NS::SharedPtr<MTL::Device> get_mtl_device();
int get_mtl_device_core_count();
Expand Down

0 comments on commit ac0dfe2

Please sign in to comment.