
Commit

add valid
shippingwang committed Nov 18, 2020
1 parent 0f22930 commit a81f48e
Showing 5 changed files with 65 additions and 19 deletions.
25 changes: 16 additions & 9 deletions configs/example.yaml
@@ -18,12 +18,10 @@ DATASET: #DATASET field
format: "VideoDataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevidel/loader/dateset'
data_prefix: "" #Mandatory, train data root path
file_path: "data/ucf101/ucf101_train_split_1_videos.txt" #Mandatory, train data index file path
# suffix: '{:05}.jpg'
valid:
format: "VideoDataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevidel/loader/dateset'
data_prefix: "data/ucf101/videos_val" #Mandatory, valid data root path
file_path: "./data/dataset/ucf101/valid_video.txt" #Mandatory, valid data index file path
# suffix: '{:05}.jpg'
data_prefix: "" #Mandatory, valid data root path
file_path: "./data/ucf101/ucf101_val_split_1_videos.txt" #Mandatory, valid data index file path

PIPELINE: #PIPELINE field
train: #Mandatory, the pipeline that processes the training data, associated with 'paddlevideo/loader/pipelines/'
@@ -36,7 +34,7 @@ PIPELINE: #PIPELINE field
valid_mode: False
transform: #Mandatory, image transform operators
- Scale:
short_size: 224
short_size: 256
- RandomCrop:
target_size: 224
- RandomFlip:
@@ -46,14 +44,23 @@ PIPELINE: #PIPELINE field
std: [0.229, 0.224, 0.225]

valid: #Mandatory, the pipeline that processes the validation data, associated with 'paddlevideo/loader/pipelines/'
decode:
name: "VideoDecoder"
sample:
name: "Sampler"
valid_mode: True
clip_len: 1
num_clips: 8
num_seg: 8
seg_len: 1
transform:
- Resize:
- Norm:
- Scale:
short_size: 256
- CenterCrop:
target_size: 224
- Image2Array:
- Normalization:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]


OPTIMIZER: #OPTIMIZER field
name: 'Momentum' #Mandatory, the type of optimizer, associate to the 'paddlevideo/solver/'
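Note: a minimal sketch (not part of this commit) of how the new DATASET.valid and PIPELINE.valid entries pair up, mirroring the build_dataset((cfg.DATASET.valid, cfg.PIPELINE.valid)) call in main.py below; it assumes only the example.yaml above and PyYAML, and the expected-ops comment reflects this diff, not a test run.

import yaml

# Load the config shown above and pull out the validation halves.
with open("configs/example.yaml") as f:
    cfg = yaml.safe_load(f)

dataset_cfg = cfg["DATASET"]["valid"]    # format / data_prefix / file_path
pipeline_cfg = cfg["PIPELINE"]["valid"]  # decode / sample / transform

# This tuple has the same shape main.py feeds to build_dataset.
valid_pair = (dataset_cfg, pipeline_cfg)

print(dataset_cfg["format"], dataset_cfg["file_path"])
print([list(op)[0] for op in pipeline_cfg["transform"]])
# Expected ops per this diff: Scale, CenterCrop, Image2Array, Normalization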
5 changes: 5 additions & 0 deletions main.py
@@ -58,6 +58,11 @@ def main():
paddle.distributed.init_parallel_env()

model = build_model(cfg.MODEL)

#NOTE: To debug the dataloader or inspect your data, print a sample from the dataset built below, e.g.:
# for i in dataset[0]:
# logger.error(i)

dataset = [build_dataset((cfg.DATASET.train, cfg.PIPELINE.train))]
if args.validate:
dataset.append(build_dataset((cfg.DATASET.valid, cfg.PIPELINE.valid)))
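Note: a hedged sketch of the debugging tip in the NOTE above, written as lines one might drop into main() right after the dataset list is built; it reuses only names visible in this hunk (dataset is the list built above, logger is assumed to exist in main.py) and assumes the dataset object supports indexing.

# Hedged sketch: place after the build_dataset call in main().
# `dataset` is the list built above, so dataset[0] is the training dataset.
sample = dataset[0][0]              # first decoded/transformed sample
for item in sample:                 # typically (imgs, label)
    logger.info("%s %s", type(item), getattr(item, "shape", ""))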
2 changes: 1 addition & 1 deletion paddlevideo/modeling/framework/recognizers/base.py
@@ -81,7 +81,7 @@ def train_step(self, data_batch, **kwargs):
return loss_metrics

def val_step(self, data_batch, **kwargs):
"""Interface to valid
"""Validating setp.
"""
imgs = data_batch[0]
labels = data_batch[1]
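Note: a hedged sketch of what a concrete val_step might look like, based only on the unpacking shown here and on train_step returning loss_metrics; the self(imgs) forward call and the self.head.loss helper are hypothetical names for illustration, not the repository's confirmed API.

# Hypothetical subclass method for illustration only.
def val_step(self, data_batch, **kwargs):
    """Validation step: forward the batch and return loss metrics."""
    imgs = data_batch[0]
    labels = data_batch[1]
    cls_score = self(imgs)                            # assumed forward pass
    loss_metrics = self.head.loss(cls_score, labels)  # assumed loss helper
    return loss_metrics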
50 changes: 41 additions & 9 deletions paddlevideo/tasks/train.py
@@ -27,7 +27,7 @@ def train_model(model,
dataset,
cfg,
parallel=True,
validate=False):
validate=True):
"""Train model entry
Args:
@@ -63,7 +63,9 @@ def train_model(model,
model = paddle.DataParallel(model)

train_loader = data_loaders[0]

if validate:
valid_loader = data_loaders[1]

metric_list = [
("loss", AverageMeter('loss', '7.5f')),
("lr", AverageMeter(
@@ -116,17 +118,47 @@
if i == 20:
break

"""
Note: validate is not ready now!

"""
end_str = ' '.join([str(m.mean) for m in metric_list.values()] +
end_str = ' '.join([str(m.mean) for m in metric_list.values()] +
[metric_list['batch_time'].total])
end_epoch_str = "END epoch:{:<3d}".format(epoch)
end_epoch_str = "END epoch:{:<3d}".format(epoch)

logger.info("{:s} {:s} {:s}s".format(
logger.info("{:s} {:s} {:s}s".format(
coloring(end_epoch_str, "RED"),
coloring("TRAIN", "PURPLE"),
coloring(end_str, "OKGREEN")))

logger.info('[TRAIN] training finished')
if validate:
model.eval()
tic = time.time()
for i, data in enumerate(valid_loader):
if parallel:
outputs = model._layers.val_step(data)
else:
outputs = model.val_step(data)
for name, value in outputs.items():
metric_list[name].update(value.numpy()[0], batch_size)
metric_list['batch_time'].update(time.time() - tic)
tic = time.time()


if i % cfg.get("log_interval", 10) == 0:
fetchs_str = ' '.join([str(m.value) for m in metric_list.values()])
epoch_str = "epoch:[{:>3d}/{:<3d}]".format(epoch, cfg.epochs)
step_str = "{:s} step:{:<4d}".format("valid", i)
logger.info("{:s} {:s} {:s}s".format(
coloring(epoch_str, "HEADER")
if i == 0 else epoch_str,
coloring(step_str, "PURPLE"),
coloring(fetchs_str, 'OKGREEN')))
if i == 20:
break

logger.info("{:s} {:s} {:s}s".format(
coloring(end_epoch_str, "RED"),
coloring("VALID", "PURPLE"),
coloring(end_str, "OKGREEN")))

#NOTE: save and resume are not implemented yet.

logger.info('training finished') #info of yaml
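Note: the new validation loop reuses the training metric bookkeeping (metric_list[name].update(value, batch_size), m.value for per-step logs, m.mean for the end-of-epoch summary), and the metric_list['batch_time'] / metric_list.values() usage implies the (name, meter) pairs end up in an OrderedDict. A minimal, illustrative AverageMeter, not the project's implementation, behaves roughly like this:

from collections import OrderedDict


class AverageMeter:
    """Illustrative running-average tracker; not PaddleVideo's implementation."""

    def __init__(self, name, fmt='f'):
        self.name, self.fmt = name, fmt
        self.val = self.sum = self.count = 0

    def update(self, val, n=1):
        # Record the latest value and fold it into the running average.
        self.val = val
        self.sum += val * n
        self.count += n

    @property
    def value(self):
        # Latest reading, as logged every `log_interval` steps.
        return '{0.name}: {0.val:{0.fmt}}'.format(self)

    @property
    def mean(self):
        # Running average, as logged in the end-of-epoch summary.
        avg = self.sum / max(self.count, 1)
        return '{0.name}_avg: {1:{0.fmt}}'.format(self, avg)


metric_list = OrderedDict([("loss", AverageMeter('loss', '7.5f'))])
metric_list["loss"].update(0.9, 16)   # same call shape as metric_list[name].update(value, batch_size)
print(metric_list["loss"].value, metric_list["loss"].mean)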
2 changes: 2 additions & 0 deletions run.sh
@@ -0,0 +1,2 @@
export CUDA_VISIBLE_DEVICES=0
python3 -B -m paddle.distributed.launch --selected_gpus=0 main.py -c configs/example.yaml -o log_level="INFO"
