Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions iotdb-core/ainode/ainode/core/manager/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
InvalidUriError,
)
from ainode.core.log import Logger
from ainode.core.model.model_info import BuiltInModelType, ModelInfo, ModelStates
from ainode.core.model.model_storage import ModelStorage
from ainode.core.util.status import get_status
from ainode.thrift.ainode.ttypes import (
Expand Down Expand Up @@ -140,3 +141,15 @@ def get_ckpt_path(self, model_id: str) -> str:

def show_models(self) -> TShowModelsResp:
return self.model_storage.show_models()

def register_built_in_model(self, model_info: ModelInfo):
self.model_storage.register_built_in_model(model_info)

def update_model_state(self, model_id: str, state: ModelStates):
self.model_storage.update_model_state(model_id, state)

def get_built_in_model_type(self, model_id: str) -> BuiltInModelType:
"""
Get the type of the model with the given model_id.
"""
return self.model_storage.get_built_in_model_type(model_id.lower())
38 changes: 35 additions & 3 deletions iotdb-core/ainode/ainode/core/model/model_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ def delete_model(self, model_id: str) -> None:
with self._lock_pool.get_lock(model_id).write_lock():
if os.path.exists(storage_path):
shutil.rmtree(storage_path)
if model_id in self._model_info_map:
del self._model_info_map[model_id]
logger.info(f"Model {model_id} deleted successfully.")

def _is_built_in(self, model_id: str) -> bool:
"""
Expand All @@ -218,9 +221,9 @@ def _is_built_in(self, model_id: str) -> bool:
Returns:
bool: True if the model is built-in, False otherwise.
"""
return (
model_id in self._model_info_map
and self._model_info_map[model_id].category == ModelCategory.BUILT_IN
return model_id in self._model_info_map and (
self._model_info_map[model_id].category == ModelCategory.BUILT_IN
or self._model_info_map[model_id].category == ModelCategory.FINE_TUNED
)

def load_model(self, model_id: str, acceleration: bool) -> Callable:
Expand Down Expand Up @@ -291,3 +294,32 @@ def show_models(self) -> TShowModelsResp:
for model_id, model_info in self._model_info_map.items()
),
)

def register_built_in_model(self, model_info: ModelInfo):
with self._lock_pool.get_lock(model_info.model_id).write_lock():
self._model_info_map[model_info.model_id] = model_info

def update_model_state(self, model_id: str, state: ModelStates):
with self._lock_pool.get_lock(model_id).write_lock():
if model_id in self._model_info_map:
self._model_info_map[model_id].state = state
else:
raise ValueError(f"Model {model_id} does not exist.")

def get_built_in_model_type(self, model_id: str) -> BuiltInModelType:
"""
Get the type of the model with the given model_id.

Args:
model_id (str): The ID of the model.

Returns:
str: The type of the model.
"""
with self._lock_pool.get_lock(model_id).read_lock():
if model_id in self._model_info_map:
return get_built_in_model_type(
self._model_info_map[model_id].model_type
)
else:
raise ValueError(f"Model {model_id} does not exist.")
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,8 @@ public TSStatus createModel(CreateModelPlan plan) {
try {
acquireModelTableWriteLock();
String modelName = plan.getModelName();
if (modelTable.containsModel(modelName)) {
return new TSStatus(TSStatusCode.MODEL_EXIST_ERROR.getStatusCode())
.setMessage(String.format("model [%s] has already been created.", modelName));
} else {
modelTable.addModel(new ModelInformation(modelName, ModelStatus.LOADING));
return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode());
}
modelTable.addModel(new ModelInformation(modelName, ModelStatus.LOADING));
return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode());
} catch (Exception e) {
final String errorMessage =
String.format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,8 @@ private TWindowParams getWindowParams() {
}

private TsBlock preProcess(TsBlock inputTsBlock) {
boolean notBuiltIn = !modelInferenceDescriptor.getModelInformation().isBuiltIn();
// boolean notBuiltIn = !modelInferenceDescriptor.getModelInformation().isBuiltIn();
boolean notBuiltIn = false;
if (windowType == null || windowType == InferenceWindowType.HEAD) {
if (notBuiltIn
&& totalRow != modelInferenceDescriptor.getModelInformation().getInputShape()[0]) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -481,13 +481,6 @@ private void checkWindowSize(long windowSize, ModelInformation modelInformation)
if (modelInformation.isBuiltIn()) {
return;
}

if (modelInformation.getInputShape()[0] != windowSize) {
throw new SemanticException(
String.format(
"Window output %d is not equal to input size of model %d",
windowSize, modelInformation.getInputShape()[0]));
}
}

private ISchemaTree analyzeSchema(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public ModelInformation(
}

public ModelInformation(String modelName, ModelStatus status) {
this.modelType = ModelType.USER_DEFINED;
this.modelType = ModelType.BUILT_IN_FORECAST;
this.modelName = modelName;
this.inputShape = new int[0];
this.outputShape = new int[0];
Expand Down
Loading