@@ -199,10 +199,7 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
199199 break ;
200200 }
201201 case kImagePitchAlignment : {
202- cl_uint row_pitch;
203- OPENCL_CALL (clGetDeviceInfo (device_id, CL_DEVICE_IMAGE_PITCH_ALIGNMENT_KHR, sizeof (row_pitch),
204- &row_pitch, nullptr ));
205- *rv = static_cast <int64_t >(row_pitch);
202+ *rv = static_cast <int64_t >(device_info[device_id].image_row_align );
206203 break ;
207204 }
208205 }
@@ -280,12 +277,45 @@ void* OpenCLWorkspace::AllocCLImage(Device dev, void* back_buffer, size_t width,
280277 return desc;
281278}
282279
// Returns the value GetDataSize() reports for a tensor described by
// (dev, ndim, shape, dtype).
//
// A throwaway DLTensor is filled in with the given device, rank, dtype and
// shape, a null data pointer, null strides and a zero byte offset, and is then
// handed to GetDataSize(). Nothing is allocated here.
//
// NOTE(review): presumably the result is a byte count (the caller stores it in
// `nbytes`) -- confirm against GetDataSize().
280+ static size_t GetMemObjectSize (Device dev, int ndim, const int64_t * shape, DLDataType dtype) {
281+ DLTensor temp;
282+ temp.data = nullptr ;
283+ temp.device = dev;
284+ temp.ndim = ndim;
285+ temp.dtype = dtype;
// The const qualifier is cast away only so that the pointer can be stored in
// the DLTensor field; this function itself never writes through it.
// NOTE(review): assumes GetDataSize() only reads the shape -- confirm.
286+ temp.shape = const_cast <int64_t *>(shape);
287+ temp.strides = nullptr ;
288+ temp.byte_offset = 0 ;
289+ size_t size = GetDataSize (temp);
290+ return size;
291+ }
292+
283293void * OpenCLWorkspace::AllocDataSpaceView (Device dev, void * data, int ndim, const int64_t * shape,
284294 DLDataType dtype, Optional<String> mem_scope) {
295+ cl::BufferDescriptor* desc = static_cast <cl::BufferDescriptor*>(data);
296+
297+ // Fall back for devices w/o "cl_khr_image2d_from_buffer"
285298 if (!IsBufferToImageSupported (dev.device_id )) {
286- return data;
299+ cl::BufferDescriptor* ret_desc = desc; // buffer -> buffer
300+ if (!mem_scope.defined () || mem_scope.value () == " global" ) {
301+ if (desc->layout != cl::BufferDescriptor::MemoryLayout::kBuffer1D ) {
302+ // image -> buffer
303+ size_t nbytes = GetMemObjectSize (dev, ndim, shape, dtype);
304+ ret_desc = static_cast <cl::BufferDescriptor*>(
305+ OpenCLWorkspace::AllocCLBuffer (dev, nbytes, kTempAllocaAlignment , dtype));
306+ ret_desc->is_compat_view = true ;
307+ }
308+ } else {
309+ // Any -> Image
310+ size_t axis = DefaultTextureLayoutSeparator (ndim, mem_scope.value ());
311+ auto texture = ApplyTexture2DFlattening<int64_t >(shape, ndim, axis);
312+ size_t row_pitch = GetRowPitch (dev, texture.width , dtype);
313+ ret_desc = static_cast <cl::BufferDescriptor*>(OpenCLWorkspace::Global ()->AllocCLImage (
314+ dev, nullptr , texture.width , texture.height , row_pitch, dtype, mem_scope));
315+ ret_desc->is_compat_view = true ;
316+ }
317+ return ret_desc;
287318 }
288- cl::BufferDescriptor* desc = static_cast <cl::BufferDescriptor*>(data);
289319
290320 if (!mem_scope.defined () || mem_scope.value () == " global" ) {
291321 if (desc->layout == cl::BufferDescriptor::MemoryLayout::kBuffer1D ) {
@@ -298,7 +328,6 @@ void* OpenCLWorkspace::AllocDataSpaceView(Device dev, void* data, int ndim, cons
298328 }
299329 size_t axis = DefaultTextureLayoutSeparator (ndim, mem_scope.value ());
300330 auto texture = ApplyTexture2DFlattening<int64_t >(shape, ndim, axis);
301-
302331 size_t row_pitch = GetRowPitch (dev, texture.width , dtype);
303332
304333 cl::BufferDescriptor* back_buffer;
@@ -314,6 +343,24 @@ void* OpenCLWorkspace::AllocDataSpaceView(Device dev, void* data, int ndim, cons
314343 row_pitch, dtype, mem_scope);
315344}
316345
// Releases a view descriptor referenced by `ptr` on device `dev`.
//
// `ptr` is interpreted as a cl::BufferDescriptor. Depending on device
// capability and descriptor state, the descriptor's cl_mem object is released
// and the descriptor itself is deleted, or the call leaves the descriptor
// untouched:
//   * Device without buffer-to-image support: only descriptors flagged
//     `is_compat_view` are released; anything else is left alone.
//   * Device with buffer-to-image support: descriptors whose layout is not
//     kBuffer1D are released; kBuffer1D descriptors are left alone.
//
// NOTE(review): `ptr` is dereferenced unconditionally, so callers are assumed
// to pass a non-null descriptor -- confirm.
346+ void OpenCLWorkspace::FreeDataSpaceView (Device dev, void * ptr) {
// Wait for the device queue to finish before any memory object is released.
347+ OPENCL_CALL (clFinish (this ->GetQueue (dev)));
348+ auto * desc = static_cast <const cl::BufferDescriptor*>(ptr);
349+ // Handle the fall back
350+ if (!IsBufferToImageSupported (dev.device_id )) {
// Only views flagged as compat views are freed here.
// NOTE(review): presumably an unflagged descriptor is the caller's original
// allocation and is freed elsewhere -- confirm against AllocDataSpaceView.
351+ if (desc->is_compat_view ) {
352+ OPENCL_CALL (clReleaseMemObject (desc->buffer ));
353+ delete desc;
354+ }
355+ return ;
356+ }
357+
// Buffer-to-image capable device: every layout other than kBuffer1D is
// released together with its descriptor.
// NOTE(review): presumably a kBuffer1D view aliases the original descriptor
// and must not be freed here -- confirm against AllocDataSpaceView.
358+ if (desc->layout != cl::BufferDescriptor::MemoryLayout::kBuffer1D ) {
359+ OPENCL_CALL (clReleaseMemObject (desc->buffer ));
360+ delete desc;
361+ }
362+ }
363+
317364void * OpenCLWorkspace::GetNativePtr (const tvm::runtime::NDArray& narr) {
318365 cl::BufferDescriptor* desc = static_cast <cl::BufferDescriptor*>(narr.operator ->()->data );
319366 return desc->host_ptr ;
@@ -329,16 +376,22 @@ void OpenCLWorkspace::FreeDataSpace(Device dev, void* ptr) {
329376 clEnqueueUnmapMemObject (this ->GetQueue (dev), desc->buffer ,
330377 reinterpret_cast <void *>(desc->host_ptr ), 0 , nullptr , nullptr );
331378 }
332- if (!IsBufferToImageSupported (dev.device_id )) {
333- OPENCL_CALL (clReleaseMemObject (desc->buffer ));
334- return ;
335- }
379+
336380 if (desc->back_buffer ) {
381+ // 2D Image w/ back buffer allocated from pool
337382 OPENCL_CALL (clReleaseMemObject (desc->buffer ));
338383 GetThreadEntry ()->mpool .FreeMemory (dev, desc->back_buffer );
339384 delete desc;
340385 } else {
341- GetThreadEntry ()->mpool .FreeMemory (dev, desc);
386+ if (desc->layout == cl::BufferDescriptor::MemoryLayout::kBuffer1D ) {
387+ // 1D buffer allocated from pool
388+ GetThreadEntry ()->mpool .FreeMemory (dev, desc);
389+ } else if (!IsBufferToImageSupported (dev.device_id )) {
390+ // 2D Image allocated w/o pool
391+ OPENCL_CALL (clReleaseMemObject (desc->buffer ));
392+ delete desc;
393+ return ;
394+ }
342395 }
343396}
344397
@@ -349,18 +402,6 @@ void OpenCLWorkspace::FreeCLBuffer(Device dev, void* ptr) {
349402 delete desc;
350403}
351404
352- void OpenCLWorkspace::FreeDataSpaceView (Device dev, void * ptr) {
353- OPENCL_CALL (clFinish (this ->GetQueue (dev)));
354- if (!IsBufferToImageSupported (dev.device_id )) {
355- return ;
356- }
357- auto * desc = static_cast <const cl::BufferDescriptor*>(ptr);
358- if (desc->layout != cl::BufferDescriptor::MemoryLayout::kBuffer1D ) {
359- OPENCL_CALL (clReleaseMemObject (desc->buffer ));
360- delete desc;
361- }
362- }
363-
364405void OpenCLWorkspace::CopyDataFromTo (DLTensor* from, DLTensor* to, TVMStreamHandle stream) {
365406 size_t nbytes = GetDataSize (*from);
366407 ICHECK_EQ (nbytes, GetDataSize (*to));
@@ -593,14 +634,15 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic
593634 cl_uint row_pitch;
594635 OPENCL_CALL (clGetDeviceInfo (did, CL_DEVICE_IMAGE_PITCH_ALIGNMENT_KHR, sizeof (row_pitch),
595636 &row_pitch, nullptr ));
637+ if (0 == row_pitch) {
638+ row_pitch = kAllocAlignment ; // Fallback
639+ }
596640 dev_info.image_row_align = row_pitch;
597-
598641 size_t reqd_size = 0 ;
599642 OPENCL_CALL (clGetDeviceInfo (did, CL_DEVICE_EXTENSIONS, 0 , nullptr , &reqd_size));
600643 std::vector<char > extn_buf (reqd_size);
601644 OPENCL_CALL (clGetDeviceInfo (did, CL_DEVICE_EXTENSIONS, reqd_size, extn_buf.data (), nullptr ));
602645 std::string extensions (extn_buf.data ());
603- LOG (WARNING) << " OpenCL Extensions:" << extensions;
604646
605647 if (extensions.find (" cl_khr_image2d_from_buffer" ) != std::string::npos) {
606648 dev_info.image_from_buffer_support = true ;
0 commit comments