Skip to content

Commit d7ae5fb

Browse files
committed
model
1 parent ae516e0 commit d7ae5fb

File tree

9 files changed

+79
-498
lines changed

9 files changed

+79
-498
lines changed

native/include/mlxsharp/api.h

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -137,18 +137,6 @@ typedef struct mlx_usage {
137137
int output_tokens;
138138
} mlx_usage;
139139

140-
typedef struct mlxsharp_generation_options {
141-
int max_tokens;
142-
float temperature;
143-
float top_p;
144-
int top_k;
145-
} mlxsharp_generation_options;
146-
147-
typedef struct mlxsharp_token_buffer {
148-
int32_t* tokens;
149-
size_t length;
150-
} mlxsharp_token_buffer;
151-
152140
int mlxsharp_create_session(
153141
const char* chat_model_id,
154142
const char* embedding_model_id,
@@ -183,21 +171,6 @@ void mlxsharp_free_buffer(unsigned char* buffer);
183171

184172
void mlxsharp_release_session(void* session);
185173

186-
int mlxsharp_session_load_model(
187-
void* session,
188-
const char* model_directory,
189-
const char* tokenizer_path);
190-
191-
int mlxsharp_session_generate_tokens(
192-
void* session,
193-
const int32_t* prompt_tokens,
194-
size_t prompt_token_count,
195-
const mlxsharp_generation_options* options,
196-
mlxsharp_token_buffer* output_tokens,
197-
mlx_usage* usage);
198-
199-
void mlxsharp_release_tokens(mlxsharp_token_buffer* buffer);
200-
201174
#ifdef __cplusplus
202175
}
203176
#endif

native/src/mlxsharp.cpp

Lines changed: 0 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#include "mlxsharp/api.h"
2-
#include "mlxsharp/llm_model_runner.h"
32

43
#include <algorithm>
54
#include <atomic>
@@ -44,8 +43,6 @@ struct mlxsharp_session {
4443
std::string chat_model;
4544
std::string embedding_model;
4645
std::string image_model;
47-
std::unique_ptr<mlxsharp::llm::ModelRunner> model_runner;
48-
4946
mlxsharp_session(mlxsharp_context_t* ctx, std::string chat, std::string embed, std::string image)
5047
: context(ctx),
5148
chat_model(std::move(chat)),
@@ -565,104 +562,6 @@ void mlxsharp_free_buffer(unsigned char* data) {
565562
std::free(data);
566563
}
567564

568-
int mlxsharp_session_load_model(
569-
void* session_ptr,
570-
const char* model_directory,
571-
const char* tokenizer_path) {
572-
if (session_ptr == nullptr) {
573-
return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Session pointer is null.");
574-
}
575-
576-
if (model_directory == nullptr || tokenizer_path == nullptr) {
577-
return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Model directory or tokenizer path is null.");
578-
}
579-
580-
auto* session = static_cast<mlxsharp_session_t*>(session_ptr);
581-
582-
return invoke([&]() -> int {
583-
auto model = mlxsharp::llm::ModelRunner::Create(model_directory, tokenizer_path);
584-
session->model_runner = std::move(model);
585-
return MLXSHARP_STATUS_SUCCESS;
586-
});
587-
}
588-
589-
int mlxsharp_session_generate_tokens(
590-
void* session_ptr,
591-
const int32_t* prompt_tokens,
592-
size_t prompt_token_count,
593-
const mlxsharp_generation_options* options,
594-
mlxsharp_token_buffer* output_tokens,
595-
mlx_usage* usage) {
596-
if (session_ptr == nullptr) {
597-
return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Session pointer is null.");
598-
}
599-
600-
if (output_tokens == nullptr) {
601-
return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, kNullOutParameter);
602-
}
603-
604-
output_tokens->tokens = nullptr;
605-
output_tokens->length = 0;
606-
607-
auto* session = static_cast<mlxsharp_session_t*>(session_ptr);
608-
609-
if (session->model_runner == nullptr) {
610-
return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Model is not loaded. Call mlxsharp_session_load_model first.");
611-
}
612-
613-
if (prompt_token_count > 0 && prompt_tokens == nullptr) {
614-
return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Prompt tokens pointer is null.");
615-
}
616-
617-
if (options == nullptr) {
618-
return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Generation options pointer is null.");
619-
}
620-
621-
return invoke([&]() -> int {
622-
std::vector<int32_t> prompt;
623-
prompt.reserve(prompt_token_count);
624-
for (size_t i = 0; i < prompt_token_count; ++i) {
625-
prompt.push_back(prompt_tokens[i]);
626-
}
627-
628-
mlxsharp::llm::GenerationOptions native_options{
629-
options->max_tokens,
630-
options->temperature,
631-
options->top_p,
632-
options->top_k,
633-
};
634-
635-
auto generated = session->model_runner->Generate(prompt, native_options);
636-
output_tokens->length = generated.size();
637-
638-
if (generated.empty()) {
639-
assign_usage(usage, static_cast<int>(prompt_token_count), 0);
640-
return MLXSHARP_STATUS_SUCCESS;
641-
}
642-
643-
auto* buffer = static_cast<int32_t*>(std::malloc(generated.size() * sizeof(int32_t)));
644-
if (buffer == nullptr) {
645-
return set_error(MLXSHARP_STATUS_OUT_OF_MEMORY, "Failed to allocate output token buffer.");
646-
}
647-
648-
std::memcpy(buffer, generated.data(), generated.size() * sizeof(int32_t));
649-
output_tokens->tokens = buffer;
650-
651-
assign_usage(usage, static_cast<int>(prompt_token_count), static_cast<int>(generated.size()));
652-
return MLXSHARP_STATUS_SUCCESS;
653-
});
654-
}
655-
656-
void mlxsharp_release_tokens(mlxsharp_token_buffer* buffer) {
657-
if (buffer == nullptr || buffer->tokens == nullptr) {
658-
return;
659-
}
660-
661-
std::free(buffer->tokens);
662-
buffer->tokens = nullptr;
663-
buffer->length = 0;
664-
}
665-
666565
void mlxsharp_release_session(void* session_ptr) {
667566
if (session_ptr == nullptr) {
668567
return;

src/MLXSharp.Tests/ModelIntegrationTests.cs

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
using MLXSharp;
66
using MLXSharp.Backends;
77
using Xunit;
8-
using Xunit.Sdk;
98

109
namespace MLXSharp.Tests;
1110

@@ -24,7 +23,7 @@ public async Task NativeBackendAnswersSimpleMathAsync()
2423
new[] { new ChatMessage(ChatRole.User, "Скільки буде 2+2?") },
2524
new ChatOptions { Temperature = 0 });
2625

27-
var result = await backend.GenerateTextAsync(request, CancellationToken.None).ConfigureAwait(false);
26+
var result = await backend.GenerateTextAsync(request, CancellationToken.None);
2827

2928
Assert.False(string.IsNullOrWhiteSpace(result.Text));
3029
Assert.Contains("4", result.Text);
@@ -39,6 +38,24 @@ private static MlxClientOptions CreateOptions()
3938
EnableNativeModelRunner = false,
4039
};
4140

41+
var modelId = Environment.GetEnvironmentVariable("MLXSHARP_HF_MODEL_ID");
42+
if (!string.IsNullOrWhiteSpace(modelId))
43+
{
44+
options.ChatModelId = modelId;
45+
}
46+
47+
var modelDirectory = Environment.GetEnvironmentVariable("MLXSHARP_MODEL_PATH");
48+
if (!string.IsNullOrWhiteSpace(modelDirectory))
49+
{
50+
options.NativeModelDirectory = modelDirectory;
51+
}
52+
53+
var tokenizerPath = Environment.GetEnvironmentVariable("MLXSHARP_TOKENIZER_PATH");
54+
if (!string.IsNullOrWhiteSpace(tokenizerPath))
55+
{
56+
options.TokenizerPath = tokenizerPath;
57+
}
58+
4259
return options;
4360
}
4461

@@ -47,13 +64,13 @@ private static void EnsureAssetsOrSkip()
4764
var modelPath = Environment.GetEnvironmentVariable("MLXSHARP_MODEL_PATH");
4865
if (string.IsNullOrWhiteSpace(modelPath) || !System.IO.Directory.Exists(modelPath))
4966
{
50-
throw new SkipException("Native model bundle not found.");
67+
Skip.If(true, "Native model bundle not found.");
5168
}
5269

5370
var library = Environment.GetEnvironmentVariable("MLXSHARP_LIBRARY");
5471
if (string.IsNullOrWhiteSpace(library) || !System.IO.File.Exists(library))
5572
{
56-
throw new SkipException("Native libmlxsharp library not configured.");
73+
Skip.If(true, "Native libmlxsharp library not configured.");
5774
}
5875
}
5976
}

src/MLXSharp.Tests/TestEnvironment.cs

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,38 +25,40 @@ public static void EnsureInitialized()
2525

2626
private static void ConfigureNativeLibrary(string repoRoot)
2727
{
28-
if (!string.IsNullOrWhiteSpace(Environment.GetEnvironmentVariable("MLXSHARP_LIBRARY")))
28+
var existing = Environment.GetEnvironmentVariable("MLXSHARP_LIBRARY");
29+
if (!string.IsNullOrWhiteSpace(existing) && File.Exists(existing))
2930
{
31+
ApplyNativeLibrary(existing);
3032
return;
3133
}
3234

3335
string? libraryPath = null;
3436
if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
3537
{
36-
var candidate = Path.Combine(repoRoot, "libs", "native-osx-arm64", "libmlxsharp.dylib");
37-
if (File.Exists(candidate))
38+
var candidates = new[]
3839
{
39-
libraryPath = candidate;
40-
}
40+
Path.Combine(repoRoot, "libs", "native-osx-arm64", "libmlxsharp.dylib"),
41+
Path.Combine(repoRoot, "libs", "native-libs", "libmlxsharp.dylib"),
42+
Path.Combine(repoRoot, "libs", "native-libs", "osx-arm64", "libmlxsharp.dylib"),
43+
};
44+
45+
libraryPath = Array.Find(candidates, File.Exists);
4146
}
4247
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
4348
{
44-
var candidate = Path.Combine(repoRoot, "libs", "native-libs", "libmlxsharp.so");
45-
if (File.Exists(candidate))
49+
var candidates = new[]
4650
{
47-
libraryPath = candidate;
48-
}
51+
Path.Combine(repoRoot, "libs", "native-linux", "libmlxsharp.so"),
52+
Path.Combine(repoRoot, "libs", "native-libs", "libmlxsharp.so"),
53+
Path.Combine(repoRoot, "libs", "native-libs", "linux-x64", "libmlxsharp.so"),
54+
};
55+
56+
libraryPath = Array.Find(candidates, File.Exists);
4957
}
5058

5159
if (!string.IsNullOrWhiteSpace(libraryPath))
5260
{
53-
Environment.SetEnvironmentVariable("MLXSHARP_LIBRARY", libraryPath);
54-
55-
var metalPath = Path.Combine(Path.GetDirectoryName(libraryPath)!, "mlx.metallib");
56-
if (File.Exists(metalPath))
57-
{
58-
Environment.SetEnvironmentVariable("MLX_METAL_PATH", metalPath);
59-
}
61+
ApplyNativeLibrary(libraryPath);
6062
}
6163
}
6264

@@ -77,4 +79,41 @@ private static void ConfigureModelPaths(string repoRoot)
7779
Environment.SetEnvironmentVariable("MLXSHARP_TOKENIZER_PATH", tokenizerPath);
7880
}
7981
}
82+
83+
private static void ApplyNativeLibrary(string libraryPath)
84+
{
85+
Environment.SetEnvironmentVariable("MLXSHARP_LIBRARY", libraryPath);
86+
87+
var metalPath = Path.Combine(Path.GetDirectoryName(libraryPath)!, "mlx.metallib");
88+
if (File.Exists(metalPath))
89+
{
90+
Environment.SetEnvironmentVariable("MLX_METAL_PATH", metalPath);
91+
Environment.SetEnvironmentVariable("MLX_METALLIB", metalPath);
92+
}
93+
94+
var fileName = RuntimeInformation.IsOSPlatform(OSPlatform.OSX)
95+
? "libmlxsharp.dylib"
96+
: RuntimeInformation.IsOSPlatform(OSPlatform.Linux)
97+
? "libmlxsharp.so"
98+
: "libmlxsharp";
99+
100+
TryCopy(libraryPath, Path.Combine(AppContext.BaseDirectory, fileName));
101+
if (File.Exists(metalPath))
102+
{
103+
TryCopy(metalPath, Path.Combine(AppContext.BaseDirectory, "mlx.metallib"));
104+
}
105+
}
106+
107+
private static void TryCopy(string source, string destination)
108+
{
109+
try
110+
{
111+
Directory.CreateDirectory(Path.GetDirectoryName(destination)!);
112+
File.Copy(source, destination, overwrite: true);
113+
}
114+
catch
115+
{
116+
// best effort copy; ignore IO errors
117+
}
118+
}
80119
}

src/MLXSharp.Tests/TokenizerSmokeTests.cs

Lines changed: 0 additions & 37 deletions
This file was deleted.

0 commit comments

Comments
 (0)