Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add more quantization strategies to contrib.epi #2440

Merged
merged 7 commits into from
Apr 24, 2020
Merged

add more quantization strategies to contrib.epi #2440

merged 7 commits into from
Apr 24, 2020

Conversation

martinjankowiak
Copy link
Collaborator

@martinjankowiak martinjankowiak commented Apr 24, 2020

Addresses #2426

elif num_quant_bins == 16:
global w16
if w16.device != s.device:
w16 = w16.to(s.device)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a better way to deal with this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this is really bad, you're updating the global variable. How about renaming the global variable to W16 and then using w16 = s.new_tensor(W16).

@@ -33,8 +33,9 @@
'contrib/autoname/mixture.py --num-epochs=1',
'contrib/autoname/tree_data.py --num-epochs=1',
'contrib/cevae/synthetic.py --num-epochs=1',
'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2',
'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2 --dct=1',
'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2',
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i changed some of the args to reduce test time to below a minute

@martinjankowiak
Copy link
Collaborator Author

i still intend to add a basic sampler test

"""

def __init__(self, compartments, duration, population):
def __init__(self, compartments, duration, population, num_quant_bins=4):
Copy link
Member

@fritzo fritzo Apr 24, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a *, separator to force use as a kwarg rather than arg?

def __init__(self, compartments, duration, population, *,
             num_quant_bins=4):

This gives us flexibility to later (1) reorder all kwargs after the *,, or even (2) lump them into a **kwargs dict.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -91,6 +93,10 @@ def __init__(self, compartments, duration, population):
assert len(compartments) == len(set(compartments))
self.compartments = compartments

if num_quant_bins not in [4, 8, 12, 16]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please weaken this check and let the other file decide what values are allowed:

assert isininstance(num_quant_bins, int)
assert num_quant_bins > 0

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"""

def __init__(self, population, recovery_time, data):
def __init__(self, population, recovery_time, data, num_quant_bins=4):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto: Add *, separator.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Comment on lines 151 to 152
arange_min = - (num_quant_bins // 2 - 1)
arange_max = num_quant_bins // 2 + 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitnit: I guess you could write these as

    arange_min = 1 - num_quant_bins // 2
    arange_max = 1 + num_quant_bins // 2

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally looks good after my minor comments.

Could you published the notebook deriving spline weights and to link to it in PR description? Feel free to push directly to my notebooks repo or whatever. We can move that to a new repo like pyro-ppl/derivations or something if you want.

@martinjankowiak
Copy link
Collaborator Author

Could you published the notebook deriving spline weights and to link to it in PR description? Feel free to push directly to my notebooks repo or whatever. We can move that to a new repo like pyro-ppl/derivations or something if you want.

my notebooks are a mess and i'd rather do this once i finalize which schemes i'd like to keep user facing

y = torch.min(y, 2 * max + 1 - y)
probs.scatter_add_(0, y, bin_probs[:, k] / num_samples)

max_deviation = (probs - 1.0 / (max + 1.0)).abs().max().item()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in total these tests take about a second

@fritzo fritzo merged commit b52a139 into dev Apr 24, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants