-
Notifications
You must be signed in to change notification settings - Fork 616
Implementation of Conditional Gradient Optimizer #469
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
30ca1ad
Add files via upload
pkan2 eb64aba
Add files via upload
pkan2 802c313
Add files via upload
pkan2 ae7dfcd
Add files via upload
pkan2 6768daf
Add files via upload
pkan2 aced155
Add files via upload
pkan2 b14081a
Add files via upload
pkan2 68dbf64
Add files via upload
pkan2 932bb7f
Add files via upload
pkan2 22d57d8
Add files via upload
pkan2 aeb991c
Add files via upload
pkan2 09ec631
Add files via upload
pkan2 d9ef23e
Add files via upload
pkan2 0c51eba
Add files via upload
pkan2 953fa39
add CG optimizer
pkan2 892d602
Revert "add CG optimizer"
pkan2 865ab62
Merge branch 'master' of https://github.com/tensorflow/addons
pkan2 c800f01
Add files via upload
pkan2 7d2b6c6
Add files via upload
pkan2 02d421a
Add files via upload
pkan2 c0dd737
Add files via upload
pkan2 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Conditional Gradient method for TensorFlow.""" | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import tensorflow as tf | ||
from tensorflow_addons.utils import keras_utils | ||
|
||
|
||
@keras_utils.register_keras_custom_object | ||
class ConditionalGradient(tf.keras.optimizers.Optimizer): | ||
"""Optimizer that implements the Conditional Gradient optimization. | ||
|
||
This optimizer helps handle constraints well. | ||
|
||
Currently only supports frobenius norm constraint. | ||
See https://arxiv.org/pdf/1803.06453.pdf | ||
|
||
``` | ||
variable -= (1-learning_rate) | ||
* (variable + lambda_ * gradient / frobenius_norm(gradient)) | ||
``` | ||
|
||
Note that we choose "lambda_" here to refer to the constraint "lambda" in the paper. | ||
""" | ||
|
||
def __init__(self, | ||
learning_rate, | ||
lambda_, | ||
use_locking=False, | ||
name='ConditionalGradient', | ||
**kwargs): | ||
"""Construct a conditional gradient optimizer. | ||
|
||
Args: | ||
learning_rate: A `Tensor` or a floating point value. | ||
pkan2 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
The learning rate. | ||
lambda_: A `Tensor` or a floating point value. The constraint. | ||
use_locking: If `True` use locks for update operations. | ||
name: Optional name prefix for the operations created when | ||
applying gradients. Defaults to 'ConditionalGradient' | ||
""" | ||
super(ConditionalGradient, self).__init__(name=name, **kwargs) | ||
self._set_hyper('learning_rate', kwargs.get('lr', learning_rate)) | ||
self._set_hyper('lambda_', lambda_) | ||
self._set_hyper('use_locking', use_locking) | ||
|
||
def get_config(self): | ||
config = { | ||
'learning_rate': self._serialize_hyperparameter('learning_rate'), | ||
'lambda_': self._serialize_hyperparameter('lambda_'), | ||
'use_locking': self._serialize_hyperparameter('use_locking') | ||
} | ||
base_config = super(ConditionalGradient, self).get_config() | ||
return dict(list(base_config.items()) + list(config.items())) | ||
|
||
def _create_slots(self, var_list): | ||
for v in var_list: | ||
self.add_slot(v, 'conditional_gradient') | ||
|
||
def _prepare_local(self, var_device, var_dtype, apply_state): | ||
super(ConditionalGradient, self)._prepare_local( | ||
var_device, var_dtype, apply_state) | ||
apply_state[(var_device, var_dtype)]['learning_rate'] = tf.identity( | ||
self._get_hyper('learning_rate', var_dtype)) | ||
apply_state[(var_device, var_dtype)]['lambda_'] = tf.identity( | ||
self._get_hyper('lambda_', var_dtype)) | ||
|
||
def _resource_apply_dense(self, grad, var, apply_state=None): | ||
def frobenius_norm(m): | ||
return tf.math.reduce_sum(m**2)**0.5 | ||
|
||
var_device, var_dtype = var.device, var.dtype.base_dtype | ||
coefficients = ((apply_state or {}).get((var_device, var_dtype)) | ||
or self._fallback_apply_state(var_device, var_dtype)) | ||
norm = tf.convert_to_tensor( | ||
frobenius_norm(grad), name='norm', dtype=var.dtype.base_dtype) | ||
lr = coefficients['learning_rate'] | ||
lambda_ = coefficients['lambda_'] | ||
var_update_tensor = ( | ||
tf.math.multiply(var, lr) - (1 - lr) * lambda_ * grad / norm) | ||
var_update_kwargs = { | ||
'resource': var.handle, | ||
'value': var_update_tensor, | ||
} | ||
var_update_op = tf.raw_ops.AssignVariableOp(**var_update_kwargs) | ||
return tf.group(var_update_op) | ||
|
||
def _resource_apply_sparse(self, grad, var, indices, apply_state=None): | ||
def frobenius_norm(m): | ||
return tf.reduce_sum(m**2)**0.5 | ||
|
||
var_device, var_dtype = var.device, var.dtype.base_dtype | ||
coefficients = ((apply_state or {}).get((var_device, var_dtype)) | ||
or self._fallback_apply_state(var_device, var_dtype)) | ||
norm = tf.convert_to_tensor( | ||
frobenius_norm(grad), name='norm', dtype=var.dtype.base_dtype) | ||
lr = coefficients['learning_rate'] | ||
lambda_ = coefficients['lambda_'] | ||
var_slice = tf.gather(var, indices) | ||
var_update_value = ( | ||
tf.math.multiply(var_slice, lr) - (1 - lr) * lambda_ * grad / norm) | ||
var_update_kwargs = { | ||
'resource': var.handle, | ||
'indices': indices, | ||
'updates': var_update_value | ||
} | ||
var_update_op = tf.raw_ops.ResourceScatterUpdate(**var_update_kwargs) | ||
return tf.group(var_update_op) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.