/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines abstractions for managing loading and storing fragments to shared memory in the
        efficient GEMM pipeline.
*/
#pragma once

#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/gemm_shared_tile.h"

namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
    /// The load iterator.
    typename Iterator_,
    /// The transformer to be applied after the data has been copied from shared memory.
    typename Transformer_ = Copy<typename Iterator_::Fragment> >
struct SharedLoadStream {
  /// The load iterator.
  typedef Iterator_ Iterator;
  /// The transformer.
  typedef Transformer_ Transformer;

  /// The fragment that is copied from shared memory.
  typedef typename Iterator::Fragment FetchedFragment;
  /// The fragment that is obtained after the transformation by the transformer.
  typedef typename Transformer::OutputFragment TransformedFragment;
  /// Make sure the fragments match.
  static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
                "The fragment fetched from shared memory must match the transformer's input fragment.");
  /// The output fragment.
  typedef TransformedFragment Fragment;

  /// Scalar data type.
  typedef typename Iterator::Scalar Scalar;
  /// Reference type to a tensor.
  typedef TensorRef<Scalar, 4> TensorRef;

  /// The params.
  struct Params {
    /// The iterator params.
    typename Iterator::Params iterator;

    /// Setup the params.
    CUTLASS_HOST_DEVICE int initialize() { return iterator.initialize(); }
  };

  /// The storage in shared memory needed by that stream.
  typedef typename Iterator::Storage SharedStorage;

  /// Ctor.
  CUTLASS_DEVICE SharedLoadStream() {}

  /// Ctor.
  CUTLASS_DEVICE SharedLoadStream(Params const &params, TensorRef const &ref) {
    this->initialize(params, ref);
  }

  /// Initialize the stream.
  CUTLASS_DEVICE void initialize(Params const &params, TensorRef const &ref) {
    // The iterator.
    iterator = Iterator(params.iterator, ref.data());
    // The transformer.
    transformer = Transformer();
  }

  /// Load the data from shared memory into the first fetch fragment.
  CUTLASS_DEVICE void copy() { iterator.load_post_increment(fetched[0]); }

  /// Load the data from shared memory into the fetch fragment selected by the step (step % 2).
  CUTLASS_DEVICE void copy(int step) { iterator.load(fetched[step % 2], step); }

  /// Commit the data by transforming the first fetched fragment.
  CUTLASS_DEVICE void commit() { transformer.transform(fetched[0], transformed[0]); }

  /// Commit the data by transforming the fetched fragment selected by the step (step % 2).
  CUTLASS_DEVICE void commit(int step) {
    transformer.transform(fetched[step % 2], transformed[step % 2]);
  }

  /// Returns the transformed fragment for the given step.
  CUTLASS_DEVICE TransformedFragment &fragment(int step = 0) { return transformed[step % 2]; }

  /// Returns the transformed fragment for the given step.
  CUTLASS_DEVICE TransformedFragment const &fragment(int step = 0) const {
    return transformed[step % 2];
  }

  /// Increment the stage.
  CUTLASS_DEVICE void inc_stage() { iterator.inc_stage(); }

  /// The iterator.
  Iterator iterator;
  /// Fetched fragments (double-buffered across two pipeline steps).
  FetchedFragment fetched[2];
  /// The transformer.
  Transformer transformer;
  /// Transformed fragments (double-buffered across two pipeline steps).
  TransformedFragment transformed[2];
};
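
// Illustrative note (not part of the original CUTLASS source): a minimal sketch of how a
// SharedLoadStream is typically driven by a double-buffered GEMM main loop. The names
// `MyIterator`, `kSteps`, and `consume()` are hypothetical placeholders, and the
// synchronization of the real pipeline is omitted.
//
// \code
//   SharedLoadStream<MyIterator> stream(params, ref);
//   stream.copy(0);                   // prefetch the first fragment into fetched[0]
//   for (int step = 0; step < kSteps; ++step) {
//     stream.copy(step + 1);          // fetch the next fragment into fetched[(step + 1) % 2]
//     stream.commit(step);            // transform fetched[step % 2] into transformed[step % 2]
//     consume(stream.fragment(step)); // use the transformed fragment, e.g. feed the MMA
//   }
//   stream.inc_stage();               // advance the iterator to the next shared memory stage
// \endcode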
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass