Skip to content

Commit 10762bc

Browse files
committed
Update iotdb.py
1 parent 724dae5 commit 10762bc

File tree

1 file changed

+25
-16
lines changed
  • iotdb-core/ainode/ainode/core/ingress

1 file changed

+25
-16
lines changed

iotdb-core/ainode/ainode/core/ingress/iotdb.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ def __init__(
6565
username: str = "root",
6666
password: str = "root",
6767
time_zone: str = "UTC+8",
68-
# start_split: float = 0,
69-
# end_split: float = 1.0,
68+
use_rate: float = 1.0,
69+
offset_rate: float = 0.0,
7070
):
7171
super().__init__(ip, port, seq_len, input_token_len, output_token_len)
7272

@@ -75,7 +75,7 @@ def __init__(
7575
self.FETCH_SERIES_SQL = "select %s from %s%s"
7676
self.FETCH_SERIES_RANGE_SQL = "select %s from %s offset %s limit %s%s"
7777

78-
self.TIME_CONDITION = " where time>%s and time<%s"
78+
self.TIME_CONDITION = " where time>=%s and time<%s"
7979

8080
self.session = Session.init_from_node_urls(
8181
node_urls=[f"{ip}:{port}"],
@@ -84,9 +84,9 @@ def __init__(
8484
zone_id=time_zone,
8585
)
8686
self.session.open(False)
87+
self.use_rate = use_rate
88+
self.offset_rate = offset_rate
8789
self._fetch_schema(data_schema_list)
88-
# self.start_idx = int(self.total_count * start_split)
89-
# self.end_idx = int(self.total_count * end_split)
9090
self.cache_enable = _cache_enable()
9191
self.cache_key_prefix = model_id + "_"
9292

@@ -128,21 +128,27 @@ def _fetch_schema(self, data_schema_list: list):
128128
)
129129

130130
sorted_series = sorted(series_to_length.items(), key=lambda x: x[1][1])
131-
# TODO: we should define a data structure for this field
132-
# structure: [(split_series_name, the number of windows of this series, prefix sum of window number, time_condition of this series), ...]
131+
# structure: [(split_series_name, the number of windows of this series, prefix sum of window number, window start offset, time_condition of this series), ...]
133132
sorted_series_with_prefix_sum = []
134133
window_sum = 0
135134
for seq_name, seq_value in sorted_series:
136135
# calculate and sum the number of training data windows for each time series
137136
window_count = seq_value[1] - self.seq_len - self.output_token_len + 1
138-
if window_count <= 0:
137+
if window_count <= 1:
139138
continue
140-
window_sum += window_count
139+
use_window_count = int(window_count * self.use_rate)
140+
window_sum += use_window_count
141141
sorted_series_with_prefix_sum.append(
142-
(seq_value[0], window_count, window_sum, seq_value[2])
142+
(
143+
seq_value[0],
144+
use_window_count,
145+
window_sum,
146+
int(window_count * self.offset_rate),
147+
seq_value[2],
148+
)
143149
)
144150

145-
self.total_count = window_sum
151+
self.total_window_count = window_sum
146152
self.sorted_series = sorted_series_with_prefix_sum
147153

148154
def __getitem__(self, index):
@@ -155,14 +161,17 @@ def __getitem__(self, index):
155161
if series_index != 0:
156162
window_index -= self.sorted_series[series_index - 1][2]
157163
series = self.sorted_series[series_index][0]
158-
time_condition = self.sorted_series[series_index][3]
164+
window_index += self.sorted_series[series_index][3]
165+
time_condition = self.sorted_series[series_index][4]
159166
if self.cache_enable:
160167
cache_key = self.cache_key_prefix + ".".join(series) + time_condition
161168
series_data = self.cache.get(cache_key)
162169
if series_data is not None:
163170
# try to get the training data window from cache first
164171
series_data = torch.tensor(series_data)
165-
result = series_data[window_index : window_index + self.seq_len + self.output_token_len]
172+
result = series_data[
173+
window_index : window_index + self.seq_len + self.output_token_len
174+
]
166175
return (
167176
result[0 : self.seq_len],
168177
result[self.input_token_len : self.seq_len + self.output_token_len],
@@ -196,13 +205,13 @@ def __getitem__(self, index):
196205
self.cache.put(cache_key, result)
197206
result = torch.tensor(result)
198207
return (
199-
result[0: self.seq_len],
200-
result[self.input_token_len: self.seq_len + self.output_token_len],
208+
result[0 : self.seq_len],
209+
result[self.input_token_len : self.seq_len + self.output_token_len],
201210
np.ones(self.token_num, dtype=np.int32),
202211
)
203212

204213
def __len__(self):
205-
return self.total_count
214+
return self.total_window_count
206215

207216

208217
class IoTDBTableModelDataset(BasicDatabaseForecastDataset):

0 commit comments

Comments
 (0)