@@ -65,8 +65,8 @@ def __init__(
6565 username : str = "root" ,
6666 password : str = "root" ,
6767 time_zone : str = "UTC+8" ,
68- # start_split : float = 0,
69- # end_split : float = 1 .0,
68+ use_rate : float = 1. 0 ,
69+ offset_rate : float = 0 .0 ,
7070 ):
7171 super ().__init__ (ip , port , seq_len , input_token_len , output_token_len )
7272
@@ -75,7 +75,7 @@ def __init__(
7575 self .FETCH_SERIES_SQL = "select %s from %s%s"
7676 self .FETCH_SERIES_RANGE_SQL = "select %s from %s offset %s limit %s%s"
7777
78- self .TIME_CONDITION = " where time>%s and time<%s"
78+ self .TIME_CONDITION = " where time>= %s and time<%s"
7979
8080 self .session = Session .init_from_node_urls (
8181 node_urls = [f"{ ip } :{ port } " ],
@@ -84,9 +84,9 @@ def __init__(
8484 zone_id = time_zone ,
8585 )
8686 self .session .open (False )
87+ self .use_rate = use_rate
88+ self .offset_rate = offset_rate
8789 self ._fetch_schema (data_schema_list )
88- # self.start_idx = int(self.total_count * start_split)
89- # self.end_idx = int(self.total_count * end_split)
9090 self .cache_enable = _cache_enable ()
9191 self .cache_key_prefix = model_id + "_"
9292
@@ -128,21 +128,27 @@ def _fetch_schema(self, data_schema_list: list):
128128 )
129129
130130 sorted_series = sorted (series_to_length .items (), key = lambda x : x [1 ][1 ])
131- # TODO: we should define a data structure for this field
132- # structure: [(split_series_name, the number of windows of this series, prefix sum of window number, time_condition of this series), ...]
131+ # structure: [(split_series_name, the number of windows of this series, prefix sum of window number, window start offset, time_condition of this series), ...]
133132 sorted_series_with_prefix_sum = []
134133 window_sum = 0
135134 for seq_name , seq_value in sorted_series :
136135 # calculate and sum the number of training data windows for each time series
137136 window_count = seq_value [1 ] - self .seq_len - self .output_token_len + 1
138- if window_count <= 0 :
137+ if window_count <= 1 :
139138 continue
140- window_sum += window_count
139+ use_window_count = int (window_count * self .use_rate )
140+ window_sum += use_window_count
141141 sorted_series_with_prefix_sum .append (
142- (seq_value [0 ], window_count , window_sum , seq_value [2 ])
142+ (
143+ seq_value [0 ],
144+ use_window_count ,
145+ window_sum ,
146+ int (window_count * self .offset_rate ),
147+ seq_value [2 ],
148+ )
143149 )
144150
145- self .total_count = window_sum
151+ self .total_window_count = window_sum
146152 self .sorted_series = sorted_series_with_prefix_sum
147153
148154 def __getitem__ (self , index ):
@@ -155,14 +161,17 @@ def __getitem__(self, index):
155161 if series_index != 0 :
156162 window_index -= self .sorted_series [series_index - 1 ][2 ]
157163 series = self .sorted_series [series_index ][0 ]
158- time_condition = self .sorted_series [series_index ][3 ]
164+ window_index += self .sorted_series [series_index ][3 ]
165+ time_condition = self .sorted_series [series_index ][4 ]
159166 if self .cache_enable :
160167 cache_key = self .cache_key_prefix + "." .join (series ) + time_condition
161168 series_data = self .cache .get (cache_key )
162169 if series_data is not None :
163170 # try to get the training data window from cache first
164171 series_data = torch .tensor (series_data )
165- result = series_data [window_index : window_index + self .seq_len + self .output_token_len ]
172+ result = series_data [
173+ window_index : window_index + self .seq_len + self .output_token_len
174+ ]
166175 return (
167176 result [0 : self .seq_len ],
168177 result [self .input_token_len : self .seq_len + self .output_token_len ],
@@ -196,13 +205,13 @@ def __getitem__(self, index):
196205 self .cache .put (cache_key , result )
197206 result = torch .tensor (result )
198207 return (
199- result [0 : self .seq_len ],
200- result [self .input_token_len : self .seq_len + self .output_token_len ],
208+ result [0 : self .seq_len ],
209+ result [self .input_token_len : self .seq_len + self .output_token_len ],
201210 np .ones (self .token_num , dtype = np .int32 ),
202211 )
203212
# Dataset length, i.e. the number of indexable training windows.
# The value is the running sum that _fetch_schema accumulates over all kept
# series, where each series contributes int(window_count * use_rate) windows.
# This hunk renames the backing attribute from total_count to total_window_count.
204213 def __len__ (self ):
205- return self .total_count
214+ return self .total_window_count
206215
207216
208217class IoTDBTableModelDataset (BasicDatabaseForecastDataset ):
0 commit comments