[AINode] Support concurrent inference for Timer-Sundial #15897

Open · wants to merge 4 commits into base: master
17 changes: 17 additions & 0 deletions iotdb-core/ainode/ainode/core/config.py
@@ -30,6 +30,7 @@
AINODE_CONF_FILE_NAME,
AINODE_CONF_GIT_FILE_NAME,
AINODE_CONF_POM_FILE_NAME,
AINODE_INFERENCE_BATCH_INTERVAL_IN_MS,
AINODE_INFERENCE_RPC_ADDRESS,
AINODE_INFERENCE_RPC_PORT,
AINODE_LOG_DIR,
@@ -55,6 +56,9 @@ def __init__(self):
# Used for connections from DataNode/ConfigNode clients
self._ain_inference_rpc_address: str = AINODE_INFERENCE_RPC_ADDRESS
self._ain_inference_rpc_port: int = AINODE_INFERENCE_RPC_PORT
self._ain_inference_batch_interval_in_ms: int = (
AINODE_INFERENCE_BATCH_INTERVAL_IN_MS
)

# log directory
self._ain_logs_dir: str = AINODE_LOG_DIR
@@ -132,6 +136,14 @@ def get_ain_inference_rpc_port(self) -> int:
def set_ain_inference_rpc_port(self, ain_inference_rpc_port: int) -> None:
self._ain_inference_rpc_port = ain_inference_rpc_port

def get_ain_inference_batch_interval_in_ms(self) -> int:
return self._ain_inference_batch_interval_in_ms

def set_ain_inference_batch_interval_in_ms(
self, ain_inference_batch_interval_in_ms: int
) -> None:
self._ain_inference_batch_interval_in_ms = ain_inference_batch_interval_in_ms

def get_ain_logs_dir(self) -> str:
return self._ain_logs_dir

@@ -273,6 +285,11 @@ def _load_config_from_file(self) -> None:
int(file_configs["ain_inference_rpc_port"])
)

if "ain_inference_batch_interval_in_ms" in config_keys:
self._config.set_ain_inference_batch_interval_in_ms(
int(file_configs["ain_inference_batch_interval_in_ms"])
)

if "ain_models_dir" in config_keys:
self._config.set_ain_models_dir(file_configs["ain_models_dir"])

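With the loader change above, the interval can be overridden in the AINode properties file (the one named by `AINODE_CONF_FILE_NAME`); a minimal sketch of the expected `key=value` entry:

```properties
# Polling interval for the inference request pool's queues, in milliseconds
# (falls back to AINODE_INFERENCE_BATCH_INTERVAL_IN_MS = 15 when unset)
ain_inference_batch_interval_in_ms=15
```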
5 changes: 5 additions & 0 deletions iotdb-core/ainode/ainode/core/constant.py
@@ -29,22 +29,27 @@
AINODE_CONF_GIT_FILE_NAME = "git.properties"
AINODE_CONF_POM_FILE_NAME = "pom.properties"
AINODE_SYSTEM_FILE_NAME = "system.properties"

# Inference RPC service configuration
AINODE_INFERENCE_RPC_ADDRESS = "127.0.0.1"
AINODE_INFERENCE_RPC_PORT = 10810
AINODE_INFERENCE_BATCH_INTERVAL_IN_MS = 15

# AINode folder structure
AINODE_MODELS_DIR = "data/ainode/models"
AINODE_BUILTIN_MODELS_DIR = "data/ainode/models/weights" # For built-in models, we only need to store their weights and config.
AINODE_SYSTEM_DIR = "data/ainode/system"
AINODE_LOG_DIR = "logs/ainode"
AINODE_THRIFT_COMPRESSION_ENABLED = False

# Used for node management
AINODE_CLUSTER_NAME = "defaultCluster"
AINODE_VERSION_INFO = "UNKNOWN"
AINODE_BUILD_INFO = "UNKNOWN"
AINODE_ROOT_DIR = os.path.dirname(
os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
)

# connect IoTDB cluster
AINODE_CLUSTER_INGRESS_ADDRESS = "127.0.0.1"
AINODE_CLUSTER_INGRESS_PORT = 6667
17 changes: 17 additions & 0 deletions iotdb-core/ainode/ainode/core/inference/__init__.py
@@ -0,0 +1,17 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
121 changes: 121 additions & 0 deletions iotdb-core/ainode/ainode/core/inference/inference_request.py
@@ -0,0 +1,121 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import threading
from typing import Any

import torch

from ainode.core.inference.strategy.abstract_inference_pipeline import (
AbstractInferencePipeline,
)
from ainode.core.log import Logger

logger = Logger()


class InferenceRequestState:
WAITING = "waiting"
RUNNING = "running"
FINISHED = "finished"


class InferenceRequest:
def __init__(
self,
req_id: str,
inputs: torch.Tensor,
inference_pipeline: AbstractInferencePipeline,
max_new_tokens: int = 96,
**infer_kwargs,
):
if inputs.ndim == 1:
inputs = inputs.unsqueeze(0)

self.req_id = req_id
self.inputs = inputs
self.infer_kwargs = infer_kwargs
self.inference_pipeline = inference_pipeline
self.max_new_tokens = max_new_tokens  # Number of time series data points to generate

self.batch_size = inputs.size(0)
self.state = InferenceRequestState.WAITING
self.cur_step_idx = 0  # Current write position in the output buffer

# Preallocate output buffer [batch_size, max_new_tokens]
device = inputs.device
self.output_tensor = torch.zeros(
self.batch_size, max_new_tokens, device=device
)  # shape: [batch_size, max_new_tokens]

def mark_running(self):
self.state = InferenceRequestState.RUNNING

def mark_finished(self):
self.state = InferenceRequestState.FINISHED

def is_finished(self) -> bool:
return (
self.state == InferenceRequestState.FINISHED
or self.cur_step_idx >= self.max_new_tokens
)

def write_step_output(self, step_output: torch.Tensor):
if step_output.ndim == 1:
step_output = step_output.unsqueeze(0)

batch_size, step_size = step_output.shape
end_idx = self.cur_step_idx + step_size

if end_idx > self.max_new_tokens:
self.output_tensor[:, self.cur_step_idx :] = step_output[
:, : self.max_new_tokens - self.cur_step_idx
]
self.cur_step_idx = self.max_new_tokens
else:
self.output_tensor[:, self.cur_step_idx : end_idx] = step_output
self.cur_step_idx = end_idx

if self.is_finished():
self.mark_finished()

def get_final_output(self) -> torch.Tensor:
return self.output_tensor[:, : self.cur_step_idx]


class InferenceRequestProxy:
"""
Wraps a raw request so that a result produced in another process can be awaited.
"""

    def __init__(self, req_id: str):
        self.req_id = req_id
        self.result = None
        self._completed = False
        self._lock = threading.Lock()
        self._condition = threading.Condition(self._lock)

    def set_result(self, result: Any):
        with self._lock:
            self.result = result
            self._completed = True
            self._condition.notify_all()

    def wait_for_completion(self) -> Any:
        with self._lock:
            # Wait on a flag so a result set before this call (or a spurious
            # wakeup) cannot leave the caller blocked forever
            while not self._completed:
                self._condition.wait()
            return self.result
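A quick standalone check of the buffering semantics above; `inference_pipeline=None` is a placeholder here, since the pipeline is only invoked by the request pool, not by `write_step_output`:

```python
import torch

# max_new_tokens=4 with two 3-step writes: the overflow is truncated.
req = InferenceRequest(
    req_id="demo",
    inputs=torch.zeros(64),   # 1-D input is unsqueezed to [1, 64]
    inference_pipeline=None,  # Placeholder; unused on the buffering path
    max_new_tokens=4,
)
req.write_step_output(torch.ones(1, 3))  # Fills steps 0..2
req.write_step_output(torch.ones(1, 3))  # Only step 3 fits; request finishes
assert req.is_finished()
assert req.get_final_output().shape == (1, 4)
```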
140 changes: 140 additions & 0 deletions iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
@@ -0,0 +1,140 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import random
import threading
import time

import numpy as np
import torch
import torch.multiprocessing as mp
from transformers import PretrainedConfig, PreTrainedModel

from ainode.core.config import AINodeDescriptor
from ainode.core.inference.inference_request import InferenceRequest
from ainode.core.log import Logger

logger = Logger()


class InferenceRequestPool(mp.Process):
"""
A per-process request pool that serves concurrent inference for a specific model.
"""

FIX_SEED = 2021
WAITING_INTERVAL_IN_MS = (
AINodeDescriptor().get_config().get_ain_inference_batch_interval_in_ms()
) # How often to check for requests in the waiting/running queue

def __init__(
self,
pool_id: int,
model: PreTrainedModel,
config: PretrainedConfig,
request_queue: mp.Queue,
result_queue: mp.Queue,
**pool_kwargs,
):
super().__init__()
self.pool_id = pool_id
self.model = model
self.device = self.model.device
self.config = config
self.pool_kwargs = pool_kwargs

# TODO: A scheduler is necessary for better handling of the following queues
self._threads = []
self._waiting_queue = request_queue # Requests that are waiting to be processed
self._running_queue = mp.Queue() # Requests that are currently being processed
self._finished_queue = result_queue # Requests that are finished
self._stop_event = mp.Event()

# Fix inference seed
random.seed(self.FIX_SEED)
torch.manual_seed(self.FIX_SEED)
np.random.seed(self.FIX_SEED)

def memory_is_available(self, request):
# TODO: Check memory headroom; needs testing with several rounds of dummy data
pass

def _activate_requests(self):
if self._waiting_queue.empty():
return
request: InferenceRequest = self._waiting_queue.get()
# TODO: Check memory size before activating requests
request.inputs = request.inference_pipeline.preprocess_inputs(request.inputs)
request.mark_running()
logger.debug(
f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is activated with inputs shape {request.inputs.shape}"
)
self._running_queue.put(request)

def _requests_activate_loop(self):
while not self._stop_event.is_set():
time.sleep(self.WAITING_INTERVAL_IN_MS / 1000)
self._activate_requests()

def _step(self):
if self._running_queue.empty():
return
# TODO: We need a batcher to accelerate concurrent inference
# TODO: Check memory size before executing requests
request: InferenceRequest = self._running_queue.get()
output = self.model.generate(
request.inputs,
max_new_tokens=request.max_new_tokens,
num_samples=10,
revin=True,
)
request.write_step_output(output[0].mean(dim=0))
request.inference_pipeline.post_decode()
if request.is_finished():
request.inference_pipeline.post_inference()
logger.debug(
f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished"
)
self._finished_queue.put(request)
else:
logger.debug(
f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing"
)
self._waiting_queue.put(request)

def _requests_execute_loop(self):
while not self._stop_event.is_set():
time.sleep(self.WAITING_INTERVAL_IN_MS / 1000)
self._step()

def run(self):
activate_daemon = threading.Thread(
target=self._requests_activate_loop, daemon=True
)
self._threads.append(activate_daemon)
activate_daemon.start()
execute_daemon = threading.Thread(
target=self._requests_execute_loop, daemon=True
)
self._threads.append(execute_daemon)
execute_daemon.start()
for thread in self._threads:
thread.join()

def stop(self):
self._stop_event.set()
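To show how the pool is meant to be driven, a rough wiring sketch: it assumes `model` and `config` are a loaded Timer-Sundial `PreTrainedModel`/`PretrainedConfig`, `pipeline` is a concrete `AbstractInferencePipeline`, and `inputs` is a `[batch_size, context_length]` tensor; it also glosses over the usual `torch.multiprocessing` caveats about sharing CUDA models across processes:

```python
import torch.multiprocessing as mp

from ainode.core.inference.inference_request import InferenceRequest
from ainode.core.inference.inference_request_pool import InferenceRequestPool

request_queue, result_queue = mp.Queue(), mp.Queue()
pool = InferenceRequestPool(
    pool_id=0,
    model=model,    # Assumed: loaded Timer-Sundial model
    config=config,  # Assumed: its PretrainedConfig
    request_queue=request_queue,
    result_queue=result_queue,
)
pool.start()  # run() spawns the activate/execute daemon threads in the child process

request = InferenceRequest(req_id="req-0", inputs=inputs, inference_pipeline=pipeline)
request_queue.put(request)     # Consumed by _requests_activate_loop
finished = result_queue.get()  # Enqueued by _step once the request is finished
print(finished.get_final_output().shape)  # [batch_size, steps_generated]
pool.stop()
```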
17 changes: 17 additions & 0 deletions iotdb-core/ainode/ainode/core/inference/strategy/__init__.py
@@ -0,0 +1,17 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#