forked from junyanz/CycleGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathaligned_data_loader.lua
44 lines (36 loc) · 1.37 KB
/
aligned_data_loader.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
--------------------------------------------------------------------------------
-- Subclass of BaseDataLoader that provides data from two datasets.
-- The samples from the two datasets are aligned with each other,
-- and both datasets have the same size.
--------------------------------------------------------------------------------
-- Load the base loader module (expected to define BaseDataLoader, which is
-- referenced below as the parent class and in the constructor).
require 'data.base_data_loader'
local class = require 'class'
-- NOTE: assigned without `local`, so data_util becomes a global.
-- paths.dofile presumably resolves 'data_util.lua' relative to this file's
-- directory -- confirm against the Torch `paths` documentation.
data_util = paths.dofile('data_util.lua')
-- Declare the AlignedDataLoader class (also a global), naming
-- 'BaseDataLoader' as its parent class.
AlignedDataLoader = class('AlignedDataLoader', 'BaseDataLoader')
-- Constructor.
-- |conf|: optional configuration table; an empty table is used when omitted.
-- The default is applied BEFORE delegating so that BaseDataLoader.__init
-- always receives a table rather than nil (previously the default was
-- assigned after the call and therefore had no effect).
function AlignedDataLoader:__init(conf)
  conf = conf or {}
  BaseDataLoader.__init(self, conf)
end
-- Identifies this loader.
-- |return|: the string 'AlignedDataLoader'
function AlignedDataLoader:name()
  local loader_name = 'AlignedDataLoader'
  return loader_name
end
-- Prepares the loader from the options table.
-- |opt|: options table; opt.input_nc and opt.output_nc are read, and
--        opt.align_data is overwritten with 1 (the caller's table is mutated).
-- Side effects: sets self.idx_A, self.idx_B and self.data.
function AlignedDataLoader:Initialize(opt)
  local n_in, n_out = opt.input_nc, opt.output_nc

  -- Flag the options as describing aligned data before the dataset is loaded.
  opt.align_data = 1

  -- {first, last} index ranges used later to slice a batch into A and B:
  -- A takes the first n_in entries, B the n_out entries that follow.
  self.idx_A = { 1, n_in }
  self.idx_B = { n_in + 1, n_in + n_out }

  -- Fixed at 3; the commented-out alternative was opt.input_nc + opt.output_nc.
  -- NOTE(review): the reason for the hard-coded value is not visible here.
  local channels = 3
  self.data = data_util.load_dataset('', opt, channels)
end
-- Fetches the next batch from self.data and slices it into the parts for
-- dataset A and dataset B. The slicing is done along the second dimension
-- using the ranges self.idx_A and self.idx_B (set in Initialize); all other
-- dimensions are taken in full.
-- (presumably the second dimension is the channel axis -- TODO confirm)
-- |return|: four values -- the slice for A, the slice for B, and the second
-- value returned by getBatch() (called `path` here) given twice, once per
-- dataset.
function AlignedDataLoader:LoadBatchForAllDatasets()
  local batch, path = self.data:getBatch()

  -- An empty table {} in an index position selects that whole dimension.
  local select_A = { {}, self.idx_A, {}, {} }
  local select_B = { {}, self.idx_B, {}, {} }

  local part_A = batch[select_A]
  local part_B = batch[select_B]
  return part_A, part_B, path, path
end
-- Reports the size of the underlying dataset.
-- |dataset|: accepted but not used; the same value is returned regardless.
-- |return|: whatever self.data:size() returns
function AlignedDataLoader:size(dataset)
  local source = self.data
  return source:size()
end