Add LARS support #10374

Merged

merged 3 commits into PaddlePaddle:develop on Jun 15, 2018

Conversation

@typhoonzero (Contributor) commented May 3, 2018

Fix #6811

Related: #7788

To use, pass LARS_weight_decay=[some value greater than 0] to enable LARS. LARS also works together with the existing learning rate schedulers, such as "polynomial_decay", e.g.:

import paddle.fluid as fluid
import paddle.fluid.layers as layers

opt = fluid.optimizer.Momentum(
    learning_rate=layers.polynomial_decay(
        learning_rate=1.0, decay_steps=100, power=2.0),
    momentum=0.8,
    LARS_weight_decay=0.3)

or

opt = fluid.optimizer.Momentum(
    learning_rate=1.0,
    momentum=0.8,
    LARS_weight_decay=0.3)
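
For context, LARS (You et al., "Large Batch Training of Convolutional Networks", arXiv:1708.03888) rescales the base learning rate per layer by the ratio of the weight norm to the gradient norm, and LARS_weight_decay enters the denominator of that ratio. The NumPy sketch below only illustrates the formula from the paper; it is not the code added in this PR, and the helper name lars_local_lr is hypothetical.

import numpy as np

def lars_local_lr(base_lr, weights, grads, weight_decay):
    # Layer-wise scaling from the LARS paper:
    #   local_lr = base_lr * ||w|| / (||g|| + weight_decay * ||w||)
    w_norm = np.linalg.norm(weights)
    g_norm = np.linalg.norm(grads)
    return base_lr * w_norm / (g_norm + weight_decay * w_norm)

# A layer whose gradients are small relative to its weights gets a larger
# effective rate than one with large gradients, which is the point of LARS.
w = np.ones(1000)
print(lars_local_lr(1.0, w, 0.001 * np.ones(1000), 0.3))  # ~3.32
print(lars_local_lr(1.0, w, 0.5 * np.ones(1000), 0.3))    # ~1.25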

@typhoonzero requested review from reyoung and Yancey0623 on May 3, 2018, 08:49
@typhoonzero changed the title from "add LARS support" to "Add LARS support" on May 3, 2018
@typhoonzero requested a review from jacquesqiao on May 3, 2018, 12:56


def append_LARS(params_grads, learning_rate, weight_decay):
    """Applies LARS (LAYER-WISE ADAPTIVE RATE SCALING) to learning rate for

Reviewer (Member) commented on this diff:

Maybe we can add the link to the paper here.

@jacquesqiao (Member) commented:

Please add a unit test for this learning rate scheduler strategy.
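
One way such a test could be structured is sketched below against a plain NumPy reference of the LARS formula rather than the Fluid scheduler itself; the names (lars_local_lr, TestLARSScaling) are illustrative and not part of this PR.

import unittest

import numpy as np

def lars_local_lr(base_lr, weights, grads, weight_decay):
    # Reference formula from the LARS paper, repeated from the sketch above.
    w_norm = np.linalg.norm(weights)
    g_norm = np.linalg.norm(grads)
    return base_lr * w_norm / (g_norm + weight_decay * w_norm)

class TestLARSScaling(unittest.TestCase):
    def test_larger_weight_decay_gives_smaller_local_lr(self):
        w = np.ones(100)
        g = 0.01 * np.ones(100)
        self.assertLess(
            lars_local_lr(1.0, w, g, weight_decay=1.0),
            lars_local_lr(1.0, w, g, weight_decay=0.1))

    def test_zero_weight_decay_matches_norm_ratio(self):
        w = 2.0 * np.ones(100)
        g = 0.5 * np.ones(100)
        expected = np.linalg.norm(w) / np.linalg.norm(g)  # = 4.0
        self.assertAlmostEqual(lars_local_lr(1.0, w, g, 0.0), expected)

if __name__ == "__main__":
    unittest.main()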

@jacquesqiao closed this May 4, 2018
@jacquesqiao reopened this May 4, 2018

def __init__(self,
             learning_rate,
             regularization=None,
             LARS_weight_decay=0.0):

Reviewer (Member) commented on this diff:

According to the paper, I think the default value of LARS_weight_decay may be 1.0.

@typhoonzero (Contributor, Author) commented:

The program cannot be transpiled into its distributed version correctly when using LARS; still debugging.

@typhoonzero (Contributor, Author) commented:

@jacquesqiao Can we merge this for now, so I can test it with NCCL2 distributed training?

@jacquesqiao (Member) left a comment:

LGTM!

@typhoonzero merged commit 53d1d0f into PaddlePaddle:develop on Jun 15, 2018