-
Notifications
You must be signed in to change notification settings - Fork 46
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Tensorboard logging fixes #544
Conversation
trieste/models/gpflow/interface.py
Outdated
for i, lengthscale in enumerate(lengthscales): | ||
logging.scalar(f"kernel.lengthscale[{i}]", lengthscale) | ||
kernel = self.get_kernel() | ||
if isinstance(kernel, gpflow.kernels.Stationary): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hard to say if this would be a good enough check, I don't know GPflow well enough - perhaps add a test for this?
it might be good enough to support only the basic case, but I think we probably can do something more generic, like fetching all trainable variables? I think they also come with names so we can still make it descriptive enough, no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The check should be sufficient: gpflow.kernels.Stationary is where the lengthscales and variance parameters are set.
Makes sense to add a test. I'll see how much we can extract generically, but it might be tricky inferring types, shapes, etc. If it's too much work this may have to do for now (though I can add it to a WIBNI issue).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this would be a bug fix for the kernel parameters, but i suspect we can come up with generic readout of kernel parameters and log all of them - that might work for complex kernels as well
other changes look good
for i, lengthscale in enumerate(lengthscales): | ||
logging.scalar(f"kernel.lengthscale[{i}]", lengthscale) | ||
kernel = self.get_kernel() | ||
components = _merge_leaf_components(leaf_components(kernel)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wasnt aware of this, great that you have worked it out!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
much better now, great work :)
No description provided.