Skip to content

Commit 137811c

Browse files
committed
Make moves generic and add alternative move based on differential evolution.
1 parent 171b68c commit 137811c

File tree

4 files changed

+200
-37
lines changed

4 files changed

+200
-37
lines changed

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ test = false
1515

1616
[dependencies]
1717
rand = { version = "0.9", default-features = false }
18+
rand_distr = "0.5"
1819
rayon = { version = "1.5", optional = true }
1920

2021
[dev-dependencies]
21-
rand_distr = "0.5"
2222
rand_pcg = "0.9"
2323

2424
[package.metadata.docs.rs]

src/lib.rs

+184-31
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
//! Simplistic MCMC ensemble sampler based on [emcee](https://emcee.readthedocs.io/), the MCMC hammer
44
//!
55
//! ```
6-
//! use hammer_and_sample::{sample, MinChainLen, Model, Serial};
6+
//! use hammer_and_sample::{sample, MinChainLen, Model, Serial, Stretch};
77
//! use rand::{Rng, SeedableRng};
88
//! use rand_pcg::Pcg64;
99
//!
@@ -39,7 +39,7 @@
3939
//! ([p], rng)
4040
//! });
4141
//!
42-
//! let (chain, _accepted) = sample(&model, walkers, MinChainLen(10 * 1000), Serial);
42+
//! let (chain, _accepted) = sample(&model, &Stretch, walkers, MinChainLen(10 * 1000), Serial);
4343
//!
4444
//! // 100 iterations of 10 walkers as burn-in
4545
//! let chain = &chain[10 * 100..];
@@ -48,11 +48,16 @@
4848
//! }
4949
//! ```
5050
use std::ops::ControlFlow;
51+
use std::ptr;
5152

5253
use rand::{
5354
distr::{Distribution, StandardUniform, Uniform},
5455
Rng,
5556
};
57+
use rand_distr::{
58+
weighted::{AliasableWeight, WeightedAliasIndex},
59+
Normal,
60+
};
5661
#[cfg(feature = "rayon")]
5762
use rayon::iter::{IntoParallelRefMutIterator, ParallelExtend, ParallelIterator};
5863

@@ -117,6 +122,157 @@ impl Params for Box<[f64]> {
117122
}
118123
}
119124

125+
/// TODO
126+
pub trait Move<M>
127+
where
128+
M: Model,
129+
{
130+
/// TODO
131+
fn propose<'a, O, R>(&self, self_: &'a M::Params, other: O, rng: &mut R) -> (M::Params, f64)
132+
where
133+
O: FnMut(&mut R) -> &'a M::Params,
134+
R: Rng;
135+
}
136+
137+
/// TODO
138+
pub struct Stretch;
139+
140+
impl<M> Move<M> for Stretch
141+
where
142+
M: Model,
143+
{
144+
fn propose<'a, O, R>(&self, self_: &'a M::Params, mut other: O, rng: &mut R) -> (M::Params, f64)
145+
where
146+
O: FnMut(&mut R) -> &'a M::Params,
147+
R: Rng,
148+
{
149+
let other = other(rng);
150+
151+
let z = ((M::SCALE - 1.) * gen_unit(rng) + 1.).powi(2) / M::SCALE;
152+
153+
let new_state = M::Params::collect(
154+
self_
155+
.values()
156+
.zip(other.values())
157+
.map(|(self_, other)| other - z * (other - self_)),
158+
);
159+
160+
let factor = (new_state.dimension() - 1) as f64 * z.ln();
161+
162+
(new_state, factor)
163+
}
164+
}
165+
166+
/// TODO
167+
pub struct DifferentialEvolution {
168+
gamma: Normal<f64>,
169+
}
170+
171+
impl DifferentialEvolution {
172+
/// TODO
173+
pub fn new(sigma: Option<f64>, gamma0: Option<f64>, dimension: usize) -> Self {
174+
let sigma = sigma.unwrap_or(1e-5);
175+
let gamma0 = gamma0.unwrap_or_else(|| 2.38 / (2. * dimension as f64).sqrt());
176+
177+
Self {
178+
gamma: Normal::new(gamma0, sigma).unwrap(),
179+
}
180+
}
181+
}
182+
183+
impl<M> Move<M> for DifferentialEvolution
184+
where
185+
M: Model,
186+
{
187+
fn propose<'a, O, R>(&self, self_: &'a M::Params, mut other: O, rng: &mut R) -> (M::Params, f64)
188+
where
189+
O: FnMut(&mut R) -> &'a M::Params,
190+
R: Rng,
191+
{
192+
let first_other = other(rng);
193+
let mut second_other = other(rng);
194+
195+
while ptr::eq(first_other, second_other) {
196+
second_other = other(rng);
197+
}
198+
199+
let gamma = self.gamma.sample(rng);
200+
201+
let new_state = M::Params::collect(
202+
self_
203+
.values()
204+
.zip(first_other.values())
205+
.zip(second_other.values())
206+
.map(|((self_, first_other), second_other)| {
207+
self_ + gamma * (first_other - second_other)
208+
}),
209+
);
210+
211+
(new_state, 0.)
212+
}
213+
}
214+
215+
/// TODO
216+
pub struct Mixture<W, M>(WeightedAliasIndex<W>, M)
217+
where
218+
W: AliasableWeight;
219+
220+
macro_rules! impl_mixture {
221+
( $( $types:ident @ $weights:ident),+ ) => {
222+
impl<W, $( $types ),+> From<( $( ( $types, W ) ),+ )> for Mixture<W, ( $( $types ),+ )> where W: AliasableWeight {
223+
#[allow(non_snake_case)]
224+
fn from(( $( ( $types, $weights ) ),+ ): ( $( ( $types, W ) ),+ )) -> Self {
225+
let index = WeightedAliasIndex::new(vec![$( $weights ),+]).unwrap();
226+
227+
Self(index, ( $( $types ),+ ))
228+
}
229+
}
230+
231+
impl<W, $( $types ),+, M> Move<M> for Mixture<W, ( $( $types ),+ )>
232+
where
233+
W: AliasableWeight,
234+
M: Model,
235+
$( $types: Move<M> ),+
236+
{
237+
#[allow(non_snake_case)]
238+
fn propose<'a, O, R>(&self, self_: &'a M::Params, other: O, rng: &mut R) -> (M::Params, f64)
239+
where
240+
O: FnMut(&mut R) -> &'a M::Params,
241+
R: Rng,
242+
{
243+
let Self(index, ( $( $types ),+ )) = self;
244+
245+
let chosen_index = index.sample(rng);
246+
247+
let mut index = 0;
248+
249+
$(
250+
251+
#[allow(unused_assignments)]
252+
if index == chosen_index {
253+
return $types.propose(self_, other, rng)
254+
} else {
255+
index += 1;
256+
}
257+
258+
)+
259+
260+
unreachable!()
261+
}
262+
}
263+
};
264+
}
265+
266+
impl_mixture!(A @ a, B @ b);
267+
impl_mixture!(A @ a, B @ b, C @ c);
268+
impl_mixture!(A @ a, B @ b, C @ c, D @ d);
269+
impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e);
270+
impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f);
271+
impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g);
272+
impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g, H @ h);
273+
impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g, H @ h, I @ i);
274+
impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g, H @ h, I @ i, J @ j);
275+
120276
/// Models are defined by the type of their parameters and their probability functions
121277
pub trait Model: Send + Sync {
122278
/// Type used to store the model parameters, e.g. `[f64; N]` or `Vec<f64>`
@@ -138,17 +294,19 @@ pub trait Model: Send + Sync {
138294
/// The number of walkers must be non-zero, even and at least twice the number of parameters.
139295
///
140296
/// A vector of samples and the number of accepted moves are returned.
141-
pub fn sample<M, W, R, S, E>(
142-
model: &M,
297+
pub fn sample<MD, MV, W, R, S, E>(
298+
model: &MD,
299+
move_: &MV,
143300
walkers: W,
144301
mut schedule: S,
145302
execution: E,
146-
) -> (Vec<M::Params>, usize)
303+
) -> (Vec<MD::Params>, usize)
147304
where
148-
M: Model,
149-
W: Iterator<Item = (M::Params, R)>,
305+
MD: Model,
306+
MV: Move<MD> + Send + Sync,
307+
W: Iterator<Item = (MD::Params, R)>,
150308
R: Rng + Send + Sync,
151-
S: Schedule<M::Params>,
309+
S: Schedule<MD::Params>,
152310
E: Execution,
153311
{
154312
let mut walkers = walkers
@@ -166,10 +324,8 @@ where
166324

167325
let random_index = Uniform::new(0, half).unwrap();
168326

169-
let update_walker = move |walker: &mut Walker<M, R>, other_walkers: &[Walker<M, R>]| {
170-
let other = &other_walkers[random_index.sample(&mut walker.rng)];
171-
172-
walker.move_(model, other)
327+
let update_walker = move |walker: &mut Walker<MD, R>, other_walkers: &[Walker<MD, R>]| {
328+
walker.move_(model, move_, |rng| &other_walkers[random_index.sample(rng)])
173329
};
174330

175331
while schedule.next_step(&chain).is_continue() {
@@ -187,22 +343,22 @@ where
187343
(chain, accepted)
188344
}
189345

190-
struct Walker<M, R>
346+
struct Walker<MD, R>
191347
where
192-
M: Model,
348+
MD: Model,
193349
{
194-
state: M::Params,
350+
state: MD::Params,
195351
log_prob: f64,
196352
rng: R,
197353
accepted: usize,
198354
}
199355

200-
impl<M, R> Walker<M, R>
356+
impl<MD, R> Walker<MD, R>
201357
where
202-
M: Model,
358+
MD: Model,
203359
R: Rng,
204360
{
205-
fn new(model: &M, state: M::Params, rng: R) -> Self {
361+
fn new(model: &MD, state: MD::Params, rng: R) -> Self {
206362
let log_prob = model.log_prob(&state);
207363

208364
Self {
@@ -213,20 +369,17 @@ where
213369
}
214370
}
215371

216-
fn move_(&mut self, model: &M, other: &Self) -> M::Params {
217-
let z = ((M::SCALE - 1.) * gen_unit(&mut self.rng) + 1.).powi(2) / M::SCALE;
218-
219-
let mut new_state = M::Params::collect(
220-
self.state
221-
.values()
222-
.zip(other.state.values())
223-
.map(|(self_, other)| other - z * (other - self_)),
224-
);
372+
fn move_<'a, MV, O>(&'a mut self, model: &MD, move_: &MV, mut other: O) -> MD::Params
373+
where
374+
MV: Move<MD>,
375+
O: FnMut(&mut R) -> &'a Self,
376+
{
377+
let (mut new_state, factor) =
378+
move_.propose(&self.state, |rng| &other(rng).state, &mut self.rng);
225379

226380
let new_log_prob = model.log_prob(&new_state);
227381

228-
let log_prob_diff =
229-
(new_state.dimension() - 1) as f64 * z.ln() + new_log_prob - self.log_prob;
382+
let log_prob_diff = factor + new_log_prob - self.log_prob;
230383

231384
if log_prob_diff > gen_unit(&mut self.rng).ln() {
232385
self.state.clone_from(&new_state);
@@ -380,7 +533,7 @@ where
380533
/// Runs the inner `schedule` after calling the given `callback`
381534
///
382535
/// ```
383-
/// # use hammer_and_sample::{sample, MinChainLen, Model, Schedule, Serial, WithProgress};
536+
/// # use hammer_and_sample::{sample, MinChainLen, Model, Schedule, Serial, Stretch, WithProgress};
384537
/// # use rand::SeedableRng;
385538
/// # use rand_pcg::Pcg64Mcg;
386539
/// #
@@ -407,7 +560,7 @@ where
407560
/// callback: |chain: &[_]| eprintln!("{} %", 100 * chain.len() / 100_000),
408561
/// };
409562
///
410-
/// let (chain, accepted) = sample(&model, walkers, schedule, Serial);
563+
/// let (chain, accepted) = sample(&model, &Stretch, walkers, schedule, Serial);
411564
/// ```
412565
pub struct WithProgress<S, C> {
413566
/// The inner schedule which determines the number of iterations

tests/coin_flip.rs

+7-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ use rand::{
44
};
55
use rand_pcg::Pcg64Mcg;
66

7-
use hammer_and_sample::{auto_corr_time, sample, MinChainLen, Model, Serial};
7+
use hammer_and_sample::{
8+
auto_corr_time, sample, DifferentialEvolution, MinChainLen, Mixture, Model, Serial, Stretch,
9+
};
810

911
#[test]
1012
fn coin_flip() {
@@ -48,7 +50,9 @@ fn coin_flip() {
4850
([guess_p], rng)
4951
});
5052

51-
let (chain, accepted) = sample(&model, walkers, MinChainLen(100_000), Serial);
53+
let move_ = Mixture::from(((Stretch, 2), (DifferentialEvolution::new(None, None, 1), 1)));
54+
55+
let (chain, accepted) = sample(&model, &move_, walkers, MinChainLen(100_000), Serial);
5256

5357
let converged_chain = &chain[10_000..];
5458

@@ -59,7 +63,7 @@ fn coin_flip() {
5963

6064
let acceptance_rate = accepted as f64 / chain.len() as f64;
6165

62-
assert!(acceptance_rate > 0.7 && acceptance_rate < 0.8);
66+
assert!(acceptance_rate > 0.6 && acceptance_rate < 0.7);
6367

6468
let auto_corr_time_p =
6569
auto_corr_time(converged_chain.iter().map(|params| params[0]), None, None).unwrap();

tests/hierarchical.rs

+8-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use rand::{seq::SliceRandom, SeedableRng};
22
use rand_distr::{Bernoulli, Distribution, Normal};
33
use rand_pcg::Pcg64Mcg;
44

5-
use hammer_and_sample::{auto_corr_time, sample, MinChainLen, Model, Parallel};
5+
use hammer_and_sample::{auto_corr_time, sample, MinChainLen, Model, Parallel, Stretch};
66

77
#[test]
88
fn hierarchical() {
@@ -116,7 +116,13 @@ fn hierarchical() {
116116
(guess, rng)
117117
});
118118

119-
let (chain, accepted) = sample(&model, walkers, MinChainLen(WALKERS * ITERATIONS), Parallel);
119+
let (chain, accepted) = sample(
120+
&model,
121+
&Stretch,
122+
walkers,
123+
MinChainLen(WALKERS * ITERATIONS),
124+
Parallel,
125+
);
120126

121127
let converged_chain = &chain[WALKERS * BURN_IN..];
122128

0 commit comments

Comments
 (0)