Skip to content

Commit

Permalink
refactor(api/core/app/app_config/entities.py): Move Type to outside a…
Browse files Browse the repository at this point in the history
…nd add EXTERNAL_DATA_TOOL. (#7444)
  • Loading branch information
laipz8200 authored Aug 20, 2024
1 parent e2d214e commit a10b207
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 104 deletions.
63 changes: 27 additions & 36 deletions api/core/app/app_config/easy_ui_based_app/variables/manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re

from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
from core.external_data_tool.factory import ExternalDataToolFactory


Expand All @@ -13,7 +13,7 @@ def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataV
:param config: model config args
"""
external_data_variables = []
variables = []
variable_entities = []

# old external_data_tools
external_data_tools = config.get('external_data_tools', [])
Expand All @@ -30,50 +30,41 @@ def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataV
)

# variables and external_data_tools
for variable in config.get('user_input_form', []):
typ = list(variable.keys())[0]
if typ == 'external_data_tool':
val = variable[typ]
if 'config' not in val:
for variables in config.get('user_input_form', []):
variable_type = list(variables.keys())[0]
if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
variable = variables[variable_type]
if 'config' not in variable:
continue

external_data_variables.append(
ExternalDataVariableEntity(
variable=val['variable'],
type=val['type'],
config=val['config']
variable=variable['variable'],
type=variable['type'],
config=variable['config']
)
)
elif typ in [
VariableEntity.Type.TEXT_INPUT.value,
VariableEntity.Type.PARAGRAPH.value,
VariableEntity.Type.NUMBER.value,
elif variable_type in [
VariableEntityType.TEXT_INPUT,
VariableEntityType.PARAGRAPH,
VariableEntityType.NUMBER,
VariableEntityType.SELECT,
]:
variables.append(
VariableEntity(
type=VariableEntity.Type.value_of(typ),
variable=variable[typ].get('variable'),
description=variable[typ].get('description'),
label=variable[typ].get('label'),
required=variable[typ].get('required', False),
max_length=variable[typ].get('max_length'),
default=variable[typ].get('default'),
)
)
elif typ == VariableEntity.Type.SELECT.value:
variables.append(
variable = variables[variable_type]
variable_entities.append(
VariableEntity(
type=VariableEntity.Type.SELECT,
variable=variable[typ].get('variable'),
description=variable[typ].get('description'),
label=variable[typ].get('label'),
required=variable[typ].get('required', False),
options=variable[typ].get('options'),
default=variable[typ].get('default'),
type=variable_type,
variable=variable.get('variable'),
description=variable.get('description'),
label=variable.get('label'),
required=variable.get('required', False),
max_length=variable.get('max_length'),
options=variable.get('options'),
default=variable.get('default'),
)
)

return variables, external_data_variables
return variable_entities, external_data_variables

@classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
Expand Down Expand Up @@ -183,4 +174,4 @@ def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: d
config=config
)

return config, ["external_data_tools"]
return config, ["external_data_tools"]
34 changes: 10 additions & 24 deletions api/core/app/app_config/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,43 +82,29 @@ def value_of(cls, value: str) -> 'PromptType':
advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None


class VariableEntityType(str, Enum):
TEXT_INPUT = "text-input"
SELECT = "select"
PARAGRAPH = "paragraph"
NUMBER = "number"
EXTERNAL_DATA_TOOL = "external-data-tool"


class VariableEntity(BaseModel):
"""
Variable Entity.
"""
class Type(Enum):
TEXT_INPUT = 'text-input'
SELECT = 'select'
PARAGRAPH = 'paragraph'
NUMBER = 'number'

@classmethod
def value_of(cls, value: str) -> 'VariableEntity.Type':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid variable type value {value}')

variable: str
label: str
description: Optional[str] = None
type: Type
type: VariableEntityType
required: bool = False
max_length: Optional[int] = None
options: Optional[list[str]] = None
default: Optional[str] = None
hint: Optional[str] = None

@property
def name(self) -> str:
return self.variable


class ExternalDataVariableEntity(BaseModel):
"""
Expand Down Expand Up @@ -252,4 +238,4 @@ class WorkflowUIBasedAppConfig(AppConfig):
"""
Workflow UI Based App Config Entity.
"""
workflow_id: str
workflow_id: str
28 changes: 14 additions & 14 deletions api/core/app/apps/base_app_generator.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,52 @@
from collections.abc import Mapping
from typing import Any, Optional

from core.app.app_config.entities import AppConfig, VariableEntity
from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType


class BaseAppGenerator:
def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_config: AppConfig) -> Mapping[str, Any]:
user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values
variables = app_config.variables
filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables}
filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
return filtered_inputs

def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
user_input_value = inputs.get(var.name)
user_input_value = inputs.get(var.variable)
if var.required and not user_input_value:
raise ValueError(f'{var.name} is required in input form')
raise ValueError(f'{var.variable} is required in input form')
if not var.required and not user_input_value:
# TODO: should we return None here if the default value is None?
return var.default or ''
if (
var.type
in (
VariableEntity.Type.TEXT_INPUT,
VariableEntity.Type.SELECT,
VariableEntity.Type.PARAGRAPH,
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
)
and user_input_value
and not isinstance(user_input_value, str)
):
raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string")
if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str):
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
# may raise ValueError if user_input_value is not a valid number
try:
if '.' in user_input_value:
return float(user_input_value)
else:
return int(user_input_value)
except ValueError:
raise ValueError(f"{var.name} in input form must be a valid number")
if var.type == VariableEntity.Type.SELECT:
raise ValueError(f"{var.variable} in input form must be a valid number")
if var.type == VariableEntityType.SELECT:
options = var.options or []
if user_input_value not in options:
raise ValueError(f'{var.name} in input form must be one of the following: {options}')
elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH):
raise ValueError(f'{var.variable} in input form must be one of the following: {options}')
elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters')
raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters')

return user_input_value

Expand Down
43 changes: 19 additions & 24 deletions api/core/tools/provider/workflow_tool_provider.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional

from core.app.app_config.entities import VariableEntity
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
Expand All @@ -18,6 +18,13 @@
from models.tools import WorkflowToolProvider
from models.workflow import Workflow

VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING,
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
}


class WorkflowToolProviderController(ToolProviderController):
provider_id: str
Expand All @@ -28,7 +35,7 @@ def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderCont

if not app:
raise ValueError('app not found')

controller = WorkflowToolProviderController(**{
'identity': {
'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
Expand All @@ -46,7 +53,7 @@ def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderCont
'credentials_schema': {},
'provider_id': db_provider.id or '',
})

# init tools

controller.tools = [controller._get_db_provider_tool(db_provider, app)]
Expand All @@ -56,7 +63,7 @@ def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderCont
@property
def provider_type(self) -> ToolProviderType:
return ToolProviderType.WORKFLOW

def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
"""
get db provider tool
Expand Down Expand Up @@ -93,23 +100,11 @@ def fetch_workflow_variable(variable_name: str) -> VariableEntity:
if variable:
parameter_type = None
options = None
if variable.type in [
VariableEntity.Type.TEXT_INPUT,
VariableEntity.Type.PARAGRAPH,
]:
parameter_type = ToolParameter.ToolParameterType.STRING
elif variable.type in [
VariableEntity.Type.SELECT
]:
parameter_type = ToolParameter.ToolParameterType.SELECT
elif variable.type in [
VariableEntity.Type.NUMBER
]:
parameter_type = ToolParameter.ToolParameterType.NUMBER
else:
if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
raise ValueError(f'unsupported variable type {variable.type}')

if variable.type == VariableEntity.Type.SELECT and variable.options:
parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]

if variable.type == VariableEntityType.SELECT and variable.options:
options = [
ToolParameterOption(
value=option,
Expand Down Expand Up @@ -200,19 +195,19 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]:
"""
if self.tools is not None:
return self.tools

db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id,
).first()

if not db_providers:
return []

self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]

return self.tools

def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
"""
get tool by name
Expand All @@ -226,5 +221,5 @@ def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
for tool in self.tools:
if tool.identity.name == tool_name:
return tool

return None
6 changes: 5 additions & 1 deletion api/core/workflow/nodes/start/entities.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from collections.abc import Sequence

from pydantic import Field

from core.app.app_config.entities import VariableEntity
from core.workflow.entities.base_node_data_entities import BaseNodeData

Expand All @@ -6,4 +10,4 @@ class StartNodeData(BaseNodeData):
"""
Start Node Data
"""
variables: list[VariableEntity] = []
variables: Sequence[VariableEntity] = Field(default_factory=list)
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ModelConfigEntity,
PromptTemplateEntity,
VariableEntity,
VariableEntityType,
)
from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode
Expand All @@ -25,23 +26,24 @@

@pytest.fixture
def default_variables():
return [
value = [
VariableEntity(
variable="text_input",
label="text-input",
type=VariableEntity.Type.TEXT_INPUT
type=VariableEntityType.TEXT_INPUT,
),
VariableEntity(
variable="paragraph",
label="paragraph",
type=VariableEntity.Type.PARAGRAPH
type=VariableEntityType.PARAGRAPH,
),
VariableEntity(
variable="select",
label="select",
type=VariableEntity.Type.SELECT
)
type=VariableEntityType.SELECT,
),
]
return value


def test__convert_to_start_node(default_variables):
Expand Down

0 comments on commit a10b207

Please sign in to comment.