Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update config for resnet50 #1

Merged
merged 9 commits into from
Mar 4, 2018
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Update config.py
Update learning rate and weight decay
  • Loading branch information
John1231983 authored Mar 2, 2018
commit 04a69eea0d3227440df9317dee017dec5522de0b
12 changes: 6 additions & 6 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Config(object):
# Validation stats are also calculated at each epoch end and they
# might take a while, so don't set this too small to avoid spending
# a lot of time on validation stats.
STEPS_PER_EPOCH = 1000
STEPS_PER_EPOCH = 670

# Number of validation steps to run at the end of every training epoch.
# A bigger number improves accuracy of validation stats, but slows
Expand Down Expand Up @@ -87,8 +87,8 @@ class Config(object):
# Images are resized such that the smallest side is >= IMAGE_MIN_DIM and
# the longest side is <= IMAGE_MAX_DIM. In case both conditions can't
# be satisfied together the IMAGE_MAX_DIM is enforced.
IMAGE_MIN_DIM = 800
IMAGE_MAX_DIM = 1024
IMAGE_MIN_DIM = 512
IMAGE_MAX_DIM = 512
# If True, pad images with zeros such that they're (max_dim by max_dim)
IMAGE_PADDING = True # currently, the False option is not supported

Expand All @@ -111,7 +111,7 @@ class Config(object):
MASK_SHAPE = [28, 28]

# Maximum number of ground truth instances to use in one image
MAX_GT_INSTANCES = 512
MAX_GT_INSTANCES = 256

# Bounding box refinement standard deviation for RPN and final detections.
RPN_BBOX_STD_DEV = np.array([0.1, 0.1, 0.2, 0.2])
Expand All @@ -131,11 +131,11 @@ class Config(object):
# The Mask RCNN paper uses lr=0.02, but on TensorFlow it causes
# weights to explode. Likely due to differences in optimzer
# implementation.
LEARNING_RATE = 0.0001
LEARNING_RATE = 0.001
LEARNING_MOMENTUM = 0.9

# Weight decay regularization
WEIGHT_DECAY = 0.001
WEIGHT_DECAY = 0.0001

# Use RPN ROIs or externally generated ROIs for training
# Keep this True for most situations. Set to False if you want to train
Expand Down