11
11
12
12
#import < Accelerate/Accelerate.h>
13
13
14
- #if PY_MAJOR_VERSION < 3
15
-
16
- #pragma clang diagnostic push
17
- #pragma clang diagnostic ignored "-Wmacro-redefined"
18
- #define PyBytes_Check (name ) PyString_Check(name)
19
- #pragma clang diagnostic pop
20
- #define PyAnyInteger_Check (name ) (PyLong_Check(name) || PyInt_Check(name))
21
-
22
- #else
23
-
24
14
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
25
15
#include < numpy/arrayobject.h>
26
- #define PyAnyInteger_Check (name ) (PyLong_Check(name) || (_import_array(), PyArray_IsScalar(name, Integer)))
27
16
28
- #endif
17
+ #define PyAnyInteger_Check ( name ) (PyLong_Check(name) || (_import_array(), PyArray_IsScalar(name, Integer)))
29
18
30
19
using namespace CoreML ::Python;
31
20
59
48
60
49
@autoreleasepool {
61
50
NSMutableDictionary <NSString *, NSObject *> *inputDict = [[NSMutableDictionary <NSString *, NSObject *> alloc] init ];
62
-
51
+
63
52
for (const auto element : dict) {
64
53
std::string key = element.first .cast <std::string>();
65
54
NSString *nsKey = [NSString stringWithUTF8String: key.c_str ()];
66
55
id nsValue = Utils::convertValueToObjC (element.second );
67
56
inputDict[nsKey] = nsValue;
68
57
}
69
-
58
+
70
59
feautreProvider = [[MLDictionaryFeatureProvider alloc ] initWithDictionary: inputDict error: &localError];
71
60
}
72
61
134
123
}
135
124
136
125
static MLFeatureValue * convertValueToDictionary (const py::handle& handle) {
137
-
126
+
138
127
if (!PyDict_Check (handle.ptr ())) {
139
128
throw std::runtime_error (" Not a dictionary." );
140
129
}
141
-
130
+
142
131
// Get the first value in the dictionary; use that as a hint.
143
132
PyObject *key = nullptr , *value = nullptr ;
144
133
Py_ssize_t pos = 0 ;
145
-
134
+
146
135
int has_values = PyDict_Next (handle.ptr (), &pos, &key, &value);
147
-
136
+
148
137
// Is it an empty dict? If so, just return an empty dictionary.
149
138
if (!has_values) {
150
139
return [MLFeatureValue featureValueWithDictionary: @{} error: nullptr ];
@@ -222,12 +211,12 @@ static void handleCVReturn(CVReturn status) {
222
211
static MLFeatureValue * convertValueToImage (const py::handle& handle) {
223
212
// assumes handle is a valid PIL image!
224
213
CVPixelBufferRef pixelBuffer = nil ;
225
-
214
+
226
215
size_t width = handle.attr (" width" ).cast <size_t >();
227
216
size_t height = handle.attr (" height" ).cast <size_t >();
228
217
OSType format;
229
218
std::string formatStr = handle.attr (" mode" ).cast <std::string>();
230
-
219
+
231
220
if (formatStr == " RGB" ) {
232
221
format = kCVPixelFormatType_32BGRA ;
233
222
} else if (formatStr == " RGBA" ) {
@@ -242,18 +231,18 @@ static void handleCVReturn(CVReturn status) {
242
231
msg << " Supported types are: RGB, RGBA, L." ;
243
232
throw std::runtime_error (msg.str ());
244
233
}
245
-
234
+
246
235
CVReturn status = CVPixelBufferCreate (kCFAllocatorDefault , width, height, format, NULL , &pixelBuffer);
247
236
handleCVReturn (status);
248
-
237
+
249
238
// get bytes out of the PIL image
250
239
py::object tobytes = handle.attr (" tobytes" );
251
240
py::object bytesResult = tobytes ();
252
241
assert (PyBytes_Check (bytesResult.ptr ()));
253
242
Py_ssize_t bytesLength = PyBytes_Size (bytesResult.ptr ());
254
243
assert (bytesLength >= 0 );
255
244
const char *bytesPtr = PyBytes_AsString (bytesResult.ptr ());
256
-
245
+
257
246
// copy data into the CVPixelBuffer
258
247
status = CVPixelBufferLockBaseAddress (pixelBuffer, 0 );
259
248
handleCVReturn (status);
@@ -268,33 +257,33 @@ static void handleCVReturn(CVReturn status) {
268
257
srcBuffer.data = const_cast <char *>(srcPointer);
269
258
srcBuffer.width = width;
270
259
srcBuffer.height = height;
271
-
260
+
272
261
vImage_Buffer dstBuffer;
273
262
memset (&dstBuffer, 0 , sizeof (dstBuffer));
274
263
dstBuffer.data = baseAddress;
275
264
dstBuffer.width = width;
276
265
dstBuffer.height = height;
277
266
278
267
if (formatStr == " RGB" ) {
279
-
268
+
280
269
// convert RGB to BGRA
281
270
assert (bytesLength == width * height * 3 );
282
271
283
272
srcBuffer.rowBytes = width * 3 ;
284
273
dstBuffer.rowBytes = bytesPerRow;
285
274
vImageConvert_RGB888toBGRA8888 (&srcBuffer, NULL , 255 , &dstBuffer, false , 0 );
286
-
275
+
287
276
} else if (formatStr == " RGBA" ) {
288
-
277
+
289
278
// convert RGBA to BGRA
290
279
assert (bytesLength == width * height * 4 );
291
280
srcBuffer.rowBytes = width * 4 ;
292
281
dstBuffer.rowBytes = bytesPerRow;
293
282
uint8_t permuteMap[4 ] = { 2 , 1 , 0 , 3 };
294
283
vImagePermuteChannels_ARGB8888 (&srcBuffer, &dstBuffer, permuteMap, 0 );
295
-
284
+
296
285
} else if (formatStr == " L" ) {
297
-
286
+
298
287
// 8 bit grayscale.
299
288
assert (bytesLength == width * height);
300
289
@@ -303,7 +292,7 @@ static void handleCVReturn(CVReturn status) {
303
292
vImageCopyBuffer (&srcBuffer, &dstBuffer, 1 , 0 );
304
293
305
294
} else if (formatStr == " F" ) {
306
-
295
+
307
296
// convert Float32 to Float16.
308
297
assert (bytesLength == width * height * sizeof (Float32));
309
298
@@ -317,14 +306,14 @@ static void handleCVReturn(CVReturn status) {
317
306
msg << " Supported types are: RGB, RGBA, L." ;
318
307
throw std::runtime_error (msg.str ());
319
308
}
320
-
309
+
321
310
#ifdef COREML_SHOW_PIL_IMAGES
322
311
if (formatStr == " RGB" ) {
323
312
// for debugging purposes, convert back to PIL image and show it
324
313
py::object scope = py::module::import (" __main__" ).attr (" __dict__" );
325
314
py::eval<py::eval_single_statement>(" import PIL.Image" , scope);
326
315
py::object pilImage = py::eval<py::eval_expr>(" PIL.Image" );
327
-
316
+
328
317
std::string cvPixelStr (count, 0 );
329
318
const char *basePtr = static_cast <char *>(baseAddress);
330
319
for (size_t row = 0 ; row < height; row++) {
@@ -334,7 +323,7 @@ static void handleCVReturn(CVReturn status) {
334
323
}
335
324
}
336
325
}
337
-
326
+
338
327
py::bytes cvPixelBytes = py::bytes (cvPixelStr);
339
328
py::object frombytes = pilImage.attr (" frombytes" );
340
329
py::str mode = " RGB" ;
@@ -343,62 +332,62 @@ static void handleCVReturn(CVReturn status) {
343
332
img.attr (" show" )();
344
333
}
345
334
#endif
346
-
335
+
347
336
status = CVPixelBufferUnlockBaseAddress (pixelBuffer, 0 );
348
337
handleCVReturn (status);
349
338
350
339
MLFeatureValue *fv = [MLFeatureValue featureValueWithPixelBuffer: pixelBuffer];
351
340
CVPixelBufferRelease (pixelBuffer);
352
-
341
+
353
342
return fv;
354
343
}
355
344
356
345
static bool IsPILImage (const py::handle& handle) {
357
346
// TODO put try/catch around this?
358
-
347
+
359
348
try {
360
349
py::module::import (" PIL.Image" );
361
350
} catch (...) {
362
351
return false ;
363
352
}
364
-
353
+
365
354
py::object scope = py::module::import (" __main__" ).attr (" __dict__" );
366
355
py::eval<py::eval_single_statement>(" import PIL.Image" , scope);
367
356
py::handle imageTypeHandle = py::eval<py::eval_expr>(" PIL.Image.Image" , scope);
368
357
assert (PyType_Check (imageTypeHandle.ptr ())); // should be a Python type
369
-
358
+
370
359
return PyObject_TypeCheck (handle.ptr (), (PyTypeObject *)(imageTypeHandle.ptr ()));
371
360
}
372
361
373
362
MLFeatureValue * Utils::convertValueToObjC (const py::handle& handle) {
374
-
363
+
375
364
if (PyAnyInteger_Check (handle.ptr ())) {
376
365
try {
377
366
int64_t val = handle.cast <int64_t >();
378
367
return [MLFeatureValue featureValueWithInt64: val];
379
368
} catch (...) {}
380
369
}
381
-
370
+
382
371
if (PyFloat_Check (handle.ptr ())) {
383
372
try {
384
373
double val = handle.cast <double >();
385
374
return [MLFeatureValue featureValueWithDouble: val];
386
375
} catch (...) {}
387
376
}
388
-
377
+
389
378
if (PyBytes_Check (handle.ptr ()) || PyUnicode_Check (handle.ptr ())) {
390
379
try {
391
380
std::string val = handle.cast <std::string>();
392
381
return [MLFeatureValue featureValueWithString: [NSString stringWithUTF8String: val.c_str ()]];
393
382
} catch (...) {}
394
383
}
395
-
384
+
396
385
if (PyDict_Check (handle.ptr ())) {
397
386
try {
398
387
return convertValueToDictionary (handle);
399
388
} catch (...) {}
400
389
}
401
-
390
+
402
391
if (PyList_Check (handle.ptr ()) || PyTuple_Check (handle.ptr ())) {
403
392
try {
404
393
return convertValueToSequence (handle);
@@ -410,11 +399,11 @@ static bool IsPILImage(const py::handle& handle) {
410
399
return convertValueToArray (handle);
411
400
} catch (...) {}
412
401
}
413
-
402
+
414
403
if (IsPILImage (handle)) {
415
404
return convertValueToImage (handle);
416
405
}
417
-
406
+
418
407
py::print (" Error: value type not convertible:" );
419
408
py::print (handle);
420
409
throw std::runtime_error (" value type not convertible" );
@@ -474,6 +463,11 @@ static size_t sizeOfArrayElement(MLMultiArrayDataType type) {
474
463
__block py::object array;
475
464
[value getBytesWithHandler: ^(const void *bytes, NSInteger size) {
476
465
switch (type) {
466
+ #if BUILT_WITH_MACOS26_SDK
467
+ case MLMultiArrayDataTypeInt8:
468
+ array = py::array (shape, strides, reinterpret_cast <const int8_t *>(bytes));
469
+ break ;
470
+ #endif
477
471
case MLMultiArrayDataTypeInt32:
478
472
array = py::array (shape, strides, reinterpret_cast <const int32_t *>(bytes));
479
473
break ;
@@ -508,7 +502,7 @@ static size_t sizeOfArrayElement(MLMultiArrayDataType type) {
508
502
NSString *nskey = static_cast <NSString *>(key);
509
503
pykey = py::str ([nskey UTF8String ]);
510
504
}
511
-
505
+
512
506
NSNumber *value = dict[key];
513
507
ret[pykey] = py::float_ ([value doubleValue ]);
514
508
}
@@ -519,7 +513,7 @@ static size_t sizeOfArrayElement(MLMultiArrayDataType type) {
519
513
if (CVPixelBufferIsPlanar (value)) {
520
514
throw std::runtime_error (" Only non-planar CVPixelBuffers are currently supported by this Python binding." );
521
515
}
522
-
516
+
523
517
// supports grayscale and BGRA format types
524
518
auto formatType = CVPixelBufferGetPixelFormatType (value);
525
519
assert (formatType == kCVPixelFormatType_32BGRA
@@ -553,10 +547,10 @@ static size_t sizeOfArrayElement(MLMultiArrayDataType type) {
553
547
554
548
auto result = CVPixelBufferLockBaseAddress (value, kCVPixelBufferLock_ReadOnly );
555
549
assert (result == kCVReturnSuccess );
556
-
550
+
557
551
uint8_t *src = reinterpret_cast <uint8_t *>(CVPixelBufferGetBaseAddress (value));
558
552
assert (src != nullptr );
559
-
553
+
560
554
size_t srcBytesPerRow = CVPixelBufferGetBytesPerRow (value);
561
555
562
556
// Prepare for vImage blitting
@@ -587,10 +581,10 @@ static size_t sizeOfArrayElement(MLMultiArrayDataType type) {
587
581
msg << " Unsupported pixel format type: " << std::hex << std::setfill (' 0' ) << std::setw (4 ) << formatType << " . " ;
588
582
throw std::runtime_error (msg.str ());
589
583
}
590
-
584
+
591
585
result = CVPixelBufferUnlockBaseAddress (value, kCVPixelBufferLock_ReadOnly );
592
586
assert (result == kCVReturnSuccess );
593
-
587
+
594
588
py::object scope = py::module::import (" __main__" ).attr (" __dict__" );
595
589
py::eval<py::eval_single_statement>(" import PIL.Image" , scope);
596
590
py::object pilImage = py::eval<py::eval_expr>(" PIL.Image" , scope);
0 commit comments