Skip to content

Commit

Permalink
DashScope: Support Wanx Models (for text-generated images) (langchain…
Browse files Browse the repository at this point in the history
…4j#1710)

## Change
Alibaba's text-to-image features are provided by the Wanx models (not Qwen),
which are served through DashScope.
    
See:
https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-wanxiang
    
This change integrates them into langchain4j-dashscope as an ImageModel implementation.



## General checklist
<!-- Please double-check the following points and mark them like this:
[X] -->
- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [X] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [ ] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
  • Loading branch information
jiangsier-xyz authored Sep 5, 2024
1 parent 4ee7b8a commit c14c86c
Show file tree
Hide file tree
Showing 14 changed files with 409 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docs/docs/integrations/language-models/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ sidebar_position: 0
| [Anthropic](/integrations/language-models/anthropic) ||| text, image | ||
| [Azure OpenAI](/integrations/language-models/azure-open-ai) ||| text, image | | |
| [ChatGLM](/integrations/language-models/chatglm) | | | text | | |
| [DashScope](/integrations/language-models/dashscope) ||| text, image | | |
| [DashScope](/integrations/language-models/dashscope) ||| text, image, audio | | |
| [Google AI Gemini](/integrations/language-models/google-ai-gemini) | || text, image, audio, video, PDF | | |
| [Google Vertex AI Gemini](/integrations/language-models/google-vertex-ai-gemini) ||| text, image, audio, video, PDF | | |
| [Google Vertex AI PaLM 2](/integrations/language-models/google-palm) | | | text | ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ static Map<String, Object> toMultiModalContent(Content content) {
}
}

private static String saveDataAsTemporaryFile(String base64Data, String mimeType) {
static String saveDataAsTemporaryFile(String base64Data, String mimeType) {
String tmpDir = System.getProperty("java.io.tmpdir", "/tmp");
String tmpFileName = UUID.randomUUID().toString();
if (Utils.isNotNullOrBlank(mimeType)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public class QwenModelName {
public static final String QWEN_VL_PLUS = "qwen-vl-plus"; // Qwen multi-modal model, supports image and text information.
public static final String QWEN_VL_MAX = "qwen-vl-max"; // Qwen multi-modal model, offers optimal performance on a wider range of complex tasks.
public static final String QWEN_AUDIO_CHAT = "qwen-audio-chat"; // Qwen open sourced speech model, sft for chatting.
public static final String QWEN2_AUDIO_INSTRUCT = "qwen2-audio-instruct"; // Qwen open sourced speech model (v2), sft for instruction
public static final String QWEN2_AUDIO_INSTRUCT = "qwen2-audio-instruct"; // Qwen open sourced speech model (v2)

// Use with QwenEmbeddingModel
public static final String TEXT_EMBEDDING_V1 = "text-embedding-v1"; // Support: en, zh, es, fr, pt, id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.dashscope.spi.QwenTokenizerBuilderFactory;
import lombok.Builder;

import java.util.Collections;

import static dev.langchain4j.internal.Utils.*;
import static dev.langchain4j.model.dashscope.QwenHelper.toQwenMessages;
import static dev.langchain4j.model.dashscope.QwenModelName.QWEN_PLUS;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;

public class QwenTokenizer implements Tokenizer {
private final String apiKey;
Expand Down Expand Up @@ -94,4 +96,18 @@ public static boolean isBlank(CharSequence cs) {
}
return true;
}

/**
 * Creates a builder for {@link QwenTokenizer}.
 * <p>
 * If at least one {@link QwenTokenizerBuilderFactory} is discovered through
 * {@code loadFactories}, the builder supplied by the first one is returned;
 * otherwise a plain {@link QwenTokenizerBuilder} is created.
 *
 * @return the builder to configure a {@link QwenTokenizer} with
 */
public static QwenTokenizer.QwenTokenizerBuilder builder() {
    java.util.Iterator<QwenTokenizerBuilderFactory> discovered =
            loadFactories(QwenTokenizerBuilderFactory.class).iterator();
    if (discovered.hasNext()) {
        // Only the first discovered factory is consulted.
        return discovered.next().get();
    }
    return new QwenTokenizer.QwenTokenizerBuilder();
}

/**
 * Builder for {@link QwenTokenizer}.
 * <p>
 * Only the constructor is declared by hand; the rest of the builder is
 * expected to be generated by Lombok.
 */
public static class QwenTokenizerBuilder {
    /**
     * Declared explicitly so that the constructor is public and the builder
     * can be extended; a Lombok-generated constructor would be package-private.
     */
    public QwenTokenizerBuilder() {
    }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package dev.langchain4j.model.dashscope;

import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisOutput;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult;
import com.alibaba.dashscope.exception.NoApiKeyException;
import com.alibaba.dashscope.utils.OSSUtils;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.internal.Utils;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.*;
import java.util.stream.Collectors;

/**
 * Package-internal helpers for the Wanx image model: converting DashScope
 * image-synthesis results into {@link Image}s and resolving an {@link Image}
 * to a URL that can be sent as a reference image.
 */
public class WanxHelper {
    /**
     * Extracts the generated images from a synthesis result.
     *
     * @param result the result returned by DashScope; may be {@code null}
     * @return one {@link Image} per result entry that carries a {@code "url"};
     *         an empty list if the result, its output or its result list is {@code null}
     */
    static List<Image> imagesFrom(ImageSynthesisResult result) {
        // ofNullable (not of): a null result must yield an empty list rather than an NPE.
        return Optional.ofNullable(result)
                .map(ImageSynthesisResult::getOutput)
                .map(ImageSynthesisOutput::getResults)
                .orElse(Collections.emptyList())
                .stream()
                .map(resultMap -> resultMap.get("url"))
                // Skip entries without a url (presumably failed sub-tasks -- confirm against
                // the DashScope response format); building an Image from null would throw.
                .filter(Objects::nonNull)
                .map(url -> Image.builder().url(url).build())
                .collect(Collectors.toList());
    }

    /**
     * Resolves an image to a URL string.
     * <p>
     * An image that already has a URL is returned as is. An image that only
     * carries base64 data is written to a temporary file, uploaded through
     * {@link OSSUtils#upload}, and the uploaded location is returned. The
     * temporary file is removed once the upload attempt has finished.
     *
     * @param image  the image to resolve
     * @param model  the model name passed to the upload call
     * @param apiKey the DashScope API key passed to the upload call
     * @return the image URL
     * @throws IllegalArgumentException if the image has neither a URL nor base64 data
     * @throws RuntimeException         wrapping {@link NoApiKeyException} raised by the upload
     */
    static String imageUrl(Image image, String model, String apiKey) {
        if (image.url() != null) {
            return image.url().toString();
        }

        if (Utils.isNotNullOrBlank(image.base64Data())) {
            String filePath = saveDataAsTemporaryFile(image.base64Data(), image.mimeType());
            try {
                return OSSUtils.upload(model, filePath, apiKey);
            } catch (NoApiKeyException e) {
                throw new RuntimeException(e);
            } finally {
                deleteQuietly(filePath);
            }
        }

        throw new IllegalArgumentException("Failed to get image url from " + image);
    }

    /**
     * Decodes base64 data and writes it to a uniquely named file in the
     * directory given by the {@code java.io.tmpdir} system property
     * (falling back to {@code /tmp}).
     *
     * @param base64Data the base64-encoded content
     * @param mimeType   optional MIME type such as {@code "image/png"}; the part after
     *                   the last {@code '/'} is used as the file extension
     * @return the absolute path of the written file
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be written
     */
    static String saveDataAsTemporaryFile(String base64Data, String mimeType) {
        String tmpDir = System.getProperty("java.io.tmpdir", "/tmp");
        String tmpFileName = UUID.randomUUID().toString();
        if (Utils.isNotNullOrBlank(mimeType)) {
            // e.g. "image/png", "image/jpeg"...
            int lastSlashIndex = mimeType.lastIndexOf("/");
            if (lastSlashIndex >= 0 && lastSlashIndex < mimeType.length() - 1) {
                tmpFileName = tmpFileName + "." + mimeType.substring(lastSlashIndex + 1);
            }
        }

        Path tmpFilePath = Paths.get(tmpDir, tmpFileName);
        byte[] data = Base64.getDecoder().decode(base64Data);
        try {
            // Creates the file, or truncates and overwrites an existing one.
            Files.write(tmpFilePath, data);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return tmpFilePath.toAbsolutePath().toString();
    }

    /**
     * Best-effort removal of a temporary file. A failed cleanup must not mask
     * the outcome of the upload, so I/O errors are deliberately ignored.
     */
    private static void deleteQuietly(String filePath) {
        try {
            Files.deleteIfExists(Paths.get(filePath));
        } catch (IOException ignored) {
            // Leaving a stray temp file behind is preferable to failing the request.
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package dev.langchain4j.model.dashscope;

import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisParam;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult;
import com.alibaba.dashscope.exception.NoApiKeyException;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.dashscope.spi.WanxImageModelBuilderFactory;
import dev.langchain4j.model.image.ImageModel;
import dev.langchain4j.model.output.Response;
import lombok.Builder;

import java.util.List;

import static dev.langchain4j.model.dashscope.WanxHelper.imageUrl;
import static dev.langchain4j.model.dashscope.WanxHelper.imagesFrom;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;

/**
 * Represents a Wanx model that generates artistic images.
 * More details are available <a href="https://help.aliyun.com/zh/dashscope/developer-reference/api-details-9">here</a>.
 */
public class WanxImageModel implements ImageModel {
    private final String apiKey;
    private final String modelName;
    // The generation method of the reference image. The optional values are
    // 'repaint' and 'refonly'; repaint represents the reference content and
    // refonly represents the reference style. Default is 'repaint'.
    private final WanxImageRefMode refMode;
    // The similarity between the expected output result and the reference image,
    // the value range is [0.0, 1.0]. The larger the number, the more similar the
    // generated result is to the reference image. Default is 0.5.
    private final Float refStrength;
    private final Integer seed;
    // The resolution of the generated image currently only supports '1024*1024',
    // '720*1280', and '1280*720' resolutions. Default is '1024*1024'.
    private final WanxImageSize size;
    private final WanxImageStyle style;
    private final ImageSynthesis imageSynthesis;

    /**
     * Creates a Wanx image model.
     *
     * @param baseUrl     optional service URL; when null or blank the default
     *                    {@link ImageSynthesis} client is used
     * @param apiKey      DashScope API key; must not be null or blank
     * @param modelName   model to call; defaults to {@link WanxModelName#WANX_V1} when null or blank
     * @param refMode     optional reference-image mode, sent as {@code ref_mode}
     * @param refStrength optional reference-image similarity, sent as {@code ref_strength}
     * @param seed        optional seed
     * @param size        optional output resolution
     * @param style       optional output style
     * @throws IllegalArgumentException if {@code apiKey} is null or blank
     */
    @Builder
    public WanxImageModel(String baseUrl,
                          String apiKey,
                          String modelName,
                          WanxImageRefMode refMode,
                          Float refStrength,
                          Integer seed,
                          WanxImageSize size,
                          WanxImageStyle style) {
        if (Utils.isNullOrBlank(apiKey)) {
            throw new IllegalArgumentException("DashScope api key must be defined. It can be generated here: https://dashscope.console.aliyun.com/apiKey");
        }
        this.modelName = Utils.isNullOrBlank(modelName) ? WanxModelName.WANX_V1 : modelName;
        this.apiKey = apiKey;
        this.refMode = refMode;
        this.refStrength = refStrength;
        this.seed = seed;
        this.size = size;
        this.style = style;
        this.imageSynthesis = Utils.isNullOrBlank(baseUrl) ? new ImageSynthesis() : new ImageSynthesis("text2image", baseUrl);
    }

    /**
     * Generates a single image for the given prompt.
     *
     * @param prompt the text prompt
     * @return the generated image
     * @throws IndexOutOfBoundsException if the service returned no image
     */
    @Override
    public Response<Image> generate(String prompt) {
        ImageSynthesisParam param = requestBuilder(prompt).n(1).build();
        return Response.from(firstImage(call(param)));
    }

    /**
     * Generates {@code n} images for the given prompt.
     *
     * @param prompt the text prompt
     * @param n      the number of images requested from the service
     * @return the generated images; may contain fewer than {@code n} entries
     *         if the service returned fewer
     */
    @Override
    public Response<List<Image>> generate(String prompt, int n) {
        ImageSynthesisParam param = requestBuilder(prompt).n(n).build();
        return Response.from(imagesFrom(call(param)));
    }

    /**
     * Generates a single image from the prompt, using {@code image} as the reference image.
     *
     * @param image  the reference image; must carry either a URL or base64 data
     * @param prompt the text prompt
     * @return the generated image
     * @throws IndexOutOfBoundsException if the service returned no image
     */
    @Override
    public Response<Image> edit(Image image, String prompt) {
        String imageUrl = imageUrl(image, modelName, apiKey);

        ImageSynthesisParam.ImageSynthesisParamBuilder<?, ?> builder = requestBuilder(prompt)
                .refImage(imageUrl)
                .n(1);

        // Reference images uploaded to OSS need this header so the service resolves the oss:// URL.
        if (imageUrl.startsWith("oss://")) {
            builder.header("X-DashScope-OssResourceResolve", "enable");
        }

        return Response.from(firstImage(call(builder.build())));
    }

    /**
     * Sends the request, translating the checked {@link NoApiKeyException}
     * into an unchecked exception.
     */
    private ImageSynthesisResult call(ImageSynthesisParam param) {
        try {
            return imageSynthesis.call(param);
        } catch (NoApiKeyException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the first image of the result, failing with a descriptive
     * message when the result contains no image.
     */
    private Image firstImage(ImageSynthesisResult result) {
        List<Image> images = imagesFrom(result);
        if (images.isEmpty()) {
            // Same exception type that List.get(0) would raise, but with a message that
            // explains what went wrong instead of "Index 0 out of bounds for length 0".
            throw new IndexOutOfBoundsException("DashScope returned no images for model " + modelName);
        }
        return images.get(0);
    }

    /**
     * Creates a request builder pre-populated with the API key, model name,
     * prompt, and every optional setting that was configured on this model.
     */
    private ImageSynthesisParam.ImageSynthesisParamBuilder<?, ?> requestBuilder(String prompt) {
        ImageSynthesisParam.ImageSynthesisParamBuilder<?, ?> builder = ImageSynthesisParam.builder()
                .apiKey(apiKey)
                .model(modelName)
                .prompt(prompt);

        if (seed != null) {
            builder.seed(seed);
        }

        if (size != null) {
            builder.size(size.toString());
        }

        if (style != null) {
            builder.style(style.toString());
        }

        if (refMode != null) {
            builder.parameter("ref_mode", refMode.toString());
        }

        if (refStrength != null) {
            builder.parameter("ref_strength", refStrength);
        }

        return builder;
    }

    /**
     * Creates a builder for this model. If a {@link WanxImageModelBuilderFactory}
     * is discovered through {@code loadFactories}, the builder supplied by the
     * first one is returned; otherwise a plain builder is created.
     *
     * @return the builder to configure a {@link WanxImageModel} with
     */
    public static WanxImageModel.WanxImageModelBuilder builder() {
        for (WanxImageModelBuilderFactory factory : loadFactories(WanxImageModelBuilderFactory.class)) {
            return factory.get();
        }
        return new WanxImageModel.WanxImageModelBuilder();
    }

    public static class WanxImageModelBuilder {
        public WanxImageModelBuilder() {
            // This is public so it can be extended
            // By default with Lombok it becomes package private
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package dev.langchain4j.model.dashscope;

/**
 * Reference-image modes for Wanx image generation.
 * <p>
 * {@link #toString()} returns the lowercase value associated with each
 * constant rather than the constant's name.
 */
public enum WanxImageRefMode {
    /** String value {@code "repaint"}. */
    REPAINT("repaint"),
    /** String value {@code "refonly"}. */
    REFONLY("refonly");

    private final String value;

    WanxImageRefMode(String value) {
        this.value = value;
    }

    /**
     * @return the string value of this mode, e.g. {@code "repaint"}
     */
    @Override
    public String toString() {
        return this.value;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package dev.langchain4j.model.dashscope;

/**
 * Image sizes for Wanx image generation.
 * <p>
 * {@link #toString()} returns the {@code "<number>*<number>"} string
 * associated with each constant rather than the constant's name.
 */
public enum WanxImageSize {
    /** String value {@code "1024*1024"}. */
    SIZE_1024_1024("1024*1024"),
    /** String value {@code "720*1280"}. */
    SIZE_720_1280("720*1280"),
    /** String value {@code "1280*720"}. */
    SIZE_1280_720("1280*720");

    private final String dimensions;

    WanxImageSize(String dimensions) {
        this.dimensions = dimensions;
    }

    /**
     * @return the string value of this size, e.g. {@code "1024*1024"}
     */
    @Override
    public String toString() {
        return this.dimensions;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package dev.langchain4j.model.dashscope;

/**
 * Image styles for Wanx image generation.
 * <p>
 * {@link #toString()} returns the angle-bracketed tag associated with each
 * constant (for example {@code "<anime>"}) rather than the constant's name.
 */
public enum WanxImageStyle {
    PHOTOGRAPHY("<photography>"),
    PORTRAIT("<portrait>"),
    CARTOON_3D("<3d cartoon>"),
    ANIME("<anime>"),
    OIL_PAINTING("<oil painting>"),
    WATERCOLOR("<watercolor>"),
    SKETCH("<sketch>"),
    CHINESE_PAINTING("<chinese painting>"),
    FLAT_ILLUSTRATION("<flat illustration>"),
    AUTO("<auto>");

    private final String tag;

    WanxImageStyle(String tag) {
        this.tag = tag;
    }

    /**
     * @return the tag of this style, including the surrounding angle brackets
     */
    @Override
    public String toString() {
        return this.tag;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package dev.langchain4j.model.dashscope;

/**
 * Names of the Wanx models, for use with WanxImageModel.
 */
public class WanxModelName {
    /** Wanx model for text-generated images, supports Chinese and English. */
    public static final String WANX_V1 = "wanx-v1";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package dev.langchain4j.model.dashscope.spi;

import dev.langchain4j.model.dashscope.QwenTokenizer;

import java.util.function.Supplier;

/**
 * Service-provider interface that supplies the
 * {@link QwenTokenizer.QwenTokenizerBuilder} to be used when building a
 * {@link QwenTokenizer}.
 */
public interface QwenTokenizerBuilderFactory
        extends Supplier<QwenTokenizer.QwenTokenizerBuilder> {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package dev.langchain4j.model.dashscope.spi;

import dev.langchain4j.model.dashscope.WanxImageModel;

import java.util.function.Supplier;

/**
 * Service-provider interface that supplies the
 * {@link WanxImageModel.WanxImageModelBuilder} to be used when building a
 * {@link WanxImageModel}.
 */
public interface WanxImageModelBuilderFactory
        extends Supplier<WanxImageModel.WanxImageModelBuilder> {
}
Loading

0 comments on commit c14c86c

Please sign in to comment.