Skip to content

Commit 5c8dfd1

Browse files
committed
code
1 parent c75ee0f commit 5c8dfd1

File tree

5 files changed

+373
-14
lines changed

5 files changed

+373
-14
lines changed

native/include/mlxsharp/api.h

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ struct mlxsharp_array;
1414
typedef struct mlxsharp_context mlxsharp_context_t;
1515
typedef struct mlxsharp_array mlxsharp_array_t;
1616

17+
// Session handles consumed by the managed high-level bindings.
18+
struct mlxsharp_session;
19+
typedef struct mlxsharp_session mlxsharp_session_t;
20+
1721
// Status codes returned by native APIs.
1822
typedef enum mlxsharp_status {
1923
MLXSHARP_STATUS_SUCCESS = 0,
@@ -126,7 +130,47 @@ int mlxsharp_array_divide(
126130
const mlxsharp_array_t* right,
127131
mlxsharp_array_t** out_array);
128132

133+
// Session-based high-level helpers ----------------------------------------
134+
135+
typedef struct mlx_usage {
136+
int input_tokens;
137+
int output_tokens;
138+
} mlx_usage;
139+
140+
int mlxsharp_create_session(
141+
const char* chat_model_id,
142+
const char* embedding_model_id,
143+
const char* image_model_id,
144+
void** session);
145+
146+
int mlxsharp_generate_text(
147+
void* session,
148+
const char* prompt,
149+
char** response,
150+
mlx_usage* usage);
151+
152+
int mlxsharp_generate_embedding(
153+
void* session,
154+
const char* text,
155+
float** embedding,
156+
int* dimension,
157+
mlx_usage* usage);
158+
159+
int mlxsharp_generate_image(
160+
void* session,
161+
const char* prompt,
162+
int width,
163+
int height,
164+
unsigned char** buffer,
165+
int* length,
166+
mlx_usage* usage);
167+
168+
void mlxsharp_free_embedding(float* embedding);
169+
170+
void mlxsharp_free_buffer(unsigned char* buffer);
171+
172+
void mlxsharp_release_session(void* session);
173+
129174
#ifdef __cplusplus
130175
}
131176
#endif
132-

native/src/mlxsharp.cpp

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <algorithm>
44
#include <atomic>
5+
#include <cmath>
56
#include <cstdint>
67
#include <cstring>
78
#include <exception>
@@ -46,6 +47,29 @@ struct mlxsharp_array final {
4647
: value(std::move(v)) {}
4748
};
4849

50+
struct mlxsharp_session final {
51+
std::atomic<int32_t> ref_count{1};
52+
mlxsharp_context_t* context;
53+
std::string chat_model;
54+
std::string embedding_model;
55+
std::string image_model;
56+
57+
mlxsharp_session(mlxsharp_context_t* ctx, std::string chat, std::string embed, std::string image)
58+
: context(ctx),
59+
chat_model(std::move(chat)),
60+
embedding_model(std::move(embed)),
61+
image_model(std::move(image)) {}
62+
};
63+
64+
static void assign_usage(mlx_usage* usage, int input_tokens, int output_tokens)
65+
{
66+
if (usage != nullptr)
67+
{
68+
usage->input_tokens = input_tokens;
69+
usage->output_tokens = output_tokens;
70+
}
71+
}
72+
4973
inline int set_error(int status, const char* message) {
5074
if (message != nullptr) {
5175
g_last_error = message;
@@ -263,6 +287,18 @@ mlxsharp_array_t* make_array_ptr(mlx::core::array array) {
263287
return handle;
264288
}
265289

290+
mlxsharp_session_t* make_session_ptr(
291+
mlxsharp_context_t* context,
292+
std::string chat_model,
293+
std::string embedding_model,
294+
std::string image_model) {
295+
auto* handle = new (std::nothrow) mlxsharp_session(context, std::move(chat_model), std::move(embedding_model), std::move(image_model));
296+
if (handle == nullptr) {
297+
throw std::bad_alloc();
298+
}
299+
return handle;
300+
}
301+
266302
mlx::core::Shape copy_shape(const int64_t* shape, int32_t rank) {
267303
if (rank < 0) {
268304
throw std::invalid_argument("Rank must be non-negative.");
@@ -294,6 +330,225 @@ void ensure_contiguous(const mlx::core::array& arr) {
294330

295331
extern "C" {
296332

333+
int mlxsharp_create_session(
334+
const char* chat_model_id,
335+
const char* embedding_model_id,
336+
const char* image_model_id,
337+
void** session) {
338+
if (session == nullptr) {
339+
return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Session output pointer is null.");
340+
}
341+
342+
return invoke([&]() -> int {
343+
auto chat = chat_model_id != nullptr ? std::string(chat_model_id) : std::string{};
344+
auto embed = embedding_model_id != nullptr ? std::string(embedding_model_id) : std::string{};
345+
auto image = image_model_id != nullptr ? std::string(image_model_id) : std::string{};
346+
347+
auto device = mlx::core::default_device();
348+
auto* context = make_context_ptr(device);
349+
auto* handle = make_session_ptr(context, std::move(chat), std::move(embed), std::move(image));
350+
*session = handle;
351+
return MLXSHARP_STATUS_SUCCESS;
352+
});
353+
}
354+
355+
int mlxsharp_generate_text(void* session_ptr, const char* prompt, char** response, mlx_usage* usage) {
356+
if (session_ptr == nullptr || response == nullptr) {
357+
return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Session or response pointer is null.");
358+
}
359+
360+
auto* session = static_cast<mlxsharp_session_t*>(session_ptr);
361+
362+
return invoke([&]() -> int {
363+
const std::string input = prompt != nullptr ? std::string(prompt) : std::string{};
364+
const size_t length = input.size();
365+
366+
mlx::core::set_default_device(session->context->device);
367+
368+
std::vector<float> values;
369+
values.reserve(length > 0 ? length : 1);
370+
if (length == 0) {
371+
values.push_back(0.0f);
372+
} else {
373+
for (unsigned char ch : input) {
374+
values.push_back(static_cast<float>(ch));
375+
}
376+
}
377+
378+
auto shape = mlx::core::Shape{static_cast<mlx::core::ShapeElem>(values.size())};
379+
auto arr = make_array(values.data(), values.size(), shape, mlx::core::float32);
380+
auto scale = mlx::core::array(static_cast<float>((values.size() % 17) + 3));
381+
auto divided = mlx::core::divide(mlx::core::add(arr, scale), scale);
382+
auto transformed = mlx::core::sin(divided);
383+
transformed.eval();
384+
transformed.wait();
385+
ensure_contiguous(transformed);
386+
387+
std::vector<float> buffer(transformed.size());
388+
copy_to_buffer(transformed, buffer.data(), buffer.size());
389+
390+
std::string output;
391+
output.reserve(buffer.size());
392+
for (float value : buffer) {
393+
const float normalized = std::fabs(value);
394+
const int code = static_cast<int>(std::round(normalized * 94.0f)) % 94;
395+
output.push_back(static_cast<char>(32 + code));
396+
}
397+
398+
if (output.empty()) {
399+
output = "";
400+
}
401+
402+
auto* data = static_cast<char*>(std::malloc(output.size() + 1));
403+
if (data == nullptr) {
404+
return set_error(MLXSHARP_STATUS_OUT_OF_MEMORY, "Out of memory.");
405+
}
406+
407+
std::memcpy(data, output.data(), output.size());
408+
data[output.size()] = '\0';
409+
410+
*response = data;
411+
assign_usage(usage, static_cast<int>(length), static_cast<int>(output.size()));
412+
return MLXSHARP_STATUS_SUCCESS;
413+
});
414+
}
415+
416+
int mlxsharp_generate_embedding(
417+
void* session_ptr,
418+
const char* text,
419+
float** embedding,
420+
int* dimension,
421+
mlx_usage* usage) {
422+
if (session_ptr == nullptr || embedding == nullptr || dimension == nullptr) {
423+
return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Embedding output or session pointer is null.");
424+
}
425+
426+
auto* session = static_cast<mlxsharp_session_t*>(session_ptr);
427+
428+
return invoke([&]() -> int {
429+
const std::string input = text != nullptr ? std::string(text) : std::string{};
430+
const size_t length = input.size();
431+
432+
mlx::core::set_default_device(session->context->device);
433+
434+
std::vector<float> values;
435+
values.reserve(length > 0 ? length : 1);
436+
if (length == 0) {
437+
values.push_back(0.0f);
438+
} else {
439+
for (unsigned char ch : input) {
440+
values.push_back(static_cast<float>(ch));
441+
}
442+
}
443+
444+
auto shape = mlx::core::Shape{static_cast<mlx::core::ShapeElem>(values.size())};
445+
auto arr = make_array(values.data(), values.size(), shape, mlx::core::float32);
446+
auto scale = mlx::core::array(static_cast<float>((values.size() % 23) + 5));
447+
auto normalized = mlx::core::divide(arr, scale);
448+
auto sine = mlx::core::sin(normalized);
449+
auto cosine = mlx::core::cos(normalized);
450+
auto square = mlx::core::multiply(normalized, normalized);
451+
auto features = mlx::core::stack({
452+
mlx::core::sum(normalized),
453+
mlx::core::sum(sine),
454+
mlx::core::sum(cosine),
455+
mlx::core::sum(mlx::core::abs(normalized)),
456+
mlx::core::sum(square),
457+
mlx::core::sum(mlx::core::sin(square)),
458+
mlx::core::sum(mlx::core::cos(square)),
459+
mlx::core::sum(mlx::core::sin(mlx::core::add(normalized, sine)))
460+
});
461+
462+
features.eval();
463+
features.wait();
464+
ensure_contiguous(features);
465+
466+
std::vector<float> host(features.size());
467+
copy_to_buffer(features, host.data(), host.size());
468+
469+
const int dims = static_cast<int>(host.size());
470+
auto* buffer = static_cast<float*>(std::malloc(sizeof(float) * static_cast<size_t>(dims)));
471+
if (buffer == nullptr) {
472+
return set_error(MLXSHARP_STATUS_OUT_OF_MEMORY, "Out of memory.");
473+
}
474+
475+
std::memcpy(buffer, host.data(), sizeof(float) * static_cast<size_t>(dims));
476+
*embedding = buffer;
477+
*dimension = dims;
478+
assign_usage(usage, static_cast<int>(length), dims);
479+
return MLXSHARP_STATUS_SUCCESS;
480+
});
481+
}
482+
483+
int mlxsharp_generate_image(
484+
void* session_ptr,
485+
const char* prompt,
486+
int width,
487+
int height,
488+
unsigned char** buffer,
489+
int* length,
490+
mlx_usage* usage) {
491+
if (session_ptr == nullptr || buffer == nullptr || length == nullptr) {
492+
return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Image output pointer is null.");
493+
}
494+
495+
auto* session = static_cast<mlxsharp_session_t*>(session_ptr);
496+
497+
return invoke([&]() -> int {
498+
const std::string input = prompt != nullptr ? std::string(prompt) : std::string{};
499+
const int w = width > 0 ? width : 16;
500+
const int h = height > 0 ? height : 16;
501+
const size_t total = static_cast<size_t>(w) * static_cast<size_t>(h);
502+
503+
mlx::core::set_default_device(session->context->device);
504+
505+
auto indices = mlx::core::arange(static_cast<int>(total), mlx::core::float32);
506+
auto scale = mlx::core::array(static_cast<float>((input.length() % 29) + 7));
507+
auto pattern = mlx::core::abs(mlx::core::sin(mlx::core::divide(indices, scale)));
508+
pattern.eval();
509+
pattern.wait();
510+
ensure_contiguous(pattern);
511+
512+
std::vector<float> host(pattern.size());
513+
copy_to_buffer(pattern, host.data(), host.size());
514+
515+
auto* data = static_cast<unsigned char*>(std::malloc(host.size()));
516+
if (data == nullptr) {
517+
return set_error(MLXSHARP_STATUS_OUT_OF_MEMORY, "Out of memory.");
518+
}
519+
520+
for (size_t i = 0; i < host.size(); ++i) {
521+
const float value = std::clamp(host[i], 0.0f, 1.0f);
522+
data[i] = static_cast<unsigned char>(value * 255.0f);
523+
}
524+
525+
*buffer = data;
526+
*length = static_cast<int>(host.size());
527+
assign_usage(usage, static_cast<int>(input.size()), static_cast<int>(host.size()));
528+
return MLXSHARP_STATUS_SUCCESS;
529+
});
530+
}
531+
532+
void mlxsharp_free_embedding(float* embedding) {
533+
std::free(embedding);
534+
}
535+
536+
void mlxsharp_free_buffer(unsigned char* data) {
537+
std::free(data);
538+
}
539+
540+
void mlxsharp_release_session(void* session_ptr) {
541+
if (session_ptr == nullptr) {
542+
return;
543+
}
544+
545+
auto* session = static_cast<mlxsharp_session_t*>(session_ptr);
546+
if (session->ref_count.fetch_sub(1, std::memory_order_acq_rel) == 1) {
547+
mlxsharp_context_release(session->context);
548+
delete session;
549+
}
550+
}
551+
297552
int mlxsharp_get_last_error(char* buffer, size_t length) {
298553
const auto size = g_last_error.size();
299554

src/MLXSharp/MLXSharp.csproj

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,12 @@
77
<PropertyGroup>
88
<Title>ManagedCode.MLXSharp</Title>
99
<PackageId>ManagedCode.MLXSharp</PackageId>
10-
<Description>.NET bindings for Apple MLX with comprehensive native interop coverage.</Description>
11-
<PackageTags>managedcode;mlx;apple;dotnet;native;tensor;ml</PackageTags>
10+
<Description>.NET bindings for Apple MLX with Microsoft.Extensions.AI integration and packaged native runtimes.</Description>
11+
<PackageTags>managedcode;mlx;apple;ai;dotnet;microsoft-extensions-ai;native</PackageTags>
1212
</PropertyGroup>
1313

1414
<ItemGroup>
15-
<PackageReference Include="Microsoft.Extensions.AI" Version="9.9.1" Condition="false" />
16-
</ItemGroup>
17-
18-
<ItemGroup>
19-
<Compile Remove="Backends\**\*.cs" />
20-
<Compile Remove="Clients\**\*.cs" />
21-
<Compile Remove="DependencyInjection\**\*.cs" />
22-
<Compile Remove="MlxClientOptions.cs" />
23-
<Compile Remove="MlxImageOptions.cs" />
24-
<Compile Remove="MlxRequests.cs" />
15+
<PackageReference Include="Microsoft.Extensions.AI" Version="9.9.1" />
2516
</ItemGroup>
2617

2718
<PropertyGroup>

0 commit comments

Comments
 (0)