Skip to content

Commit

Permalink
Switch Unet mode
Browse files Browse the repository at this point in the history
  • Loading branch information
saddam213 committed May 13, 2024
1 parent 4134ece commit 2dcd76f
Show file tree
Hide file tree
Showing 13 changed files with 64 additions and 48 deletions.
4 changes: 2 additions & 2 deletions OnnxStack.Console/Examples/ControlNetExample.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public async Task RunAsync()
var controlImage = await OnnxImage.FromFileAsync("D:\\Repositories\\OnnxStack\\Assets\\Samples\\OpenPose.png");

// Create ControlNet
var controlNet = ControlNetModel.Create("D:\\Models\\controlnet_onnx\\controlnet\\openpose.onnx", ControlNetType.OpenPose, DiffuserPipelineType.StableDiffusion);
var controlNet = ControlNetModel.Create("D:\\Models\\controlnet_onnx\\controlnet\\openpose.onnx", ControlNetType.OpenPose);

// Create Pipeline
var pipeline = StableDiffusionPipeline.CreatePipeline("D:\\Models\\stable-diffusion-v1-5-onnx");
Expand All @@ -49,7 +49,7 @@ public async Task RunAsync()
};

// Preload (optional)
await pipeline.LoadAsync(true);
await pipeline.LoadAsync(UnetModeType.ControlNet);

// Run pipeline
var result = await pipeline.RunAsync(promptOptions, controlNet: controlNet, progressCallback: OutputHelpers.ProgressCallback);
Expand Down
4 changes: 2 additions & 2 deletions OnnxStack.Console/Examples/ControlNetFeatureExample.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public async Task RunAsync()
await controlImage.SaveAsync(Path.Combine(_outputDirectory, $"Depth.png"));

// Create ControlNet
var controlNet = ControlNetModel.Create("D:\\Models\\controlnet_onnx\\controlnet\\depth.onnx", ControlNetType.Depth, DiffuserPipelineType.StableDiffusion);
var controlNet = ControlNetModel.Create("D:\\Models\\controlnet_onnx\\controlnet\\depth.onnx", ControlNetType.Depth);

// Create Pipeline
var pipeline = StableDiffusionPipeline.CreatePipeline("D:\\Models\\stable-diffusion-v1-5-onnx");
Expand All @@ -57,7 +57,7 @@ public async Task RunAsync()
};

// Preload (optional)
await pipeline.LoadAsync(true);
await pipeline.LoadAsync(UnetModeType.ControlNet);

// Run pipeline
var result = await pipeline.RunAsync(promptOptions, controlNet: controlNet, progressCallback: OutputHelpers.ProgressCallback);
Expand Down
9 changes: 9 additions & 0 deletions OnnxStack.StableDiffusion/Enums/UnetModeType.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
namespace OnnxStack.StableDiffusion.Enums
{
public enum UnetModeType
{
Default = 0,
ControlNet = 1,
Both = 2
}
}
8 changes: 3 additions & 5 deletions OnnxStack.StableDiffusion/Models/ControlNetModel.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Microsoft.ML.OnnxRuntime;

using Microsoft.ML.OnnxRuntime;
using OnnxStack.Core.Config;
using OnnxStack.Core.Model;
using OnnxStack.StableDiffusion.Enums;
Expand All @@ -15,14 +16,13 @@ public ControlNetModel(ControlNetModelConfig configuration)
}

public ControlNetType Type => _configuration.Type;
public DiffuserPipelineType PipelineType => _configuration.PipelineType;

public static ControlNetModel Create(ControlNetModelConfig configuration)
{
return new ControlNetModel(configuration);
}

public static ControlNetModel Create(string modelFile, ControlNetType type, DiffuserPipelineType pipeline, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML)
public static ControlNetModel Create(string modelFile, ControlNetType type, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML)
{
var configuration = new ControlNetModelConfig
{
Expand All @@ -33,7 +33,6 @@ public static ControlNetModel Create(string modelFile, ControlNetType type, Diff
IntraOpNumThreads = 0,
OnnxModelPath = modelFile,
Type = type,
PipelineType = pipeline,
};
return new ControlNetModel(configuration);
}
Expand All @@ -42,6 +41,5 @@ public static ControlNetModel Create(string modelFile, ControlNetType type, Diff
public record ControlNetModelConfig : OnnxModelConfig
{
public ControlNetType Type { get; set; }
public DiffuserPipelineType PipelineType { get; set; }
}
}
12 changes: 11 additions & 1 deletion OnnxStack.StableDiffusion/Pipelines/Base/IPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ public interface IPipeline
/// </summary>
string Name { get; }

/// <summary>
/// Gets the type of the pipeline.
/// </summary>
DiffuserPipelineType PipelineType { get; }

/// <summary>
/// Gets the pipelines supported diffusers.
Expand All @@ -39,11 +43,17 @@ public interface IPipeline
SchedulerOptions DefaultSchedulerOptions { get; }


/// <summary>
/// Gets the current unet mode.
/// </summary>
UnetModeType CurrentUnetMode { get; }


/// <summary>
/// Loads the pipeline.
/// </summary>
/// <returns></returns>
Task LoadAsync(bool controlNet = false);
Task LoadAsync(UnetModeType unetMode = UnetModeType.Default);


/// <summary>
Expand Down
10 changes: 7 additions & 3 deletions OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,17 @@ protected PipelineBase(PipelineOptions pipelineOptions, ILogger logger)
public abstract SchedulerOptions DefaultSchedulerOptions { get; }


/// <summary>
/// Gets the current unet mode.
/// </summary>
public abstract UnetModeType CurrentUnetMode { get; }


/// <summary>
/// Loads the pipeline.
/// </summary>
/// <returns></returns>
public abstract Task LoadAsync(bool controlNet = false);
public abstract Task LoadAsync(UnetModeType unetMode = UnetModeType.Default);


/// <summary>
Expand Down Expand Up @@ -96,8 +102,6 @@ protected PipelineBase(PipelineOptions pipelineOptions, ILogger logger)
public abstract Task<DenseTensor<float>> RunAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);




/// <summary>
/// Runs the pipeline batch.
/// </summary>
Expand Down
6 changes: 3 additions & 3 deletions OnnxStack.StableDiffusion/Pipelines/StableCascadePipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,16 @@ public StableCascadePipeline(PipelineOptions pipelineOptions, TokenizerModel tok
public UNetConditionModel DecoderUnet => _decoderUnet;


public override Task LoadAsync(bool controlNet = false)
public override Task LoadAsync(UnetModeType unetMode = UnetModeType.Default)
{
if (_pipelineOptions.MemoryMode == MemoryModeType.Minimum)
return base.LoadAsync(controlNet);
return base.LoadAsync(unetMode);

// Preload all models into VRAM
return Task.WhenAll
(
_decoderUnet.LoadAsync(),
base.LoadAsync(controlNet)
base.LoadAsync(unetMode)
);
}

Expand Down
28 changes: 21 additions & 7 deletions OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public class StableDiffusionPipeline : PipelineBase
protected List<DiffuserType> _supportedDiffusers;
protected IReadOnlyList<SchedulerType> _supportedSchedulers;
protected SchedulerOptions _defaultSchedulerOptions;
private UnetModeType _currentUnetMode;

protected sealed record BatchResultInternal(SchedulerOptions SchedulerOptions, List<DenseTensor<float>> Result);

Expand Down Expand Up @@ -110,6 +111,11 @@ public StableDiffusionPipeline(PipelineOptions pipelineOptions, TokenizerModel t
/// </summary>
public override SchedulerOptions DefaultSchedulerOptions => _defaultSchedulerOptions;

/// <summary>
/// Gets the current unet mode.
/// </summary>
public override UnetModeType CurrentUnetMode => _currentUnetMode;

/// <summary>
/// Gets the unet.
/// </summary>
Expand Down Expand Up @@ -144,22 +150,29 @@ public StableDiffusionPipeline(PipelineOptions pipelineOptions, TokenizerModel t
/// <summary>
/// Loads the pipeline.
/// </summary>
public override Task LoadAsync(bool controlNet = false)
public override Task LoadAsync(UnetModeType unetMode = UnetModeType.Default)
{
_currentUnetMode = unetMode;
if (_pipelineOptions.MemoryMode == MemoryModeType.Minimum)
return Task.CompletedTask;

// Preload all models into VRAM
return Task.WhenAll
var unetModels = Task.CompletedTask;
if (_currentUnetMode == UnetModeType.Default)
unetModels = Task.WhenAll(_unet.LoadAsync(), _controlNetUnet?.UnloadAsync() ?? Task.CompletedTask);
if (_currentUnetMode == UnetModeType.ControlNet)
unetModels = Task.WhenAll(_controlNetUnet.LoadAsync(), _unet.UnloadAsync());
if (_currentUnetMode == UnetModeType.Both)
unetModels = Task.WhenAll(_unet.LoadAsync(), _controlNetUnet?.LoadAsync() ?? Task.CompletedTask);

var subModels = Task.WhenAll
(
controlNet
? _controlNetUnet.LoadAsync()
: _unet.LoadAsync(),
_tokenizer.LoadAsync(),
_tokenizer.LoadAsync(),
_textEncoder.LoadAsync(),
_vaeDecoder.LoadAsync(),
_vaeEncoder.LoadAsync()
);

return Task.WhenAll(unetModels, subModels);
}


Expand Down Expand Up @@ -695,4 +708,5 @@ public static StableDiffusionPipeline CreatePipeline(string modelFolder, ModelTy
return CreatePipeline(ModelFactory.CreateModelSet(modelFolder, DiffuserPipelineType.StableDiffusion, modelType, deviceId, executionProvider, memoryMode), logger);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,17 @@ public StableDiffusionXLPipeline(PipelineOptions pipelineOptions, TokenizerModel
/// <summary>
/// Loads the pipeline
/// </summary>
public override Task LoadAsync(bool controlNet = false)
public override Task LoadAsync(UnetModeType unetMode = UnetModeType.Default)
{
if (_pipelineOptions.MemoryMode == MemoryModeType.Minimum)
return base.LoadAsync(controlNet);
return base.LoadAsync(unetMode);

// Preload all models into VRAM
return Task.WhenAll
(
_tokenizer2.LoadAsync(),
_textEncoder2.LoadAsync(),
base.LoadAsync(controlNet)
base.LoadAsync(unetMode)
);
}

Expand Down
11 changes: 1 addition & 10 deletions OnnxStack.UI/Dialogs/AddControlNetModelDialog.xaml.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,6 @@ public ControlNetType SelectedControlNetType

}

private DiffuserPipelineType _selectedPipelineType;

public DiffuserPipelineType SelectedPipelineType
{
get { return _selectedPipelineType; }
set { _selectedPipelineType = value; NotifyPropertyChanged(); CreateModelSet(); }
}


public ControlNetModelSet ModelSetResult
{
get { return _modelSetResult; }
Expand All @@ -104,7 +95,7 @@ private void CreateModelSet()
if (string.IsNullOrEmpty(_modelFile))
return;

_modelSetResult = _modelFactory.CreateControlNetModelSet(ModelName.Trim(), _selectedControlNetType, _selectedPipelineType, _modelFile);
_modelSetResult = _modelFactory.CreateControlNetModelSet(ModelName.Trim(), _selectedControlNetType, _modelFile);

// Validate
ValidationResults.Add(new ValidationResult("Name", !_invalidOptions.Contains(_modelName, StringComparer.OrdinalIgnoreCase) && _modelName.Length > 2 && _modelName.Length < 50));
Expand Down
9 changes: 0 additions & 9 deletions OnnxStack.UI/Models/UpdateControlNetModelSetViewModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ public class UpdateControlNetModelSetViewModel : INotifyPropertyChanged
private ExecutionProvider _executionProvider;
private string _modelFile;
private ControlNetType _controlNetType;
private DiffuserPipelineType _pipelineType;


public string Name
Expand All @@ -37,12 +36,6 @@ public ControlNetType ControlNetType
set { _controlNetType = value; NotifyPropertyChanged(); }
}

public DiffuserPipelineType PipelineType
{
get { return _pipelineType; }
set { _pipelineType = value; NotifyPropertyChanged(); }
}

public int DeviceId
{
get { return _deviceId; }
Expand Down Expand Up @@ -86,7 +79,6 @@ public static UpdateControlNetModelSetViewModel FromModelSet(ControlNetModelSet
{
Name = modelset.Name,
ControlNetType = modelset.ControlNetConfig.Type,
PipelineType = modelset.ControlNetConfig.PipelineType,
DeviceId = modelset.DeviceId,
ExecutionMode = modelset.ExecutionMode,
ExecutionProvider = modelset.ExecutionProvider,
Expand All @@ -111,7 +103,6 @@ public static ControlNetModelSet ToModelSet(UpdateControlNetModelSetViewModel mo
ControlNetConfig = new ControlNetModelConfig
{
Type = modelset.ControlNetType,
PipelineType = modelset.PipelineType,
OnnxModelPath = modelset.ModelFile
}
};
Expand Down
2 changes: 1 addition & 1 deletion OnnxStack.UI/Services/IModelFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public interface IModelFactory

UpscaleModelSet CreateUpscaleModelSet(string name, string filename, UpscaleModelTemplate modelTemplate);
StableDiffusionModelSet CreateStableDiffusionModelSet(string name, string folder, StableDiffusionModelTemplate modelTemplate);
ControlNetModelSet CreateControlNetModelSet(string name, ControlNetType controlNetType, DiffuserPipelineType pipelineType, string modelFilename);
ControlNetModelSet CreateControlNetModelSet(string name, ControlNetType controlNetType, string modelFilename);
FeatureExtractorModelSet CreateFeatureExtractorModelSet(string name, bool normalize, int sampleSize, int channels, string modelFilename);
}
}
3 changes: 1 addition & 2 deletions OnnxStack.UI/Services/ModelFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ public UpscaleModelSet CreateUpscaleModelSet(string name, string filename, Upsca
}


public ControlNetModelSet CreateControlNetModelSet(string name, ControlNetType controlNetType, DiffuserPipelineType pipelineType, string modelFilename)
public ControlNetModelSet CreateControlNetModelSet(string name, ControlNetType controlNetType, string modelFilename)
{
return new ControlNetModelSet
{
Expand All @@ -159,7 +159,6 @@ public ControlNetModelSet CreateControlNetModelSet(string name, ControlNetType c
ControlNetConfig = new ControlNetModelConfig
{
Type = controlNetType,
PipelineType = pipelineType,
OnnxModelPath = modelFilename
}
};
Expand Down

0 comments on commit 2dcd76f

Please sign in to comment.