@@ -594,124 +594,213 @@ ur_result_t urUSMReleaseExp(ur_context_handle_t Context, void *HostPtr) {
594
594
return UR_RESULT_SUCCESS;
595
595
}
596
596
597
+ static ur_result_t enqueueUSMAllocHelper (
598
+ ur_queue_handle_t Queue, ur_usm_pool_handle_t Pool, const size_t Size,
599
+ const ur_exp_async_usm_alloc_properties_t *Properties,
600
+ uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList,
601
+ void **RetMem, ur_event_handle_t *OutEvent, ur_usm_type_t Type) {
602
+ std::ignore = Pool;
603
+ std::ignore = Properties;
604
+
605
+ std::scoped_lock<ur_shared_mutex> lock (Queue->Mutex );
606
+
607
+ bool UseCopyEngine = false ;
608
+ _ur_ze_event_list_t TmpWaitList;
609
+ UR_CALL (TmpWaitList.createAndRetainUrZeEventList (
610
+ NumEventsInWaitList, EventWaitList, Queue, UseCopyEngine));
611
+
612
+ bool OkToBatch = true ;
613
+ // Get a new command list to be used on this call
614
+ ur_command_list_ptr_t CommandList{};
615
+ UR_CALL (Queue->Context ->getAvailableCommandList (
616
+ Queue, CommandList, UseCopyEngine, NumEventsInWaitList, EventWaitList,
617
+ OkToBatch, nullptr /* ForcedCmdQueue*/ ));
618
+
619
+ ze_event_handle_t ZeEvent = nullptr ;
620
+ ur_event_handle_t InternalEvent{};
621
+ bool IsInternal = OutEvent == nullptr ;
622
+ ur_event_handle_t *Event = OutEvent ? OutEvent : &InternalEvent;
623
+
624
+ ur_command_t CommandType = UR_COMMAND_FORCE_UINT32;
625
+ switch (Type) {
626
+ case UR_USM_TYPE_HOST:
627
+ CommandType = UR_COMMAND_ENQUEUE_USM_HOST_ALLOC_EXP;
628
+ break ;
629
+ case UR_USM_TYPE_DEVICE:
630
+ CommandType = UR_COMMAND_ENQUEUE_USM_DEVICE_ALLOC_EXP;
631
+ break ;
632
+ case UR_USM_TYPE_SHARED:
633
+ CommandType = UR_COMMAND_ENQUEUE_USM_SHARED_ALLOC_EXP;
634
+ break ;
635
+ default :
636
+ logger::error (" enqueueUSMAllocHelper: unsupported USM type" );
637
+ throw UR_RESULT_ERROR_UNKNOWN;
638
+ }
639
+ UR_CALL (createEventAndAssociateQueue (Queue, Event, CommandType, CommandList,
640
+ IsInternal, false ));
641
+ ZeEvent = (*Event)->ZeEvent ;
642
+ (*Event)->WaitList = TmpWaitList;
643
+
644
+ // Allocate USM memory
645
+ ur_usm_pool_handle_t USMPool = nullptr ;
646
+ if (Pool) {
647
+ USMPool = Pool;
648
+ } else {
649
+ USMPool = &Queue->Context ->AsyncPool ;
650
+ }
651
+
652
+ auto Device = (Type == UR_USM_TYPE_HOST) ? nullptr : Queue->Device ;
653
+ auto Ret =
654
+ USMPool->allocate (Queue->Context , Device, nullptr , Type, Size, RetMem);
655
+ if (Ret) {
656
+ return Ret;
657
+ }
658
+
659
+ // Signal that USM allocation event was finished
660
+ ZE2UR_CALL (zeCommandListAppendSignalEvent, (CommandList->first , ZeEvent));
661
+
662
+ UR_CALL (Queue->executeCommandList (CommandList, false , OkToBatch));
663
+
664
+ return UR_RESULT_SUCCESS;
665
+ }
666
+
597
667
/// Enqueues an asynchronous USM device allocation by forwarding every
/// argument to enqueueUSMAllocHelper with UR_USM_TYPE_DEVICE.
ur_result_t urEnqueueUSMDeviceAllocExp(
    ur_queue_handle_t Queue,   ///< [in] handle of the queue object
    ur_usm_pool_handle_t Pool, ///< [in][optional] USM pool descriptor
    const size_t Size, ///< [in] minimum size in bytes of the USM memory
                       ///< object to be allocated
    const ur_exp_async_usm_alloc_properties_t
        *Properties, ///< [in][optional] pointer to the enqueue asynchronous
                     ///< USM allocation properties
    uint32_t NumEventsInWaitList, ///< [in] size of the event wait list
    const ur_event_handle_t
        *EventWaitList, ///< [in][optional][range(0, NumEventsInWaitList)]
                        ///< pointer to a list of events this command depends
                        ///< on. If nullptr, NumEventsInWaitList must be 0,
                        ///< indicating no wait events.
    void **Mem,         ///< [out] pointer to USM memory object
    ur_event_handle_t *OutEvent ///< [out][optional] return an event object
                                ///< that identifies the async alloc
) {
  const ur_usm_type_t AllocKind = UR_USM_TYPE_DEVICE;
  return enqueueUSMAllocHelper(Queue, Pool, Size, Properties,
                               NumEventsInWaitList, EventWaitList, Mem,
                               OutEvent, AllocKind);
}
628
690
629
691
/// Enqueues an asynchronous USM shared allocation by forwarding every
/// argument to enqueueUSMAllocHelper with UR_USM_TYPE_SHARED.
ur_result_t urEnqueueUSMSharedAllocExp(
    ur_queue_handle_t Queue,   ///< [in] handle of the queue object
    ur_usm_pool_handle_t Pool, ///< [in][optional] USM pool descriptor
    const size_t Size, ///< [in] minimum size in bytes of the USM memory
                       ///< object to be allocated
    const ur_exp_async_usm_alloc_properties_t
        *Properties, ///< [in][optional] pointer to the enqueue asynchronous
                     ///< USM allocation properties
    uint32_t NumEventsInWaitList, ///< [in] size of the event wait list
    const ur_event_handle_t
        *EventWaitList, ///< [in][optional][range(0, NumEventsInWaitList)]
                        ///< pointer to a list of events this command depends
                        ///< on. If nullptr, NumEventsInWaitList must be 0,
                        ///< indicating no wait events.
    void **Mem,         ///< [out] pointer to USM memory object
    ur_event_handle_t *OutEvent ///< [out][optional] return an event object
                                ///< that identifies the async alloc
) {
  const ur_usm_type_t AllocKind = UR_USM_TYPE_SHARED;
  return enqueueUSMAllocHelper(Queue, Pool, Size, Properties,
                               NumEventsInWaitList, EventWaitList, Mem,
                               OutEvent, AllocKind);
}
660
714
661
715
/// Enqueues an asynchronous USM host allocation by forwarding every argument
/// to enqueueUSMAllocHelper with UR_USM_TYPE_HOST.
ur_result_t urEnqueueUSMHostAllocExp(
    ur_queue_handle_t Queue, ///< [in] handle of the queue object
    ur_usm_pool_handle_t
        Pool,          ///< [in][optional] handle of the USM memory pool
    const size_t Size, ///< [in] minimum size in bytes of the USM memory
                       ///< object to be allocated
    const ur_exp_async_usm_alloc_properties_t
        *Properties, ///< [in][optional] pointer to the enqueue asynchronous
                     ///< USM allocation properties
    uint32_t NumEventsInWaitList, ///< [in] size of the event wait list
    const ur_event_handle_t
        *EventWaitList, ///< [in][optional][range(0, NumEventsInWaitList)]
                        ///< pointer to a list of events this command depends
                        ///< on. If nullptr, NumEventsInWaitList must be 0,
                        ///< indicating no wait events.
    void **Mem,         ///< [out] pointer to USM memory object
    ur_event_handle_t
        *OutEvent ///< [out][optional] return an event object that identifies
                  ///< the asynchronous USM host allocation
) {
  const ur_usm_type_t AllocKind = UR_USM_TYPE_HOST;
  return enqueueUSMAllocHelper(Queue, Pool, Size, Properties,
                               NumEventsInWaitList, EventWaitList, Mem,
                               OutEvent, AllocKind);
}
692
739
693
740
/// Enqueues a USM free on \p Queue.
///
/// With a non-empty wait list, a wait on those events is appended to the
/// command list and executed in blocking mode before the memory is released
/// through USMFreeHelper. Afterwards a signal of the resulting event is
/// appended and the command list is executed again. \p Pool is not consulted.
///
/// Returns UR_RESULT_SUCCESS, or the first failing result of the calls below.
ur_result_t urEnqueueUSMFreeExp(
    ur_queue_handle_t Queue,      ///< [in] handle of the queue object
    ur_usm_pool_handle_t Pool,    ///< [in][optional] USM pool descriptor
    void *Mem,                    ///< [in] pointer to USM memory object
    uint32_t NumEventsInWaitList, ///< [in] size of the event wait list
    const ur_event_handle_t
        *EventWaitList, ///< [in][optional][range(0, NumEventsInWaitList)]
                        ///< pointer to a list of events that must be complete
                        ///< before the memory is freed. If nullptr, the
                        ///< NumEventsInWaitList must be 0, indicating no wait
                        ///< events.
    ur_event_handle_t *OutEvent ///< [out][optional] return an event object
                                ///< that identifies the asynchronous USM
                                ///< deallocation
) {
  std::ignore = Pool;

  std::scoped_lock<ur_shared_mutex> QueueGuard(Queue->Mutex);

  bool OnCopyEngine = false;
  _ur_ze_event_list_t RetainedDeps;
  UR_CALL(RetainedDeps.createAndRetainUrZeEventList(
      NumEventsInWaitList, EventWaitList, Queue, OnCopyEngine));

  // Batching is disabled for the free path.
  bool AllowBatching = false;
  // Get a new command list to be used on this call
  ur_command_list_ptr_t CmdList{};
  UR_CALL(Queue->Context->getAvailableCommandList(
      Queue, CmdList, OnCopyEngine, NumEventsInWaitList, EventWaitList,
      AllowBatching, nullptr /*ForcedCmdQueue*/));

  // Fall back to a local handle when the caller did not request an event.
  ur_event_handle_t LocalEvent{};
  bool EventIsInternal = (OutEvent == nullptr);
  ur_event_handle_t *EventSlot = &LocalEvent;
  if (OutEvent) {
    EventSlot = OutEvent;
  }

  UR_CALL(createEventAndAssociateQueue(Queue, EventSlot,
                                       UR_COMMAND_ENQUEUE_USM_FREE_EXP, CmdList,
                                       EventIsInternal, false));
  ze_event_handle_t SignalEvent = (*EventSlot)->ZeEvent;
  (*EventSlot)->WaitList = RetainedDeps;

  const auto &ZeCmdList = CmdList->first;
  const auto &Deps = (*EventSlot)->WaitList;
  if (Deps.Length) {
    ZE2UR_CALL(zeCommandListAppendWaitOnEvents,
               (ZeCmdList, Deps.Length, Deps.ZeEventList));

    // Wait for commands execution until USM can be freed
    UR_CALL(Queue->executeCommandList(CmdList, true, AllowBatching)); // Blocking
  }

  // Free USM memory
  // NOTE(review): on failure the function returns without signalling the
  // event created above - confirm whether callers can end up waiting on it.
  if (auto FreeStatus = USMFreeHelper(Queue->Context, Mem)) {
    return FreeStatus;
  }

  // Signal that USM free event was finished
  // NOTE(review): when the wait list was non-empty this appends to a command
  // list that has already been executed once above - confirm that reuse is
  // valid for every command list kind the queue may hand out.
  ZE2UR_CALL(zeCommandListAppendSignalEvent, (ZeCmdList, SignalEvent));

  UR_CALL(Queue->executeCommandList(CmdList, false, AllowBatching));

  return UR_RESULT_SUCCESS;
}
716
805
717
806
ur_result_t UR_APICALL urUSMPoolCreateExp (
0 commit comments