152 changes: 76 additions & 76 deletions iotdb-core/ainode/ainode/core/ingress/iotdb.py
@@ -216,120 +216,120 @@ def __len__(self):

class IoTDBTableModelDataset(BasicDatabaseForecastDataset):

DEFAULT_TAG = "__DEFAULT_TAG__"

def __init__(
self,
input_len: int,
out_len: int,
model_id: str,
seq_len: int,
input_token_len: int,
output_token_len: int,
data_schema_list: list,
ip: str = "127.0.0.1",
port: int = 6667,
username: str = "root",
password: str = "root",
time_zone: str = "UTC+8",
start_split: float = 0,
end_split: float = 1,
use_rate: float = 1.0,
offset_rate: float = 0.0,
):
super().__init__(ip, port, input_len, out_len)
if end_split < start_split:
raise ValueError("end_split must be greater than start_split")

# database, table
self.SELECT_SERIES_FORMAT_SQL = "select distinct item_id from %s"
self.COUNT_SERIES_LENGTH_SQL = (
"select count(value) from %s where item_id = '%s'"
)
self.FETCH_SERIES_SQL = (
"select value from %s where item_id = '%s' offset %s limit %s"
)
self.SERIES_NAME = "%s.%s"
super().__init__(ip, port, seq_len, input_token_len, output_token_len)

table_session_config = TableSessionConfig(
node_urls=[f"{ip}:{port}"],
username=username,
password=password,
time_zone=time_zone,
)

self.session = TableSession(table_session_config)
self.context_length = self.input_len + self.output_len
self.token_num = self.context_length // self.input_len
self._fetch_schema(data_schema_list)
self.use_rate = use_rate
self.offset_rate = offset_rate

self.start_index = int(self.total_count * start_split)
self.end_index = self.total_count * end_split
# used for caching data
self._fetch_schema(data_schema_list)

def _fetch_schema(self, data_schema_list: list):
series_to_length = {}
for data_schema in data_schema_list:
series_list = []
with self.session.execute_query_statement(
self.SELECT_SERIES_FORMAT_SQL % data_schema
) as show_devices_result:
while show_devices_result.has_next():
series_map = {}
for target_sql in data_schema_list:
target_sql = target_sql.schemaName
with self.session.execute_query_statement(target_sql) as target_data:
while target_data.has_next():
cur_data = target_data.next()
# TODO: currently, we only support the following simple table form
time_col, value_col, tag_col = -1, -1, -1
for i, field in enumerate(cur_data.get_fields()):
if field.get_data_type() == TSDataType.TIMESTAMP:
time_col = i
elif field.get_data_type() in (
TSDataType.INT32,
TSDataType.INT64,
TSDataType.FLOAT,
TSDataType.DOUBLE,
):
value_col = i
elif field.get_data_type() == TSDataType.TEXT:
tag_col = i
if time_col == -1 or value_col == -1:
raise ValueError(
"The training cannot start due to invalid data schema"
)
if tag_col == -1:
tag = self.DEFAULT_TAG
else:
tag = cur_data.get_fields()[tag_col].get_string_value()
if tag not in series_map:
series_map[tag] = []
series_list = series_map[tag]
series_list.append(
get_field_value(show_devices_result.next().get_fields()[0])
get_field_value(cur_data.get_fields()[value_col])
)

for series in series_list:
with self.session.execute_query_statement(
self.COUNT_SERIES_LENGTH_SQL % (data_schema.schemaName, series)
) as count_series_result:
length = get_field_value(count_series_result.next().get_fields()[0])
series_to_length[
self.SERIES_NAME % (data_schema.schemaName, series)
] = length

sorted_series = sorted(series_to_length.items(), key=lambda x: x[1])
sorted_series_with_prefix_sum = []
# TODO: Unify the following implementation
# structure: [(series_name, number of usable windows in this series, prefix sum of window counts, window start offset, series_data), ...]
series_with_prefix_sum = []
window_sum = 0
for seq_name, seq_length in sorted_series:
window_count = seq_length - self.context_length + 1
if window_count < 0:
for seq_name, seq_values in series_map.items():
# calculate and sum the number of training data windows for each time series
window_count = len(seq_values) - self.seq_len - self.output_token_len + 1
if window_count <= 1:
continue
window_sum += window_count
sorted_series_with_prefix_sum.append((seq_name, window_count, window_sum))
use_window_count = int(window_count * self.use_rate)
window_sum += use_window_count
series_with_prefix_sum.append(
(
seq_name,
use_window_count,
window_sum,
int(window_count * self.offset_rate),
seq_values,
)
)

self.total_count = window_sum
self.sorted_series = sorted_series_with_prefix_sum
self.total_window_count = window_sum
self.series_with_prefix_sum = series_with_prefix_sum

def __getitem__(self, index):
window_index = index

# locate the series to be queried
series_index = 0

while self.sorted_series[series_index][2] < window_index:
while self.series_with_prefix_sum[series_index][2] <= window_index:
series_index += 1

# locate the window of this series to be queried
if series_index != 0:
window_index -= self.sorted_series[series_index - 1][2]

if window_index != 0:
window_index -= 1
series = self.sorted_series[series_index][0]
schema = series.split(".")

result = []
sql = self.FETCH_SERIES_SQL % (
schema[0:1],
schema[2],
window_index,
self.context_length,
)
try:
with self.session.execute_query_statement(sql) as query_result:
while query_result.has_next():
result.append(get_field_value(query_result.next().get_fields()[0]))
except Exception as e:
logger.error("Executing sql: {} with exception: {}".format(sql, e))
window_index -= self.series_with_prefix_sum[series_index - 1][2]
window_index += self.series_with_prefix_sum[series_index][3]
result = self.series_with_prefix_sum[series_index][4][
window_index : window_index + self.seq_len + self.output_token_len
]
result = torch.tensor(result)
return (
result[0 : self.input_len],
result[-self.output_len :],
result[0 : self.seq_len],
result[self.input_token_len : self.seq_len + self.output_token_len],
np.ones(self.token_num, dtype=np.int32),
)

def __len__(self):
return self.end_index - self.start_index
return self.total_window_count


def register_dataset(key: str, dataset: Dataset):
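
The rewritten dataset stops issuing an offset/limit query per sample: _fetch_schema now caches every series in memory, and __getitem__ resolves a flat index through the prefix sums built above. A minimal, self-contained sketch of that arithmetic follows; the sample data and the build_index/locate helper names are illustrative, not part of the PR.

```python
# Sketch of the (series_name, use_window_count, prefix_sum, start_offset, values)
# indexing scheme from IoTDBTableModelDataset, under the assumption that prefix
# sums act as exclusive upper bounds for 0-based flat indices.

def build_index(series_map, seq_len, output_token_len, use_rate=1.0, offset_rate=0.0):
    """Return ([(name, use_window_count, prefix_sum, start_offset, values), ...], total)."""
    index, window_sum = [], 0
    for name, values in series_map.items():
        # each window holds seq_len input points plus output_token_len label points
        window_count = len(values) - seq_len - output_token_len + 1
        if window_count <= 1:  # same guard as the diff: skip series that are too short
            continue
        use_window_count = int(window_count * use_rate)
        window_sum += use_window_count
        index.append(
            (name, use_window_count, window_sum, int(window_count * offset_rate), values)
        )
    return index, window_sum  # window_sum is what __len__ reports

def locate(index, flat_idx, seq_len, output_token_len):
    """Map a flat dataset index back to one training window."""
    i = 0
    while index[i][2] <= flat_idx:  # walk to the series owning this index
        i += 1
    local = flat_idx - (index[i - 1][2] if i else 0) + index[i][3]
    return index[i][4][local : local + seq_len + output_token_len]

series_map = {"s1": list(range(20)), "s2": list(range(100, 130))}
index, total = build_index(series_map, seq_len=8, output_token_len=4)
assert total == 28
assert locate(index, 0, 8, 4) == list(range(0, 12))      # first window of s1
assert locate(index, 27, 8, 4) == list(range(118, 130))  # last window of s2
```

Under this reading, use_rate trims how many windows of each series are used and offset_rate shifts where they start, which appears to replace the old start_split/end_split bookkeeping for train/eval splits.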
@@ -167,7 +167,6 @@
import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRestartReq;
import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRestartResp;
import org.apache.iotdb.confignode.rpc.thrift.TDataPartitionTableResp;
import org.apache.iotdb.confignode.rpc.thrift.TDataSchemaForTable;
import org.apache.iotdb.confignode.rpc.thrift.TDatabaseSchema;
import org.apache.iotdb.confignode.rpc.thrift.TDeactivateSchemaTemplateReq;
import org.apache.iotdb.confignode.rpc.thrift.TDeleteDatabasesReq;
@@ -248,7 +247,6 @@
import org.apache.iotdb.confignode.rpc.thrift.TStartPipeReq;
import org.apache.iotdb.confignode.rpc.thrift.TStopPipeReq;
import org.apache.iotdb.confignode.rpc.thrift.TSubscribeReq;
import org.apache.iotdb.confignode.rpc.thrift.TTableInfo;
import org.apache.iotdb.confignode.rpc.thrift.TThrottleQuotaResp;
import org.apache.iotdb.confignode.rpc.thrift.TTimeSlotList;
import org.apache.iotdb.confignode.rpc.thrift.TUnsetSchemaTemplateReq;
@@ -2641,10 +2639,6 @@ public TSStatus createModel(TCreateModelReq req) {

private List<IDataSchema> fetchSchemaForTreeModel(TCreateTrainingReq req) {
List<IDataSchema> dataSchemaList = new ArrayList<>();
if (req.useAllData) {
dataSchemaList.add(new IDataSchema("root.**"));
return dataSchemaList;
}
for (int i = 0; i < req.getDataSchemaForTree().getPathSize(); i++) {
IDataSchema dataSchema = new IDataSchema(req.getDataSchemaForTree().getPath().get(i));
dataSchema.setTimeRange(req.getTimeRanges().get(i));
@@ -2654,28 +2648,7 @@ private List<IDataSchema> fetchSchemaForTreeModel(TCreateTrainingReq req) {
}

private List<IDataSchema> fetchSchemaForTableModel(TCreateTrainingReq req) {
List<IDataSchema> dataSchemaList = new ArrayList<>();
TDataSchemaForTable dataSchemaForTable = req.getDataSchemaForTable();
if (req.useAllData || !dataSchemaForTable.getDatabaseList().isEmpty()) {
List<String> databaseNameList = new ArrayList<>();
if (req.useAllData) {
TShowDatabaseResp resp = showDatabase(new TGetDatabaseReq());
databaseNameList.addAll(resp.getDatabaseInfoMap().keySet());
} else {
databaseNameList.addAll(dataSchemaForTable.getDatabaseList());
}

for (String database : databaseNameList) {
TShowTableResp resp = showTables(database, false);
for (TTableInfo tableInfo : resp.getTableInfoList()) {
dataSchemaList.add(new IDataSchema(database + DOT + tableInfo.tableName));
}
}
}
for (String tableName : dataSchemaForTable.getTableList()) {
dataSchemaList.add(new IDataSchema(tableName));
}
return dataSchemaList;
return Collections.singletonList(new IDataSchema(req.getDataSchemaForTable().getTargetSql()));
}

public TSStatus createTraining(TCreateTrainingReq req) {
@@ -2687,11 +2660,11 @@ public TSStatus createTraining(TCreateTrainingReq req) {

TTrainingReq trainingReq = new TTrainingReq();
trainingReq.setModelId(req.getModelId());
trainingReq.setModelType("sundial");
if (req.existingModelId != null) {
trainingReq.setModelType(req.getModelType());
if (req.isSetExistingModelId()) {
trainingReq.setExistingModelId(req.getExistingModelId());
}
if (!req.parameters.isEmpty()) {
if (req.isSetParameters() && !req.getParameters().isEmpty()) {
trainingReq.setParameters(req.getParameters());
}
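
With fetchSchemaForTableModel reduced to singletonList(new IDataSchema(targetSql)), the ConfigNode no longer enumerates databases and tables; the user's SQL travels as-is to the AINode, where _fetch_schema executes each entry's schemaName and infers column roles purely from the result types (the scan keeps the last TIMESTAMP column as time, the last numeric column as value, and an optional TEXT column as tag). A stand-alone sketch of that rule; the TSDataType enum here is a local stand-in for IoTDB's, not the real import.

```python
from enum import Enum, auto

class TSDataType(Enum):  # local stand-in for IoTDB's TSDataType
    TIMESTAMP = auto()
    INT32 = auto()
    INT64 = auto()
    FLOAT = auto()
    DOUBLE = auto()
    TEXT = auto()

NUMERIC_TYPES = {TSDataType.INT32, TSDataType.INT64, TSDataType.FLOAT, TSDataType.DOUBLE}

def classify_columns(column_types):
    """Pick (time_col, value_col, tag_col) the way _fetch_schema does."""
    time_col = value_col = tag_col = -1
    for i, data_type in enumerate(column_types):
        if data_type == TSDataType.TIMESTAMP:
            time_col = i
        elif data_type in NUMERIC_TYPES:
            value_col = i
        elif data_type == TSDataType.TEXT:
            tag_col = i
    if time_col == -1 or value_col == -1:
        raise ValueError("The training cannot start due to invalid data schema")
    return time_col, value_col, tag_col

# a query shaped like "SELECT time, value, item_id FROM db1.weather" yields (0, 1, 2);
# no TEXT column means every row falls under the dataset's DEFAULT_TAG
print(classify_columns([TSDataType.TIMESTAMP, TSDataType.DOUBLE, TSDataType.TEXT]))
```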

@@ -211,7 +211,6 @@
import org.apache.tsfile.utils.Binary;
import org.apache.tsfile.utils.Pair;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
@@ -1359,26 +1358,12 @@ protected IConfigTask visitRemoveRegion(RemoveRegion removeRegion, MPPQueryContext
protected IConfigTask visitCreateTraining(CreateTraining node, MPPQueryContext context) {
context.setQueryType(QueryType.WRITE);

String curDatabase = clientSession.getDatabaseName();
List<String> tableList = new ArrayList<>();
for (QualifiedName tableName : node.getTargetTables()) {
List<String> parts = tableName.getParts();
if (parts.size() == 1) {
tableList.add(curDatabase + "." + parts.get(0));
} else {
tableList.add(parts.get(1) + "." + parts.get(0));
}
}

return new CreateTrainingTask(
node.getModelId(),
node.getModelType(),
node.getParameters(),
node.isUseAllData(),
node.getTargetTimeRanges(),
node.getExistingModelId(),
tableList,
node.getTargetDbs());
node.getTargetSql());
}

@Override
@@ -810,7 +810,6 @@ public IConfigTask visitCreateTraining(
createTrainingStatement.getModelId(),
createTrainingStatement.getModelType(),
createTrainingStatement.getParameters(),
false,
createTrainingStatement.getTargetTimeRanges(),
createTrainingStatement.getExistingModelId(),
targetPathPatterns);
@@ -3334,11 +3334,9 @@ public SettableFuture<ConfigTaskResult> createTraining(
String modelType,
boolean isTableModel,
Map<String, String> parameters,
boolean useAllData,
List<List<Long>> timeRanges,
String existingModelId,
@Nullable List<String> tableList,
@Nullable List<String> databaseList,
@Nullable String targetSql,
@Nullable List<String> pathList) {
final SettableFuture<ConfigTaskResult> future = SettableFuture.create();
try (final ConfigNodeClient client =
@@ -3347,16 +3345,14 @@

if (isTableModel) {
TDataSchemaForTable dataSchemaForTable = new TDataSchemaForTable();
dataSchemaForTable.setTableList(tableList);
dataSchemaForTable.setDatabaseList(databaseList);
dataSchemaForTable.setTargetSql(targetSql);
req.setDataSchemaForTable(dataSchemaForTable);
} else {
TDataSchemaForTree dataSchemaForTree = new TDataSchemaForTree();
dataSchemaForTree.setPath(pathList);
req.setDataSchemaForTree(dataSchemaForTree);
}
req.setParameters(parameters);
req.setUseAllData(useAllData);
req.setTimeRanges(timeRanges);
req.setExistingModelId(existingModelId);
final TSStatus executionStatus = client.createTraining(req);
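
On the executor side, the table-model branch now fills only targetSql, and the request no longer carries useAllData or table/database lists. A hedged, JSON-flavoured view of what the narrowed TCreateTrainingReq carries; the field names come from the setters visible in this diff, the placeholder values are invented, and the real encoding is Thrift, not JSON.

```python
# Field names mirror setDataSchemaForTable/setTargetSql/setParameters/
# setTimeRanges/setExistingModelId; values are made-up placeholders.
create_training_req = {
    "modelId": "weather_forecaster",  # hypothetical model id
    "dataSchemaForTable": {
        "targetSql": "SELECT time, value, item_id FROM db1.weather",
    },
    "parameters": {"epochs": "10"},  # only set when non-empty
    "timeRanges": [[0, 1700000000000]],
    "existingModelId": None,  # only set when fine-tuning an existing model
}
```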
@@ -427,10 +427,8 @@ SettableFuture<ConfigTaskResult> createTraining(
String modelType,
boolean isTableModel,
Map<String, String> parameters,
boolean useAllData,
List<List<Long>> timeRanges,
String existingModelId,
@Nullable List<String> tableList,
@Nullable List<String> databaseList,
@Nullable String targetSql,
@Nullable List<String> pathList);
}