Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,11 @@ public void testInformationSchema() throws SQLException {
"model_id,",
new HashSet<>(
Arrays.asList(
"_STLForecaster,", "_NaiveForecaster,", "_ARIMA,", "_ExponentialSmoothing,")));
"_timerxl,",
"_STLForecaster,",
"_NaiveForecaster,",
"_ARIMA,",
"_ExponentialSmoothing,")));

TestUtils.assertResultSetEqual(
statement.executeQuery(
Expand Down
20 changes: 13 additions & 7 deletions iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@
from ainode.core.util.masking import prepare_4d_causal_attention_mask
from ainode.core.util.huggingface_cache import Cache, DynamicCache

import safetensors
from safetensors.torch import load_file as load_safetensors
from huggingface_hub import hf_hub_download

from ainode.core.log import Logger
logger = Logger()

@dataclass
class Output:
outputs: torch.Tensor
Expand Down Expand Up @@ -211,12 +214,15 @@ def __init__(self, config: TimerxlConfig):
state_dict = torch.load(config.ckpt_path)
elif config.ckpt_path.endswith('.safetensors'):
if not os.path.exists(config.ckpt_path):
print(f"[INFO] Checkpoint not found at {config.ckpt_path}, downloading from HuggingFace...")
logger.info(f"Checkpoint not found at {config.ckpt_path}, downloading from HuggingFace...")
repo_id = "thuml/timer-base-84m"
filename = os.path.basename(config.ckpt_path) # eg: model.safetensors
config.ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
print(f"[INFO] Downloaded checkpoint to {config.ckpt_path}")
state_dict = safetensors.torch.load_file(config.ckpt_path)
try:
config.ckpt_path = hf_hub_download(repo_id=repo_id, filename=os.path.basename(config.ckpt_path), local_dir=os.path.dirname(config.ckpt_path))
logger.info(f"Got checkpoint to {config.ckpt_path}")
except Exception as e:
logger.error(f"Failed to download checkpoint to {config.ckpt_path} due to {e}")
raise e
state_dict = load_safetensors(config.ckpt_path)
else:
raise ValueError('unsupported model weight type')
# If there is no key beginning with 'model.model' in state_dict, add a 'model.' before all keys. (The model code here has an additional layer of encapsulation compared to the code on huggingface.)
Expand All @@ -234,7 +240,7 @@ def inference(self, x, max_new_tokens: int = 96):
# change [L, C=1] to [batchsize=1, L]
self.device = next(self.model.parameters()).device

x = torch.tensor(x.values, dtype=next(self.model.parameters()).dtype, device=self.device)
x = torch.tensor(x, dtype=next(self.model.parameters()).dtype, device=self.device)
x = x.view(1, -1)

preds = self.forward(x, max_new_tokens)
Expand Down
5 changes: 3 additions & 2 deletions iotdb-core/ainode/ainode/core/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
def __init__(self):
self._model_manager = ModelManager()
self._inference_manager = InferenceManager(model_manager=self._model_manager)

def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp:
return self._model_manager.register_model(req)
Expand All @@ -37,10 +38,10 @@ def deleteModel(self, req: TDeleteModelReq) -> TSStatus:
return self._model_manager.delete_model(req)

def inference(self, req: TInferenceReq) -> TInferenceResp:
return InferenceManager.inference(req, self._model_manager)
return self._inference_manager.inference(req)

def forecast(self, req: TForecastReq) -> TSStatus:
return InferenceManager.forecast(req, self._model_manager)
return self._inference_manager.forecast(req)

def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp:
return ClusterManager.get_heart_beat(req)
Expand Down
329 changes: 112 additions & 217 deletions iotdb-core/ainode/ainode/core/manager/inference_manager.py

Large diffs are not rendered by default.

22 changes: 9 additions & 13 deletions iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
# specific language governing permissions and limitations
# under the License.
#
import os
from abc import abstractmethod
from typing import List, Dict
import os

import numpy as np
from sklearn.preprocessing import MinMaxScaler
Expand All @@ -28,15 +28,15 @@
from sktime.forecasting.naive import NaiveForecaster
from sktime.forecasting.trend import STLForecaster

from ainode.TimerXL.models import timer_xl
from ainode.TimerXL.models.configuration_timer import TimerxlConfig
from ainode.core.config import AINodeDescriptor
from ainode.core.constant import AttributeName, BuiltInModelType
from ainode.core.exception import InferenceModelInternalError, AttributeNotSupportError
from ainode.core.exception import InferenceModelInternalError
from ainode.core.exception import WrongAttributeTypeError, NumericalRangeException, StringRangeException, \
ListRangeException, BuiltInModelNotSupportError
from ainode.core.log import Logger

from ainode.TimerXL.models import timer_xl
from ainode.TimerXL.models.configuration_timer import TimerxlConfig

logger = Logger()


Expand Down Expand Up @@ -79,11 +79,6 @@ def fetch_built_in_model(model_id, inference_attributes):
"""
attribute_map = get_model_attributes(model_id)

# validate the inference attributes
for attribute_name in inference_attributes:
if attribute_name not in attribute_map:
raise AttributeNotSupportError(model_id, attribute_name)

# parse the inference attributes, attributes is a Dict[str, Any]
attributes = parse_attribute(inference_attributes, attribute_map)

Expand Down Expand Up @@ -398,9 +393,10 @@ def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, A
),
AttributeName.TIMERXL_CKPT_PATH.value: StringAttribute(
name=AttributeName.TIMERXL_CKPT_PATH.value,
default_value=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'weights', 'timerxl', 'model.safetensors'),
value_choices=[os.path.join(os.path.dirname(os.path.abspath(__file__)), 'weights', 'timerxl', 'model.safetensors'), ""],
),
default_value=os.path.join(os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir(), 'weights',
'timerxl', 'model.safetensors'),
value_choices=['']
)
}

# built-in sktime model attributes
Expand Down
Loading
Loading