Adapted task list: text2image_generation (dalle-mini, dalle-mega-v16, dalle-mega, pai-painter-painting/scenery/commercial-base-zh)
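This patch registers a new text2image_generation task in Taskflow. For orientation, the end-user call pattern enabled by the change, taken from the usage string added in paddlenlp/taskflow/text2image_generation.py below, looks roughly like this sketch:

    from paddlenlp import Taskflow

    # Build the task with the default model ("pai-painter-painting-base-zh")
    # and generate images from a text prompt.
    text2imagegen = Taskflow("text2image_generation")
    images = text2imagegen("风阁水帘今在眼,且来先看早梅红")
    images[0].save("figure.png")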
diff --git a/paddlenlp/taskflow/taskflow.py b/paddlenlp/taskflow/taskflow.py
index 5be4842efee4..83270831ab45 100644
--- a/paddlenlp/taskflow/taskflow.py
+++ b/paddlenlp/taskflow/taskflow.py
@@ -37,6 +37,7 @@
from .dialogue import DialogueTask
from .information_extraction import UIETask
from .code_generation import CodeGenerationTask
+from .text2image_generation import Text2ImageGenerationTask
warnings.simplefilter(action='ignore', category=Warning, lineno=0, append=False)
@@ -317,6 +318,46 @@
},
"default": {
"model": "Salesforce/codegen-350M-mono",
+ },
+ },
+ "text2image_generation": {
+ "models": {
+ "dalle-mini": {
+ "task_class": Text2ImageGenerationTask,
+ "task_flag": "text2image_generation-dalle-mini",
+ "task_priority_path": "dalle-mini",
+ },
+ "dalle-mega-v16": {
+ "task_class": Text2ImageGenerationTask,
+ "task_flag": "text2image_generation-dalle-mega-v16",
+ "task_priority_path": "dalle-mega-v16",
+ },
+ "dalle-mega": {
+ "task_class": Text2ImageGenerationTask,
+ "task_flag": "text2image_generation-dalle-mega",
+ "task_priority_path": "dalle-mega",
+ },
+ "pai-painter-painting-base-zh": {
+ "task_class": Text2ImageGenerationTask,
+ "task_flag":
+ "text2image_generation-pai-painter-painting-base-zh",
+ "task_priority_path": "pai-painter-painting-base-zh",
+ },
+ "pai-painter-scenery-base-zh": {
+ "task_class": Text2ImageGenerationTask,
+ "task_flag":
+ "text2image_generation-pai-painter-scenery-base-zh",
+ "task_priority_path": "pai-painter-scenery-base-zh",
+ },
+ "pai-painter-commercial-base-zh": {
+ "task_class": Text2ImageGenerationTask,
+ "task_flag":
+ "text2image_generation-pai-painter-commercial-base-zh",
+ "task_priority_path": "pai-painter-commercial-base-zh",
+ },
+ },
+ "default": {
+ "model": "pai-painter-painting-base-zh",
}
}
}
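With the registry entries above, any of the listed models can be selected through Taskflow's model argument instead of the default; a minimal sketch (the model name must be one of the keys registered above):

    from paddlenlp import Taskflow

    # Pick a specific registered model rather than the default
    # "pai-painter-painting-base-zh".
    text2imagegen = Taskflow("text2image_generation", model="dalle-mini")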
diff --git a/paddlenlp/taskflow/text2image_generation.py b/paddlenlp/taskflow/text2image_generation.py
new file mode 100644
index 000000000000..9fa3ea7e4e5b
--- /dev/null
+++ b/paddlenlp/taskflow/text2image_generation.py
@@ -0,0 +1,144 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import numpy as np
+from PIL import Image
+from ..transformers import AutoModelForImageGeneration, AutoTokenizer
+from .task import Task
+
+usage = r"""
+ from paddlenlp import Taskflow
+
+ text2imagegen = Taskflow("text2image_generation")
+ images = text2imagegen("风阁水帘今在眼,且来先看早梅红")
+ images[0].save("figure.png")
+
+ """
+
+
+class Text2ImageGenerationTask(Task):
+ """
+    The text2image generation task, which generates images from a text prompt.
+ Args:
+        task (string): The name of the task.
+        model (string): The model name used for the task.
+ kwargs (dict, optional): Additional keyword arguments passed along to the specific task.
+ """
+
+ def __init__(self, task, model="pai-painter-painting-base-zh", **kwargs):
+ super().__init__(task=task, model=model, **kwargs)
+ self._batch_size = kwargs.get("batch_size", 1)
+ self._temperature = kwargs.get("temperature", 1.)
+ self._top_k = kwargs.get("top_k", 32)
+ self._top_p = kwargs.get("top_p", 1.)
+ self._condition_scale = kwargs.get("condition_scale", 10.)
+ self._num_return_images = kwargs.get("num_return_images", 4)
+ self._use_faster = kwargs.get("use_faster", False)
+ self._use_fp16_decoding = kwargs.get("use_fp16_decoding", False)
+ self._construct_tokenizer(model)
+ self._construct_model(model)
+
+ def _construct_model(self, model):
+ """
+ Construct the inference model for the predictor.
+ """
+ self._model = AutoModelForImageGeneration.from_pretrained(model)
+ self._model.eval()
+
+ def _construct_tokenizer(self, model):
+ """
+ Construct the tokenizer for the predictor.
+ """
+ self._tokenizer = AutoTokenizer.from_pretrained(model)
+
+ def _batchify(self, data, batch_size):
+ """
+ Generate input batches.
+ """
+
+        def _parse_batch(batch_examples):
+            tokenized_inputs = self._tokenizer(batch_examples,
+                                               return_tensors="pd",
+                                               padding="max_length",
+                                               truncation=True)
+            if self._model.base_model_prefix == "dallebart":
+                # condition_scale (super conditioning) is only used by the dallebart models.
+                tokenized_inputs["condition_scale"] = self._condition_scale
+            return tokenized_inputs
+
+        # Separates the data into batches.
+ one_batch = []
+ for example in data:
+ one_batch.append(example)
+ if len(one_batch) == batch_size:
+ yield _parse_batch(one_batch)
+ one_batch = []
+ if one_batch:
+ yield _parse_batch(one_batch)
+
+ def _preprocess(self, inputs):
+ """
+        Transform the raw text into model inputs; two steps are involved:
+ 1) Transform the raw text to token ids.
+ 2) Generate the other model inputs from the raw text and token ids.
+ """
+ inputs = self._check_input_text(inputs)
+ batches = self._batchify(inputs, self._batch_size)
+ outputs = {'batches': batches, 'text': inputs}
+ return outputs
+
+ def _run_model(self, inputs):
+ """
+ Run the task model from the outputs of the `_preprocess` function.
+ """
+ all_images = []
+
+ for batch_inputs in inputs["batches"]:
+ images = self._model.generate(
+ **batch_inputs,
+ temperature=self._temperature,
+ top_k=self._top_k,
+ top_p=self._top_p,
+ num_return_sequences=self._num_return_images,
+ use_faster=self._use_faster,
+ use_fp16_decoding=self._use_fp16_decoding)
+ all_images.append(images.numpy())
+ inputs['images'] = np.concatenate(all_images, axis=0)
+ return inputs
+
+ def _postprocess(self, inputs):
+ """
+        The model output is a batch of images; this function converts it to PIL Images.
+ """
+ batch_out = []
+ generated_images = inputs['images']
+ # [batch_size, num_return_sequences, 256, 256, 3] -> [batch_size, 256, num_return_sequences*256, 3]
+ generated_images = generated_images.transpose([0, 2, 1, 3, 4]).reshape([
+ -1, generated_images.shape[-3],
+ self._num_return_images * generated_images.shape[-2],
+ generated_images.shape[-1]
+ ])
+ for generated_image in generated_images:
+ batch_out.append(Image.fromarray(generated_image))
+
+ return batch_out
+
+ def _construct_input_spec(self):
+ """
+        Construct the input spec for converting the dygraph model to a static model.
+ """
+ self._input_spec = [
+ paddle.static.InputSpec(shape=[None, None],
+ dtype="int64",
+ name='input_ids'),
+ ]
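The keyword arguments read in Text2ImageGenerationTask.__init__ (batch_size, top_k, top_p, temperature, condition_scale, num_return_images, use_faster, use_fp16_decoding) are picked up from the Taskflow constructor's kwargs; a sketch with illustrative values, assuming the usual kwargs forwarding:

    from paddlenlp import Taskflow

    # Illustrative values; each kwarg maps to a field read in __init__.
    text2imagegen = Taskflow("text2image_generation",
                             batch_size=2,
                             top_k=16,
                             num_return_images=2)
    # _postprocess returns one tiled PIL Image per input prompt.
    images = text2imagegen(["风阁水帘今在眼,且来先看早梅红",
                            "见说春风偏有贺,露花千朵照庭闹"])
    images[0].save("figure_0.png")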
diff --git a/paddlenlp/transformers/artist/modeling.py b/paddlenlp/transformers/artist/modeling.py
index 7f9946d0bcd6..c86579a76872 100644
--- a/paddlenlp/transformers/artist/modeling.py
+++ b/paddlenlp/transformers/artist/modeling.py
@@ -212,7 +212,7 @@ def generate(self,
Returns:
Tensor: Returns tensor `images`, which is the output of :class:`VQGanDetokenizer`.
- Its data type should be float32 and has a shape of [batch_size, num_return_sequences, 256, 256, 3].
+ Its data type should be uint8 and has a shape of [batch_size, num_return_sequences, 256, 256, 3].
Example:
.. code-block::
@@ -228,28 +228,18 @@ def generate(self,
# Prepare the model inputs.
prompts = ["风阁水帘今在眼,且来先看早梅红", "见说春风偏有贺,露花千朵照庭闹"]
- tokenized_inputs = tokenizer(
- prompts,
- return_tensors="pd",
- padding="max_length",
- truncation=True,
- return_token_type_ids=False,
- return_attention_mask=False,
- max_length=32,
- )
+ tokenized_inputs = tokenizer(prompts, return_tensors="pd")
top_k = 32
num_return_sequences = 4
images = model.generate(**tokenized_inputs,
top_k=top_k,
num_return_sequences=num_return_sequences)
- print(images.shape)
- # [2, 4, 256, 256, 3]
- images = ((images.cpu().numpy() + 1.0) * 127.5).clip(0, 255).astype("uint8")
+ print(images.shape) # [2, 4, 256, 256, 3]
# [2, 256, 4*256, 3]
- images = images.transpose([0, 2, 1, 3,
- 4]).reshape(-1, images.shape[-3],
+ images = images.numpy().transpose([0, 2, 1, 3,
+ 4]).reshape([-1, images.shape[-3],
num_return_sequences * images.shape[-2],
- images.shape[-1])
+ images.shape[-1]])
for i, image in enumerate(images):
image = Image.fromarray(image)
image.save(f"figure_{i}.png")
@@ -273,4 +263,5 @@ def generate(self,
-1, num_return_sequences, images.shape[1], images.shape[2],
images.shape[3]
])
- return images
+ images = ((images + 1.0) * 127.5).clip(0, 255).astype("uint8")
+ return images
\ No newline at end of file
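For context on the conversion moved into generate(): the Artist decoder output is presumably in the [-1, 1] range, given the (images + 1.0) * 127.5 mapping, and the added line now produces uint8 pixels inside the model. A standalone numpy sketch of that mapping:

    import numpy as np

    x = np.array([-1.0, 0.0, 1.0])  # illustrative decoder outputs in [-1, 1]
    pixels = ((x + 1.0) * 127.5).clip(0, 255).astype("uint8")
    print(pixels)  # [  0 127 255]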
diff --git a/paddlenlp/transformers/artist/tokenizer.py b/paddlenlp/transformers/artist/tokenizer.py
index befd11849886..5e5a74c3f147 100644
--- a/paddlenlp/transformers/artist/tokenizer.py
+++ b/paddlenlp/transformers/artist/tokenizer.py
@@ -203,11 +203,11 @@ def __call__(
self,
text,
text_pair=None,
- max_length=None,
+ max_length=32, # default
stride=0,
is_split_into_words=False,
- padding=False,
- truncation=False,
+ padding="max_length", # default
+ truncation=True, # default
return_position_ids=False,
return_token_type_ids=False, # don't return token_type_ids
return_attention_mask=False,
diff --git a/paddlenlp/transformers/dallebart/modeling.py b/paddlenlp/transformers/dallebart/modeling.py
index 95b91de30655..ef661acd61bd 100644
--- a/paddlenlp/transformers/dallebart/modeling.py
+++ b/paddlenlp/transformers/dallebart/modeling.py
@@ -1729,30 +1729,23 @@ def generate(self,
sequences for each sequence in the batch. Default to 1.
Returns:
Tensor: Returns tensor `images`, which is the output of :class:`VQGanDetokenizer`.
- Its data type should be float32 and has a shape of [batch_size, num_return_sequences, 256, 256, 3].
+ Its data type should be uint8 and has a shape of [batch_size, num_return_sequences, 256, 256, 3].
Example:
.. code-block::
import paddle
- from paddlenlp.transformers import DalleBartForImageGeneration, DalleBartTokenizer
+ from paddlenlp.transformers import AutoModelForImageGeneration, AutoTokenizer
from PIL import Image
# Initialize the model and tokenizer
model_name_or_path = 'dalle-mini'
- model = DalleBartForImageGeneration.from_pretrained(model_name_or_path)
- tokenizer = DalleBartTokenizer.from_pretrained(model_name_or_path)
+ model = AutoModelForImageGeneration.from_pretrained(model_name_or_path)
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model.eval()
# Prepare the model inputs.
prompts = ["graphite sketch of Elon Musk", "Mohanlal graphite sketch"]
- tokenized_inputs = tokenizer(
- prompts,
- return_tensors="pd",
- padding="max_length",
- truncation=True,
- return_attention_mask=True,
- max_length=64,
- )
+ tokenized_inputs = tokenizer(prompts, return_tensors="pd")
top_k = 32
condition_scale = 16.0
num_return_sequences = 4
@@ -1760,14 +1753,12 @@ def generate(self,
top_k=top_k,
condition_scale=condition_scale,
num_return_sequences=num_return_sequences)
- print(images.shape)
- # [2, 4, 256, 256, 3]
- images = (images.cpu().numpy().clip(0, 1) * 255).astype("uint8")
+ print(images.shape) # [2, 4, 256, 256, 3]
# [2, 256, 4*256, 3]
- images = images.transpose([0, 2, 1, 3,
- 4]).reshape(-1, images.shape[-3],
+ images = images.numpy().transpose([0, 2, 1, 3,
+ 4]).reshape([-1, images.shape[-3],
num_return_sequences * images.shape[-2],
- images.shape[-1])
+ images.shape[-1]])
for i, image in enumerate(images):
image = Image.fromarray(image)
image.save(f"figure_{i}.png")
@@ -1787,4 +1778,5 @@ def generate(self,
-1, num_return_sequences, images.shape[1], images.shape[2],
images.shape[3]
])
+ images = (images.clip(0, 1) * 255).astype("uint8")
return images
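The docstring example above tiles the num_return_sequences images generated for each prompt side by side before saving. A self-contained numpy sketch of that transpose/reshape, with dummy data standing in for the generated images:

    import numpy as np

    batch_size, n, h, w, c = 2, 4, 256, 256, 3
    images = np.zeros([batch_size, n, h, w, c], dtype="uint8")
    # [batch, n, h, w, c] -> [batch, h, n, w, c] -> [batch, h, n*w, c]
    grid = images.transpose([0, 2, 1, 3, 4]).reshape([-1, h, n * w, c])
    print(grid.shape)  # (2, 256, 1024, 3)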
diff --git a/paddlenlp/transformers/dallebart/tokenizer.py b/paddlenlp/transformers/dallebart/tokenizer.py
index 49f745ac852f..9be91793ab21 100644
--- a/paddlenlp/transformers/dallebart/tokenizer.py
+++ b/paddlenlp/transformers/dallebart/tokenizer.py
@@ -490,27 +490,28 @@ def create_token_type_ids_from_sequences(self,
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
- def __call__(self,
- text,
- text_pair=None,
- max_length=None,
- stride=0,
- is_split_into_words=False,
- padding=False,
- truncation=False,
- return_position_ids=False,
- return_token_type_ids=False,
- return_attention_mask=False,
- return_length=False,
- return_overflowing_tokens=False,
- return_special_tokens_mask=False,
- return_dict=True,
- return_offsets_mapping=False,
- add_special_tokens=True,
- pad_to_multiple_of=None,
- return_tensors=None,
- verbose: bool = True,
- **kwargs):
+ def __call__(
+ self,
+ text,
+ text_pair=None,
+ max_length=64, # default
+ stride=0,
+ is_split_into_words=False,
+ padding="max_length", # default
+ truncation=True, # default
+ return_position_ids=False,
+ return_token_type_ids=False, # don't return token_type_ids
+ return_attention_mask=True, # default
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_special_tokens_mask=False,
+ return_dict=True,
+ return_offsets_mapping=False,
+ add_special_tokens=True,
+ pad_to_multiple_of=None,
+ return_tensors=None,
+ verbose: bool = True,
+ **kwargs):
if self.normalize_text:
is_batched = isinstance(text, (list, tuple))
if is_batched: