-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathhierarchical.rs
149 lines (106 loc) · 4.2 KB
/
hierarchical.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
use rand::{seq::SliceRandom, SeedableRng};
use rand_distr::{Bernoulli, Distribution, Normal};
use rand_pcg::Pcg64Mcg;
use hammer_and_sample::{auto_corr_time, sample, MinChainLen, Model, Parallel};
/// End-to-end check of the sampler on a hierarchical (random-intercept) Bernoulli model.
///
/// Synthetic data: `GROUPS` groups, each with `OBSERVATIONS` Bernoulli trials whose success
/// probability is `expit(true_logit_theta + alpha_g)` with `alpha_g ~ Normal(0, true_sigma)`.
/// The sampler is then run with `WALKERS` walkers, requesting a chain of at least
/// `WALKERS * ITERATIONS` samples, and after discarding the first `WALKERS * BURN_IN` samples
/// the test asserts that
/// - every group's posterior-mean success probability is within 0.005 of the simulated truth,
/// - the overall acceptance rate lies strictly between 0.3 and 0.4,
/// - the kept chain spans more than 10,000 autocorrelation times of `logit_theta`.
///
/// All randomness is derived from one seeded `Pcg64Mcg`, so the run is reproducible.
#[test]
fn hierarchical() {
    const GROUPS: usize = 10;
    const OBSERVATIONS: usize = 10000;
    const WALKERS: usize = 100;
    const ITERATIONS: usize = 5000;
    const BURN_IN: usize = 1000;

    /// Random-intercept Bernoulli model; `data[g]` holds `(successes, failures)` of group `g`.
    struct Hierarchical {
        data: [(usize, usize); GROUPS],
    }

    impl Model for Hierarchical {
        /// Layout: `[logit_theta, sigma, alpha_0, ..., alpha_{GROUPS - 1}]`.
        type Params = [f64; 2 + GROUPS];

        /// Unnormalised log posterior (the `log_normal` helper drops the `-ln(sqrt(2π))`
        /// constants, which do not affect sampling).
        ///
        /// Priors: `logit_theta ~ Normal(-2, 2)`, `sigma ~ HalfNormal(2)`,
        /// `alpha_g ~ Normal(0, sigma)`; likelihood: Bernoulli with success probability
        /// `expit(logit_theta + alpha_g)` for each group.
        fn log_prob(&self, state: &Self::Params) -> f64 {
            let logit_theta = state[0];
            let sigma = state[1];

            // prior on logit_theta
            let mut log_prob = log_normal(logit_theta, -2., 2.);

            // prior on sigma (half normal). `sigma == 0` must be rejected too: it is used as
            // the scale of the group effects below, and `log_normal(_, 0., 0.)` evaluates to
            // NaN rather than a usable log density.
            if sigma <= 0. {
                return f64::NEG_INFINITY;
            }
            log_prob += log_normal(sigma, 0., 2.);

            for group in 0..GROUPS {
                let group_alpha = state[2 + group];
                let group_theta = expit(logit_theta + group_alpha);

                // likelihood of group_alpha given sigma
                log_prob += log_normal(group_alpha, 0., sigma);

                // likelihood of data given group_theta: successes * ln(p) + failures * ln(1 - p)
                let outcomes = self.data[group];
                log_prob += outcomes.0 as f64 * group_theta.ln();
                log_prob += outcomes.1 as f64 * (1. - group_theta).ln();
            }

            log_prob
        }
    }

    // Simulate the data set as one `(outcome, group)` pair per Bernoulli trial.
    let mut data = Vec::new();
    let true_logit_theta = -3.;
    let true_sigma = 1.;
    let mut true_theta = [0.; GROUPS];
    let mut rng = Pcg64Mcg::seed_from_u64(0);
    let true_alpha_dist = Normal::<f64>::new(0., true_sigma).unwrap();

    // `group` both indexes `true_theta` and is stored as the group label in `data`, so an
    // iterator over `true_theta` alone would not express this loop more clearly.
    #[allow(clippy::needless_range_loop)]
    for group in 0..GROUPS {
        let true_group_alpha = true_alpha_dist.sample(&mut rng);
        let true_group_theta = expit(true_logit_theta + true_group_alpha);
        true_theta[group] = true_group_theta;
        let dist = Bernoulli::new(true_group_theta).unwrap();
        for _ in 0..OBSERVATIONS {
            data.push((dist.sample(&mut rng), group));
        }
    }

    // NOTE(review): the shuffle does not change the aggregated counts below, but it advances
    // `rng`, which later seeds the walkers; removing it would change the seeded results.
    data.shuffle(&mut rng);

    // Aggregate the trials into per-group `(successes, failures)` counts.
    let model = Hierarchical {
        data: data
            .into_iter()
            .fold([(0, 0); GROUPS], |mut outcomes, (success, group)| {
                if success {
                    outcomes[group].0 += 1;
                } else {
                    outcomes[group].1 += 1;
                }
                outcomes
            }),
    };

    // Initialise each walker with its own RNG (split off the main one) and a draw from the
    // prior: `sigma` from the half-normal, then every `alpha_g` from `Normal(0, sigma)`.
    let prior_logit_theta = Normal::<f64>::new(-2., 2.).unwrap();
    let prior_sigma = Normal::<f64>::new(0., 2.).unwrap();

    let walkers = (0..WALKERS).map(|_| {
        let mut rng = Pcg64Mcg::from_rng(&mut rng);

        let guess_logit_theta = prior_logit_theta.sample(&mut rng);
        let guess_sigma = prior_sigma.sample(&mut rng).abs();

        let mut guess = [0.; 2 + GROUPS];
        guess[0] = guess_logit_theta;
        guess[1] = guess_sigma;

        let prior_alpha = Normal::new(0., guess_sigma).unwrap();

        for group in 0..GROUPS {
            let guess_group_alpha = prior_alpha.sample(&mut rng);
            guess[2 + group] = guess_group_alpha;
        }

        (guess, rng)
    });

    let (chain, accepted) = sample(&model, walkers, MinChainLen(WALKERS * ITERATIONS), Parallel);

    // Discard the first `WALKERS * BURN_IN` samples as burn-in (presumably `BURN_IN`
    // iterations of every walker, assuming the chain is stored iteration by iteration —
    // TODO confirm against `sample`).
    let converged_chain = &chain[WALKERS * BURN_IN..];

    // The posterior mean of each group's success probability must match the simulated truth.
    for group in 0..GROUPS {
        let estimated_group_theta = converged_chain
            .iter()
            .map(|params| expit(params[0] + params[2 + group]))
            .sum::<f64>()
            / converged_chain.len() as f64;

        assert!((true_theta[group] - estimated_group_theta).abs() < 0.005);
    }

    let acceptance_rate = accepted as f64 / chain.len() as f64;
    assert!(acceptance_rate > 0.3 && acceptance_rate < 0.4);

    // Mixing check: the kept chain must span more than 10,000 autocorrelation times of
    // `logit_theta` (as reported by `auto_corr_time`).
    let auto_corr_time_logit_theta =
        auto_corr_time(converged_chain.iter().map(|params| params[0]), None, None).unwrap();
    assert!(converged_chain.len() as f64 / auto_corr_time_logit_theta > 10_000.);
}
/// Logistic sigmoid `1 / (1 + e^(-x))`, mapping a log-odds value to a probability.
///
/// In `f64` the result saturates to exactly `0.0` for very negative `x` (the exponential
/// overflows to infinity) and to exactly `1.0` for very positive `x`.
fn expit(x: f64) -> f64 {
    let denominator = 1. + f64::exp(-x);
    1. / denominator
}
/// Log-density of `Normal(mu, sigma)` evaluated at `x`, without the additive constant
/// `-ln(sqrt(2π))`.
///
/// `sigma` must be strictly positive: for `sigma <= 0` the result is NaN.
fn log_normal(x: f64, mu: f64, sigma: f64) -> f64 {
    let standardized = (x - mu) / sigma;
    let quadratic_term = -0.5 * standardized.powi(2);
    quadratic_term - f64::ln(sigma)
}