-
Notifications
You must be signed in to change notification settings - Fork 4
/
resource.h
139 lines (133 loc) · 4.3 KB
/
resource.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
/*!
* Copyright (c) 2015 by Contributors
* \file resource.h
* \brief Global resource allocation handling.
*/
#ifndef MXNET_RESOURCE_H_
#define MXNET_RESOURCE_H_
#include <dmlc/logging.h>
#include "./base.h"
#include "./engine.h"
namespace mxnet {
/*!
 * \brief The resources that can be requested by Operator
 */
struct ResourceRequest {
  /*! \brief Resource type, indicating what the pointer type is */
  enum Type {
    /*! \brief mshadow::Random<xpu> object */
    kRandom,
    /*! \brief A dynamic temp space that can be arbitrary size */
    kTempSpace
  };
  /*! \brief type of resources */
  Type type;
  /*!
   * \brief default constructor.
   *  Sets type to kRandom (the first enumerator) so that a default-constructed
   *  request never holds an indeterminate value; assign the intended type
   *  before relying on it.
   */
  ResourceRequest() : type(kRandom) {}
  /*!
   * \brief constructor, allow implicit conversion, so that e.g.
   *  `ResourceRequest r = ResourceRequest::kTempSpace;` is valid.
   * \param type type of resources
   */
  ResourceRequest(Type type)  // NOLINT(*)
      : type(type) {}
};
/*!
* \brief Resources used by mxnet operations.
* A resource is something special other than NDArray,
* but will still participate
*/
struct Resource {
/*! \brief The original request */
ResourceRequest req;
/*! \brief engine variable */
engine::VarHandle var;
/*! \brief identifier of id information, used for debug purpose */
int32_t id;
/*!
* \brief pointer to the resource, do not use directly,
* access using member functions
*/
void *ptr_;
/*! \brief default constructor */
Resource() : id(0) {}
/*!
* \brief Get random number generator.
* \param stream The stream to use in the random number generator.
* \return the mshadow random number generator requested.
* \tparam xpu the device type of random number generator.
*/
template<typename xpu, typename DType>
inline mshadow::Random<xpu, DType>* get_random(
mshadow::Stream<xpu> *stream) const {
CHECK_EQ(req.type, ResourceRequest::kRandom);
mshadow::Random<xpu, DType> *ret =
static_cast<mshadow::Random<xpu, DType>*>(ptr_);
ret->set_stream(stream);
return ret;
}
/*!
* \brief Get space requested as mshadow Tensor.
* The caller can request arbitrary size.
*
* \param shape the Shape of returning tensor.
* \param stream the stream of retruning tensor.
* \return the mshadow tensor requested.
* \tparam xpu the device type of random number generator.
* \tparam ndim the number of dimension of the tensor requested.
*/
template<typename xpu, int ndim>
inline mshadow::Tensor<xpu, ndim, real_t> get_space(
mshadow::Shape<ndim> shape, mshadow::Stream<xpu> *stream) const {
return get_space_typed<xpu, ndim, real_t>(shape, stream);
}
/*!
* \brief Get space requested as mshadow Tensor in specified type.
* The caller can request arbitrary size.
*
* \param shape the Shape of returning tensor.
* \param stream the stream of retruning tensor.
* \return the mshadow tensor requested.
* \tparam xpu the device type of random number generator.
* \tparam ndim the number of dimension of the tensor requested.
*/
template<typename xpu, int ndim, typename DType>
inline mshadow::Tensor<xpu, ndim, DType> get_space_typed(
mshadow::Shape<ndim> shape, mshadow::Stream<xpu> *stream) const {
CHECK_EQ(req.type, ResourceRequest::kTempSpace);
return mshadow::Tensor<xpu, ndim, DType>(
reinterpret_cast<DType*>(get_space_internal(shape.Size() * sizeof(DType))),
shape, shape[ndim - 1], stream);
}
/*!
* \brief internal function to get space from resources.
* \param size The size of the space.
* \return The allocated space.
*/
void* get_space_internal(size_t size) const;
};
/*!
 * \brief Global resource manager.
 *  Abstract interface: concrete managers override Request and SeedRandom,
 *  and the shared instance is reached through Get().
 */
class ResourceManager {
 public:
  /*!
   * \brief Hand out a resource matching a request.
   * \param ctx the context the request is made for.
   * \param req describes which kind of resource is wanted.
   * \return the resource that satisfies the request.
   * \note Ownership of the returned resource stays with the
   *  manager singleton.
   */
  virtual Resource Request(Context ctx,
                           const ResourceRequest &req) = 0;
  /*!
   * \brief Re-seed every random number generator handed out so far.
   * \param seed seed applied to the generators on all devices.
   */
  virtual void SeedRandom(uint32_t seed) = 0;
  /*! \brief virtual destructor, so deletion through a base pointer is safe */
  virtual ~ResourceManager() DMLC_THROW_EXCEPTION {
  }
  /*!
   * \brief Access the resource manager singleton.
   * \return pointer to the singleton instance.
   */
  static ResourceManager* Get();
};
} // namespace mxnet
#endif // MXNET_RESOURCE_H_