Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Previous commit
Next commit
Chroma: Fix t5 chunk length
  • Loading branch information
stduhpf committed Jun 2, 2025
commit 4fdedd5ac59bf0d5e4d197077eed53929fbe35e6
31 changes: 16 additions & 15 deletions conditioner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,7 @@ struct FluxCLIPEmbedder : public Conditioner {
T5UniGramTokenizer t5_tokenizer;
std::shared_ptr<CLIPTextModelRunner> clip_l;
std::shared_ptr<T5Runner> t5;
size_t chunk_len = 256;

FluxCLIPEmbedder(ggml_backend_t backend,
std::map<std::string, enum ggml_type>& tensor_types,
Expand Down Expand Up @@ -1109,7 +1110,6 @@ struct FluxCLIPEmbedder : public Conditioner {
struct ggml_tensor* pooled = NULL; // [768,]
std::vector<float> hidden_states_vec;

size_t chunk_len = 256;
size_t chunk_count = t5_tokens.size() / chunk_len;
for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
// clip_l
Expand Down Expand Up @@ -1196,7 +1196,7 @@ struct FluxCLIPEmbedder : public Conditioner {
int height,
int adm_in_channels = -1,
bool force_zero_embeddings = false) {
auto tokens_and_weights = tokenize(text, 256, true);
auto tokens_and_weights = tokenize(text, chunk_len, true);
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
}

Expand All @@ -1221,6 +1221,7 @@ struct FluxCLIPEmbedder : public Conditioner {
struct PixArtCLIPEmbedder : public Conditioner {
T5UniGramTokenizer t5_tokenizer;
std::shared_ptr<T5Runner> t5;
size_t chunk_len = 512;

PixArtCLIPEmbedder(ggml_backend_t backend,
std::map<std::string, enum ggml_type>& tensor_types,
Expand Down Expand Up @@ -1304,8 +1305,18 @@ struct PixArtCLIPEmbedder : public Conditioner {

std::vector<float> hidden_states_vec;

size_t chunk_len = 256;
size_t chunk_count = t5_tokens.size() / chunk_len;

bool use_mask = true;
const char* SD_CHROMA_USE_T5_MASK = getenv("SD_CHROMA_USE_T5_MASK");
if (SD_CHROMA_USE_T5_MASK != nullptr) {
std::string sd_chroma_use_t5_mask_str = SD_CHROMA_USE_T5_MASK;
if (sd_chroma_use_t5_mask_str == "OFF" || sd_chroma_use_t5_mask_str == "FALSE") {
use_mask = false;
} else if (sd_chroma_use_t5_mask_str != "ON" && sd_chroma_use_t5_mask_str != "TRUE") {
LOG_WARN("SD_CHROMA_USE_T5_MASK environment variable has unexpected value. Assuming default (\"ON\"). (Expected \"ON\"/\"TRUE\" or\"OFF\"/\"FALSE\", got \"%s\")", SD_CHROMA_USE_T5_MASK);
}
}
for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
// t5
std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
Expand All @@ -1316,17 +1327,7 @@ struct PixArtCLIPEmbedder : public Conditioner {
t5_attn_mask_vec.begin() + (chunk_idx + 1) * chunk_len);

auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
auto t5_attn_mask_chunk = vector_to_ggml_tensor(work_ctx, chunk_mask);

const char* SD_CHROMA_USE_T5_MASK = getenv("SD_CHROMA_USE_T5_MASK");
if (SD_CHROMA_USE_T5_MASK != nullptr) {
std::string sd_chroma_use_t5_mask_str = SD_CHROMA_USE_T5_MASK;
if (sd_chroma_use_t5_mask_str == "OFF" || sd_chroma_use_t5_mask_str == "FALSE") {
t5_attn_mask_chunk = NULL;
} else if (sd_chroma_use_t5_mask_str != "ON" && sd_chroma_use_t5_mask_str != "TRUE") {
LOG_WARN("SD_CHROMA_USE_T5_MASK environment variable has unexpected value. Assuming default (\"ON\"). (Expected \"ON\"/\"TRUE\" or\"OFF\"/\"FALSE\", got \"%s\")", SD_CHROMA_USE_T5_MASK);
}
}
auto t5_attn_mask_chunk = use_mask ? vector_to_ggml_tensor(work_ctx, chunk_mask) : NULL;

t5->compute(n_threads,
input_ids,
Expand Down Expand Up @@ -1384,7 +1385,7 @@ struct PixArtCLIPEmbedder : public Conditioner {
int height,
int adm_in_channels = -1,
bool force_zero_embeddings = false) {
auto tokens_and_weights = tokenize(text, 512, true);
auto tokens_and_weights = tokenize(text, chunk_len, true);
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
}

Expand Down
Loading