-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Features/340 gaussian nb #474
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i made a few comments here. nothing crazy though. other than that it looks good to go
heat/core/statistics.py
Outdated
raise NotImplementedError( | ||
"weights.split does not match data.split: not implemented yet." | ||
) | ||
# fix after Issue #425 is solved |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
too many tabs and 1 line below this it looks like there is dead code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
too many tabs (black's doing) or too many nested conditional statements?
if former:
🤷♀️
else:
I'll think about it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i was more confused about the structure here. but after sleeping on it i see what you meant. i read it now with the meaning of after issue #425 is solved then the if statement above these lines and the raise within will both be removed. i misunderstood it before i think
@@ -1227,7 +1230,7 @@ def reduce_vars_elementwise(output_shape_i): | |||
|
|||
var_shape = list(var.shape) if list(var.shape) else [1] | |||
|
|||
var_tot = factories.zeros(([x.comm.size, 2] + var_shape), device=x.device) | |||
var_tot = factories.zeros(([x.comm.size, 2] + var_shape), dtype=x.dtype, device=x.device) | |||
n_tot = factories.zeros(x.comm.size, device=x.device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
possible dtype problem here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is var() supposed to return a float32? Even if x is float64?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it should not. i guess i missed a couple dtype calls here. can you add them for me? it should be just in this spot
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's what I thought, so is this the possible dtype problem you were talking about? I guess I just misunderstood your first comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i just also realized that the dtype of n_tot doesnt matter because it is only used internally
…z-analytics/heat into features/340-GaussianNB
good work! |
Description
Initiating submodule naive_bayes and class GaussianNB().
Implementation of Gaussian Naive Bayes classification along the scikit-learn lines. Known issues listed below.
Issue/s resolved: #340
Changes proposed:
Known issues:
Adapt documentation, conventions from sklearn to HeAT (I know about those missing dunderscores!)DONEType of change
Due Diligence
Does this change modify the behaviour of other functions? If so, which?
no