Skip to content

Commit

Permalink
Bugfix: memsetAsync uses wrong default stream
Browse files Browse the repository at this point in the history
Signed-off-by: Peter Han <fujun.han@iluvatar.ai>
  • Loading branch information
Peter9606 committed Mar 23, 2021
1 parent 50bf00e commit 92393b2
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 16 deletions.
2 changes: 1 addition & 1 deletion include/cutlass/conv/device/implicit_gemm_convolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ class ImplicitGemmConvolution {
void *workspace = nullptr,
cudaStream_t stream = nullptr) {

Status status = initialize(args, workspace);
Status status = initialize(args, workspace, stream);

if (status == Status::kSuccess) {
status = run(stream);
Expand Down
6 changes: 3 additions & 3 deletions include/cutlass/gemm/device/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ class Gemm {
void *workspace = nullptr,
cudaStream_t stream = nullptr) {

Status status = initialize(args, workspace);
Status status = initialize(args, workspace, stream);

if (status == Status::kSuccess) {
status = run(stream);
Expand Down Expand Up @@ -673,7 +673,7 @@ class Gemm<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

return underlying_operator_.initialize(to_underlying_arguments(args), workspace);
return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
}

/// Lightweight update given a subset of arguments
Expand All @@ -699,7 +699,7 @@ class Gemm<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {

Status status = initialize(args, workspace);
Status status = initialize(args, workspace, stream);

if (status == Status::kSuccess) {
status = run(stream);
Expand Down
6 changes: 3 additions & 3 deletions include/cutlass/gemm/device/gemm_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ class GemmArray {
void *workspace = nullptr,
cudaStream_t stream = nullptr) {

Status status = initialize(args, workspace);
Status status = initialize(args, workspace, stream);

if (status == Status::kSuccess) {
status = run(stream);
Expand Down Expand Up @@ -700,7 +700,7 @@ class GemmArray<
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

return underlying_operator_.initialize(to_underlying_arguments(args), workspace);
return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
}

/// Lightweight update given a subset of arguments
Expand All @@ -726,7 +726,7 @@ class GemmArray<
void *workspace = nullptr,
cudaStream_t stream = nullptr) {

Status status = initialize(args, workspace);
Status status = initialize(args, workspace, stream);

if (status == Status::kSuccess) {
status = run(stream);
Expand Down
6 changes: 3 additions & 3 deletions include/cutlass/gemm/device/gemm_batched.h
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ class GemmBatched {
void *workspace = nullptr,
cudaStream_t stream = nullptr) {

Status status = initialize(args, workspace);
Status status = initialize(args, workspace, stream);

if (status == Status::kSuccess) {
status = run(stream);
Expand Down Expand Up @@ -666,7 +666,7 @@ class GemmBatched<
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

return underlying_operator_.initialize(to_underlying_arguments(args), workspace);
return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
}

/// Lightweight update given a subset of arguments
Expand All @@ -692,7 +692,7 @@ class GemmBatched<
void *workspace = nullptr,
cudaStream_t stream = nullptr) {

Status status = initialize(args, workspace);
Status status = initialize(args, workspace, stream);

if (status == Status::kSuccess) {
status = run(stream);
Expand Down
6 changes: 3 additions & 3 deletions include/cutlass/gemm/device/gemm_complex.h
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ class GemmComplex {
void *workspace = nullptr,
cudaStream_t stream = nullptr) {

Status status = initialize(args, workspace);
Status status = initialize(args, workspace, stream);

if (status == Status::kSuccess) {
status = run(stream);
Expand Down Expand Up @@ -674,7 +674,7 @@ class GemmComplex<
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

return underlying_operator_.initialize(to_underlying_arguments(args), workspace);
return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
}

/// Lightweight update given a subset of arguments
Expand All @@ -700,7 +700,7 @@ class GemmComplex<
void *workspace = nullptr,
cudaStream_t stream = nullptr) {

Status status = initialize(args, workspace);
Status status = initialize(args, workspace, stream);

if (status == Status::kSuccess) {
status = run(stream);
Expand Down
2 changes: 1 addition & 1 deletion include/cutlass/gemm/device/gemm_sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ class SparseGemm {
void *workspace = nullptr,
cudaStream_t stream = nullptr) {

Status status = initialize(args, workspace);
Status status = initialize(args, workspace, stream);

if (status == Status::kSuccess) {
status = run(stream);
Expand Down
2 changes: 1 addition & 1 deletion include/cutlass/gemm/device/gemm_universal_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ class GemmUniversalBase {
return Status::kErrorWorkspaceNull;
}

params_.update(args, workspace);
params_.update(args, workspace, stream);

return Status::kSuccess;
}
Expand Down
2 changes: 1 addition & 1 deletion include/cutlass/reduction/device/reduce_split_k.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ class ReduceSplitK {
void *workspace = nullptr,
cudaStream_t stream = nullptr) {

Status status = initialize(args, workspace);
Status status = initialize(args, workspace, stream);

if (status == Status::kSuccess) {
status = run(stream);
Expand Down

0 comments on commit 92393b2

Please sign in to comment.