
Commit 2f83350

massi-ang and hwchase17 authored
Feat bedrock cohere support (#11230)
**Description:** Added support for the Cohere Command model via Bedrock. With this change it is now possible to use the `cohere.command-text-v14` model via the Bedrock API.

About streaming: the Cohere model emits two additional chunks at the end of the generated text: a chunk containing the text `<EOS_TOKEN>`, and a chunk indicating the end of the stream. In this implementation I chose to ignore both chunks. An alternative would be to replace `<EOS_TOKEN>` with `\n`.

Tests: manually verified that the new model works with both `llm.generate()` and `llm.stream()`. Tested with the `temperature`, `p`, and `stop` parameters.

**Issue:** #11181
**Dependencies:** No new dependencies
**Tag maintainer:** @baskaryan
**Twitter handle:** mangelino

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
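For orientation (not part of the commit), a minimal usage sketch of the model this change enables. The region and prompts are placeholders, and it assumes AWS credentials with Bedrock access are already configured; depending on the langchain version, `streaming=True` may need to be set on the constructor for `.stream()` to yield incremental chunks.

```python
from langchain.llms import Bedrock

# Assumes AWS credentials with Bedrock access; the region is a placeholder.
llm = Bedrock(
    model_id="cohere.command-text-v14",
    region_name="us-east-1",
    model_kwargs={"temperature": 0.5, "p": 0.9},
)

# Non-streaming: one request, full completion back.
result = llm.generate(["Write a haiku about rivers."], stop=["\n\n"])
print(result.generations[0][0].text)

# Streaming: chunks arrive as they are generated; the trailing <EOS_TOKEN>
# and end-of-stream chunks are swallowed by this commit.
for chunk in llm.stream("Write a haiku about rivers."):
    print(chunk, end="", flush=True)
```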
1 parent 37f2f71 commit 2f83350

File tree

1 file changed: +15 −4 lines


libs/langchain/langchain/llms/bedrock.py

Lines changed: 15 additions & 4 deletions
@@ -65,6 +65,7 @@ class LLMInputOutputAdapter:
     provider_to_output_key_map = {
         "anthropic": "completion",
         "amazon": "outputText",
+        "cohere": "text",
     }
 
     @classmethod
@@ -74,7 +75,7 @@ def prepare_input(
         input_body = {**model_kwargs}
         if provider == "anthropic":
             input_body["prompt"] = _human_assistant_format(prompt)
-        elif provider == "ai21":
+        elif provider == "ai21" or provider == "cohere":
             input_body["prompt"] = prompt
         elif provider == "amazon":
             input_body = dict()
@@ -98,6 +99,8 @@ def prepare_output(cls, provider: str, response: Any) -> str:
 
         if provider == "ai21":
             return response_body.get("completions")[0].get("data").get("text")
+        elif provider == "cohere":
+            return response_body.get("generations")[0].get("text")
         else:
             return response_body.get("results")[0].get("outputText")
 
@@ -119,6 +122,12 @@ def prepare_output_stream(
             chunk = event.get("chunk")
             if chunk:
                 chunk_obj = json.loads(chunk.get("bytes").decode())
+                if provider == "cohere" and (
+                    chunk_obj["is_finished"]
+                    or chunk_obj[cls.provider_to_output_key_map[provider]]
+                    == "<EOS_TOKEN>"
+                ):
+                    return
 
                 # chunk obj format varies with provider
                 yield GenerationChunk(
@@ -159,6 +168,7 @@ class BedrockBase(BaseModel, ABC):
         "anthropic": "stop_sequences",
         "amazon": "stopSequences",
         "ai21": "stop_sequences",
+        "cohere": "stop_sequences",
     }
 
     @root_validator()
@@ -259,9 +269,10 @@ def _prepare_input_and_invoke_stream(
 
             # stop sequence from _generate() overrides
             # stop sequences in the class attribute
-            _model_kwargs[
-                self.provider_stop_sequence_key_name_map.get(provider),
-            ] = stop
+            _model_kwargs[self.provider_stop_sequence_key_name_map.get(provider)] = stop
+
+            if provider == "cohere":
+                _model_kwargs["stream"] = True
 
         params = {**_model_kwargs, **kwargs}
         input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
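To make the adapter changes concrete, here is a self-contained sketch (not from the repo) that mirrors the Cohere paths above; the event payloads are made-up stand-ins shaped like the Bedrock stream this code parses. Note that the last hunk also fixes a pre-existing bug: the trailing comma inside the old subscript made the stop-sequence key a one-element tuple instead of a string.

```python
import json
from typing import Any, Dict, Iterator


def prepare_cohere_input(prompt: str, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Mirrors prepare_input(): Cohere takes the prompt verbatim, like ai21.
    return {"prompt": prompt, **model_kwargs}


def parse_cohere_output(response_body: Dict[str, Any]) -> str:
    # Mirrors prepare_output(): non-streaming responses carry a "generations" list.
    return response_body["generations"][0]["text"]


def stream_cohere_text(events: Iterator[Dict[str, Any]]) -> Iterator[str]:
    # Mirrors prepare_output_stream(): stop at the end-of-stream marker or at
    # the literal <EOS_TOKEN> chunk, both of which this commit swallows.
    for event in events:
        chunk = event.get("chunk")
        if not chunk:
            continue
        chunk_obj = json.loads(chunk["bytes"].decode())
        if chunk_obj["is_finished"] or chunk_obj["text"] == "<EOS_TOKEN>":
            return
        yield chunk_obj["text"]


# Made-up event stream shaped like the Bedrock responses parsed above.
fake_stream_events = [
    {"chunk": {"bytes": json.dumps({"is_finished": False, "text": "Hello"}).encode()}},
    {"chunk": {"bytes": json.dumps({"is_finished": False, "text": "<EOS_TOKEN>"}).encode()}},
    {"chunk": {"bytes": json.dumps({"is_finished": True, "text": ""}).encode()}},
]
assert "".join(stream_cohere_text(iter(fake_stream_events))) == "Hello"
```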
