Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 97b691d

Browse files
committed
Tidy up ModelSet collections
1 parent e39f454 commit 97b691d

File tree

6 files changed

+100
-62
lines changed

6 files changed

+100
-62
lines changed
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
using System.Collections.Generic;
2+
3+
namespace OnnxStack.Core.Config
4+
{
5+
public class OnnxModelEqualityComparer : IEqualityComparer<IOnnxModel>
6+
{
7+
public bool Equals(IOnnxModel x, IOnnxModel y)
8+
{
9+
return x != null && y != null && x.Name == y.Name;
10+
}
11+
12+
public int GetHashCode(IOnnxModel obj)
13+
{
14+
return obj?.Name?.GetHashCode() ?? 0;
15+
}
16+
}
17+
}

OnnxStack.Core/Services/OnnxModelService.cs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ namespace OnnxStack.Core.Services
1616
public sealed class OnnxModelService : IOnnxModelService
1717
{
1818
private readonly OnnxStackConfig _configuration;
19-
private readonly ConcurrentDictionary<string, OnnxModelSet> _onnxModelSets;
20-
private readonly ConcurrentDictionary<string, IOnnxModelSetConfig> _onnxModelSetConfigs;
19+
private readonly ConcurrentDictionary<IOnnxModel, OnnxModelSet> _onnxModelSets;
20+
private readonly ConcurrentDictionary<IOnnxModel, IOnnxModelSetConfig> _onnxModelSetConfigs;
2121

2222
/// <summary>
2323
/// Initializes a new instance of the <see cref="OnnxModelService"/> class.
@@ -26,8 +26,8 @@ public sealed class OnnxModelService : IOnnxModelService
2626
public OnnxModelService(OnnxStackConfig configuration)
2727
{
2828
_configuration = configuration;
29-
_onnxModelSets = new ConcurrentDictionary<string, OnnxModelSet>();
30-
_onnxModelSetConfigs = new ConcurrentDictionary<string, IOnnxModelSetConfig>();
29+
_onnxModelSets = new ConcurrentDictionary<IOnnxModel, OnnxModelSet>(new OnnxModelEqualityComparer());
30+
_onnxModelSetConfigs = new ConcurrentDictionary<IOnnxModel, IOnnxModelSetConfig>(new OnnxModelEqualityComparer());
3131
}
3232

3333

@@ -50,7 +50,7 @@ public OnnxModelService(OnnxStackConfig configuration)
5050
/// <returns></returns>
5151
public Task<bool> AddModelSet(IOnnxModelSetConfig modelSet)
5252
{
53-
return Task.FromResult(_onnxModelSetConfigs.TryAdd(modelSet.Name, modelSet));
53+
return Task.FromResult(_onnxModelSetConfigs.TryAdd(modelSet, modelSet));
5454
}
5555

5656
/// <summary>
@@ -74,7 +74,7 @@ public Task AddModelSet(IEnumerable<IOnnxModelSetConfig> modelSets)
7474
/// <returns></returns>
7575
public Task<bool> RemoveModelSet(IOnnxModelSetConfig modelSet)
7676
{
77-
return Task.FromResult(_onnxModelSetConfigs.TryRemove(modelSet.Name, out _));
77+
return Task.FromResult(_onnxModelSetConfigs.TryRemove(modelSet, out _));
7878
}
7979

8080

@@ -85,8 +85,8 @@ public Task<bool> RemoveModelSet(IOnnxModelSetConfig modelSet)
8585
/// <returns></returns>
8686
public Task<bool> UpdateModelSet(IOnnxModelSetConfig modelSet)
8787
{
88-
_onnxModelSetConfigs.TryRemove(modelSet.Name, out _);
89-
return Task.FromResult(_onnxModelSetConfigs.TryAdd(modelSet.Name, modelSet));
88+
_onnxModelSetConfigs.TryRemove(modelSet, out _);
89+
return Task.FromResult(_onnxModelSetConfigs.TryAdd(modelSet, modelSet));
9090
}
9191

9292

@@ -120,7 +120,7 @@ public async Task<bool> UnloadModelAsync(IOnnxModel model)
120120
/// </returns>
121121
public bool IsModelLoaded(IOnnxModel model)
122122
{
123-
return _onnxModelSets.ContainsKey(model.Name);
123+
return _onnxModelSets.ContainsKey(model);
124124
}
125125

126126

@@ -251,7 +251,7 @@ private OnnxMetadata GetNodeMetadataInternal(IOnnxModel model, OnnxModelType mod
251251
/// <exception cref="System.Exception">Model {model.Name} has not been loaded</exception>
252252
private OnnxModelSet GetModelSet(IOnnxModel model)
253253
{
254-
if (!_onnxModelSets.TryGetValue(model.Name, out var modelSet))
254+
if (!_onnxModelSets.TryGetValue(model, out var modelSet))
255255
throw new Exception($"Model {model.Name} has not been loaded");
256256

257257
return modelSet;
@@ -266,17 +266,17 @@ private OnnxModelSet GetModelSet(IOnnxModel model)
266266
/// <exception cref="System.Exception">Model {model.Name} not found in configuration</exception>
267267
private OnnxModelSet LoadModelSet(IOnnxModel model)
268268
{
269-
if (_onnxModelSets.ContainsKey(model.Name))
270-
return _onnxModelSets[model.Name];
269+
if (_onnxModelSets.ContainsKey(model))
270+
return _onnxModelSets[model];
271271

272-
if (!_onnxModelSetConfigs.TryGetValue(model.Name, out var modelSetConfig))
273-
throw new Exception($"Model {model.Name} not found in configuration");
272+
if (!_onnxModelSetConfigs.TryGetValue(model, out var modelSetConfig))
273+
throw new Exception($"Model {model.Name} not found");
274274

275275
if (!modelSetConfig.IsEnabled)
276276
throw new Exception($"Model {model.Name} is not enabled");
277277

278278
var modelSet = new OnnxModelSet(modelSetConfig);
279-
_onnxModelSets.TryAdd(model.Name, modelSet);
279+
_onnxModelSets.TryAdd(model, modelSet);
280280
return modelSet;
281281
}
282282

@@ -288,10 +288,10 @@ private OnnxModelSet LoadModelSet(IOnnxModel model)
288288
/// <returns></returns>
289289
private bool UnloadModelSet(IOnnxModel model)
290290
{
291-
if (!_onnxModelSets.TryGetValue(model.Name, out var modelSet))
291+
if (!_onnxModelSets.TryGetValue(model, out _))
292292
return true;
293293

294-
if (_onnxModelSets.TryRemove(model.Name, out modelSet))
294+
if (_onnxModelSets.TryRemove(model, out var modelSet))
295295
{
296296
modelSet?.Dispose();
297297
return true;

OnnxStack.ImageUpscaler/Services/IUpscaleService.cs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,10 @@ namespace OnnxStack.ImageUpscaler.Services
1313
public interface IUpscaleService
1414
{
1515

16-
/// <summary>
17-
/// Gets the configuration.
18-
/// </summary>
19-
ImageUpscalerConfig Configuration { get; }
20-
2116
/// <summary>
2217
/// Gets the model sets.
2318
/// </summary>
24-
IReadOnlyList<UpscaleModelSet> ModelSets { get; }
19+
IReadOnlyCollection<UpscaleModelSet> ModelSets { get; }
2520

2621
/// <summary>
2722
/// Adds the model.

OnnxStack.ImageUpscaler/Services/UpscaleService.cs

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
using SixLabors.ImageSharp;
1212
using SixLabors.ImageSharp.PixelFormats;
1313
using SixLabors.ImageSharp.Processing;
14+
using System;
1415
using System.Collections.Generic;
1516
using System.IO;
1617
using System.Linq;
@@ -22,6 +23,7 @@ public class UpscaleService : IUpscaleService
2223
{
2324
private readonly IOnnxModelService _modelService;
2425
private readonly ImageUpscalerConfig _configuration;
26+
private readonly HashSet<UpscaleModelSet> _modelSetConfigs;
2527

2628
/// <summary>
2729
/// Initializes a new instance of the <see cref="UpscaleService"/> class.
@@ -33,20 +35,15 @@ public UpscaleService(ImageUpscalerConfig configuration, IOnnxModelService model
3335
{
3436
_configuration = configuration;
3537
_modelService = modelService;
36-
_modelService.AddModelSet(_configuration.ModelSets);
38+
_modelSetConfigs = new HashSet<UpscaleModelSet>(_configuration.ModelSets, new OnnxModelEqualityComparer());
39+
_modelService.AddModelSet(_modelSetConfigs);
3740
}
3841

3942

40-
/// <summary>
41-
/// Gets the configuration.
42-
/// </summary>
43-
public ImageUpscalerConfig Configuration => _configuration;
44-
45-
4643
/// <summary>
4744
/// Gets the model sets.
4845
/// </summary>
49-
public IReadOnlyList<UpscaleModelSet> ModelSets => _configuration.ModelSets;
46+
public IReadOnlyCollection<UpscaleModelSet> ModelSets => _modelSetConfigs;
5047

5148

5249
/// <summary>
@@ -55,9 +52,14 @@ public UpscaleService(ImageUpscalerConfig configuration, IOnnxModelService model
5552
/// <param name="model">The model.</param>
5653
/// <returns></returns>
5754
/// <exception cref="System.NotImplementedException"></exception>
58-
public Task<bool> AddModelAsync(UpscaleModelSet model)
55+
public async Task<bool> AddModelAsync(UpscaleModelSet model)
5956
{
60-
return _modelService.AddModelSet(model);
57+
if (await _modelService.AddModelSet(model))
58+
{
59+
_modelSetConfigs.Add(model);
60+
return true;
61+
}
62+
return false;
6163
}
6264

6365

@@ -67,9 +69,14 @@ public Task<bool> AddModelAsync(UpscaleModelSet model)
6769
/// <param name="model">The model.</param>
6870
/// <returns></returns>
6971
/// <exception cref="System.NotImplementedException"></exception>
70-
public Task<bool> RemoveModelAsync(UpscaleModelSet model)
72+
public async Task<bool> RemoveModelAsync(UpscaleModelSet model)
7173
{
72-
return _modelService.RemoveModelSet(model);
74+
if (await _modelService.RemoveModelSet(model))
75+
{
76+
_modelSetConfigs.Remove(model);
77+
return true;
78+
}
79+
return false;
7380
}
7481

7582

@@ -79,9 +86,15 @@ public Task<bool> RemoveModelAsync(UpscaleModelSet model)
7986
/// <param name="model">The model.</param>
8087
/// <returns></returns>
8188
/// <exception cref="System.NotImplementedException"></exception>
82-
public Task<bool> UpdateModelAsync(UpscaleModelSet model)
89+
public async Task<bool> UpdateModelAsync(UpscaleModelSet model)
8390
{
84-
return _modelService.UpdateModelSet(model);
91+
if (await _modelService.UpdateModelSet(model))
92+
{
93+
_modelSetConfigs.Remove(model);
94+
_modelSetConfigs.Add(model);
95+
return true;
96+
}
97+
return false;
8598
}
8699

87100

@@ -92,6 +105,9 @@ public Task<bool> UpdateModelAsync(UpscaleModelSet model)
92105
/// <returns></returns>
93106
public async Task<bool> LoadModelAsync(UpscaleModelSet model)
94107
{
108+
if (!_modelSetConfigs.TryGetValue(model, out _))
109+
throw new Exception("ModelSet not found");
110+
95111
var modelSet = await _modelService.LoadModelAsync(model);
96112
return modelSet is not null;
97113
}

OnnxStack.StableDiffusion/Common/IStableDiffusionService.cs

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,10 @@ namespace OnnxStack.StableDiffusion.Common
1313
{
1414
public interface IStableDiffusionService
1515
{
16-
17-
/// <summary>
18-
/// Gets the configuration.
19-
/// </summary>
20-
StableDiffusionConfig Configuration { get; }
21-
22-
/// <summary>
16+
/// <summary>
2317
/// Gets the models.
2418
/// </summary>
25-
IReadOnlyList<StableDiffusionModelSet> ModelSets { get; }
19+
IReadOnlyCollection<StableDiffusionModelSet> ModelSets { get; }
2620

2721
/// <summary>
2822
/// Adds the model.

OnnxStack.StableDiffusion/Services/StableDiffusionService.cs

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Microsoft.ML.OnnxRuntime.Tensors;
22
using OnnxStack.Core;
3+
using OnnxStack.Core.Config;
34
using OnnxStack.Core.Services;
45
using OnnxStack.StableDiffusion.Common;
56
using OnnxStack.StableDiffusion.Config;
@@ -27,6 +28,7 @@ public sealed class StableDiffusionService : IStableDiffusionService
2728
{
2829
private readonly IOnnxModelService _modelService;
2930
private readonly StableDiffusionConfig _configuration;
31+
private readonly HashSet<StableDiffusionModelSet> _modelSetConfigs;
3032
private readonly ConcurrentDictionary<DiffuserPipelineType, IPipeline> _pipelines;
3133

3234
/// <summary>
@@ -37,31 +39,31 @@ public StableDiffusionService(StableDiffusionConfig configuration, IOnnxModelSer
3739
{
3840
_configuration = configuration;
3941
_modelService = onnxModelService;
40-
_modelService.AddModelSet(configuration.ModelSets);
42+
_modelSetConfigs = new HashSet<StableDiffusionModelSet>(_configuration.ModelSets, new OnnxModelEqualityComparer());
43+
_modelService.AddModelSet(_modelSetConfigs);
4144
_pipelines = pipelines.ToConcurrentDictionary(k => k.PipelineType, k => k);
4245
}
4346

4447

45-
/// <summary>
46-
/// Gets the configuration.
47-
/// </summary>
48-
public StableDiffusionConfig Configuration => _configuration;
49-
50-
5148
/// <summary>
5249
/// Gets the model sets.
5350
/// </summary>
54-
public IReadOnlyList<StableDiffusionModelSet> ModelSets => _configuration.ModelSets;
51+
public IReadOnlyCollection<StableDiffusionModelSet> ModelSets => _modelSetConfigs;
5552

5653

5754
/// <summary>
5855
/// Adds the model.
5956
/// </summary>
6057
/// <param name="model">The model.</param>
6158
/// <returns></returns>
62-
public Task<bool> AddModelAsync(StableDiffusionModelSet model)
59+
public async Task<bool> AddModelAsync(StableDiffusionModelSet model)
6360
{
64-
return _modelService.AddModelSet(model);
61+
if (await _modelService.AddModelSet(model))
62+
{
63+
_modelSetConfigs.Add(model);
64+
return true;
65+
}
66+
return false;
6567
}
6668

6769

@@ -70,9 +72,14 @@ public Task<bool> AddModelAsync(StableDiffusionModelSet model)
7072
/// </summary>
7173
/// <param name="model">The model.</param>
7274
/// <returns></returns>
73-
public Task<bool> RemoveModelAsync(StableDiffusionModelSet model)
75+
public async Task<bool> RemoveModelAsync(StableDiffusionModelSet model)
7476
{
75-
return _modelService.RemoveModelSet(model);
77+
if (await _modelService.RemoveModelSet(model))
78+
{
79+
_modelSetConfigs.Remove(model);
80+
return true;
81+
}
82+
return false;
7683
}
7784

7885

@@ -81,9 +88,15 @@ public Task<bool> RemoveModelAsync(StableDiffusionModelSet model)
8188
/// </summary>
8289
/// <param name="model">The model.</param>
8390
/// <returns></returns>
84-
public Task<bool> UpdateModelAsync(StableDiffusionModelSet model)
91+
public async Task<bool> UpdateModelAsync(StableDiffusionModelSet model)
8592
{
86-
return _modelService.UpdateModelSet(model);
93+
if (await _modelService.UpdateModelSet(model))
94+
{
95+
_modelSetConfigs.Remove(model);
96+
_modelSetConfigs.Add(model);
97+
return true;
98+
}
99+
return false;
87100
}
88101

89102

@@ -92,10 +105,13 @@ public Task<bool> UpdateModelAsync(StableDiffusionModelSet model)
92105
/// </summary>
93106
/// <param name="model">The model options.</param>
94107
/// <returns></returns>
95-
public async Task<bool> LoadModelAsync(StableDiffusionModelSet modelSet)
108+
public async Task<bool> LoadModelAsync(StableDiffusionModelSet model)
96109
{
97-
var model = await _modelService.LoadModelAsync(modelSet);
98-
return model is not null;
110+
if (!_modelSetConfigs.TryGetValue(model, out _))
111+
throw new Exception("ModelSet not found");
112+
113+
var modelSet = await _modelService.LoadModelAsync(model);
114+
return modelSet is not null;
99115
}
100116

101117

0 commit comments

Comments
 (0)