Skip to content

Commit

Permalink
fix: less than 200 now
Browse files Browse the repository at this point in the history
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
  • Loading branch information
yihong0618 committed Dec 20, 2024
1 parent 9bc0fa3 commit a9df950
Show file tree
Hide file tree
Showing 16 changed files with 140 additions and 88 deletions.
2 changes: 1 addition & 1 deletion api/constants/model_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from models.model import AppMode

default_app_templates = {
default_app_templates: dict[AppMode, dict] = {
# workflow default mode
AppMode.WORKFLOW: {
"app": {
Expand Down
10 changes: 8 additions & 2 deletions api/core/tools/tool/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def set_image_variable(self, variable_name: str, image_key: str) -> None:
"""
if not self.variables:
return
if self.identity is None:
return

self.variables.set_file(self.identity.name, variable_name, image_key)

Expand All @@ -114,6 +116,8 @@ def set_text_variable(self, variable_name: str, text: str) -> None:
"""
if not self.variables:
return
if self.identity is None:
return

self.variables.set_text(self.identity.name, variable_name, text)

Expand Down Expand Up @@ -197,9 +201,11 @@ def list_default_image_variables(self) -> list[ToolRuntimeVariable]:

return result

def invoke(self, user_id: str, tool_parameters: Mapping[str, Any]) -> list[ToolInvokeMessage]:
def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
# update tool_parameters
# TODO: Fix type error.
if self.runtime is None:
return []
if self.runtime.runtime_parameters:
tool_parameters.update(self.runtime.runtime_parameters)

Expand All @@ -221,7 +227,7 @@ def _transform_tool_parameters_type(self, tool_parameters: Mapping[str, Any]) ->
Transform tool parameters type
"""
# Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials
result = deepcopy(tool_parameters)
result: dict[str, Any] = deepcopy(dict(tool_parameters))
for parameter in self.parameters or []:
if parameter.name in tool_parameters:
result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name])
Expand Down
2 changes: 1 addition & 1 deletion api/core/tools/utils/workflow_configuration_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: Mapping[str, Any]):
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
for configuration in configurations:
WorkflowToolParameterConfiguration.model_validate(configuration)

Expand Down
10 changes: 4 additions & 6 deletions api/core/workflow/nodes/llm/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ class LLMNode(BaseNode[LLMNodeData]):
_node_data_cls = LLMNodeData
_node_type = NodeType.LLM

def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]:
node_inputs = None
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
node_inputs: Optional[dict[str, Any]] = None
process_data = None

try:
Expand Down Expand Up @@ -196,7 +196,6 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]
error_type=type(e).__name__,
)
)
return
except Exception as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
Expand All @@ -206,7 +205,6 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]
process_data=process_data,
)
)
return

outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}

Expand Down Expand Up @@ -302,7 +300,7 @@ def _transform_chat_messages(
return messages

def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
variables = {}
variables: dict[str, Any] = {}

if not node_data.prompt_config:
return variables
Expand All @@ -319,7 +317,7 @@ def parse_dict(input_dict: Mapping[str, Any]) -> str:
"""
# check if it's a context structure
if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict:
return input_dict["content"]
return str(input_dict["content"])

# else, parse the dict
try:
Expand Down
8 changes: 4 additions & 4 deletions api/extensions/storage/opendal_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import Generator
from pathlib import Path

import opendal
import opendal # type: ignore[import]
from dotenv import dotenv_values

from extensions.storage.base_storage import BaseStorage
Expand All @@ -18,7 +18,7 @@ def _get_opendal_kwargs(*, scheme: str, env_file_path: str = ".env", prefix: str
if key.startswith(config_prefix):
kwargs[key[len(config_prefix) :].lower()] = value

file_env_vars = dotenv_values(env_file_path)
file_env_vars: dict = dotenv_values(env_file_path) or {}
for key, value in file_env_vars.items():
if key.startswith(config_prefix) and key[len(config_prefix) :].lower() not in kwargs and value:
kwargs[key[len(config_prefix) :].lower()] = value
Expand Down Expand Up @@ -48,7 +48,7 @@ def load_once(self, filename: str) -> bytes:
if not self.exists(filename):
raise FileNotFoundError("File not found")

content = self.op.read(path=filename)
content: bytes = self.op.read(path=filename)
logger.debug(f"file {filename} loaded")
return content

Expand All @@ -75,7 +75,7 @@ def exists(self, filename: str) -> bool:
# error handler here when opendal python-binding has a exists method, we should use it
# more https://github.com/apache/opendal/blob/main/bindings/python/src/operator.rs
try:
res = self.op.stat(path=filename).mode.is_file()
res: bool = self.op.stat(path=filename).mode.is_file()
logger.debug(f"file {filename} checked")
return res
except Exception:
Expand Down
4 changes: 2 additions & 2 deletions api/libs/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,13 @@ def get_token_data(cls, token: str, token_type: str) -> Optional[dict[str, Any]]
if token_data_json is None:
logging.warning(f"{token_type} token {token} not found with key {key}")
return None
token_data = json.loads(token_data_json)
token_data: Optional[dict[str, Any]] = json.loads(token_data_json)
return token_data

@classmethod
def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]:
key = cls._get_account_token_key(account_id, token_type)
current_token = redis_client.get(key)
current_token: Optional[str] = redis_client.get(key)
return current_token

@classmethod
Expand Down
24 changes: 16 additions & 8 deletions api/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Mapping
from datetime import datetime
from enum import Enum, StrEnum
from typing import Any, Literal, Optional
from typing import TYPE_CHECKING, Any, Literal, Optional, cast

import sqlalchemy as sa
from flask import request
Expand All @@ -23,6 +23,9 @@
from .account import Account, Tenant
from .types import StringUUID

if TYPE_CHECKING:
from models.workflow import Workflow


class DifySetup(db.Model): # type: ignore[name-defined]
__tablename__ = "dify_setups"
Expand Down Expand Up @@ -150,7 +153,7 @@ def mode_compatible_with_agent(self) -> str:
if self.mode == AppMode.CHAT.value and self.is_agent:
return AppMode.AGENT_CHAT.value

return self.mode
return str(self.mode)

@property
def deleted_tools(self) -> list:
Expand Down Expand Up @@ -318,7 +321,7 @@ def external_data_tools_list(self) -> list[dict]:
return json.loads(self.external_data_tools) if self.external_data_tools else []

@property
def user_input_form_list(self) -> dict:
def user_input_form_list(self) -> list[dict]:
return json.loads(self.user_input_form) if self.user_input_form else []

@property
Expand All @@ -340,7 +343,7 @@ def completion_prompt_config_dict(self) -> dict:
@property
def dataset_configs_dict(self) -> dict:
if self.dataset_configs:
dataset_configs = json.loads(self.dataset_configs)
dataset_configs: dict = json.loads(self.dataset_configs)
if "retrieval_model" not in dataset_configs:
return {"retrieval_model": "single"}
else:
Expand Down Expand Up @@ -582,6 +585,8 @@ def inputs(self, value: Mapping[str, Any]):
@property
def model_config(self):
model_config = {}
app_model_config: Optional[AppModelConfig] = None

if self.mode == AppMode.ADVANCED_CHAT.value:
if self.override_model_configs:
override_model_configs = json.loads(self.override_model_configs)
Expand All @@ -593,6 +598,7 @@ def model_config(self):
if "model" in override_model_configs:
app_model_config = AppModelConfig()
app_model_config = app_model_config.from_model_config_dict(override_model_configs)
assert app_model_config is not None, "app model config not found"
model_config = app_model_config.to_dict()
else:
model_config["configs"] = override_model_configs
Expand Down Expand Up @@ -1248,7 +1254,7 @@ class OperationLog(db.Model): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))


class EndUser(UserMixin, db.Model):
class EndUser(UserMixin, db.Model): # type: ignore[name-defined]
__tablename__ = "end_users"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="end_user_pkey"),
Expand Down Expand Up @@ -1488,7 +1494,7 @@ class MessageAgentThought(db.Model): # type: ignore[name-defined]
@property
def files(self) -> list:
if self.message_files:
return json.loads(self.message_files)
return cast(list, json.loads(self.message_files))
else:
return []

Expand All @@ -1500,7 +1506,7 @@ def tools(self) -> list[str]:
def tool_labels(self) -> dict:
try:
if self.tool_labels_str:
return json.loads(self.tool_labels_str)
return cast(dict, json.loads(self.tool_labels_str))
else:
return {}
except Exception as e:
Expand All @@ -1510,7 +1516,7 @@ def tool_labels(self) -> dict:
def tool_meta(self) -> dict:
try:
if self.tool_meta_str:
return json.loads(self.tool_meta_str)
return cast(dict, json.loads(self.tool_meta_str))
else:
return {}
except Exception as e:
Expand Down Expand Up @@ -1558,6 +1564,8 @@ def tool_outputs_dict(self) -> dict:
except Exception as e:
if self.observation:
return dict.fromkeys(tools, self.observation)
else:
return {}


class DatasetRetrieverResource(db.Model): # type: ignore[name-defined]
Expand Down
7 changes: 4 additions & 3 deletions api/models/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from factories import variable_factory
from libs import helper
from models.enums import CreatedByRole
from models.model import AppMode, Message

from .account import Account
from .types import StringUUID
Expand Down Expand Up @@ -43,7 +44,7 @@ def value_of(cls, value: str) -> "WorkflowType":
raise ValueError(f"invalid workflow type value {value}")

@classmethod
def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType":
def from_app_mode(cls, app_mode: Union[str, AppMode]) -> "WorkflowType":
"""
Get workflow type from app mode.
Expand Down Expand Up @@ -198,7 +199,7 @@ def user_input_form(self, to_old_structure: bool = False) -> list:
return []

# get user_input_form from start node
variables = start_node.get("data", {}).get("variables", [])
variables: list[Any] = start_node.get("data", {}).get("variables", [])

if to_old_structure:
old_structure_variables = []
Expand Down Expand Up @@ -435,7 +436,7 @@ def outputs_dict(self) -> Mapping[str, Any]:
return json.loads(self.outputs) if self.outputs else {}

@property
def message(self) -> Optional["Message"]:
def message(self) -> Optional[Message]:
from models.model import Message

return (
Expand Down
22 changes: 13 additions & 9 deletions api/services/app_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
from datetime import UTC, datetime
from typing import cast
from typing import Optional, cast

from flask_login import current_user # type: ignore
from flask_sqlalchemy.pagination import Pagination
Expand Down Expand Up @@ -83,7 +83,7 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App:
# get default model instance
try:
model_instance = model_manager.get_default_model_instance(
tenant_id=account.current_tenant_id, model_type=ModelType.LLM
tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM
)
except (ProviderTokenNotInitError, LLMBadRequestError):
model_instance = None
Expand All @@ -100,6 +100,8 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App:
else:
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if model_schema is None:
raise ValueError(f"model schema not found for model {model_instance.model}")

default_model_dict = {
"provider": model_instance.provider,
Expand All @@ -109,7 +111,7 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App:
}
else:
provider, model = model_manager.get_default_provider_model_name(
tenant_id=account.current_tenant_id, model_type=ModelType.LLM
tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM
)
default_model_config["model"]["provider"] = provider
default_model_config["model"]["name"] = model
Expand Down Expand Up @@ -314,7 +316,7 @@ def get_app_meta(self, app_model: App) -> dict:
"""
app_mode = AppMode.value_of(app_model.mode)

meta = {"tool_icons": {}}
meta: dict = {"tool_icons": {}}

if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
Expand All @@ -336,7 +338,7 @@ def get_app_meta(self, app_model: App) -> dict:
}
)
else:
app_model_config: AppModelConfig = app_model.app_model_config
app_model_config: Optional[AppModelConfig] = app_model.app_model_config

if not app_model_config:
return meta
Expand All @@ -352,16 +354,18 @@ def get_app_meta(self, app_model: App) -> dict:
keys = list(tool.keys())
if len(keys) >= 4:
# current tool standard
provider_type = tool.get("provider_type")
provider_id = tool.get("provider_id")
tool_name = tool.get("tool_name")
provider_type = tool.get("provider_type", "")
provider_id = tool.get("provider_id", "")
tool_name = tool.get("tool_name", "")
if provider_type == "builtin":
meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon"
elif provider_type == "api":
try:
provider: ApiToolProvider = (
provider: Optional[ApiToolProvider] = (
db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first()
)
if provider is None:
raise ValueError(f"provider not found for tool {tool_name}")
meta["tool_icons"][tool_name] = json.loads(provider.icon)
except:
meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"}
Expand Down
8 changes: 6 additions & 2 deletions api/services/entities/model_provider_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from pydantic import BaseModel, ConfigDict

from configs import dify_config
from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity
from core.entities.model_entities import (
ModelWithProviderEntity,
ProviderModelWithStatusEntity,
SimpleModelProviderEntity,
)
from core.entities.provider_entities import QuotaConfiguration
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType
Expand Down Expand Up @@ -148,7 +152,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity):
Model with provider entity.
"""

provider: SimpleProviderEntityResponse
provider: SimpleModelProviderEntity

def __init__(self, model: ModelWithProviderEntity) -> None:
super().__init__(**model.model_dump())
Loading

0 comments on commit a9df950

Please sign in to comment.