101 changes: 101 additions & 0 deletions xllm/core/framework/dit_model_loader.cpp
@@ -81,6 +81,12 @@ bool DiTFolderLoader::load_args(const std::string& model_weights_path) {
return false;
}

if (!load_image_preprocessor_args(model_weights_path)) {
LOG(ERROR) << "Failed to load image preprocess args from "
<< model_weights_path;
return false;
}

return true;
}

@@ -219,6 +225,101 @@ bool DiTFolderLoader::load_tokenizer_args(
return true;
}

bool DiTFolderLoader::load_image_preprocessor_args(
const std::string& model_weights_path) {
// image preprocessor args
JsonReader image_preprocess_reader;
const std::string image_preprocess_file_path =
model_weights_path + "/preprocessor_config.json";
if (image_preprocess_reader.parse(image_preprocess_file_path)) {
LOG(INFO) << "Success to parse image preprocess args file: "
<< image_preprocess_file_path;
args_.mm_image_do_center_crop() =
image_preprocess_reader.value_or<bool>("do_center_crop", false);
args_.mm_image_crop_height_size() =
image_preprocess_reader.value_or<int>("crop_size.height", 335);
args_.mm_image_crop_width_size() =
image_preprocess_reader.value_or<int>("crop_size.width", 335);

args_.mm_image_size_height() =
image_preprocess_reader.value_or<int>("size.height", 384);

args_.mm_image_size_width() =
image_preprocess_reader.value_or<int>("size.width", 384);

args_.mm_image_do_resize() =
image_preprocess_reader.value_or<bool>("do_resize", false);
args_.mm_image_resize_shortest_edge() =
image_preprocess_reader.value_or<int>("size.shortest_edge", 335);
args_.mm_image_resample() =
image_preprocess_reader.value_or<int>("resample", 335);

args_.mm_image_do_rescale() =
image_preprocess_reader.value_or<bool>("do_rescale", false);
args_.mm_image_rescale_factor() =
image_preprocess_reader.value_or<double>("rescale_factor", 0);

args_.mm_image_do_normalize() =
image_preprocess_reader.value_or<bool>("do_normalize", false);

const auto& image_preprocess_data = image_preprocess_reader.data();
if (image_preprocess_reader.contains("image_mean")) {
args_.mm_image_normalize_mean() =
image_preprocess_data["image_mean"].get<std::vector<double>>();
}

if (image_preprocess_reader.contains("image_std")) {
args_.mm_image_normalize_std() =
image_preprocess_data["image_std"].get<std::vector<double>>();
}

if (image_preprocess_reader.contains("norm_mean")) {
args_.mm_image_normalize_mean() =
image_preprocess_data["norm_mean"].get<std::vector<double>>();
}

if (image_preprocess_reader.contains("norm_std")) {
args_.mm_image_normalize_std() =
image_preprocess_data["norm_std"].get<std::vector<double>>();
}
Review comment on lines +266 to +284 (severity: high):

The code checks for image_mean and then for norm_mean, with both writing to args_.mm_image_normalize_mean(). The same pattern exists for image_std and norm_std. If a config file contains both keys (e.g., image_mean and norm_mean), the value from the latter key will silently overwrite the former. This could lead to unexpected behavior. It would be safer and clearer to use an else if structure to prioritize one key over the other.

Suggested change — from:

if (image_preprocess_reader.contains("image_mean")) {
args_.mm_image_normalize_mean() =
image_preprocess_data["image_mean"].get<std::vector<double>>();
}
if (image_preprocess_reader.contains("image_std")) {
args_.mm_image_normalize_std() =
image_preprocess_data["image_std"].get<std::vector<double>>();
}
if (image_preprocess_reader.contains("norm_mean")) {
args_.mm_image_normalize_mean() =
image_preprocess_data["norm_mean"].get<std::vector<double>>();
}
if (image_preprocess_reader.contains("norm_std")) {
args_.mm_image_normalize_std() =
image_preprocess_data["norm_std"].get<std::vector<double>>();
}

to:

if (image_preprocess_reader.contains("image_mean")) {
args_.mm_image_normalize_mean() =
image_preprocess_data["image_mean"].get<std::vector<double>>();
} else if (image_preprocess_reader.contains("norm_mean")) {
args_.mm_image_normalize_mean() =
image_preprocess_data["norm_mean"].get<std::vector<double>>();
}
if (image_preprocess_reader.contains("image_std")) {
args_.mm_image_normalize_std() =
image_preprocess_data["image_std"].get<std::vector<double>>();
} else if (image_preprocess_reader.contains("norm_std")) {
args_.mm_image_normalize_std() =
image_preprocess_data["norm_std"].get<std::vector<double>>();
}


args_.mm_image_shortest_edge() =
image_preprocess_reader.value_or<int>("size.shortest_edge", 0);

args_.mm_image_longest_edge() =
image_preprocess_reader.value_or<int>("size.longest_edge", 0);

args_.mm_image_min_pixels() =
image_preprocess_reader.value_or<int>("min_pixels", 0);

args_.mm_image_max_pixels() =
image_preprocess_reader.value_or<int>("max_pixels", 0);

args_.mm_image_patch_size() =
image_preprocess_reader.value_or<int>("patch_size", 0);

args_.mm_image_temporal_patch_size() =
image_preprocess_reader.value_or<int>("temporal_patch_size", 0);

args_.mm_image_merge_size() =
image_preprocess_reader.value_or<int>("merge_size", 0);

args_.mm_image_feature_size() =
image_preprocess_reader.value_or<int>("image_feature_size", 0);

args_.mm_scale_resolution() =
image_preprocess_reader.value_or<int>("scale_resolution", 0);

args_.mm_slice_mode() =
image_preprocess_reader.value_or<bool>("slice_mode", false);

args_.mm_use_image_id() =
image_preprocess_reader.value_or<bool>("use_image_id", false);
}

return true;
}

DiTModelLoader::DiTModelLoader(const std::string& model_root_path)
: model_root_path_(model_root_path) {
if (!std::filesystem::exists(model_root_path_)) {
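For reference, a minimal preprocessor_config.json that load_image_preprocessor_args() would consume could look like the sketch below. The key names mirror the lookups in the parser above; the values are typical CLIP-style preprocessor defaults and are illustrative assumptions, not taken from this repository. The dotted keys in the parser (e.g. "crop_size.height") suggest JsonReader resolves nested JSON paths.

{
  "do_resize": true,
  "size": { "height": 384, "width": 384, "shortest_edge": 336 },
  "do_center_crop": true,
  "crop_size": { "height": 336, "width": 336 },
  "resample": 3,
  "do_rescale": true,
  "rescale_factor": 0.00392156862745098,
  "do_normalize": true,
  "image_mean": [0.48145466, 0.4578275, 0.40821073],
  "image_std": [0.26862954, 0.26130258, 0.27577711]
}

A file that supplied both image_mean and norm_mean would hit the silent-overwrite case flagged in the review comment above.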
1 change: 1 addition & 0 deletions xllm/core/framework/dit_model_loader.h
@@ -43,6 +43,7 @@ class DiTFolderLoader {
bool load_args(const std::string& model_weights_path);
bool load_model_args(const std::string& model_weights_path);
bool load_tokenizer_args(const std::string& model_weights_path);
bool load_image_preprocessor_args(const std::string& model_weights_path);

// model args
ModelArgs args_;
2 changes: 2 additions & 0 deletions xllm/core/framework/model/model_args.h
@@ -287,6 +287,8 @@ struct ModelArgs {
// VLM image preprocessor resize
PROPERTY(bool, mm_image_do_resize) = false;
PROPERTY(int, mm_image_resize_shortest_edge) = 336;
PROPERTY(int64_t, mm_image_size_height) = 384;
PROPERTY(int64_t, mm_image_size_width) = 384;

PROPERTY(int, mm_image_resample) = 0;

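The loader assigns through accessor calls (args_.mm_image_size_height() = ...), which implies the PROPERTY macro generates an accessor returning a mutable reference to a member with a default initializer. A minimal sketch of such a macro follows; this is a hypothetical reconstruction for illustration, not the actual definition in xllm.

// Hypothetical sketch: PROPERTY(T, name) declares accessors plus a private
// member name##_, and ends with the member declaration so a trailing
// "= value" in the struct body becomes its default initializer.
#define PROPERTY(T, name)                        \
 public:                                         \
  T& name() { return name##_; }                  \
  const T& name() const { return name##_; }      \
                                                 \
 private:                                        \
  T name##_

struct ExampleArgs {
  PROPERTY(int64_t, mm_image_size_height) = 384;
  PROPERTY(int64_t, mm_image_size_width) = 384;
};

// Usage: ExampleArgs args; args.mm_image_size_height() = 512;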
8 changes: 7 additions & 1 deletion xllm/core/framework/request/dit_request_state.h
@@ -41,7 +41,9 @@ struct DiTGenerationParams {
num_images_per_prompt == other.num_images_per_prompt &&
seed == other.seed &&
max_sequence_length == other.max_sequence_length &&
strength == other.strength;
strength == other.strength &&
prompt_embeds_scale == other.prompt_embeds_scale &&
pooled_prompt_embeds_scale == other.pooled_prompt_embeds_scale;
}

bool operator!=(const DiTGenerationParams& other) const {
@@ -65,6 +67,10 @@ struct DiTGenerationParams {
int32_t max_sequence_length = 512;

float strength = 1.0;

float prompt_embeds_scale = 1.0;

float pooled_prompt_embeds_scale = 1.0;
};

struct DiTInputParams {
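Including the two new scale fields in operator== means requests that differ only in embedding scales no longer compare equal, which matters wherever DiTGenerationParams is used for request comparison or caching. A tiny hypothetical check, not part of the PR:

#include <cassert>

// Reduced illustrative stand-in for DiTGenerationParams.
struct ParamsSketch {
  float strength = 1.0f;
  float prompt_embeds_scale = 1.0f;
  float pooled_prompt_embeds_scale = 1.0f;
  bool operator==(const ParamsSketch& o) const {
    return strength == o.strength &&
           prompt_embeds_scale == o.prompt_embeds_scale &&
           pooled_prompt_embeds_scale == o.pooled_prompt_embeds_scale;
  }
};

int main() {
  ParamsSketch a, b;
  b.prompt_embeds_scale = 0.5f;  // differs only in a new field
  assert(!(a == b));             // equality now reflects the difference
  return 0;
}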
114 changes: 6 additions & 108 deletions xllm/models/dit/clip_text_model.h
@@ -16,24 +16,12 @@ limitations under the License.

#pragma once

#include <atb/atb_infer.h>
#include <c10/core/ScalarType.h>
#include <torch/torch.h>

#include <regex>
#include <unordered_map>

#include "core/framework/dit_model_loader.h"
#include "core/framework/kv_cache/kv_cache.h"
#include "core/framework/model/model_input_params.h"
#include "core/framework/model_context.h"
#include "core/layers/npu/npu_siglip_encoder_layer_impl.h"
#include "dit_linear.h"
#include "models/model_registry.h"
#include "processors/clip_image_processor.h"
#include "processors/input_processor.h"
#include "processors/pywarpper_image_processor.h"
#include "xllm_kernels/core/include/atb_speed/log.h"

namespace xllm {
// clip_text_model compatible with huggingface weights
@@ -59,96 +47,6 @@ torch::Tensor _create_4d_causal_attention_mask(torch::IntArrayRef input_shape,
return causal_mask;
}

class CLIPVLInputProcessor : public InputProcessor {
enum class TokenType {
INVALID,
IMAGE,
VIDEO,
};

public:
explicit CLIPVLInputProcessor(const ModelArgs& args) {
merge_size_ = args.mm_image_merge_size();
}
void process(std::string& prompt, const MMData& mm_data) override {
torch::Tensor image_grid_thw;
if (auto res = mm_data.get<torch::Tensor>("image_grid_thw"))
image_grid_thw = res.value();
torch::Tensor video_grid_thw;
if (auto res = mm_data.get<torch::Tensor>("video_grid_thw"))
video_grid_thw = res.value();
if (!image_grid_thw.defined() && !video_grid_thw.defined()) return;
auto merge_length = merge_size_ * merge_size_;
int total_image_token = 0;
if (image_grid_thw.defined()) {
auto count = image_grid_thw.sizes()[0];
for (int idx = 0; idx < count; ++idx)
total_image_token +=
image_grid_thw[idx].prod().item<int>() / merge_length;
}
int total_video_token = 0;
if (video_grid_thw.defined()) {
auto count = video_grid_thw.sizes()[0];
for (int idx = 0; idx < count; ++idx)
total_video_token +=
video_grid_thw[idx].prod().item<int>() / merge_length;
}
size_t total_token_len = total_image_token * image_token_.size() +
total_video_token * video_token_.size();
std::string data;
data.reserve(prompt.size() + total_token_len);
int image_index = 0;
int video_index = 0;
const torch::Tensor* grid_thw = nullptr;
const std::string* token = nullptr;
int* index = 0;
size_t begin = 0;
auto pair = _find_vision_token(prompt, begin);
while (pair.second != std::string::npos) {
data.append(prompt, begin, pair.second - begin);
if (pair.first == TokenType::IMAGE) {
grid_thw = &image_grid_thw;
token = &image_token_;
index = &image_index;
} else if (pair.first == TokenType::VIDEO) {
grid_thw = &video_grid_thw;
token = &video_token_;
index = &video_index;
} else {
assert(false);
}
auto token_num = (*grid_thw)[(*index)].prod().item<int>() / merge_length;
while (token_num--) data.append(*token);
++(*index);
begin = pair.second + token->size();
pair = _find_vision_token(prompt, begin);
}
if (begin < prompt.size()) data.append(prompt, begin, std::string::npos);
prompt = std::move(data);
}

private:
std::pair<TokenType, size_t> _find_vision_token(const std::string& prompt,
size_t begin) {
auto img_pos = prompt.find(image_token_, begin);
auto vid_pos = prompt.find(video_token_, begin);
if (img_pos == std::string::npos && vid_pos == std::string::npos)
return {TokenType::INVALID, std::string::npos};
else if (vid_pos == std::string::npos)
return {TokenType::IMAGE, img_pos};
else if (img_pos == std::string::npos)
return {TokenType::VIDEO, vid_pos};
else
return img_pos < vid_pos ? std::make_pair(TokenType::IMAGE, img_pos)
: std::make_pair(TokenType::VIDEO, vid_pos);
}

private:
const std::string image_token_ = "<|image_pad|>";
const std::string video_token_ = "<|video_pad|>";
int merge_size_ = 0;
};

class CLIPTextEmbeddingImpl : public torch::nn::Module {
public:
explicit CLIPTextEmbeddingImpl(const ModelContext& context) {
@@ -189,23 +87,23 @@ class CLIPTextEmbeddingImpl : public torch::nn::Module {
weight::load_weight(state_dict,
"token_embedding.weight",
token_embedding_->weight,
is_token_embedding_loaded);
is_token_embedding_loaded_);
weight::load_weight(state_dict,
"position_embedding.weight",
position_embedding_,
is_position_embedding_loaded);
is_position_embedding_loaded_);
}

void verify_loaded_weights(const std::string& prefix) const {
CHECK(is_position_embedding_loaded)
CHECK(is_position_embedding_loaded_)
<< "weight is not loaded for " << prefix + "position_embedding.weight";
CHECK(is_token_embedding_loaded)
CHECK(is_token_embedding_loaded_)
<< "weight is not loaded for " << prefix + "token_embedding.weight";
}

private:
bool is_position_embedding_loaded = false;
bool is_token_embedding_loaded = false;
bool is_position_embedding_loaded_ = false;
bool is_token_embedding_loaded_ = false;
torch::Tensor position_ids_;
torch::nn::Embedding token_embedding_ = nullptr;
torch::Tensor position_embedding_;