Graphing Model Created with Model Subclassing #3527
One way I know of is to write the graph of the model's traced `call` function to the TensorBoard log directory:

```python
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import summary_ops_v2


class SubclassedModel(tf.keras.Model):
    def __init__(self):
        super(SubclassedModel, self).__init__()
        self.layer_1 = tf.keras.layers.Dense(
            1, activation="sigmoid", input_shape=[3])

    @tf.function
    def call(self, inputs):
        return self.layer_1(inputs)


model = SubclassedModel()
model.compile(loss="binary_crossentropy", optimizer="adam")

logdir = "/tmp/subclassed_model_logdir"
xs = np.ones([4, 3])
ys = np.zeros([4, 1])
# `callbacks` must be a list; passing a bare TensorBoard instance raises
# "TypeError: 'TensorBoard' object is not iterable".
model.fit(xs, ys, epochs=1,
          callbacks=[tf.keras.callbacks.TensorBoard(logdir)])

# Write the graph of the traced `call` function so TensorBoard can show it.
writer = summary_ops_v2.create_file_writer_v2(logdir)
with writer.as_default():
    summary_ops_v2.graph(model.call.get_concrete_function(xs).graph)  # <--
writer.flush()
```

Does this suit your needs, @ryanmaxwell96?
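As an alternative to the private `summary_ops_v2` module, here is a minimal sketch that uses the public `tf.summary.trace_on` / `tf.summary.trace_export` API, assuming a TF 2.x release where those functions are available. The `forward` wrapper and the temporary log directory are illustrative choices, not part of the original code. The weights are created eagerly first so that only the forward pass is recorded in the trace.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class SubclassedModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.layer_1 = tf.keras.layers.Dense(1, activation="sigmoid")

    def call(self, inputs):
        return self.layer_1(inputs)


model = SubclassedModel()
xs = np.ones([4, 3], dtype=np.float32)
model(xs)  # create the weights eagerly, outside the traced function


@tf.function
def forward(x):
    # Hypothetical wrapper: tracing this function records the model's ops.
    return model(x)


logdir = tempfile.mkdtemp(prefix="subclassed_model_graph_")
writer = tf.summary.create_file_writer(logdir)

tf.summary.trace_on(graph=True)  # start recording graphs of traced functions
forward(xs)                      # the first call traces `forward`
with writer.as_default():
    # Writes the recorded graph and stops tracing.
    tf.summary.trace_export(name="subclassed_model_graph", step=0)
writer.flush()

print(os.listdir(logdir))  # the events file that TensorBoard reads
```

Running `tensorboard --logdir` on that directory should then show the graph under the Graphs tab.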
I'm getting this error now:
```python
# Written by Patrick Coady (pat-coady.github.io)
import tensorflow as tf
from fractalnet_regularNN import *
from tensorflow.python.ops import summary_ops_v2


class Policy(object):
    ...


class PolicyNN(Layer):
    ...


class KLEntropy(Layer):
    ...


class LogProb(Layer):
    ...


class TRPO(Model):
    ...
```
Sorry, that code is really hard to read. I have uploaded the relevant file, where I added your suggested code to the `Policy` class at lines 48-55: https://github.com/ryanmaxwell96/trpo_fractal1NN_3/blob/master/policy.py
My bad, I forgot to include the `@tf.function` part. But after adding it (the change is now on my GitHub), I get the same error. The code is now at lines 136 to 150 in policy.py, with a call to the corresponding function in train.py at line 311.
The error you referred to above (related to …
I was running into the same problem when I tried to rearrange things as you suggested, so I decided to run your code, and I'm getting the same error as well: `TypeError: 'TensorBoard' object is not iterable`. I have the following installed: I'm not sure if that helps at all or not.
I'm trying to plot a TF2 model that was built with the model subclassing method. The code I've been modifying builds all of its models by subclassing, and I have not found a working way to plot the fractal model.
Is there any way to plot a model that was built with the subclassing method?
Thanks,
Ryan
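A common workaround for plotting a subclassed model is to run its `call()` on a symbolic `tf.keras.Input`, which yields an ordinary functional model that `summary()` and `tf.keras.utils.plot_model` can inspect layer by layer. The sketch below assumes that `call()` only chains Keras layers that accept symbolic inputs and is not decorated with `@tf.function`; the helper name `as_functional_model` is hypothetical.

```python
import tensorflow as tf


class SubclassedModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.layer_1 = tf.keras.layers.Dense(1, activation="sigmoid")

    def call(self, inputs):
        return self.layer_1(inputs)

    def as_functional_model(self, input_shape):
        # Hypothetical helper: run call() on a symbolic Input so Keras records
        # the inner layers as an ordinary functional graph.
        inputs = tf.keras.Input(shape=input_shape)
        return tf.keras.Model(inputs=inputs, outputs=self.call(inputs))


model = SubclassedModel()
graph_model = model.as_functional_model((3,))
graph_model.summary()  # lists the inner Dense layer instead of one opaque block

# Rendering an image additionally needs the pydot package and Graphviz:
# tf.keras.utils.plot_model(graph_model, to_file="model.png", show_shapes=True)
```

Because the functional model reuses the same layer objects, it shares weights with the original subclassed model, so it can be built after training as well.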