We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 8709dde commit 35e57edCopy full SHA for 35e57ed
qlib/contrib/model/pytorch_gats_ts.py
@@ -27,10 +27,9 @@
27
28
class DailyBatchSampler(Sampler):
29
def __init__(self, data_source):
30
-
31
self.data_source = data_source
32
- self.data = self.data_source.data.loc[self.data_source.get_index()]
33
- self.daily_count = self.data.groupby(level=0).size().values # calculate number of samples in each batch
+ # calculate number of samples in each batch
+ self.daily_count = pd.Series(index=self.data_source.get_index()).groupby('datetime').size().values
34
self.daily_index = np.roll(np.cumsum(self.daily_count), 1) # calculate begin index of each batch
35
self.daily_index[0] = 0
36
0 commit comments