Skip to content

Commit

Permalink
Add embedding compute task
Browse files Browse the repository at this point in the history
  • Loading branch information
photosssa committed Aug 31, 2023
1 parent 39eb96f commit febe28e
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 26 deletions.
2 changes: 1 addition & 1 deletion doc/mvp/workflow.drawio
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@
</mxGraphModel>
</diagram>
<diagram id="kWxfmPxtNxOAf0a73TCG" name="Page-2">
<mxGraphModel dx="500" dy="864" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="850" pageHeight="1100" math="0" shadow="0">
<mxGraphModel dx="625" dy="1080" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="850" pageHeight="1100" math="0" shadow="0">
<root>
<mxCell id="0"/>
<mxCell id="1" parent="0"/>
Expand Down
8 changes: 8 additions & 0 deletions src/aios_kernel/compute_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ def set_llm_params(self,prompts,model_name,max_token_size,callchain_id = None):
self.params["model_name"] = "gpt-4-0613"
self.params["max_token_size"] = max_token_size

def set_embeding_params(self, model_name, input, callchain_id = None):
self.task_type = "embeding"
self.create_time = time.time()
self.task_id = uuid.uuid4().hex
self.callchain_id = callchain_id
self.params["model_name"] = model_name
self.params["input"] = input

def display(self) -> str:
    """Return a short human-readable summary of this task."""
    summary_fields = (self.task_id, self.task_type, self.state)
    return "ComputeTask: {} {} {}".format(*summary_fields)

Expand Down
58 changes: 33 additions & 25 deletions src/aios_kernel/open_ai_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,31 +47,39 @@ async def remove_task(self,task_id:str):

def _run_task(self, task: ComputeTask):
    """Execute one compute task against the OpenAI API.

    Dispatches on ``task.task_type``:
      * ``"llm_completion"`` — sent to the ChatCompletion endpoint.
      * ``"embeding"`` — not implemented yet; the task is failed explicitly
        instead of being left in RUNNING state.

    Returns a ComputeTaskResult on success, or None after marking the task
    as ERROR (callers must check ``task.state``).
    """
    task.state = ComputeTaskState.RUNNING

    # switch on task type
    if task.task_type == "llm_completion":
        model_name = task.params["model_name"]
        # max_token_size = task.params["max_token_size"]
        prompts = task.params["prompts"]

        logger.info(f"call openai {model_name} prompts: {prompts}")
        resp = openai.ChatCompletion.create(model=model_name,
                                            messages=prompts,
                                            max_tokens=4000,
                                            temperature=1.2)
        logger.info(f"openai response: {resp}")

        # "stop" is the only finish_reason treated as success; anything
        # else (e.g. "length", "content_filter") fails the task.
        status_code = resp["choices"][0]["finish_reason"]
        if status_code != "stop":
            task.state = ComputeTaskState.ERROR
            task.error_str = f"The status code was {status_code}."
            return None

        result = ComputeTaskResult()
        result.set_from_task(task)
        result.worker_id = self.node_id
        result.result_str = resp["choices"][0]["message"]["content"]
        result.result = resp["choices"][0]["message"]
        return result

    if task.task_type == "embeding":
        # TODO: embedding computation is not implemented on this node yet.
        # Fail explicitly rather than silently returning None with the
        # task still in RUNNING state.
        task.state = ComputeTaskState.ERROR
        task.error_str = "embeding task is not implemented"
        return None

    # Unknown task type: report it instead of silently returning None.
    task.state = ComputeTaskState.ERROR
    task.error_str = f"unknown task type: {task.task_type}"
    return None

def start(self):
async def _run_task_loop():
while True:
Expand Down

0 comments on commit febe28e

Please sign in to comment.