Skip to content

Commit

Permalink
[metal] update language version (#7116)
Browse files Browse the repository at this point in the history
* [metal] update language version

* fix mps
  • Loading branch information
antinucleon authored Dec 16, 2020
1 parent 7a20b4a commit 0a3e178
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
9 changes: 6 additions & 3 deletions src/runtime/contrib/mps/conv.mm
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(buf->ctx);
id<MTLBuffer> temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]);
entry_ptr->metal_api->CopyDataFromTo((__bridge void*)mtlbuf, 0, (__bridge void*)temp, 0,
[mtlbuf length], buf -> ctx, buf -> ctx, nullptr);
[mtlbuf length], buf -> ctx, buf -> ctx, buf -> dtype,
nullptr);

MPSImageDescriptor* desc =
[MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat32
Expand Down Expand Up @@ -69,7 +70,8 @@
imageIndex:0];

entry_ptr->metal_api->CopyDataFromTo((__bridge void*)temp, 0, (__bridge void*)mtlbuf, 0,
[mtlbuf length], buf -> ctx, buf -> ctx, nullptr);
[mtlbuf length], buf -> ctx, buf -> ctx, buf -> dtype,
nullptr);
});

TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d").set_body([](TVMArgs args, TVMRetValue* ret) {
Expand Down Expand Up @@ -111,7 +113,8 @@
id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(weight->data);
id<MTLBuffer> tempB = rt->GetTempBuffer(weight->ctx, [bufB length]);
entry_ptr->metal_api->CopyDataFromTo((__bridge void*)bufB, 0, (__bridge void*)tempB, 0,
[bufB length], weight -> ctx, weight -> ctx, nullptr);
[bufB length], weight -> ctx, weight -> ctx, tmp_in.dtype,
nullptr);
float* ptr_w = (float*)[tempB contents];
// output to MPSImage
DLTensor tmp_out;
Expand Down
3 changes: 1 addition & 2 deletions src/runtime/metal/metal_module.mm
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ void SaveToBinary(dmlc::Stream* stream) final {
if (e.lib == nil) {
if (fmt_ == "metal") {
MTLCompileOptions* opts = [MTLCompileOptions alloc];
// Use the Metal 1.2 for now.
opts.languageVersion = MTLLanguageVersion1_2;
opts.languageVersion = MTLLanguageVersion2_3;
opts.fastMathEnabled = YES;
// opts = nil;
e.lib = [w->devices[device_id]
Expand Down

0 comments on commit 0a3e178

Please sign in to comment.