
Commit 4a60373

stduhpf authored and thxCode committed
feat: partial LyCORIS support (tucker decomposition for LoCon + LoHa + LoKr) (leejet#577)
1 parent 3eb18db · commit 4a60373

File tree

2 files changed (+572 -344 lines)

ggml_extend.hpp

Lines changed: 67 additions & 2 deletions
@@ -79,6 +79,71 @@ __STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const cha
    }
}

// n-mode tensor-matrix product
// example: 2-mode product
// A: [ne03, k, ne01, ne00]
// B: k rows, m columns => [k, m]
// result is [ne03, m, ne01, ne00]
__STATIC_INLINE__ struct ggml_tensor* ggml_mul_n_mode(struct ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b, int mode = 0) {
    // reshape A: swap the 0th and the mode-th axis
    a = ggml_cont(ctx, ggml_permute(ctx, a, mode, mode != 1 ? 1 : 0, mode != 2 ? 2 : 0, mode != 3 ? 3 : 0));
    int ne1 = a->ne[1];
    int ne2 = a->ne[2];
    int ne3 = a->ne[3];
    // flatten to 2D so the mode contraction becomes a plain matrix product
    a = ggml_cont(ctx, ggml_reshape_2d(ctx, a, a->ne[0], (ne3 * ne2 * ne1)));

    struct ggml_tensor* result = ggml_cont(ctx, ggml_transpose(ctx, ggml_mul_mat(ctx, a, b)));

    // reshape output (same shape as a after permutation, except the first dim)
    result = ggml_reshape_4d(ctx, result, result->ne[0], ne1, ne2, ne3);
    // swap the 0th and the mode-th axis back
    result = ggml_permute(ctx, result, mode, mode != 1 ? 1 : 0, mode != 2 ? 2 : 0, mode != 3 ? 3 : 0);
    return result;
}
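
For reference, the mode-n product contracts a tensor with a matrix along one chosen axis: every fiber of A along that axis is multiplied through B. Below is a minimal standalone sketch of the 1-mode case on a 3D tensor (the function name and test values are illustrative, not part of the commit); it assumes ggml's memory layout, where the index i0 along ne[0] varies fastest:

#include <cstdio>
#include <vector>

// R(i0, j, i2) = sum_k A(i0, k, i2) * B(k, j)
// A has shape (d0, d1, d2) with i0 fastest; B is a row-major d1 x m matrix.
std::vector<float> mode1_product(const std::vector<float>& A, int d0, int d1, int d2,
                                 const std::vector<float>& B, int m) {
    std::vector<float> R((size_t)d0 * m * d2, 0.0f);
    for (int i2 = 0; i2 < d2; i2++)
        for (int j = 0; j < m; j++)
            for (int k = 0; k < d1; k++)
                for (int i0 = 0; i0 < d0; i0++)
                    R[((size_t)i2 * m + j) * d0 + i0] +=
                        A[((size_t)i2 * d1 + k) * d0 + i0] * B[(size_t)k * m + j];
    return R;
}

int main() {
    // (2, 2, 1) tensor times a 2x3 matrix along axis 1 -> (2, 3, 1)
    std::vector<float> A = {1, 2, 3, 4};        // A(i0, k, 0)
    std::vector<float> B = {1, 0, 1, 0, 1, 1};  // B(k, j)
    for (float v : mode1_product(A, 2, 2, 1, B, 3))
        printf("%g ", v);                       // prints: 1 2 3 4 4 6
    printf("\n");
}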

__STATIC_INLINE__ struct ggml_tensor* ggml_merge_lora(ggml_context* ctx, struct ggml_tensor* lora_down, struct ggml_tensor* lora_up, struct ggml_tensor* lora_mid = NULL) {
    struct ggml_tensor* updown;
    // flatten the lora tensors so they can be multiplied as matrices
    int64_t lora_up_rows = lora_up->ne[ggml_n_dims(lora_up) - 1];
    lora_up              = ggml_reshape_2d(ctx, lora_up, ggml_nelements(lora_up) / lora_up_rows, lora_up_rows);
    auto lora_down_n_dims = ggml_n_dims(lora_down);
    // round n_dims up to the next even number (with rank 1 the last dim collapses and n_dims comes out odd)
    lora_down_n_dims       = (lora_down_n_dims + lora_down_n_dims % 2);
    int64_t lora_down_rows = lora_down->ne[lora_down_n_dims - 1];
    lora_down              = ggml_reshape_2d(ctx, lora_down, ggml_nelements(lora_down) / lora_down_rows, lora_down_rows);

    // ggml_mul_mat requires tensor b transposed
    lora_down = ggml_cont(ctx, ggml_transpose(ctx, lora_down));
    if (lora_mid == NULL) {
        updown = ggml_mul_mat(ctx, lora_up, lora_down);
        updown = ggml_cont(ctx, ggml_transpose(ctx, updown));
    } else {
        // undoing the Tucker decomposition for conv layers:
        // lora_mid has shape (3, 3, Rank, Rank)
        // lora_down has shape (Rank, In, 1, 1)
        // lora_up has shape (Rank, Out, 1, 1)
        // the conv layer weight has shape (3, 3, Out, In)
        updown = ggml_mul_n_mode(ctx, ggml_mul_n_mode(ctx, lora_mid, lora_down, 3), lora_up, 2);
        updown = ggml_cont(ctx, updown);
    }
    return updown;
}
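
The lora_mid == NULL branch is the ordinary LoRA reconstruction: the weight delta is the rank-r product of the up and down factors. Written out as plain loops (a reference-only sketch with illustrative names; any alpha/rank scaling happens outside this helper):

#include <vector>

// delta(o, i) = sum_r up(o, r) * down(r, i), i.e. updown = up * down for a
// rank-r factorization; row-major matrices. The merged weight is then
// W + scale * delta.
std::vector<float> lora_delta(const std::vector<float>& up,   // out_dim x rank
                              const std::vector<float>& down, // rank x in_dim
                              int out_dim, int rank, int in_dim) {
    std::vector<float> delta((size_t)out_dim * in_dim, 0.0f);
    for (int o = 0; o < out_dim; o++)
        for (int r = 0; r < rank; r++)
            for (int i = 0; i < in_dim; i++)
                delta[(size_t)o * in_dim + i] +=
                    up[(size_t)o * rank + r] * down[(size_t)r * in_dim + i];
    return delta;
}

In the lora_mid branch the same contraction is applied per mode instead: the small (Rank, Rank) Tucker core is expanded back to the (3, 3, Out, In) conv weight by contracting the down factor along one axis and the up factor along another, via ggml_mul_n_mode above.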

// Kronecker product
// [ne03,ne02,ne01,ne00] x [ne13,ne12,ne11,ne10] => [ne03*ne13,ne02*ne12,ne01*ne11,ne00*ne10]
__STATIC_INLINE__ struct ggml_tensor* ggml_kronecker(ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b) {
    return ggml_mul(ctx,
                    ggml_upscale_ext(ctx,
                                     a,
                                     a->ne[0] * b->ne[0],
                                     a->ne[1] * b->ne[1],
                                     a->ne[2] * b->ne[2],
                                     a->ne[3] * b->ne[3]),
                    b);
}
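
This leans on two ggml behaviors: nearest-neighbor upscaling repeats each element of a across a block the size of b, and ggml_mul broadcasts b across the enlarged tensor, so every block becomes (element of a) * b, which is exactly the Kronecker product. As a plain 2D reference (an illustrative sketch, not the commit's code):

#include <vector>

// C(i*r + k, j*s + l) = A(i, j) * B(k, l): each element of A scales a full
// copy of B. A is p x q, B is r x s, all row-major.
std::vector<float> kron2d(const std::vector<float>& A, int p, int q,
                          const std::vector<float>& B, int r, int s) {
    std::vector<float> C((size_t)p * r * q * s);
    for (int i = 0; i < p; i++)
        for (int j = 0; j < q; j++)
            for (int k = 0; k < r; k++)
                for (int l = 0; l < s; l++)
                    C[(size_t)(i * r + k) * (q * s) + (j * s + l)] =
                        A[(size_t)i * q + j] * B[(size_t)k * s + l];
    return C;
}

LoKr factors a weight matrix into the Kronecker product of two much smaller matrices; this helper expands those factors back to the full weight at merge time.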

__STATIC_INLINE__ void ggml_tensor_set_f32_randn(struct ggml_tensor* tensor, std::shared_ptr<RNG> rng) {
    uint32_t n = (uint32_t)ggml_nelements(tensor);
    std::vector<float> random_numbers = rng->randn(n);

@@ -1078,8 +1143,8 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
}

/* SDXL with LoRA requires more space */
-#define MAX_PARAMS_TENSOR_NUM 15360
-#define MAX_GRAPH_SIZE 15360
+#define MAX_PARAMS_TENSOR_NUM 32768
+#define MAX_GRAPH_SIZE 32768

struct GGMLRunner {
protected:
