Skip to content

Commit 8b9fb46

Browse files
committed
Update plugin classes to match the base IPluginV2 class in TensorRT8
1 parent 311f328 commit 8b9fb46

File tree

2 files changed

+50
-48
lines changed

2 files changed

+50
-48
lines changed

torch2trt/plugins/group_norm.cpp

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -112,19 +112,19 @@ class GroupNormPlugin : public IPluginV2 {
112112
return data_str.str();
113113
}
114114

115-
const char* getPluginType() const override {
115+
const char* getPluginType() const noexcept override {
116116
return "group_norm";
117117
};
118118

119-
const char* getPluginVersion() const override {
119+
const char* getPluginVersion() const noexcept override {
120120
return "1";
121121
}
122122

123-
int getNbOutputs() const override {
123+
int getNbOutputs() const noexcept override {
124124
return 1;
125125
}
126126

127-
Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override {
127+
Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) noexcept override {
128128
Dims dims;
129129
dims.nbDims = inputs->nbDims;
130130

@@ -135,8 +135,8 @@ class GroupNormPlugin : public IPluginV2 {
135135
return dims;
136136
}
137137

138-
bool supportsFormat(DataType type, PluginFormat format) const override {
139-
if (format != PluginFormat::kNCHW) {
138+
bool supportsFormat(DataType type, PluginFormat format) const noexcept override {
139+
if (format != PluginFormat::kLINEAR) {
140140
return false;
141141
}
142142
if (type == DataType::kINT32 || type == DataType::kINT8) {
@@ -146,7 +146,7 @@ class GroupNormPlugin : public IPluginV2 {
146146
}
147147

148148
void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims,
149-
int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) override {
149+
int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) noexcept override {
150150

151151
// set data type
152152
if (type == DataType::kFLOAT) {
@@ -170,7 +170,7 @@ class GroupNormPlugin : public IPluginV2 {
170170
}
171171
}
172172

173-
int initialize() override {
173+
int initialize() noexcept override {
174174
// set device
175175
tensor_options = tensor_options.device(c10::kCUDA);
176176

@@ -188,11 +188,12 @@ class GroupNormPlugin : public IPluginV2 {
188188
return 0;
189189
}
190190

191-
void terminate() override {}
191+
void terminate() noexcept override {}
192192

193-
size_t getWorkspaceSize(int maxBatchSize) const override { return 0; }
193+
size_t getWorkspaceSize(int maxBatchSize) const noexcept override { return 0; }
194194

195-
int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override {
195+
int32_t enqueue(int32_t batchSize, void const* const* inputs, void* const* outputs, void* workspace,
196+
cudaStream_t stream) noexcept override {
196197
// get input / output dimensions
197198
std::vector<long> batch_input_sizes = input_sizes;
198199
std::vector<long> batch_output_sizes = output_sizes;
@@ -235,25 +236,25 @@ class GroupNormPlugin : public IPluginV2 {
235236
}
236237

237238

238-
size_t getSerializationSize() const override {
239+
size_t getSerializationSize() const noexcept override {
239240
return serializeToString().size();
240241
}
241242

242-
void serialize(void* buffer) const override {
243+
void serialize(void* buffer) const noexcept override {
243244
std::string data = serializeToString();
244245
size_t size = getSerializationSize();
245246
data.copy((char *) buffer, size);
246247
}
247248

248-
void destroy() override {}
249+
void destroy() noexcept override {}
249250

250-
IPluginV2* clone() const override {
251+
IPluginV2* clone() const noexcept override {
251252
return new GroupNormPlugin(num_groups, weight, bias, eps);
252253
}
253254

254-
void setPluginNamespace(const char* pluginNamespace) override {}
255+
void setPluginNamespace(const char* pluginNamespace) noexcept override {}
255256

256-
const char *getPluginNamespace() const override {
257+
const char *getPluginNamespace() const noexcept override {
257258
return "torch2trt";
258259
}
259260

@@ -263,26 +264,26 @@ class GroupNormPluginCreator : public IPluginCreator {
263264
public:
264265
GroupNormPluginCreator() {}
265266

266-
const char *getPluginNamespace() const override {
267+
const char *getPluginNamespace() const noexcept override {
267268
return "torch2trt";
268269
}
269270

270-
const char *getPluginName() const override {
271+
const char *getPluginName() const noexcept override {
271272
return "group_norm";
272273
}
273274

274-
const char *getPluginVersion() const override {
275+
const char *getPluginVersion() const noexcept override {
275276
return "1";
276277
}
277278

278-
IPluginV2 *deserializePlugin(const char *name, const void *data, size_t length) override {
279+
IPluginV2 *deserializePlugin(const char *name, const void *data, size_t length) noexcept override {
279280
return new GroupNormPlugin((const char*) data, length);
280281
}
281282

282-
void setPluginNamespace(const char *N) override {}
283-
const PluginFieldCollection *getFieldNames() override { return nullptr; }
283+
void setPluginNamespace(const char *N) noexcept override {}
284+
const PluginFieldCollection *getFieldNames() noexcept override { return nullptr; }
284285

285-
IPluginV2 *createPlugin(const char *name, const PluginFieldCollection *fc) override { return nullptr; }
286+
IPluginV2 *createPlugin(const char *name, const PluginFieldCollection *fc) noexcept override { return nullptr; }
286287

287288
};
288289

torch2trt/plugins/interpolate.cpp

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -103,19 +103,19 @@ class InterpolatePlugin : public IPluginV2 {
103103
return data_str.str();
104104
}
105105

106-
const char* getPluginType() const override {
106+
const char* getPluginType() const noexcept override {
107107
return "interpolate";
108108
};
109109

110-
const char* getPluginVersion() const override {
110+
const char* getPluginVersion() const noexcept override {
111111
return "1";
112112
}
113113

114-
int getNbOutputs() const override {
114+
int getNbOutputs() const noexcept override {
115115
return 1;
116116
}
117117

118-
Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override {
118+
Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) noexcept override {
119119
Dims dims;
120120
dims.nbDims = inputs->nbDims;
121121

@@ -127,8 +127,8 @@ class InterpolatePlugin : public IPluginV2 {
127127
return dims;
128128
}
129129

130-
bool supportsFormat(DataType type, PluginFormat format) const override {
131-
if (format != PluginFormat::kNCHW) {
130+
bool supportsFormat(DataType type, PluginFormat format) const noexcept override {
131+
if (format != PluginFormat::kLINEAR) {
132132
return false;
133133
}
134134
if (type == DataType::kINT32 || type == DataType::kINT8) {
@@ -138,7 +138,7 @@ class InterpolatePlugin : public IPluginV2 {
138138
}
139139

140140
void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims,
141-
int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) override {
141+
int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) noexcept override {
142142

143143
// set data type
144144
if (type == DataType::kFLOAT) {
@@ -162,7 +162,7 @@ class InterpolatePlugin : public IPluginV2 {
162162
}
163163
}
164164

165-
int initialize() override {
165+
int initialize() noexcept override {
166166
// set device
167167
tensor_options = tensor_options.device(c10::kCUDA);
168168

@@ -176,11 +176,12 @@ class InterpolatePlugin : public IPluginV2 {
176176
return 0;
177177
}
178178

179-
void terminate() override {}
179+
void terminate() noexcept override {}
180180

181-
size_t getWorkspaceSize(int maxBatchSize) const override { return 0; }
181+
size_t getWorkspaceSize(int maxBatchSize) const noexcept override { return 0; }
182182

183-
int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override {
183+
int32_t enqueue(int32_t batchSize, void const* const* inputs, void* const* outputs, void* workspace,
184+
cudaStream_t stream) noexcept override {
184185
// get input / output dimensions
185186
std::vector<long> batch_input_sizes = input_sizes;
186187
std::vector<long> batch_output_sizes = output_sizes;
@@ -227,25 +228,25 @@ class InterpolatePlugin : public IPluginV2 {
227228
return 0;
228229
}
229230

230-
size_t getSerializationSize() const override {
231+
size_t getSerializationSize() const noexcept override {
231232
return serializeToString().size();
232233
}
233234

234-
void serialize(void* buffer) const override {
235+
void serialize(void* buffer) const noexcept override {
235236
std::string data = serializeToString();
236237
size_t size = getSerializationSize();
237238
data.copy((char *) buffer, size);
238239
}
239240

240-
void destroy() override {}
241+
void destroy() noexcept override {}
241242

242-
IPluginV2* clone() const override {
243+
IPluginV2* clone() const noexcept override {
243244
return new InterpolatePlugin(size, mode, align_corners);
244245
}
245246

246-
void setPluginNamespace(const char* pluginNamespace) override {}
247+
void setPluginNamespace(const char* pluginNamespace) noexcept override {}
247248

248-
const char *getPluginNamespace() const override {
249+
const char *getPluginNamespace() const noexcept override {
249250
return "torch2trt";
250251
}
251252

@@ -255,26 +256,26 @@ class InterpolatePluginCreator : public IPluginCreator {
255256
public:
256257
InterpolatePluginCreator() {}
257258

258-
const char *getPluginNamespace() const override {
259+
const char *getPluginNamespace() const noexcept override {
259260
return "torch2trt";
260261
}
261262

262-
const char *getPluginName() const override {
263+
const char *getPluginName() const noexcept override {
263264
return "interpolate";
264265
}
265266

266-
const char *getPluginVersion() const override {
267+
const char *getPluginVersion() const noexcept override {
267268
return "1";
268269
}
269270

270-
IPluginV2 *deserializePlugin(const char *name, const void *data, size_t length) override {
271+
IPluginV2 *deserializePlugin(const char *name, const void *data, size_t length) noexcept override {
271272
return new InterpolatePlugin((const char*) data, length);
272273
}
273274

274-
void setPluginNamespace(const char *N) override {}
275-
const PluginFieldCollection *getFieldNames() override { return nullptr; }
275+
void setPluginNamespace(const char *N) noexcept override {}
276+
const PluginFieldCollection *getFieldNames() noexcept override { return nullptr; }
276277

277-
IPluginV2 *createPlugin(const char *name, const PluginFieldCollection *fc) override { return nullptr; }
278+
IPluginV2 *createPlugin(const char *name, const PluginFieldCollection *fc) noexcept override { return nullptr; }
278279

279280
};
280281

0 commit comments

Comments
 (0)