
Commit

bugfix and feature (modelscope#439)
jiaqianjing authored May 31, 2024
1 parent 3287ece commit 3a03f4f
Showing 5 changed files with 26 additions and 16 deletions.
modelscope_agent/agents/role_play.py (4 additions, 0 deletions)

@@ -6,6 +6,7 @@
 from modelscope_agent.agent_env_util import AgentEnvMixin
 from modelscope_agent.llm.base import BaseChatModel
 from modelscope_agent.tools.base import BaseTool
+from modelscope_agent.utils.logger import agent_logger as logger
 from modelscope_agent.utils.tokenization_utils import count_tokens
 from modelscope_agent.utils.utils import check_and_limit_input_length

@@ -260,11 +261,13 @@ def _run(self,
         planning_prompt = self.llm.build_raw_prompt(messages)

         max_turn = 10
+        call_llm_count = 0
         while True and max_turn > 0:
             # print('=====one input planning_prompt======')
             # print(planning_prompt)
             # print('=============Answer=================')
             max_turn -= 1
+            call_llm_count += 1
             if self.llm.support_function_calling():
                 output = self.llm.chat_with_functions(
                     messages=messages,
@@ -282,6 +285,7 @@
                     **kwargs)

             llm_result = ''
+            logger.info(f'call llm {call_llm_count} times output: {output}')
             for s in output:
                 if isinstance(s, dict):
                     llm_result = s
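For context, a minimal, self-contained sketch of the loop this hunk instruments; `plan` and `fake_llm` are illustrative stand-ins, not the agent's real `chat_with_functions`/`chat` API:

```python
# Sketch of the bounded planning loop with the added call counter and
# per-turn logging. Names here are hypothetical stand-ins.
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('agent')


def plan(call_llm, prompt, max_turn=10):
    call_llm_count = 0
    while max_turn > 0:
        max_turn -= 1
        call_llm_count += 1
        output = call_llm(prompt)
        # Log every raw model output so a failed turn can be diagnosed later.
        logger.info(f'call llm {call_llm_count} times output: {output}')
        if output:
            return output
    return None


def fake_llm(prompt):
    return f'answer to: {prompt}'


print(plan(fake_llm, 'list the files'))
```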
modelscope_agent/llm/__init__.py (2 additions, 2 deletions)

@@ -5,7 +5,7 @@
 from .modelscope import ModelScopeChatGLM, ModelScopeLLM
 from .ollama import OllamaLLM
 from .openai import OpenAi
-from .zhipu import GLM4, ZhipuLLM
+from .zhipu import ZhipuLLM


 def get_chat_model(model: str, model_server: str, **kwargs) -> BaseChatModel:
@@ -26,5 +26,5 @@ def get_chat_model(model: str, model_server: str, **kwargs) -> BaseChatModel:

 __all__ = [
     'LLM_REGISTRY', 'BaseChatModel', 'OpenAi', 'DashScopeLLM', 'QwenChatAtDS',
-    'ModelScopeLLM', 'ModelScopeChatGLM', 'ZhipuLLM', 'GLM4', 'OllamaLLM'
+    'ModelScopeLLM', 'ModelScopeChatGLM', 'ZhipuLLM', 'OllamaLLM'
 ]
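With the dedicated GLM4 class gone, callers that imported it directly need to switch to the factory. A hedged migration sketch, assuming glm-4 requests now resolve through ZhipuLLM and that 'zhipu' is the registered model_server key (neither is confirmed by this diff):

```python
# Hypothetical migration: replace direct GLM4 imports with the factory.
# The model_server value 'zhipu' is an assumption.
from modelscope_agent.llm import get_chat_model

# before: from modelscope_agent.llm import GLM4
llm = get_chat_model(model='glm-4', model_server='zhipu', api_key='...')
```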
modelscope_agent/llm/base.py (3 additions, 0 deletions)

@@ -211,6 +211,9 @@ def support_function_calling(self) -> bool:
                 if response.get('function_call', None):
                     # logger.info('Support of function calling is detected.')
                     self._support_fn_call = True
+                if response.get('tool_calls', None):
+                    # logger.info('Support of function calling is detected.')
+                    self._support_fn_call = True
         except FnCallNotImplError:
             pass
         except AttributeError:
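The probe previously recognized only the legacy function_call field; OpenAI-compatible backends report tool invocations under tool_calls instead, so responses from those models looked like plain text. A standalone sketch of the extended check:

```python
# A chat response signals function-calling support via either the legacy
# 'function_call' field or the OpenAI-style 'tool_calls' field.
def signals_fn_call(response: dict) -> bool:
    return bool(response.get('function_call') or response.get('tool_calls'))


assert signals_fn_call({'function_call': {'name': 'search'}})
assert signals_fn_call({'tool_calls': [{'id': 'call_0', 'type': 'function'}]})
assert not signals_fn_call({'content': 'plain text answer'})
```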
modelscope_agent/llm/zhipu.py (14 additions, 12 deletions)

@@ -1,6 +1,7 @@
 import os
 from typing import Dict, Iterator, List, Optional

+from modelscope_agent.utils.logger import agent_logger as logger
 from zhipuai import ZhipuAI

 from .base import BaseChatModel, register_llm
@@ -31,9 +32,12 @@ class ZhipuLLM(BaseChatModel):
     Universal LLM model interface on zhipu
     """

-    def __init__(self, model: str, model_server: str, **kwargs):
-        super().__init__(model, model_server)
-        self._support_fn_call = True
+    def __init__(self,
+                 model: str,
+                 model_server: str,
+                 support_fn_call: bool = True,
+                 **kwargs):
+        super().__init__(model, model_server, support_fn_call=support_fn_call)
         api_key = kwargs.get('api_key', os.getenv('ZHIPU_API_KEY', '')).strip()
         assert api_key, 'ZHIPU_API_KEY is required.'
         self.client = ZhipuAI(api_key=api_key)
@@ -45,7 +49,8 @@ def _chat_stream(self,
                      **kwargs) -> Iterator[str]:
         if not functions or not len(functions):
             tool_choice = 'none'
-        print(f'====> stream messages: {messages}')
+        logger.info(
+            f'====> stream messages: {messages}, functions: {functions}')
         response = self.client.chat.completions.create(
             model=self.model,
             messages=messages,
@@ -62,18 +67,15 @@ def _chat_no_stream(self,
                         **kwargs) -> str:
         if not functions or not len(functions):
             tool_choice = 'none'
-        print(f'====> no stream messages: {messages}')
+        logger.info(
+            f'====> no stream messages: {messages}, functions: {functions}')
         response = self.client.chat.completions.create(
             model=self.model,
             messages=messages,
             tools=functions,
             tool_choice=tool_choice,
         )
-        return response.choices[0].message
-
-
-@register_llm('glm-4')
-class GLM4(ZhipuLLM):
-    """
-    glm-4 from zhipu
-    """
+        message = response.choices[0].message
+        output = message.content if not functions else [message.model_dump()]
+        return output
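Two behavioral notes on this file: the constructor now forwards a support_fn_call flag to the base class instead of hard-coding the private attribute, and _chat_no_stream no longer returns the raw message object: it returns plain text when no functions were passed, and otherwise a one-element list holding the message dict. A hedged sketch of how a caller could branch on that shape (helper name is hypothetical, shape inferred from the diff):

```python
# Illustrative consumer of the new _chat_no_stream return value.
def consume(output):
    if isinstance(output, str):
        # No functions were supplied: output is message.content.
        return ('text', output)
    # Functions were supplied: output is [message.model_dump()].
    message = output[0]
    return ('tool_calls', message.get('tool_calls'))


print(consume('hello'))
print(consume([{'content': None,
                'tool_calls': [{'id': 'call_0', 'type': 'function'}]}]))
```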
modelscope_agent/utils/retry.py (3 additions, 2 deletions)

@@ -1,5 +1,6 @@
 import time
 from functools import wraps
+from traceback import format_exc

 from modelscope_agent.utils.logger import agent_logger as logger

@@ -26,9 +27,9 @@ def wrapper(*args, **kwargs):
                     return func(*args, **kwargs)
                 except AssertionError as e:
                     raise AssertionError(e)
-                except Exception as e:
+                except Exception:
                     logger.warning(
-                        f'Attempt to run {func.__name__} {attempts + 1} failed: {e}'
+                        f'Attempt to run {func.__name__} {attempts + 1} failed: {format_exc()}'
                     )
                     attempts += 1
                     time.sleep(delay_seconds)
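str(e) keeps only the exception message, which made multi-layer retry failures hard to trace; format_exc() captures the full traceback of the active exception. A minimal sketch of the surrounding decorator, assuming the usual retry structure (the hunk only shows the except block, so the signature, loop, and final re-raise are assumptions):

```python
import logging
import time
from functools import wraps
from traceback import format_exc

logger = logging.getLogger('agent')


def retry(max_retries: int = 3, delay_seconds: float = 1.0):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            attempts = 0
            while attempts < max_retries:
                try:
                    return func(*args, **kwargs)
                except Exception:
                    # format_exc() logs the whole traceback, not just str(e).
                    logger.warning(
                        f'Attempt to run {func.__name__} {attempts + 1} '
                        f'failed: {format_exc()}')
                    attempts += 1
                    time.sleep(delay_seconds)
            raise RuntimeError(
                f'{func.__name__} failed after {max_retries} attempts')
        return wrapper
    return decorator
```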
