diff --git a/torch_geometric/nn/nlp/txt2kg.py b/torch_geometric/nn/nlp/txt2kg.py
index 6963536543cf..f8c47a4663f1 100644
--- a/torch_geometric/nn/nlp/txt2kg.py
+++ b/torch_geometric/nn/nlp/txt2kg.py
@@ -54,18 +54,19 @@ def chunks_to_triples_strs(self, txt_batch: List[str]) -> List[str]:
                 self.model = LLM(LM_name, num_params=14).eval()
                 self.initd_LM = True
             out_strs = self.model.inference(
-                question=[txt + '\n' + self.system_prompt for txt in txt_batch],
-                max_tokens=self.chunk_size)
+                question=[
+                    txt + '\n' + self.system_prompt for txt in txt_batch
+                ], max_tokens=self.chunk_size)
         else:
             messages = []
             for txt in txt_batch:
                 messages.append({
-                    "role":
-                    "user",
-                    "content":
-                    txt + '\n' + self.system_prompt})
+                    "role": "user",
+                    "content": txt + '\n' + self.system_prompt
+                })
             completion = self.client.chat.completions.create(
-                model=self.model, messages=messages, temperature=0, top_p=1, max_tokens=1024, stream=True)
+                model=self.model, messages=messages, temperature=0, top_p=1,
+                max_tokens=1024, stream=True)
             out_str = ""
             for chunk in completion:
                 if chunk.choices[0].delta.content is not None: