@@ -150,7 +150,16 @@ def __ne__(self, other):
150
150
151
151
152
152
class Device (ctypes .Structure ):
153
- """TVM device strucure."""
153
+ """TVM device structure.
154
+
155
+ Typically constructed using convenience function
156
+ :meth:`tvm.runtime.device`.
157
+
158
+ Exposes uniform interface to device-specific APIs such as CUDA or
159
+ OpenCL. Some properties may return None depending on whether an
160
+ API exposes that particular property.
161
+
162
+ """
154
163
155
164
_fields_ = [("device_type" , ctypes .c_int ), ("device_id" , ctypes .c_int )]
156
165
MASK2STR = {
@@ -205,63 +214,190 @@ def _GetDeviceAttr(self, device_type, device_id, attr_id):
205
214
206
215
@property
207
216
def exist(self):
    """Report whether this device is available.

    The result is True when TVM has support for the device, the
    physical device is present, and the device can be reached
    through the appropriate drivers (e.g. cuda/vulkan).

    Returns
    -------
    exist : bool
        True if the device exists

    """
    # The existence flag is queried under attribute id 0; any non-zero
    # answer is treated as "the device exists".
    existence_flag = self._GetDeviceAttr(self.device_type, self.device_id, 0)
    return existence_flag != 0
210
230
211
231
@property
212
232
def max_threads_per_block(self):
    """Upper limit on the number of threads in a single block.

    The device's own value is reported for cuda, metal, rocm, opencl,
    and vulkan devices.  RPC devices report the value of the remote
    device.  Every other device reports None.

    Returns
    -------
    max_threads_per_block : int or None
        The number of threads on each block

    """
    # Attribute id 1 is the one this property queries.
    thread_limit = self._GetDeviceAttr(self.device_type, self.device_id, 1)
    return thread_limit
215
246
216
247
@property
217
248
def warp_size(self):
    """Number of threads that execute concurrently.

    The device's own value is reported for cuda, rocm, and vulkan.
    Metal and opencl devices always report 1, regardless of the
    physical device.  RPC devices report the value of the remote
    device.  Every other device reports None.

    Returns
    -------
    warp_size : int or None
        Number of threads that execute concurrently

    """
    # Attribute id 2 is the one this property queries.
    concurrent_threads = self._GetDeviceAttr(self.device_type, self.device_id, 2)
    return concurrent_threads
220
263
221
264
@property
222
265
def max_shared_memory_per_block(self):
    """Total shared memory available to one block, in bytes.

    The device's own value is reported for cuda, rocm, opencl, and
    vulkan.  RPC devices report the value of the remote device.
    Every other device reports None.

    Returns
    -------
    max_shared_memory_per_block : int or None
        Total amount of shared memory per block in bytes

    """
    # Attribute id 3 is the one this property queries.
    shared_bytes = self._GetDeviceAttr(self.device_type, self.device_id, 3)
    return shared_bytes
225
279
226
280
@property
227
281
def compute_version(self):
    """Compute version number of the device, as a string.

    This is the maximum API version (e.g. CUDA/OpenCL/Vulkan)
    supported by the device.

    The device's own value is reported for cuda, rocm, opencl, and
    vulkan.  RPC devices report the value of the remote device.
    Every other device reports None.

    Returns
    -------
    version : str or None
        The version string in `major.minor` format.

    """
    # Attribute id 4 is the one this property queries.
    version_string = self._GetDeviceAttr(self.device_type, self.device_id, 4)
    return version_string
238
298
239
299
@property
240
300
def device_name(self):
    """Vendor-specific name of the device.

    The device's own value is reported for cuda, rocm, opencl, and
    vulkan.  RPC devices report the value of the remote device.
    Every other device reports None.

    Returns
    -------
    device_name : str or None
        The name of the device.

    """
    # Attribute id 5 is the one this property queries.
    vendor_name = self._GetDeviceAttr(self.device_type, self.device_id, 5)
    return vendor_name
243
314
244
315
@property
245
316
def max_clock_rate(self):
    """Maximum clock frequency of the device (kHz).

    The device's own value is reported for cuda, rocm, and opencl.
    RPC devices report the value of the remote device.  Every other
    device reports None.

    Returns
    -------
    max_clock_rate : int or None
        The maximum clock frequency of the device (kHz)

    """
    # Attribute id 6 is the one this property queries.
    clock_khz = self._GetDeviceAttr(self.device_type, self.device_id, 6)
    return clock_khz
248
330
249
331
@property
250
332
def multi_processor_count(self):
    """Number of compute units in the device.

    The device's own value is reported for cuda, rocm, and opencl.
    RPC devices report the value of the remote device.  Every other
    device reports None.

    Returns
    -------
    multi_processor_count : int or None
        The number of compute units in the device

    """
    # Attribute id 7 is the one this property queries.
    compute_units = self._GetDeviceAttr(self.device_type, self.device_id, 7)
    return compute_units
253
346
254
347
@property
255
348
def max_thread_dimensions(self):
    """Return the maximum size of each thread axis.

    Returns device value for cuda, rocm, opencl, and vulkan.
    Returns remote device value for RPC devices.  Returns None for
    all other devices.

    Returns
    -------
    dims: List of int, or None
        The maximum length of threadIdx.x, threadIdx.y, threadIdx.z

    """
    # Attribute id 8 is the one this property queries; the answer is
    # decoded with json.loads below.
    encoded_dims = self._GetDeviceAttr(self.device_type, self.device_id, 8)
    # Per the documented contract a device that does not expose this
    # attribute yields None.  json.loads(None) would raise TypeError, so
    # pass the None through instead of attempting to decode it.
    if encoded_dims is None:
        return None
    return json.loads(encoded_dims)
264
362
363
+ @property
364
def api_version(self):
    """Return the version number of the SDK used to compile TVM.

    For example, CUDA_VERSION for cuda or VK_HEADER_VERSION for
    Vulkan.

    Returns device value for cuda, rocm, opencl, and vulkan.
    Returns remote device value for RPC devices.  Returns None for
    all other devices.

    Returns
    -------
    version : int or None
        The version of the SDK

    """
    # The API version is attribute id 11 (kApiVersion in TVM's
    # DeviceAttrKind enum).  Id 12 is the driver-version attribute, so
    # querying 12 here would return the driver version instead.
    return self._GetDeviceAttr(self.device_type, self.device_id, 11)
381
+
382
+ @property
383
def driver_version(self):
    """Version number of the driver.

    This is the driver vendor's internal version number
    (e.g. "450.408.256" for nvidia-driver-450).

    The device's own value is reported for opencl and vulkan.  RPC
    devices report the value of the remote device.  Every other
    device reports None.

    Returns
    -------
    version : str or None
        The version string in `major.minor.patch` format.

    """
    # Attribute id 12 is the one this property queries.
    driver_string = self._GetDeviceAttr(self.device_type, self.device_id, 12)
    return driver_string
400
+
265
401
def create_raw_stream (self ):
266
402
"""Create a new runtime stream at the context.
267
403
0 commit comments