Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Distribution a templated class by sample type. #4

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions cxx/distributions/base.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#pragma once

template <typename SampleType = double> class Distribution {
// Abstract base class for probability distributions in HIRM.
// SampleType is the type of a single observation (defaults to double).
public:
// N is the number of incorporated observations.
int N = 0;

// Add observation x to the accumulated statistics.
virtual void incorporate(SampleType x) = 0;
// Remove a previously incorporated observation x.
// NOTE(review): implementations assume x was actually incorporated
// earlier; there is no error reporting for a mismatched call.
virtual void unincorporate(SampleType x) = 0;

// The log probability of x according to the distribution we have
// accumulated so far.
virtual double logp(SampleType x) const = 0;

// Log probability of all observations incorporated so far
// (the marginal likelihood in conjugate implementations).
virtual double logp_score() const = 0;

// A sample from the distribution we have accumulated so far.
virtual SampleType sample() = 0;

virtual ~Distribution() = default;
};

51 changes: 51 additions & 0 deletions cxx/distributions/beta_bernoulli.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright 2024
// See LICENSE.txt

#pragma once
#include "base.hh"

class BetaBernoulli : public Distribution<double> {
public:
  double alpha = 1;  // Beta prior hyperparameter (pseudo-count of 1s)
  double beta = 1;   // Beta prior hyperparameter (pseudo-count of 0s)
  int s = 0;         // sum of observed values (number of 1s seen)
  PRNG *prng;        // borrowed, not owned; must outlive this object

  BetaBernoulli(PRNG *prng) {
    this->prng = prng;
  }
  // Record one observation; x must be 0 or 1.
  void incorporate(double x) override {
    assert(x == 0 || x == 1);
    N += 1;
    s += x;
  }
  // Undo a previous incorporate(x); x must be 0 or 1.
  void unincorporate(double x) override {
    assert(x == 0 || x == 1);
    N -= 1;
    s -= x;
    assert(0 <= s);
    assert(0 <= N);
  }
  // Posterior predictive log probability of x given the observations
  // incorporated so far (Beta-Bernoulli conjugacy).
  double logp(double x) const override {
    double log_denom = log(N + alpha + beta);
    if (x == 1) { return log(s + alpha) - log_denom; }
    if (x == 0) { return log(N - s + beta) - log_denom; }
    // x must be 0 or 1; reaching here is a caller bug.  Return -infinity
    // rather than falling off the end of the function, which is undefined
    // behavior in release builds where assert() compiles away.
    assert(false);
    return log(0.0);  // -infinity
  }
  // Log marginal likelihood of the incorporated observations.
  double logp_score() const override {
    double v1 = lbeta(s + alpha, N - s + beta);
    double v2 = lbeta(alpha, beta);
    return v1 - v2;
  }
  // Draw 0 or 1 from the posterior predictive distribution.
  double sample() override {
    double p = exp(logp(1));
    vector<int> items {0, 1};
    vector<double> weights {1-p, p};
    auto idx = choice(weights, prng);
    return items[idx];
  }

  // Disable copying.
  BetaBernoulli & operator=(const BetaBernoulli&) = delete;
  BetaBernoulli(const BetaBernoulli&) = delete;
};
65 changes: 65 additions & 0 deletions cxx/distributions/dirichlet_categorical.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright 2024
// See LICENSE.txt

#pragma once
#include "base.hh"

class DirichletCategorical : public Distribution<double> {
public:
  double alpha = 1;         // symmetric Dirichlet concentration (applies to all categories)
  std::vector<int> counts;  // counts of observed categories
  int n;                    // total number of observations
  PRNG *prng;               // borrowed, not owned; must outlive this object

  DirichletCategorical(PRNG *prng, int k) {  // k is number of categories
    this->prng = prng;
    counts = std::vector<int>(k, 0);
    n = 0;
  }
  // Record one observation; x encodes a category index in [0, k).
  void incorporate(double x) override {
    assert(x >= 0 && x < counts.size());
    counts[size_t(x)] += 1;
    ++n;
  }
  // Undo a previous incorporate(x).
  void unincorporate(double x) override {
    const size_t y = x;
    assert(y < counts.size());
    counts[y] -= 1;
    --n;
    assert(0 <= counts[y]);
    assert(0 <= n);
  }
  // Posterior predictive log probability of category x.
  double logp(double x) const override {
    assert(x >= 0 && x < counts.size());
    const double numer = log(alpha + counts[size_t(x)]);
    const double denom = log(n + alpha * counts.size());
    return numer - denom;
  }
  // Log marginal likelihood of the counts under a symmetric
  // Dirichlet(alpha) prior:
  //   lgamma(k*alpha) - lgamma(k*alpha + n)
  //     + sum_y lgamma(count_y + alpha) - k*lgamma(alpha)
  double logp_score() const override {
    const size_t k = counts.size();
    const double a = alpha * k;
    const double lg = std::transform_reduce(
        counts.cbegin(),
        counts.cend(),
        // Seed with a double: an int 0 here would make the accumulator
        // an int and truncate every partial sum of lgamma values.
        0.0,
        std::plus{},
        // transform_reduce passes the *element* (the count itself), not
        // an index; the original `counts[y]` indexed by the count value,
        // which is wrong and can read out of bounds.
        [&](int y) -> double { return lgamma(y + alpha); }
    );
    return lgamma(a) - lgamma(a + n) + lg - k * lgamma(alpha);
  }
  // Draw a category index from the posterior predictive distribution.
  double sample() override {
    vector<double> weights(counts.size());
    std::transform(
        counts.begin(),
        counts.end(),
        weights.begin(),
        [&](int y) -> double { return y + alpha; }  // y is the count for that category
    );
    int idx = choice(weights, prng);
    return double(idx);
  }

  // Disable copying.
  DirichletCategorical & operator=(const DirichletCategorical&) = delete;
  DirichletCategorical(const DirichletCategorical&) = delete;
};
59 changes: 59 additions & 0 deletions cxx/distributions/normal.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright 2024
// See LICENSE.txt

#pragma once
#include "base.hh"

#ifndef M_2PI
#define M_2PI 6.28318530717958647692528676655
#endif

class Normal : public Distribution<double> {
public:
  // We use Welford's algorithm for computing the mean and variance
  // of streaming data in a numerically stable way. See Knuth's
  // Art of Computer Programming vol. 2, 3rd edition, page 232.
  //
  // Both running statistics must be double: the original `int` members
  // truncated every Welford update to an integer, destroying the
  // statistics entirely.
  double mean = 0.0;  // running mean of observed values
  double var = 0.0;   // running sum of squared deviations (Welford's M2);
                      // the sample variance is var / N, not var itself

  PRNG *prng;  // borrowed, not owned; must outlive this object

  Normal(PRNG *prng) {
    this->prng = prng;
  }

  // Record one observation via the forward Welford update.
  void incorporate(double x) override {
    N += 1;
    double old_mean = mean;
    mean += (x - mean) / N;
    var += (x - mean) * (x - old_mean);
  }

  // Undo a previous incorporate(x) by reversing the Welford update.
  void unincorporate(double x) override {
    int old_N = N;
    N -= 1;
    if (N == 0) {
      // Removing the only observation: reset to the empty state
      // instead of dividing by zero below.
      mean = 0.0;
      var = 0.0;
      return;
    }
    double old_mean = mean;
    mean = (mean * old_N - x) / N;
    var -= (x - mean) * (x - old_mean);
  }

  // Gaussian log density of x under the accumulated mean and the
  // sample variance var / N (the original used the M2 sum `var`
  // directly as the variance, which is only correct when N == 1).
  double logp(double x) const override {
    double v = var / N;
    double y = (x - mean);
    return -0.5 * (y * y / v + log(v) + log(M_2PI));
  }

  double logp_score() const override {
    // TODO(thomaswc): This.
    return 0.0;
  }

  // Draw from Normal(mean, var / N).
  double sample() override {
    // std::normal_distribution takes the standard deviation, not the
    // variance, as its second parameter.
    std::normal_distribution<double> d(mean, sqrt(var / N));
    return d(*prng);
  }

  // Disable copying.
  Normal & operator=(const Normal&) = delete;
  Normal(const Normal&) = delete;
};

72 changes: 9 additions & 63 deletions cxx/hirm.hh
Original file line number Diff line number Diff line change
Expand Up @@ -5,68 +5,14 @@
#include "globals.hh"
#include "util_hash.hh"
#include "util_math.hh"
#include "distributions/base.hh"
#include "distributions/beta_bernoulli.hh"
#include "distributions/dirichlet_categorical.hh"

typedef int T_item;
typedef vector<T_item> T_items;
typedef VectorIntHash H_items;

class Distribution {
public:
int N = 0;
virtual void incorporate(double x) = 0;
virtual void unincorporate(double x) = 0;
virtual double logp(double x) const = 0;
virtual double logp_score() const = 0;
virtual double sample() = 0;
~Distribution(){};
};

class BetaBernoulli : public Distribution {
public:
double alpha = 1; // hyperparameter
double beta = 1; // hyperparameter
int s = 0; // sum of observed values
PRNG *prng;

BetaBernoulli(PRNG *prng) {
this->prng = prng;
}
void incorporate(double x){
assert(x == 0 || x == 1);
N += 1;
s += x;
}
void unincorporate(double x) {
assert(x == 0 || x ==1);
N -= 1;
s -= x;
assert(0 <= s);
assert(0 <= N);
}
double logp(double x) const {
double log_denom = log(N + alpha + beta);
if (x == 1) { return log(s + alpha) - log_denom; }
if (x == 0) { return log(N - s + beta) - log_denom; }
assert(false);
}
double logp_score() const {
double v1 = lbeta(s + alpha, N - s + beta);
double v2 = lbeta(alpha, beta);
return v1 - v2;
}
double sample() {
double p = exp(logp(1));
vector<int> items {0, 1};
vector<double> weights {1-p, p};
auto idx = choice(weights, prng);
return items[idx];
}

// Disable copying.
BetaBernoulli & operator=(const BetaBernoulli&) = delete;
BetaBernoulli(const BetaBernoulli&) = delete;
};

class CRP {
public:
double alpha = 1; // concentration parameter
Expand Down Expand Up @@ -243,7 +189,7 @@ public:
// list of domain pointers
const vector<Domain*> domains;
// map from cluster multi-index to Distribution pointer
umap<const vector<int>, Distribution*, VectorIntHash> clusters;
umap<const vector<int>, Distribution<double>*, VectorIntHash> clusters;
// map from item to observed data
umap<const T_items, double, H_items> data;
// map from domain name to reverse map from item to
Expand Down Expand Up @@ -421,7 +367,7 @@ public:
auto z = get_cluster_assignment_gibbs(items_list[0], domain, item, table);

BetaBernoulli aux (prng);
Distribution * cluster = clusters.count(z) > 0 ? clusters.at(z) : &aux;
Distribution<double> * cluster = clusters.count(z) > 0 ? clusters.at(z) : &aux;
// auto cluster = self.clusters.get(z, self.aux())
auto logp0 = cluster->logp_score();
for (const auto &items : items_list) {
Expand Down Expand Up @@ -480,7 +426,7 @@ public:
i_list = {0};
} else {
auto tables_weights = domain->tables_weights();
auto Z = log(1 + domain->crp.N);
auto Z = log(domain->crp.alpha + domain->crp.N);
int idx = 0;
for (const auto &[t, w] : tables_weights) {
t_list.push_back(t);
Expand All @@ -505,7 +451,7 @@ public:
logp_w += wi;
}
BetaBernoulli aux (prng);
Distribution * cluster = clusters.count(z) > 0 ? clusters.at(z) : &aux;
Distribution<double> * cluster = clusters.count(z) > 0 ? clusters.at(z) : &aux;
auto logp_z = cluster->logp(value);
auto logp_zw = logp_z + logp_w;
logps.push_back(logp_zw);
Expand Down Expand Up @@ -676,7 +622,7 @@ public:
i_list = {0};
} else {
auto tables_weights = domain->tables_weights();
auto Z = log(1 + domain->crp.N);
auto Z = log(domain->crp.alpha + domain->crp.N);
int idx = 0;
for (const auto &[t, w] : tables_weights) {
t_list.push_back(t);
Expand Down Expand Up @@ -717,7 +663,7 @@ public:
z.push_back(t);
}
BetaBernoulli aux (prng);
Distribution * cluster = relation->clusters.count(z) > 0 \
Distribution<double> * cluster = relation->clusters.count(z) > 0 \
? relation->clusters.at(z)
: &aux;
logp_indexes += cluster->logp(value);
Expand Down
15 changes: 11 additions & 4 deletions cxx/tests/test_irm_two_relations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,20 @@ int main(int argc, char **argv) {
std::cout << "writing clusters to " << path_clusters << std::endl;
to_txt(path_clusters, irm, encoding);

auto item_to_code = std::get<0>(encoding);
auto code_item_0_D1 = item_to_code.at("D1").at("0");
auto code_item_10_D1 = item_to_code.at("D1").at("10");
auto code_item_0_D2 = item_to_code.at("D2").at("0");
auto code_item_10_D2 = item_to_code.at("D2").at("10");
auto code_item_novel = 100;

map<int, map<int, double>> expected_p0 {
{0, { {0, 1}, {10, 1}, {100, .5} } },
{10, { {0, 0}, {10, 0}, {100, .5} } },
{100, { {0, .66}, {10, .66}, {100, .5} } },
{code_item_0_D1, { {code_item_0_D2, 1}, {code_item_10_D2, 1}, {code_item_novel, .5} } },
{code_item_10_D1, { {code_item_0_D2, 0}, {code_item_10_D2, 0}, {code_item_novel, .5} } },
{code_item_novel, { {code_item_0_D2, .66}, {code_item_10_D2, .66}, {code_item_novel, .5} } },
};

vector<vector<int>> indexes {{0, 10, 100}, {0, 10, 100}};
vector<vector<int>> indexes {{code_item_0_D1, code_item_10_D1, code_item_novel}, {code_item_0_D1, code_item_10_D2, code_item_novel}};
for (const auto &l : product(indexes)) {
assert(l.size() == 2);
auto x1 = l.at(0);
Expand Down
12 changes: 12 additions & 0 deletions cxx/tests/test_misc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ int main(int argc, char **argv) {
}
printf("\n");

DirichletCategorical dc (&prng, 8);
dc.incorporate(1);
dc.incorporate(1);
dc.incorporate(3);
dc.unincorporate(1);
printf("%f\n", exp(dc.logp(5)));
printf("%f\n", exp(dc.logp_score()));
for (int i = 0; i < 100; i++) {
printf("%1.f ", dc.sample());
}
printf("\n");

CRP crp (&prng);
crp.alpha = 1.5;
printf("starting crp\n");
Expand Down
4 changes: 2 additions & 2 deletions src/hirm.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def logp(self, items, value):
i_list = [0]
else:
tables_weights = domain.tables_weights()
Z = math.log(1 + domain.crp.N)
Z = math.log(domain.crp.alpha + domain.crp.N)
t_list = tuple(tables_weights.keys())
w_list = tuple(math.log(x) - Z for x in tables_weights.values())
i_list = tuple(range(len(tables_weights)))
Expand Down Expand Up @@ -593,7 +593,7 @@ def logp_observations(observations):
else:
tables_weights = domain.tables_weights()
t_list = tuple(tables_weights.keys())
Z = math.log(1 + domain.crp.N)
Z = math.log(domain.crp.alpha + domain.crp.N)
w_list = tuple(math.log(x) - Z for x in tables_weights.values())
i_list = tuple(range(len(tables_weights)))
item_universe.add((domain.name, item))
Expand Down