@@ -16,6 +16,10 @@ namespace runtime {
1616 */
1717class MicroDeviceAPI final : public DeviceAPI {
1818 public:
19+ MicroDeviceAPI () {
20+ session_ = MicroSession::Global ();
21+ }
22+
1923 void SetDevice (TVMContext ctx) final {}
2024
2125 void GetAttr (TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
@@ -28,15 +32,12 @@ class MicroDeviceAPI final : public DeviceAPI {
2832 size_t nbytes,
2933 size_t alignment,
3034 TVMType type_hint) final {
31- // TODO: can make this a private member, but where to best init it?
32- std::shared_ptr<MicroSession> session = MicroSession::Global ();
33- void * alloc_ptr = session->AllocateInSection (kHeap , nbytes);
35+ void * alloc_ptr = session_->AllocateInSection (kHeap , nbytes);
3436 return alloc_ptr;
3537 }
3638
3739 void FreeDataSpace (TVMContext ctx, void * ptr) final {
38- std::shared_ptr<MicroSession> session = MicroSession::Global ();
39- session->FreeInSection (kHeap , ptr);
40+ session_->FreeInSection (kHeap , ptr);
4041 }
4142
4243 void CopyDataFromTo (const void * from,
@@ -48,27 +49,33 @@ class MicroDeviceAPI final : public DeviceAPI {
4849 TVMContext ctx_to,
4950 TVMType type_hint,
5051 TVMStreamHandle stream) final {
51- std::shared_ptr<MicroSession> session = MicroSession::Global ();
52- uint8_t buffer[size];
5352 constexpr int micro_devtype = kDLMicroDev ;
5453 std::tuple<int , int > type_from_to (ctx_from.device_type , ctx_to.device_type );
5554
5655 if (type_from_to == std::make_tuple (micro_devtype, micro_devtype)) {
57- // TODO: ignored ctx because we assume only one low-level micro_dev - is ok?
58- std::shared_ptr<LowLevelDevice> from_lld = session->low_level_device ();
59- std::shared_ptr<LowLevelDevice> to_lld = session->low_level_device ();
60- from_lld->Read ((uint8_t *)(from) + from_offset, buffer, size);
61- to_lld->Write ((uint8_t *)(to) + to_offset, buffer, size);
62-
56+ CHECK (ctx_from.device_id == ctx_to.device_id )
57+ << " can only copy between the same micro device" ;
58+ std::string buffer;
59+ const std::shared_ptr<LowLevelDevice>& from_lld = session_->low_level_device ();
60+ const std::shared_ptr<LowLevelDevice>& to_lld = session_->low_level_device ();
61+ from_lld->Read (
62+ const_cast <uint8_t *>(static_cast <const uint8_t *>(from)) + from_offset,
63+ const_cast <char *>(&buffer[0 ]), size);
64+ to_lld->Write (
65+ const_cast <uint8_t *>(static_cast <const uint8_t *>(to)) + to_offset,
66+ const_cast <char *>(&buffer[0 ]), size);
6367 } else if (type_from_to == std::make_tuple (micro_devtype, kDLCPU )) {
64- std::shared_ptr<LowLevelDevice> from_lld = session->low_level_device ();
65- from_lld->Read ((uint8_t *)(from) + from_offset, buffer, size);
66- memcpy (static_cast <uint8_t *>(to) + to_offset, buffer, size);
68+ const std::shared_ptr<LowLevelDevice>& from_lld = session_->low_level_device ();
69+ from_lld->Read (
70+ const_cast <uint8_t *>(static_cast <const uint8_t *>(from)) + from_offset,
71+ const_cast <uint8_t *>(static_cast <const uint8_t *>(to)), size);
6772
6873 } else if (type_from_to == std::make_tuple (micro_devtype, kDLCPU )) {
69- std::shared_ptr<LowLevelDevice> to_lld = session->low_level_device ();
70- to_lld->Write ((uint8_t *)(to) + to_offset,
71- (uint8_t *)(from) + from_offset, size);
74+ const std::shared_ptr<LowLevelDevice>& to_lld = session_->low_level_device ();
75+ to_lld->Write (
76+ const_cast <uint8_t *>(static_cast <const uint8_t *>(to)) + to_offset,
77+ const_cast <uint8_t *>(static_cast <const uint8_t *>(from)) + from_offset,
78+ size);
7279
7380 } else {
7481 LOG (FATAL) << " Expect copy from/to micro_dev or between micro_dev\n " ;
@@ -81,15 +88,13 @@ class MicroDeviceAPI final : public DeviceAPI {
8188
8289 // TODO: what about ctx?
8390 void * AllocWorkspace (TVMContext ctx, size_t size, TVMType type_hint) final {
84- std::shared_ptr<MicroSession> session = MicroSession::Global ();
85- void * alloc_ptr = session->AllocateInSection (kWorkspace , size);
91+ void * alloc_ptr = session_->AllocateInSection (kWorkspace , size);
8692 return alloc_ptr;
8793 }
8894
8995 // TODO: what about ctx?
9096 void FreeWorkspace (TVMContext ctx, void * data) final {
91- std::shared_ptr<MicroSession> session = MicroSession::Global ();
92- session->FreeInSection (kWorkspace , data);
97+ session_->FreeInSection (kWorkspace , data);
9398 }
9499
95100 /* !
@@ -101,6 +106,10 @@ class MicroDeviceAPI final : public DeviceAPI {
101106 std::make_shared<MicroDeviceAPI>();
102107 return inst;
103108 }
109+
110+ private:
111+ /* ! \brief pointer to global session */
112+ MicroSession* session_;
104113};
105114
106115// register device that can be obtained from Python frontend
0 commit comments