@@ -2,21 +2,19 @@
 from functools import partial
 from typing import (
     Any,
-    AsyncIterator,
-    Iterator,
     List,
     Optional,
 )
 
 from ai21.models import CompletionsResponse, Penalty
-
-from langchain_ai21.ai21_base import AI21Base
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models import BaseLLM
-from langchain_core.outputs import GenerationChunk, LLMResult, Generation, RunInfo
+from langchain_core.outputs import Generation, LLMResult
+
+from langchain_ai21.ai21_base import AI21Base
 
 
 class AI21(BaseLLM, AI21Base):
@@ -58,11 +56,17 @@ class AI21(BaseLLM, AI21Base):
     """A penalty applied to tokens that are already present in the prompt."""
 
     count_penalty: Optional[Penalty] = None
-    """A penalty applied to tokens based on their frequency in the generated responses."""
+    """A penalty applied to tokens based on their frequency
+    in the generated responses."""
 
     custom_model: Optional[str] = None
     epoch: Optional[int] = None
 
+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+
     @property
     def _llm_type(self) -> str:
         """Return type of LLM."""
@@ -101,30 +105,12 @@ async def _agenerate(
             None, partial(self._generate, **kwargs), prompts, stop, run_manager
         )
 
-    def _stream(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> Iterator[GenerationChunk]:
-        raise NotImplementedError
-
-    async def _astream(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> AsyncIterator[GenerationChunk]:
-        raise NotImplementedError
-
     def _invoke_completion(
         self,
         prompt: str,
         model: str,
         stop_sequences: Optional[List[str]] = None,
-        **kwargs,
+        **kwargs: Any,
     ) -> CompletionsResponse:
         return self.client.completion.create(
             prompt=prompt,
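
Note on the `allow_population_by_field_name` flag added in the second hunk: this is pydantic v1 configuration (the API LangChain targeted at the time), and it lets a model be constructed with either a field's declared name or its alias. A minimal sketch with a hypothetical `Example` model, not the AI21 class itself:

    from pydantic import BaseModel, Field


    class Example(BaseModel):
        # The field is declared with an alias; by default pydantic v1
        # accepts only the alias when constructing the model.
        api_key: str = Field(alias="apiKey")

        class Config:
            # With this flag set, both the alias and the field name work.
            allow_population_by_field_name = True


    print(Example(apiKey="k1").api_key)   # populated via the alias
    print(Example(api_key="k2").api_key)  # populated via the field name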