Skip to content

Commit

Permalink
Merge branch 'yjyjy131_dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
yjyjy131 committed Jun 27, 2022
2 parents 3a2d6bf + d3f04ae commit dba0cdb
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 44 deletions.
22 changes: 10 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,19 @@ For vehicles equipped with the automatic parking system, the accuracy and speed

## Usage

1. You can set your data path in './SPFCN/dataset/__init__.py'.
./data/training/image/
./data/training/label/
./data/validating/image/
./data/validating/label/
./data/testing/image/
./data/testing/label/

2. slot_network_training : A function that runs the network training code.

1. 데이터 셋 구성
./data/train/
./data/train_raw_label/
./data/test/all/
./data/test/test_raw_label/







3. slot_network_testing : A function that runs the network testing code.

4. SlotDetector : A class that helps to return coordinate values ​​that can be used in an image based on the results of the network.


## Performance
Expand Down
6 changes: 4 additions & 2 deletions SPFCN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ def setup(seed):
cudnn.deterministic = True


def slot_network_training(data_num, batch_size, epoch, input_res, device_id=0, num_workers=0):
def slot_network_training(data_num, batch_size, valid_data_num, valid_batch_size, epoch, input_res, device_id=0, num_workers=0):
# Initial
setup(19960229)
net = SlotNetwork([32, 44, 64, 92, 128], device_id=device_id)

# Train
auto_train(get_training_set(data_num, batch_size, input_res, device_id, num_workers), net, device_id=device_id,
auto_train(get_training_set(data_num, batch_size, input_res, device_id, num_workers),
get_validating_set(valid_data_num, valid_batch_size, input_res, device_id, num_workers),
net, device_id=device_id,
epoch_limit=epoch, save_path="parameters/")


Expand Down
11 changes: 6 additions & 5 deletions SPFCN/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,19 @@ def get_training_set(data_size: int,
def get_validating_set(data_size: int,
batch_size: int,
resolution: int = 224,
device_id: int = 0):
device_id: int = 0,
num_workers: int = 0.):
assert 0 < data_size < 1538 and 0 < batch_size and 0 < resolution
vps_set = VisionParkingSlotDataset(
image_path="./data/testing/image/",
label_path="./data/testing/label/",
image_path="./data/validating/image/",
label_path="./data/validating/label/",
data_size=data_size,
resolution=resolution)
if device_id < 0:
return DataLoader(dataset=vps_set, shuffle=True, batch_size=batch_size, num_workers=4)
return DataLoader(dataset=vps_set, shuffle=True, batch_size=batch_size, num_workers=num_workers)
else:
return DataPrefetcher(device=torch.device('cuda:%d' % device_id),
dataset=vps_set, batch_size=batch_size, shuffle=False)
dataset=vps_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)


def get_testing_set(data_size: int,
Expand Down
1 change: 0 additions & 1 deletion SPFCN/model/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ class SlotDetector(object):
def __init__(self, device_id: int, **kwargs):
self.device = torch.device('cpu' if device_id < 0 else 'cuda:%d' % device_id)
self.config = self.update_config(**kwargs)
print(self.config)
self.network = SlotNetwork(self.config['dim_encoder'], device_id)
self.network.merge()
try:
Expand Down
6 changes: 3 additions & 3 deletions SPFCN/test/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,17 @@ def step(self):
try:
mark_recall = mark_co_count / mark_gt_count
except ZeroDivisionError:
print("ZeroDivisionError at mark_re_count")
print("ZeroDivisionError at mark_gt_count")

try:
slot_precision = slot_co_count / slot_re_count
except ZeroDivisionError:
print("ZeroDivisionError at mark_re_count")
print("ZeroDivisionError at slot_re_count")

try:
slot_recall = slot_co_count / slot_gt_count
except ZeroDivisionError:
print("ZeroDivisionError at mark_re_count")
print("ZeroDivisionError at slot_gt_count")

print("\rIndex: {}, Mark: Precision {:.4f}, Recall {:.4f}, Slot: Precision {:.4f}, Recall {:.4f}"
.format(index, mark_precision, mark_recall, slot_precision, slot_recall), end='')
Expand Down
22 changes: 12 additions & 10 deletions SPFCN/train/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,17 @@
@torch.no_grad()
def auto_validate(dataset,
network,
device_id: int = 0,
load_path: str = None):
device_id: int = 0):
device = torch.device('cpu' if device_id < 0 else 'cuda:%d' % device_id)

assert os.path.exists(load_path)
network.load_state_dict(torch.load(load_path, map_location=device))
network.eval()

auto_tester = Validator(dataset, network, device)
auto_tester.step()
auto_tester.get_network_inference_time()
auto_tester.get_detector_inference_time()
auto_validator = Validator(dataset, network, device)
auto_validator.step()
auto_validator.get_network_inference_time()
auto_validator.get_detector_inference_time()


def auto_train(dataset,
valid_dataset,
network,
device_id: int = 0,
load_path: str = None,
Expand All @@ -52,7 +48,9 @@ def auto_train(dataset,
epoch_unit = epoch_limit // 10
stage = "warm_up"
auto_trainer = Trainer(dataset, network, device, lr)

for epoch in range(1, epoch_limit + 1):
print("Train model ... ")
if epoch < epoch_unit:
pass
elif epoch == epoch_unit:
Expand Down Expand Up @@ -86,4 +84,8 @@ def auto_train(dataset,
info += 'Time left: ' + 'About {} minutes'.format(int(time_left / 60)) if time_left < 3600 \
else 'About {} hours'.format(int(time_left / 3600)) if time_left < 36000 else 'Just go to sleep'
print(info)

print("Validate model ... ")
auto_validate(valid_dataset, network, device_id)

print(network.get_encoder())
37 changes: 29 additions & 8 deletions SPFCN/train/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,34 @@ def step(self):

validating_image, validating_label = self.dataset.next()
index += 1

mark_precision, mark_recall, slot_precision, slot_recall = -1, -1, -1, -1

try:
mark_precision = mark_co_count / mark_re_count
except ZeroDivisionError:
print("ZeroDivisionError at mark_re_count")

try:
mark_recall = mark_co_count / mark_gt_count
except ZeroDivisionError:
print("ZeroDivisionError at mark_gt_count")

try:
slot_precision = slot_co_count / slot_re_count
except ZeroDivisionError:
print("ZeroDivisionError at slot_re_count")

try:
slot_recall = slot_co_count / slot_gt_count
except ZeroDivisionError:
print("ZeroDivisionError at slot_gt_count")

print("\rIndex: {}, Mark: Precision {:.4f}, Recall {:.4f}, Slot: Precision {:.4f}, Recall {:.4f}"
.format(index, mark_co_count / mark_re_count, mark_co_count / mark_gt_count,
slot_co_count / slot_re_count, slot_co_count / slot_gt_count), end='')
print('\r' + ' ' * 50, end="")
print("Mark: Precision {:.4f}, Recall {:.4f}, Slot: Precision {:.4f}, Recall {:.4f}"
.format(mark_co_count / mark_re_count, mark_co_count / mark_gt_count,
slot_co_count / slot_re_count, slot_co_count / slot_gt_count))
.format(index, mark_precision, mark_recall, slot_precision, slot_recall))

print("Current epoch score : Mark: Precision {:.4f}, Recall {:.4f}, Slot: Precision {:.4f}, Recall {:.4f}"
.format(mark_precision, mark_recall, slot_precision, slot_recall))

def get_network_inference_time(self):
def foo(img):
Expand Down Expand Up @@ -144,6 +165,6 @@ def get_inference_time(self, foo):
time_step += time() - timestamp
validating_image, _ = self.dataset.next()
index += 1
print("\rIndex: {}, Inference Time: {:.1f}ms".format(index, 1e3 * time_step / index), end="")
print('\r' + ' ' * 40, end="")
print("\rIndex: {}, Inference Time: {:.1f}ms".format(index, 1e3 * time_step / index))
return "Inference Time: {:.1f}ms".format(1e3 * time_step / index)
6 changes: 3 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@

### ORIGINAL VERSION ###
# Train model
slot_network_training(data_num=6535, batch_size=10, epoch=10, input_res=224, device_id=0, num_workers=0)
slot_network_training(data_num=6500, batch_size=32, valid_data_num=1500, valid_batch_size=48, epoch=80, input_res=224, device_id=0, num_workers=0)

# Test model
params_path = './parameters/merge_bn_epoch10_loss4.pkl'
slot_network_testing(parameter_path=params_path, data_num=1500, batch_size=50, input_res=224, device_id=0, num_workers=0)
params_path = './parameters/merge_bn_epoch80_loss1.pkl'
slot_network_testing(parameter_path=params_path, data_num=1500, batch_size=48, input_res=224, device_id=0, num_workers=0)

# Load detector
detector = SlotDetector(device_id=0, dim_encoder=[32, 44, 64, 92, 128], parameter_path=params_path)
Expand Down

0 comments on commit dba0cdb

Please sign in to comment.