Skip to content

Commit 8f6ba31

Browse files
committed
add streaming feature
1 parent a09b133 commit 8f6ba31

File tree

1 file changed

+337
-2
lines changed

1 file changed

+337
-2
lines changed

examples/sd-server/main.cpp

Lines changed: 337 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <optional>
1212
#include <sstream>
1313
#include <iostream>
14+
#include <memory>
1415
#include <limits>
1516
#include <mutex>
1617
#include <random>
@@ -509,6 +510,315 @@ struct ImageResultGuard {
509510
}
510511
};
511512

513+
// Forward declaration (the definition is outside this section). Returns a JSON
// value built from `collector`; StreamingImageResponder stores the result under
// the "logs" key of its "error" and "complete" stream elements.
json logs_to_json(const LogCollector& collector);
514+
// Forward declaration (the definition is outside this section). Returns a JSON
// value that StreamingImageResponder stores under the "telemetry" key of its
// "error" and "complete" stream elements. Callers in this section pass the
// collected logs, the generation request, the context configuration, the
// elapsed time in milliseconds and the seed that was applied.
json make_telemetry(const LogCollector& collector,
515+
const GenerationRequest& request,
516+
const CtxConfig& config,
517+
int64_t elapsed_ms,
518+
int64_t effective_seed);
519+
520+
class StreamingImageResponder {
521+
public:
522+
StreamingImageResponder(ServerState& state,
523+
std::unique_lock<std::mutex>&& ctx_lock,
524+
std::unique_ptr<LogCaptureScope>&& capture_scope,
525+
std::shared_ptr<LogCollector> collector,
526+
GenerationRequest request,
527+
CtxConfig ctx_config,
528+
bool random_seed_requested,
529+
int64_t effective_seed)
530+
: state_(state),
531+
ctx_lock_(std::move(ctx_lock)),
532+
capture_scope_(std::move(capture_scope)),
533+
collector_(std::move(collector)),
534+
request_(std::move(request)),
535+
ctx_config_(std::move(ctx_config)),
536+
random_seed_requested_(random_seed_requested),
537+
effective_seed_(effective_seed),
538+
default_sample_method_(sd_get_default_sample_method(state.ctx)),
539+
start_time_(std::chrono::steady_clock::now()) {}
540+
541+
~StreamingImageResponder() {
542+
finalize_resources();
543+
}
544+
545+
bool next(httplib::DataSink& sink) {
546+
if (done_) {
547+
return false;
548+
}
549+
550+
if (next_index_ < request_.batch_count) {
551+
if (!emit_image_chunk(sink, next_index_)) {
552+
done_ = true;
553+
return false;
554+
}
555+
++next_index_;
556+
return true;
557+
}
558+
559+
emit_final_summary(sink);
560+
done_ = true;
561+
return false;
562+
}
563+
564+
void cancel() {
565+
done_ = true;
566+
finalize_resources();
567+
}
568+
569+
private:
570+
bool emit_image_chunk(httplib::DataSink& sink, int index) {
571+
sd_img_gen_params_t params;
572+
sd_img_gen_params_init(&params);
573+
574+
params.prompt = request_.prompt.c_str();
575+
params.negative_prompt = request_.negative_prompt.c_str();
576+
params.clip_skip = request_.clip_skip;
577+
params.width = request_.width;
578+
params.height = request_.height;
579+
params.batch_count = 1;
580+
params.seed = effective_seed_ + index;
581+
if (request_.has_vae_tiling_override) {
582+
params.vae_tiling_params = request_.vae_tiling_params;
583+
}
584+
585+
sd_sample_params_t& sample_params = params.sample_params;
586+
sample_params.sample_steps = request_.sample_steps;
587+
sample_params.guidance.txt_cfg = request_.cfg_scale;
588+
if (request_.has_img_cfg_scale) {
589+
sample_params.guidance.img_cfg = request_.img_cfg_scale;
590+
}
591+
if (!std::isfinite(sample_params.guidance.img_cfg)) {
592+
sample_params.guidance.img_cfg = sample_params.guidance.txt_cfg;
593+
}
594+
if (request_.override_sample_method) {
595+
sample_params.sample_method = request_.sample_method;
596+
}
597+
if (sample_params.sample_method == SAMPLE_METHOD_DEFAULT) {
598+
sample_params.sample_method = default_sample_method_;
599+
}
600+
if (request_.override_scheduler) {
601+
sample_params.scheduler = request_.scheduler;
602+
}
603+
if (request_.has_eta) {
604+
sample_params.eta = request_.eta;
605+
}
606+
sample_params.shifted_timestep = request_.shifted_timestep;
607+
608+
sd_image_t* results = generate_image(state_.ctx, &params);
609+
if (results == nullptr) {
610+
emit_error(sink, "image generation failed", index);
611+
return false;
612+
}
613+
614+
ImageResultGuard guard{results, params.batch_count};
615+
616+
sd_image_t& image = results[0];
617+
if (image.data == nullptr) {
618+
emit_error(sink, "image data is empty", index);
619+
return false;
620+
}
621+
622+
auto encode_start = std::chrono::steady_clock::now();
623+
624+
int png_size = 0;
625+
unsigned char* png_data = stbi_write_png_to_mem(image.data, 0, image.width, image.height, image.channel, &png_size, nullptr);
626+
if (png_data == nullptr) {
627+
emit_error(sink, "failed to encode PNG", index);
628+
return false;
629+
}
630+
std::string encoded = base64_encode(png_data, static_cast<size_t>(png_size));
631+
STBIW_FREE(png_data);
632+
633+
auto encode_end = std::chrono::steady_clock::now();
634+
const double encode_ms = std::chrono::duration_cast<std::chrono::microseconds>(encode_end - encode_start).count() / 1000.0;
635+
const std::size_t encoded_size = encoded.size();
636+
637+
// Preserve the legacy -1 seed while still reporting the concrete seed that was used.
638+
int64_t actual_seed = random_seed_requested_ ? (effective_seed_ + index) : (request_.seed + index);
639+
int64_t reported_seed = random_seed_requested_ ? -1 : actual_seed;
640+
641+
json image_chunk = json::object();
642+
image_chunk["type"] = "image";
643+
image_chunk["index"] = index;
644+
image_chunk["seed"] = reported_seed;
645+
image_chunk["actual_seed"] = actual_seed;
646+
image_chunk["width"] = image.width;
647+
image_chunk["height"] = image.height;
648+
image_chunk["format"] = "png";
649+
image_chunk["mime_type"] = "image/png";
650+
image_chunk["payload_bytes"] = png_size;
651+
image_chunk["encoded_bytes"] = static_cast<int64_t>(encoded_size);
652+
image_chunk["encode_ms"] = encode_ms;
653+
image_chunk["data"] = std::move(encoded);
654+
655+
auto prepare_end = std::chrono::steady_clock::now();
656+
const double prepare_ms = std::chrono::duration_cast<std::chrono::microseconds>(prepare_end - encode_start).count() / 1000.0;
657+
image_chunk["dispatch_prepare_ms"] = prepare_ms;
658+
659+
std::size_t serialized_bytes = 0;
660+
if (!write_json_array_item(sink, image_chunk, false, &serialized_bytes)) {
661+
done_ = true;
662+
finalize_resources();
663+
return false;
664+
}
665+
666+
auto dispatch_end = std::chrono::steady_clock::now();
667+
const double dispatch_total_ms = std::chrono::duration_cast<std::chrono::microseconds>(dispatch_end - encode_start).count() / 1000.0;
668+
const double write_ms = std::chrono::duration_cast<std::chrono::microseconds>(dispatch_end - prepare_end).count() / 1000.0;
669+
670+
json summary_entry = json::object();
671+
summary_entry["index"] = index;
672+
summary_entry["seed"] = reported_seed;
673+
summary_entry["actual_seed"] = actual_seed;
674+
summary_entry["width"] = image.width;
675+
summary_entry["height"] = image.height;
676+
summary_entry["format"] = "png";
677+
summary_entry["mime_type"] = "image/png";
678+
summary_entry["streamed"] = true;
679+
summary_entry["encode_ms"] = encode_ms;
680+
summary_entry["dispatch_prepare_ms"] = prepare_ms;
681+
summary_entry["dispatch_total_ms"] = dispatch_total_ms;
682+
summary_entry["write_ms"] = write_ms;
683+
summary_entry["payload_bytes"] = png_size;
684+
summary_entry["encoded_bytes"] = static_cast<int64_t>(encoded_size);
685+
summary_entry["serialized_bytes"] = static_cast<int64_t>(serialized_bytes);
686+
image_summaries_.push_back(std::move(summary_entry));
687+
688+
return true;
689+
}
690+
691+
void emit_error(httplib::DataSink& sink, const std::string& message, int index) {
692+
encountered_error_ = true;
693+
done_ = true;
694+
const int64_t elapsed = elapsed_ms();
695+
json error_chunk = json::object();
696+
error_chunk["type"] = "error";
697+
error_chunk["success"] = false;
698+
error_chunk["error"] = message;
699+
error_chunk["index"] = index;
700+
error_chunk["requested_seed"] = request_.seed;
701+
error_chunk["applied_seed"] = effective_seed_;
702+
error_chunk["random_seed_requested"] = random_seed_requested_;
703+
error_chunk["elapsed_ms"] = elapsed;
704+
if (!ctx_config_.model_path.empty()) {
705+
error_chunk["model_path"] = ctx_config_.model_path;
706+
}
707+
error_chunk["logs"] = logs_to_json(*collector_);
708+
error_chunk["telemetry"] = make_telemetry(*collector_, request_, ctx_config_, elapsed, effective_seed_);
709+
if (write_json_array_item(sink, error_chunk, true)) {
710+
finalize_stream(sink);
711+
} else {
712+
finalize_resources();
713+
}
714+
}
715+
716+
void emit_final_summary(httplib::DataSink& sink) {
717+
const int64_t elapsed = elapsed_ms();
718+
json summary = json::object();
719+
summary["type"] = "complete";
720+
summary["success"] = !encountered_error_;
721+
summary["batch_count"] = request_.batch_count;
722+
summary["requested_seed"] = request_.seed;
723+
summary["applied_seed"] = effective_seed_;
724+
summary["random_seed_requested"] = random_seed_requested_;
725+
summary["elapsed_ms"] = elapsed;
726+
if (!ctx_config_.model_path.empty()) {
727+
summary["model_path"] = ctx_config_.model_path;
728+
}
729+
summary["images"] = image_summaries_;
730+
summary["logs"] = logs_to_json(*collector_);
731+
summary["telemetry"] = make_telemetry(*collector_, request_, ctx_config_, elapsed, effective_seed_);
732+
done_ = true;
733+
if (write_json_array_item(sink, summary, true)) {
734+
finalize_stream(sink);
735+
} else {
736+
finalize_resources();
737+
}
738+
}
739+
740+
int64_t elapsed_ms() const {
741+
auto end_time = std::chrono::steady_clock::now();
742+
return std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time_).count();
743+
}
744+
745+
bool write_json_array_item(httplib::DataSink& sink,
746+
const json& payload,
747+
bool final_item,
748+
std::size_t* serialized_size = nullptr) {
749+
std::string serialized = payload.dump();
750+
if (serialized_size != nullptr) {
751+
*serialized_size = serialized.size();
752+
}
753+
754+
if (!array_opened_) {
755+
const char prefix[] = "[\n";
756+
if (!sink.write(prefix, sizeof(prefix) - 1)) {
757+
return false;
758+
}
759+
sink.os.flush();
760+
array_opened_ = true;
761+
}
762+
763+
if (!first_object_) {
764+
const char separator[] = ",\n";
765+
if (!sink.write(separator, sizeof(separator) - 1)) {
766+
return false;
767+
}
768+
sink.os.flush();
769+
}
770+
771+
std::string chunk = std::move(serialized);
772+
if (final_item) {
773+
chunk.append("\n]");
774+
}
775+
chunk.push_back('\n');
776+
bool ok = sink.write(chunk.data(), chunk.size());
777+
if (ok) {
778+
sink.os.flush();
779+
}
780+
first_object_ = false;
781+
return ok;
782+
}
783+
784+
void finalize_stream(httplib::DataSink& sink) {
785+
if (sink.done) {
786+
sink.done();
787+
}
788+
finalize_resources();
789+
}
790+
791+
void finalize_resources() {
792+
if (finalized_) {
793+
return;
794+
}
795+
capture_scope_.reset();
796+
collector_.reset();
797+
if (ctx_lock_.owns_lock()) {
798+
ctx_lock_.unlock();
799+
}
800+
finalized_ = true;
801+
}
802+
803+
ServerState& state_;
804+
std::unique_lock<std::mutex> ctx_lock_;
805+
std::unique_ptr<LogCaptureScope> capture_scope_;
806+
std::shared_ptr<LogCollector> collector_;
807+
GenerationRequest request_;
808+
CtxConfig ctx_config_;
809+
bool random_seed_requested_ = false;
810+
int64_t effective_seed_ = 0;
811+
sample_method_t default_sample_method_ = SAMPLE_METHOD_DEFAULT;
812+
std::chrono::steady_clock::time_point start_time_;
813+
int next_index_ = 0;
814+
bool done_ = false;
815+
bool encountered_error_ = false;
816+
bool finalized_ = false;
817+
bool array_opened_ = false;
818+
bool first_object_ = true;
819+
std::vector<json> image_summaries_;
820+
};
821+
512822
bool apply_context_overrides(const json& body, CtxConfig& config, std::string& error) {
513823
auto assign_string = [&](const char* key, std::string& target) -> bool {
514824
auto it = body.find(key);
@@ -1830,7 +2140,8 @@ int main(int argc, char** argv) {
18302140
});
18312141

18322142
server.Post("/generate", [&](const httplib::Request& req, httplib::Response& res) {
1833-
LogCollector collector;
2143+
auto collector_ptr = std::make_shared<LogCollector>();
2144+
LogCollector& collector = *collector_ptr;
18342145

18352146
json body;
18362147
try {
@@ -1852,7 +2163,7 @@ int main(int argc, char** argv) {
18522163
}
18532164

18542165
std::unique_lock<std::mutex> lock(state.mutex);
1855-
LogCaptureScope capture(state, collector);
2166+
auto capture_scope = std::make_unique<LogCaptureScope>(state, collector);
18562167

18572168
CtxConfig desired_config = state.ctx_config;
18582169
if (desired_config.model_path.empty()) {
@@ -1886,6 +2197,30 @@ int main(int argc, char** argv) {
18862197
effective_seed = generate_random_seed();
18872198
}
18882199

2200+
const bool enable_streaming = request_params.batch_count > 1;
2201+
if (enable_streaming) {
2202+
GenerationRequest streaming_request = std::move(request_params);
2203+
CtxConfig active_config = state.ctx_config;
2204+
auto streaming_responder = std::make_shared<StreamingImageResponder>(state,
2205+
std::move(lock),
2206+
std::move(capture_scope),
2207+
collector_ptr,
2208+
std::move(streaming_request),
2209+
std::move(active_config),
2210+
random_seed_requested,
2211+
effective_seed);
2212+
res.status = 200;
2213+
res.set_chunked_content_provider(
2214+
"application/json",
2215+
[streaming_responder](size_t, httplib::DataSink& sink) {
2216+
return streaming_responder->next(sink);
2217+
},
2218+
[streaming_responder](bool) {
2219+
streaming_responder->cancel();
2220+
});
2221+
return;
2222+
}
2223+
18892224
sd_img_gen_params_t img_params;
18902225
sd_img_gen_params_init(&img_params);
18912226

0 commit comments

Comments
 (0)