Avoid unclear TypeError when using theano.shared variables as input to distribution parameters #4445
We could instead also do a |
Interesting. Do we want to nudge people towards pm.Data (which can be safely used for distribution parameters since #3925)? If we want to nudge, I would mention it in the AttributeError (splitting the infinite and the TypeError messages). If we want to support theano.shared, I wouldn't close the linked issue. Edit: Just for completeness, since #3925 shared tensors will work properly ( |
But I do agree with @kc611 that this may all be made redundant with the |
I think we can avoid the TypeError altogether if we change: https://github.com/pymc-devs/pymc3/blob/03d7af5b6dd5ad99ab2f3bd8ca7987a744dbef46/pymc3/distributions/distribution.py#L170-L171
To:
    if isinstance(val, (
        tt.sharedvar.TensorSharedVariable,  # pm.Data or theano.shared tensor
        theano.tensor.sharedvar.ScalarSharedVariable,  # theano.shared scalar
        theano.compile.sharedvalue.SharedVariable,  # theano.shared tensor from non-numpy array such as list
    )):
        return val.get_value()
I don't see any obvious drawbacks, and it makes the API more forgiving. However, we should then add a unit test along these lines to make sure it works as intended: https://github.com/pymc-devs/pymc3/blob/03d7af5b6dd5ad99ab2f3bd8ca7987a744dbef46/pymc3/tests/test_data_container.py#L139-L157
@AlexAndorra since you worked on enabling pm.Data as input to other RVs, do you think there is any reason not to accommodate these other shared types? |
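For readers without Theano at hand, the check proposed above can be sketched roughly as follows. This is only an illustration under stated assumptions: the `Fake*` classes are hypothetical stand-ins that mimic the `get_value()` interface of shared variables, and `unwrap_shared` is a made-up name, not the function in `distribution.py`.

```python
# Hypothetical stand-ins: they only mimic the get_value() interface of
# Theano shared variables so that this sketch runs without Theano installed.
class FakeSharedVariable:
    def __init__(self, value):
        self._value = value

    def get_value(self):
        return self._value


class FakeTensorSharedVariable(FakeSharedVariable):
    """Stand-in for a shared tensor (e.g. what pm.Data wraps)."""


class FakeScalarSharedVariable(FakeSharedVariable):
    """Stand-in for a shared scalar."""


# Tuple of accepted shared types, mirroring the shape of the proposed check.
SHARED_TYPES = (
    FakeTensorSharedVariable,
    FakeScalarSharedVariable,
    FakeSharedVariable,
)


def unwrap_shared(val):
    """Return the value stored in a shared variable, or val itself otherwise."""
    if isinstance(val, SHARED_TYPES):
        return val.get_value()
    return val


# A test "along these lines" would check each kind of input.
assert unwrap_shared(FakeScalarSharedVariable(5.0)) == 5.0
assert unwrap_shared(FakeTensorSharedVariable([5.0, 5.0])) == [5.0, 5.0]
assert unwrap_shared(FakeSharedVariable([5.0, 5.0])) == [5.0, 5.0]
assert unwrap_shared(3.0) == 3.0  # plain values pass through unchanged
```

Passing a tuple to `isinstance` accepts any of the listed types, which is why a single branch can handle all three kinds of shared variable.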
Thanks for the deep dive @ricardoV94 !
No, adding these looks really fine to me 👌 |
@kc611 is this something that you want to do in this PR? (it's totally fine if you are not interested) |
No issues on my side. I'll make the changes shortly. |
I just realized that the changes checking for
    with pm.Model() as m:
        shared_var = theano.shared([5.0, 5.0])  # Fails: cannot be safely coerced into theano.config.floatX
        v = pm.Normal("v", mu=shared_var, shape=2)
Raises
So I think we can just drop that condition / check. The best would be to have a way to nudge users into using np.array, but I don't see an easy way to achieve that. Note that the original goal of this PR would not address this issue either.
The current PR still solves the case where theano.shared is a scalar (which was the failing example that motivated this PR), which is already more forgiving compared to what we had before:
    with pm.Model() as m:
        shared_var = theano.shared(5.0)  # Failed before
        v = pm.Normal("v", mu=shared_var, shape=2)
And just for completeness, it is still fine to use a theano.shared with a numpy array (or equivalent, pm.Data):
    with pm.Model() as m:
        shared_var = theano.shared(np.array([5.0, 5.0]))  # Still fine
        v = pm.Normal("v", mu=shared_var, shape=2) |
Last suggestion, I promise!
In the meantime, you can go ahead and add a maintenance Release Note mentioning that theano ScalarSharedVariable can now also be used as input to other RVs.
Here I'm feeling stupid for not noticing such obvious stuff. :-P Anyway, thanks for your huge support in this PR. (I wasn't exactly familiar with |
I also didn't notice the "obvious" stuff before! I am sorry if I came across a bit heavy-handed; I was unsure of how willing / comfortable you were with the suggested changes, since the original PR / issue had a very different angle to it. I really appreciate your patience and effort! |
Leaving it open just in case anyone finds issues with it. Will merge in a day or so otherwise. |
Fixes #3139
As suggested by @rpgoldman over here #3139 (comment). This PR just makes the error message a bit clearer (I just re-arranged it a bit; the error message suggesting to add a test_val argument was already present).