1
- #
2
- # From: https://raw.githubusercontent.com/KohakuBlueleaf/LoCon/main/extract_locon.py
3
- #
4
-
1
+ import os , sys
2
+ sys .path .insert (0 , os .getcwd ())
5
3
import argparse
6
4
5
+
7
6
def get_args ():
8
7
parser = argparse .ArgumentParser ()
9
8
parser .add_argument (
@@ -29,11 +28,15 @@ def get_args():
29
28
parser .add_argument (
30
29
"--mode" ,
31
30
help = (
32
- 'extraction mode, can be "fixed", "threshold", "ratio", "percentile ". '
31
+ 'extraction mode, can be "fixed", "threshold", "ratio", "quantile ". '
33
32
'If not "fixed", network_dim and conv_dim will be ignored'
34
33
),
35
34
default = 'fixed' , type = str
36
35
)
36
+ parser .add_argument (
37
+ "--safetensors" , help = 'use safetensors to save locon model' ,
38
+ default = True , action = "store_true"
39
+ )
37
40
parser .add_argument (
38
41
"--linear_dim" , help = "network dim for linear layer in fixed mode" ,
39
42
default = 1 , type = int
@@ -59,20 +62,34 @@ def get_args():
59
62
default = 0. , type = float
60
63
)
61
64
parser .add_argument (
62
- "--linear_percentile " , help = "singular value percentile for linear layer percentile mode" ,
65
+ "--linear_quantile " , help = "singular value quantile for linear layer quantile mode" ,
63
66
default = 1. , type = float
64
67
)
65
68
parser .add_argument (
66
- "--conv_percentile " , help = "singular value percentile for conv layer percentile mode" ,
69
+ "--conv_quantile " , help = "singular value quantile for conv layer quantile mode" ,
67
70
default = 1. , type = float
68
71
)
72
+ parser .add_argument (
73
+ "--use_sparse_bias" , help = "enable sparse bias" ,
74
+ default = False , action = "store_true"
75
+ )
76
+ parser .add_argument (
77
+ "--sparsity" , help = "sparsity for sparse bias" ,
78
+ default = 0.98 , type = float
79
+ )
80
+ parser .add_argument (
81
+ "--disable_cp" , help = "don't use cp decomposition" ,
82
+ default = False , action = "store_true"
83
+ )
69
84
return parser .parse_args ()
70
85
ARGS = get_args ()
71
86
72
- from locon .utils import extract_diff
73
- from locon .kohya_model_utils import load_models_from_stable_diffusion_checkpoint
87
+
88
+ from lycoris .utils import extract_diff
89
+ from lycoris .kohya .model_utils import load_models_from_stable_diffusion_checkpoint
74
90
75
91
import torch
92
+ from safetensors .torch import save_file
76
93
77
94
78
95
def main ():
@@ -84,22 +101,28 @@ def main():
84
101
'fixed' : args .linear_dim ,
85
102
'threshold' : args .linear_threshold ,
86
103
'ratio' : args .linear_ratio ,
87
- 'percentile ' : args .linear_percentile ,
104
+ 'quantile ' : args .linear_quantile ,
88
105
}[args .mode ]
89
106
conv_mode_param = {
90
107
'fixed' : args .conv_dim ,
91
108
'threshold' : args .conv_threshold ,
92
109
'ratio' : args .conv_ratio ,
93
- 'percentile ' : args .conv_percentile ,
110
+ 'quantile ' : args .conv_quantile ,
94
111
}[args .mode ]
95
112
96
113
state_dict = extract_diff (
97
114
base , db ,
98
115
args .mode ,
99
116
linear_mode_param , conv_mode_param ,
100
- args .device
117
+ args .device ,
118
+ args .use_sparse_bias , args .sparsity ,
119
+ not args .disable_cp
101
120
)
102
- torch .save (state_dict , args .output_name )
121
+
122
+ if args .safetensors :
123
+ save_file (state_dict , args .output_name )
124
+ else :
125
+ torch .save (state_dict , args .output_name )
103
126
104
127
105
128
if __name__ == '__main__' :
0 commit comments