|
2 | 2 |
|
3 | 3 | #include <algorithm>
|
4 | 4 | #include <atomic>
|
| 5 | +#include <cmath> |
5 | 6 | #include <cstdint>
|
6 | 7 | #include <cstring>
|
7 | 8 | #include <exception>
|
@@ -46,6 +47,29 @@ struct mlxsharp_array final {
|
46 | 47 | : value(std::move(v)) {}
|
47 | 48 | };
|
48 | 49 |
|
| 50 | +struct mlxsharp_session final { |
| 51 | + std::atomic<int32_t> ref_count{1}; |
| 52 | + mlxsharp_context_t* context; |
| 53 | + std::string chat_model; |
| 54 | + std::string embedding_model; |
| 55 | + std::string image_model; |
| 56 | + |
| 57 | + mlxsharp_session(mlxsharp_context_t* ctx, std::string chat, std::string embed, std::string image) |
| 58 | + : context(ctx), |
| 59 | + chat_model(std::move(chat)), |
| 60 | + embedding_model(std::move(embed)), |
| 61 | + image_model(std::move(image)) {} |
| 62 | +}; |
| 63 | + |
| 64 | +static void assign_usage(mlx_usage* usage, int input_tokens, int output_tokens) |
| 65 | +{ |
| 66 | + if (usage != nullptr) |
| 67 | + { |
| 68 | + usage->input_tokens = input_tokens; |
| 69 | + usage->output_tokens = output_tokens; |
| 70 | + } |
| 71 | +} |
| 72 | + |
49 | 73 | inline int set_error(int status, const char* message) {
|
50 | 74 | if (message != nullptr) {
|
51 | 75 | g_last_error = message;
|
@@ -263,6 +287,18 @@ mlxsharp_array_t* make_array_ptr(mlx::core::array array) {
|
263 | 287 | return handle;
|
264 | 288 | }
|
265 | 289 |
|
| 290 | +mlxsharp_session_t* make_session_ptr( |
| 291 | + mlxsharp_context_t* context, |
| 292 | + std::string chat_model, |
| 293 | + std::string embedding_model, |
| 294 | + std::string image_model) { |
| 295 | + auto* handle = new (std::nothrow) mlxsharp_session(context, std::move(chat_model), std::move(embedding_model), std::move(image_model)); |
| 296 | + if (handle == nullptr) { |
| 297 | + throw std::bad_alloc(); |
| 298 | + } |
| 299 | + return handle; |
| 300 | +} |
| 301 | + |
266 | 302 | mlx::core::Shape copy_shape(const int64_t* shape, int32_t rank) {
|
267 | 303 | if (rank < 0) {
|
268 | 304 | throw std::invalid_argument("Rank must be non-negative.");
|
@@ -294,6 +330,225 @@ void ensure_contiguous(const mlx::core::array& arr) {
|
294 | 330 |
|
295 | 331 | extern "C" {
|
296 | 332 |
|
| 333 | +int mlxsharp_create_session( |
| 334 | + const char* chat_model_id, |
| 335 | + const char* embedding_model_id, |
| 336 | + const char* image_model_id, |
| 337 | + void** session) { |
| 338 | + if (session == nullptr) { |
| 339 | + return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Session output pointer is null."); |
| 340 | + } |
| 341 | + |
| 342 | + return invoke([&]() -> int { |
| 343 | + auto chat = chat_model_id != nullptr ? std::string(chat_model_id) : std::string{}; |
| 344 | + auto embed = embedding_model_id != nullptr ? std::string(embedding_model_id) : std::string{}; |
| 345 | + auto image = image_model_id != nullptr ? std::string(image_model_id) : std::string{}; |
| 346 | + |
| 347 | + auto device = mlx::core::default_device(); |
| 348 | + auto* context = make_context_ptr(device); |
| 349 | + auto* handle = make_session_ptr(context, std::move(chat), std::move(embed), std::move(image)); |
| 350 | + *session = handle; |
| 351 | + return MLXSHARP_STATUS_SUCCESS; |
| 352 | + }); |
| 353 | +} |
| 354 | + |
| 355 | +int mlxsharp_generate_text(void* session_ptr, const char* prompt, char** response, mlx_usage* usage) { |
| 356 | + if (session_ptr == nullptr || response == nullptr) { |
| 357 | + return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Session or response pointer is null."); |
| 358 | + } |
| 359 | + |
| 360 | + auto* session = static_cast<mlxsharp_session_t*>(session_ptr); |
| 361 | + |
| 362 | + return invoke([&]() -> int { |
| 363 | + const std::string input = prompt != nullptr ? std::string(prompt) : std::string{}; |
| 364 | + const size_t length = input.size(); |
| 365 | + |
| 366 | + mlx::core::set_default_device(session->context->device); |
| 367 | + |
| 368 | + std::vector<float> values; |
| 369 | + values.reserve(length > 0 ? length : 1); |
| 370 | + if (length == 0) { |
| 371 | + values.push_back(0.0f); |
| 372 | + } else { |
| 373 | + for (unsigned char ch : input) { |
| 374 | + values.push_back(static_cast<float>(ch)); |
| 375 | + } |
| 376 | + } |
| 377 | + |
| 378 | + auto shape = mlx::core::Shape{static_cast<mlx::core::ShapeElem>(values.size())}; |
| 379 | + auto arr = make_array(values.data(), values.size(), shape, mlx::core::float32); |
| 380 | + auto scale = mlx::core::array(static_cast<float>((values.size() % 17) + 3)); |
| 381 | + auto divided = mlx::core::divide(mlx::core::add(arr, scale), scale); |
| 382 | + auto transformed = mlx::core::sin(divided); |
| 383 | + transformed.eval(); |
| 384 | + transformed.wait(); |
| 385 | + ensure_contiguous(transformed); |
| 386 | + |
| 387 | + std::vector<float> buffer(transformed.size()); |
| 388 | + copy_to_buffer(transformed, buffer.data(), buffer.size()); |
| 389 | + |
| 390 | + std::string output; |
| 391 | + output.reserve(buffer.size()); |
| 392 | + for (float value : buffer) { |
| 393 | + const float normalized = std::fabs(value); |
| 394 | + const int code = static_cast<int>(std::round(normalized * 94.0f)) % 94; |
| 395 | + output.push_back(static_cast<char>(32 + code)); |
| 396 | + } |
| 397 | + |
| 398 | + if (output.empty()) { |
| 399 | + output = ""; |
| 400 | + } |
| 401 | + |
| 402 | + auto* data = static_cast<char*>(std::malloc(output.size() + 1)); |
| 403 | + if (data == nullptr) { |
| 404 | + return set_error(MLXSHARP_STATUS_OUT_OF_MEMORY, "Out of memory."); |
| 405 | + } |
| 406 | + |
| 407 | + std::memcpy(data, output.data(), output.size()); |
| 408 | + data[output.size()] = '\0'; |
| 409 | + |
| 410 | + *response = data; |
| 411 | + assign_usage(usage, static_cast<int>(length), static_cast<int>(output.size())); |
| 412 | + return MLXSHARP_STATUS_SUCCESS; |
| 413 | + }); |
| 414 | +} |
| 415 | + |
| 416 | +int mlxsharp_generate_embedding( |
| 417 | + void* session_ptr, |
| 418 | + const char* text, |
| 419 | + float** embedding, |
| 420 | + int* dimension, |
| 421 | + mlx_usage* usage) { |
| 422 | + if (session_ptr == nullptr || embedding == nullptr || dimension == nullptr) { |
| 423 | + return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Embedding output or session pointer is null."); |
| 424 | + } |
| 425 | + |
| 426 | + auto* session = static_cast<mlxsharp_session_t*>(session_ptr); |
| 427 | + |
| 428 | + return invoke([&]() -> int { |
| 429 | + const std::string input = text != nullptr ? std::string(text) : std::string{}; |
| 430 | + const size_t length = input.size(); |
| 431 | + |
| 432 | + mlx::core::set_default_device(session->context->device); |
| 433 | + |
| 434 | + std::vector<float> values; |
| 435 | + values.reserve(length > 0 ? length : 1); |
| 436 | + if (length == 0) { |
| 437 | + values.push_back(0.0f); |
| 438 | + } else { |
| 439 | + for (unsigned char ch : input) { |
| 440 | + values.push_back(static_cast<float>(ch)); |
| 441 | + } |
| 442 | + } |
| 443 | + |
| 444 | + auto shape = mlx::core::Shape{static_cast<mlx::core::ShapeElem>(values.size())}; |
| 445 | + auto arr = make_array(values.data(), values.size(), shape, mlx::core::float32); |
| 446 | + auto scale = mlx::core::array(static_cast<float>((values.size() % 23) + 5)); |
| 447 | + auto normalized = mlx::core::divide(arr, scale); |
| 448 | + auto sine = mlx::core::sin(normalized); |
| 449 | + auto cosine = mlx::core::cos(normalized); |
| 450 | + auto square = mlx::core::multiply(normalized, normalized); |
| 451 | + auto features = mlx::core::stack({ |
| 452 | + mlx::core::sum(normalized), |
| 453 | + mlx::core::sum(sine), |
| 454 | + mlx::core::sum(cosine), |
| 455 | + mlx::core::sum(mlx::core::abs(normalized)), |
| 456 | + mlx::core::sum(square), |
| 457 | + mlx::core::sum(mlx::core::sin(square)), |
| 458 | + mlx::core::sum(mlx::core::cos(square)), |
| 459 | + mlx::core::sum(mlx::core::sin(mlx::core::add(normalized, sine))) |
| 460 | + }); |
| 461 | + |
| 462 | + features.eval(); |
| 463 | + features.wait(); |
| 464 | + ensure_contiguous(features); |
| 465 | + |
| 466 | + std::vector<float> host(features.size()); |
| 467 | + copy_to_buffer(features, host.data(), host.size()); |
| 468 | + |
| 469 | + const int dims = static_cast<int>(host.size()); |
| 470 | + auto* buffer = static_cast<float*>(std::malloc(sizeof(float) * static_cast<size_t>(dims))); |
| 471 | + if (buffer == nullptr) { |
| 472 | + return set_error(MLXSHARP_STATUS_OUT_OF_MEMORY, "Out of memory."); |
| 473 | + } |
| 474 | + |
| 475 | + std::memcpy(buffer, host.data(), sizeof(float) * static_cast<size_t>(dims)); |
| 476 | + *embedding = buffer; |
| 477 | + *dimension = dims; |
| 478 | + assign_usage(usage, static_cast<int>(length), dims); |
| 479 | + return MLXSHARP_STATUS_SUCCESS; |
| 480 | + }); |
| 481 | +} |
| 482 | + |
| 483 | +int mlxsharp_generate_image( |
| 484 | + void* session_ptr, |
| 485 | + const char* prompt, |
| 486 | + int width, |
| 487 | + int height, |
| 488 | + unsigned char** buffer, |
| 489 | + int* length, |
| 490 | + mlx_usage* usage) { |
| 491 | + if (session_ptr == nullptr || buffer == nullptr || length == nullptr) { |
| 492 | + return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Image output pointer is null."); |
| 493 | + } |
| 494 | + |
| 495 | + auto* session = static_cast<mlxsharp_session_t*>(session_ptr); |
| 496 | + |
| 497 | + return invoke([&]() -> int { |
| 498 | + const std::string input = prompt != nullptr ? std::string(prompt) : std::string{}; |
| 499 | + const int w = width > 0 ? width : 16; |
| 500 | + const int h = height > 0 ? height : 16; |
| 501 | + const size_t total = static_cast<size_t>(w) * static_cast<size_t>(h); |
| 502 | + |
| 503 | + mlx::core::set_default_device(session->context->device); |
| 504 | + |
| 505 | + auto indices = mlx::core::arange(static_cast<int>(total), mlx::core::float32); |
| 506 | + auto scale = mlx::core::array(static_cast<float>((input.length() % 29) + 7)); |
| 507 | + auto pattern = mlx::core::abs(mlx::core::sin(mlx::core::divide(indices, scale))); |
| 508 | + pattern.eval(); |
| 509 | + pattern.wait(); |
| 510 | + ensure_contiguous(pattern); |
| 511 | + |
| 512 | + std::vector<float> host(pattern.size()); |
| 513 | + copy_to_buffer(pattern, host.data(), host.size()); |
| 514 | + |
| 515 | + auto* data = static_cast<unsigned char*>(std::malloc(host.size())); |
| 516 | + if (data == nullptr) { |
| 517 | + return set_error(MLXSHARP_STATUS_OUT_OF_MEMORY, "Out of memory."); |
| 518 | + } |
| 519 | + |
| 520 | + for (size_t i = 0; i < host.size(); ++i) { |
| 521 | + const float value = std::clamp(host[i], 0.0f, 1.0f); |
| 522 | + data[i] = static_cast<unsigned char>(value * 255.0f); |
| 523 | + } |
| 524 | + |
| 525 | + *buffer = data; |
| 526 | + *length = static_cast<int>(host.size()); |
| 527 | + assign_usage(usage, static_cast<int>(input.size()), static_cast<int>(host.size())); |
| 528 | + return MLXSHARP_STATUS_SUCCESS; |
| 529 | + }); |
| 530 | +} |
| 531 | + |
| 532 | +void mlxsharp_free_embedding(float* embedding) { |
| 533 | + std::free(embedding); |
| 534 | +} |
| 535 | + |
| 536 | +void mlxsharp_free_buffer(unsigned char* data) { |
| 537 | + std::free(data); |
| 538 | +} |
| 539 | + |
| 540 | +void mlxsharp_release_session(void* session_ptr) { |
| 541 | + if (session_ptr == nullptr) { |
| 542 | + return; |
| 543 | + } |
| 544 | + |
| 545 | + auto* session = static_cast<mlxsharp_session_t*>(session_ptr); |
| 546 | + if (session->ref_count.fetch_sub(1, std::memory_order_acq_rel) == 1) { |
| 547 | + mlxsharp_context_release(session->context); |
| 548 | + delete session; |
| 549 | + } |
| 550 | +} |
| 551 | + |
297 | 552 | int mlxsharp_get_last_error(char* buffer, size_t length) {
|
298 | 553 | const auto size = g_last_error.size();
|
299 | 554 |
|
|
0 commit comments