Skip to content

Commit 35e57ed

Browse files
committed
fix gat dataset
1 parent 8709dde commit 35e57ed

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

qlib/contrib/model/pytorch_gats_ts.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,9 @@
2727

2828
class DailyBatchSampler(Sampler):
2929
def __init__(self, data_source):
30-
3130
self.data_source = data_source
32-
self.data = self.data_source.data.loc[self.data_source.get_index()]
33-
self.daily_count = self.data.groupby(level=0).size().values # calculate number of samples in each batch
31+
# calculate number of samples in each batch
32+
self.daily_count = pd.Series(index=self.data_source.get_index()).groupby('datetime').size().values
3433
self.daily_index = np.roll(np.cumsum(self.daily_count), 1) # calculate begin index of each batch
3534
self.daily_index[0] = 0
3635

0 commit comments

Comments
 (0)