Skip to content

Commit

Permalink
Merge branch 'yjyjy131_dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
yjyjy131 committed Jun 25, 2022
2 parents 30d771e + 0ebce7e commit 96f63f6
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 13 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ lib/
lib64/
parts/
sdist/
parameters/
data/
var/
wheels/
Expand Down
6 changes: 3 additions & 3 deletions SPFCN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ def setup(seed):
cudnn.deterministic = True


def slot_network_training(device_id=0):
def slot_network_training(data_num, batch_size, epoch, input_res, device_id=0, num_workers=0):
# Initial
setup(19960229)
net = SlotNetwork([32, 44, 64, 92, 128], device_id=device_id)

# Train
auto_train(get_training_set(6535, 12, 224, device_id), net, device_id=device_id,
epoch_limit=1000, save_path="parameters/")
auto_train(get_training_set(data_num, batch_size, input_res, device_id, num_workers), net, device_id=device_id,
epoch_limit=epoch, save_path="parameters/")


# TODO
Expand Down
14 changes: 7 additions & 7 deletions SPFCN/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
def get_training_set(data_size: int,
batch_size: int,
resolution: int = 224,
device_id: int = 0):
device_id: int = 0,
num_workers: int = 0.):
assert 0 < data_size < 6596 and 0 < batch_size and 0 < resolution

vps_set = VisionParkingSlotDataset(
Expand All @@ -17,12 +18,11 @@ def get_training_set(data_size: int,
data_size=data_size,
resolution=resolution)

return DataLoader(dataset=vps_set, shuffle=True, batch_size=batch_size, num_workers=4)
# if device_id < 0:
# return DataLoader(dataset=vps_set, shuffle=True, batch_size=batch_size, num_workers=4)
# else:
# return DataPrefetcher(device=torch.device('cuda:%d' % device_id),
# dataset=vps_set, batch_size=batch_size, shuffle=True)
if device_id < 0:
return DataLoader(dataset=vps_set, shuffle=True, batch_size=batch_size, num_workers=num_workers)
else:
return DataPrefetcher(device=torch.device('cuda:%d' % device_id),
dataset=vps_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)


def get_validating_set(data_size: int,
Expand Down
4 changes: 2 additions & 2 deletions SPFCN/dataset/prefetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@


class DataPrefetcher(object):
def __init__(self, dataset, batch_size, shuffle, device):
def __init__(self, dataset, batch_size, shuffle, device, num_workers):
self.stream = torch.cuda.Stream(device=device)
self.device = device

self.loader = DataLoader(dataset=dataset, shuffle=shuffle, batch_size=batch_size,
num_workers=4, pin_memory=True)
num_workers=num_workers, pin_memory=True)
self.fetcher = None
self.next_images = None
self.next_labels = None
Expand Down
2 changes: 2 additions & 0 deletions SPFCN/train/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def auto_train(dataset,
if epoch == epoch_limit:
network.merge()
stage = "merge_bn"

torch.save(network.state_dict(), "%s%s_epoch%d_loss%d.pkl" % (save_path, stage, epoch, int(epoch_loss)))
torch.save()

curr = datetime.now()
info = '{:02d}:{:02d}:{:02d} '.format(curr.hour, curr.minute, curr.minute)
Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
if __name__ == "__main__":

# auto train
slot_network_training(device_id=0)
slot_network_training(data_num=6535, batch_size=10, epoch=10, input_res=224, device_id=0, num_workers=0)

# auto test
model_path = './SPFCN/'
Expand Down

0 comments on commit 96f63f6

Please sign in to comment.