Fixed lbfgs for ray-cluster #180
def set_weights(new_weights):
  sess.run(assignment_nodes, feed_dict={p: w for p, w in zip(assignment_placeholders, new_weights)})
class LinearModel(object):
Can you move this class definition before the __main__ section of the file (after the imports)? The same goes for net_initialization, net_reinitialization, loss, grad, full_loss, and full_grad.
  sess.run(assignment_nodes, feed_dict={p: w for p, w in zip(assignment_placeholders, new_weights)})
class LinearModel(object):
  def __init__(self, shape):
    image_dimension = shape[0]
Why don't you get rid of image_dimension and label_dimension? In general, LinearModel can be an arbitrary linear model (not related to images) with arbitrary shapes, so using shape[0] and shape[1] directly is enough.
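A minimal sketch of what this suggestion could look like, using NumPy as a stand-in for the TensorFlow graph in the PR. Only LinearModel and shape come from the diff; the weights and bias attributes and the predict method are assumptions made for illustration.

```python
import numpy as np


class LinearModel(object):
    """Linear map x -> x.W + b, with W of shape (shape[0], shape[1])."""

    def __init__(self, shape):
        # Use shape[0] and shape[1] directly instead of naming them
        # image_dimension and label_dimension, so the model is not tied to images.
        self.shape = shape
        self.weights = np.zeros((shape[0], shape[1]))  # hypothetical attribute
        self.bias = np.zeros(shape[1])                 # hypothetical attribute

    def predict(self, x):
        # x has shape (batch_size, shape[0]);
        # the result has shape (batch_size, shape[1]).
        return np.dot(x, self.weights) + self.bias
```

Nothing in this version refers to images or labels, so the same class would serve any input and output dimensions.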
@@ -73,27 +70,29 @@ def set_weights(new_weights):
def net_reinitialization(net_vars):
  return net_vars

# Create a reusable variable for the network.
# Register the network with Ray and create a reusable variable for it.
ray.register_class(LinearModel)
register_class shouldn't be necessary: we shouldn't have to serialize the network (and if we do, there is a problem somewhere).
@@ -73,27 +70,29 @@ def set_weights(new_weights):
def net_reinitialization(net_vars):
  return net_vars

# Create a reusable variable for the network.
# Register the network with Ray and create a reusable variable for it.
ray.register_class(LinearModel)
ray.reusables.net_vars = ray.Reusable(net_initialization, net_reinitialization)

# Load the weights into the network.
def load_weights(theta):
Can you define generic functions set_flat and get_flat in TensorFlowVariables for this and use them below?
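One way such helpers could behave is sketched below with plain NumPy arrays. VariableStore is a hypothetical stand-in, not the TensorFlowVariables class from the PR, and the real methods would read and assign TensorFlow variables through a session rather than hold arrays directly.

```python
import numpy as np


class VariableStore(object):
    """Hypothetical NumPy stand-in for a TensorFlowVariables-like container."""

    def __init__(self, arrays):
        self.arrays = [np.asarray(a, dtype=float) for a in arrays]

    def get_flat(self):
        # Concatenate every variable into a single 1-D vector.
        return np.concatenate([a.ravel() for a in self.arrays])

    def set_flat(self, theta):
        # Split the flat vector back into pieces matching each variable's shape.
        theta = np.asarray(theta, dtype=float)
        expected = sum(a.size for a in self.arrays)
        if theta.size != expected:
            raise ValueError("expected %d values, got %d" % (expected, theta.size))
        offset = 0
        new_arrays = []
        for a in self.arrays:
            chunk = theta[offset:offset + a.size]
            new_arrays.append(chunk.reshape(a.shape).copy())
            offset += a.size
        self.arrays = new_arrays
```

With helpers like these, load_weights(theta) would reduce to a single set_flat(theta) call, and the flat vector is the form that L-BFGS routines typically take as their parameter argument.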
No description provided.