Skip to content

Commit 60080cc

Browse files
committed
Make moves generic and add alternative and mixed moves.
1 parent 171b68c commit 60080cc

File tree

4 files changed

+258
-39
lines changed

4 files changed

+258
-39
lines changed

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ test = false
1515

1616
[dependencies]
1717
rand = { version = "0.9", default-features = false }
18+
rand_distr = "0.5"
1819
rayon = { version = "1.5", optional = true }
1920

2021
[dev-dependencies]
21-
rand_distr = "0.5"
2222
rand_pcg = "0.9"
2323

2424
[package.metadata.docs.rs]

src/lib.rs

+236-34
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
//! Simplistic MCMC ensemble sampler based on [emcee](https://emcee.readthedocs.io/), the MCMC hammer
44
//!
55
//! ```
6-
//! use hammer_and_sample::{sample, MinChainLen, Model, Serial};
6+
//! use hammer_and_sample::{sample, MinChainLen, Model, Serial, Stretch};
77
//! use rand::{Rng, SeedableRng};
88
//! use rand_pcg::Pcg64;
99
//!
@@ -39,7 +39,7 @@
3939
//! ([p], rng)
4040
//! });
4141
//!
42-
//! let (chain, _accepted) = sample(&model, walkers, MinChainLen(10 * 1000), Serial);
42+
//! let (chain, _accepted) = sample(&model, &Stretch::default(), walkers, MinChainLen(10 * 1000), Serial);
4343
//!
4444
//! // 100 iterations of 10 walkers as burn-in
4545
//! let chain = &chain[10 * 100..];
@@ -48,11 +48,16 @@
4848
//! }
4949
//! ```
5050
use std::ops::ControlFlow;
51+
use std::ptr;
5152

5253
use rand::{
5354
distr::{Distribution, StandardUniform, Uniform},
5455
Rng,
5556
};
57+
use rand_distr::{
58+
weighted::{AliasableWeight, WeightedAliasIndex},
59+
Normal,
60+
};
5661
#[cfg(feature = "rayon")]
5762
use rayon::iter::{IntoParallelRefMutIterator, ParallelExtend, ParallelIterator};
5863

@@ -117,6 +122,209 @@ impl Params for Box<[f64]> {
117122
}
118123
}
119124

125+
/// A proposal move used by the sampler to generate new states for walkers of the given [`Model`]
126+
pub trait Move<M>
127+
where
128+
M: Model,
129+
{
130+
/// Proposes a new state based on the current state `self_` and other walkers drawn via `other`, returning the proposed parameters together with the logarithm of the correction factor applied to the acceptance probability
131+
fn propose<'a, O, R>(&self, self_: &'a M::Params, other: O, rng: &mut R) -> (M::Params, f64)
132+
where
133+
O: FnMut(&mut R) -> &'a M::Params,
134+
R: Rng;
135+
}
136+
137+
/// The affine-invariant stretch move which proposes states along the line connecting the current walker to another walker, scaled by a random factor
138+
pub struct Stretch {
139+
scale: f64,
140+
}
141+
142+
impl Stretch {
143+
/// Creates a stretch move with the given scale parameter (the default uses a scale of 2)
144+
pub fn new(scale: f64) -> Self {
145+
Self { scale }
146+
}
147+
}
148+
149+
impl Default for Stretch {
150+
fn default() -> Self {
151+
Self::new(2.)
152+
}
153+
}
154+
155+
impl<M> Move<M> for Stretch
156+
where
157+
M: Model,
158+
{
159+
fn propose<'a, O, R>(&self, self_: &'a M::Params, mut other: O, rng: &mut R) -> (M::Params, f64)
160+
where
161+
O: FnMut(&mut R) -> &'a M::Params,
162+
R: Rng,
163+
{
164+
let other = other(rng);
165+
166+
let z = ((self.scale - 1.) * gen_unit(rng) + 1.).powi(2) / self.scale;
167+
168+
let new_state = M::Params::collect(
169+
self_
170+
.values()
171+
.zip(other.values())
172+
.map(|(self_, other)| other - z * (other - self_)),
173+
);
174+
175+
let factor = (new_state.dimension() - 1) as f64 * z.ln();
176+
177+
(new_state, factor)
178+
}
179+
}
180+
181+
/// The differential evolution move which shifts the current walker along the difference between two distinct other walkers, scaled by a normally distributed factor `gamma`
182+
pub struct DifferentialEvolution {
183+
gamma: Normal<f64>,
184+
}
185+
186+
impl DifferentialEvolution {
187+
/// Creates a differential evolution move drawing `gamma` from a normal distribution with the given mean and standard deviation
188+
pub fn new(gamma_mean: f64, gamma_std_dev: f64) -> Self {
189+
Self {
190+
gamma: Normal::new(gamma_mean, gamma_std_dev).unwrap(),
191+
}
192+
}
193+
}
194+
195+
impl<M> Move<M> for DifferentialEvolution
196+
where
197+
M: Model,
198+
{
199+
fn propose<'a, O, R>(&self, self_: &'a M::Params, mut other: O, rng: &mut R) -> (M::Params, f64)
200+
where
201+
O: FnMut(&mut R) -> &'a M::Params,
202+
R: Rng,
203+
{
204+
let first_other = other(rng);
205+
let mut second_other = other(rng);
206+
207+
while ptr::eq(first_other, second_other) {
208+
second_other = other(rng);
209+
}
210+
211+
let gamma = self.gamma.sample(rng);
212+
213+
let new_state = M::Params::collect(
214+
self_
215+
.values()
216+
.zip(first_other.values())
217+
.zip(second_other.values())
218+
.map(|((self_, first_other), second_other)| {
219+
self_ + gamma * (first_other - second_other)
220+
}),
221+
);
222+
223+
(new_state, 0.)
224+
}
225+
}
226+
227+
/// A move which perturbs a single randomly chosen parameter of the current walker by zero-mean Gaussian noise, ignoring the other walkers
228+
pub struct RandomGaussian {
229+
dist: Normal<f64>,
230+
}
231+
232+
impl RandomGaussian {
233+
/// Creates a random Gaussian move with noise of the given standard deviation `scale`
234+
pub fn new(scale: f64) -> Self {
235+
Self {
236+
dist: Normal::new(0., scale).unwrap(),
237+
}
238+
}
239+
}
240+
241+
impl<M> Move<M> for RandomGaussian
242+
where
243+
M: Model,
244+
{
245+
fn propose<'a, O, R>(&self, self_: &'a M::Params, _other: O, rng: &mut R) -> (M::Params, f64)
246+
where
247+
O: FnMut(&mut R) -> &'a M::Params,
248+
R: Rng,
249+
{
250+
let dir = rng.random_range(0..self_.dimension());
251+
252+
let new_state = M::Params::collect(self_.values().enumerate().map(|(idx, value)| {
253+
if idx == dir {
254+
value + self.dist.sample(rng)
255+
} else {
256+
*value
257+
}
258+
}));
259+
260+
(new_state, 0.)
261+
}
262+
}
263+
264+
/// A weighted mixture of moves: for each proposal, one of the inner moves is chosen at random with probability proportional to its weight
265+
pub struct Mixture<W, M>(WeightedAliasIndex<W>, M)
266+
where
267+
W: AliasableWeight;
268+
269+
macro_rules! impl_mixture {
270+
( $( $types:ident @ $weights:ident),+ ) => {
271+
impl<W, $( $types ),+> From<( $( ( $types, W ) ),+ )> for Mixture<W, ( $( $types ),+ )>
272+
where
273+
W: AliasableWeight
274+
{
275+
#[allow(non_snake_case)]
276+
fn from(( $( ( $types, $weights ) ),+ ): ( $( ( $types, W ) ),+ )) -> Self {
277+
let index = WeightedAliasIndex::new(vec![$( $weights ),+]).unwrap();
278+
279+
Self(index, ( $( $types ),+ ))
280+
}
281+
}
282+
283+
impl<W, $( $types ),+, M> Move<M> for Mixture<W, ( $( $types ),+ )>
284+
where
285+
W: AliasableWeight,
286+
M: Model,
287+
$( $types: Move<M> ),+
288+
{
289+
#[allow(non_snake_case)]
290+
fn propose<'a, O, R>(&self, self_: &'a M::Params, other: O, rng: &mut R) -> (M::Params, f64)
291+
where
292+
O: FnMut(&mut R) -> &'a M::Params,
293+
R: Rng,
294+
{
295+
let Self(index, ( $( $types ),+ )) = self;
296+
297+
let chosen_index = index.sample(rng);
298+
299+
let mut index = 0;
300+
301+
$(
302+
303+
#[allow(unused_assignments)]
304+
if chosen_index == index {
305+
return $types.propose(self_, other, rng)
306+
} else {
307+
index += 1;
308+
}
309+
310+
)+
311+
312+
unreachable!()
313+
}
314+
}
315+
};
316+
}
317+
318+
impl_mixture!(A @ a, B @ b);
319+
impl_mixture!(A @ a, B @ b, C @ c);
320+
impl_mixture!(A @ a, B @ b, C @ c, D @ d);
321+
impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e);
322+
impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f);
323+
impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g);
324+
impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g, H @ h);
325+
impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g, H @ h, I @ i);
326+
impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g, H @ h, I @ i, J @ j);
327+
120328
/// Models are defined by the type of their parameters and their probability functions
121329
pub trait Model: Send + Sync {
122330
/// Type used to store the model parameters, e.g. `[f64; N]` or `Vec<f64>`
@@ -126,9 +334,6 @@ pub trait Model: Send + Sync {
126334
///
127335
/// The sampler will only ever consider differences of these values, i.e. any addititive constant that does _not_ depend on `state` can be omitted when computing them.
128336
fn log_prob(&self, state: &Self::Params) -> f64;
129-
130-
/// Scale parameter for stretch moves
131-
const SCALE: f64 = 2.;
132337
}
133338

134339
/// Runs the sampler on the given [`model`][Model] using the chosen [`schedule`][Schedule] and [`execution`][Execution] strategy
@@ -138,17 +343,19 @@ pub trait Model: Send + Sync {
138343
/// The number of walkers must be non-zero, even and at least twice the number of parameters.
139344
///
140345
/// A vector of samples and the number of accepted moves are returned.
141-
pub fn sample<M, W, R, S, E>(
142-
model: &M,
346+
pub fn sample<MD, MV, W, R, S, E>(
347+
model: &MD,
348+
move_: &MV,
143349
walkers: W,
144350
mut schedule: S,
145351
execution: E,
146-
) -> (Vec<M::Params>, usize)
352+
) -> (Vec<MD::Params>, usize)
147353
where
148-
M: Model,
149-
W: Iterator<Item = (M::Params, R)>,
354+
MD: Model,
355+
MV: Move<MD> + Send + Sync,
356+
W: Iterator<Item = (MD::Params, R)>,
150357
R: Rng + Send + Sync,
151-
S: Schedule<M::Params>,
358+
S: Schedule<MD::Params>,
152359
E: Execution,
153360
{
154361
let mut walkers = walkers
@@ -166,10 +373,8 @@ where
166373

167374
let random_index = Uniform::new(0, half).unwrap();
168375

169-
let update_walker = move |walker: &mut Walker<M, R>, other_walkers: &[Walker<M, R>]| {
170-
let other = &other_walkers[random_index.sample(&mut walker.rng)];
171-
172-
walker.move_(model, other)
376+
let update_walker = move |walker: &mut Walker<MD, R>, other_walkers: &[Walker<MD, R>]| {
377+
walker.move_(model, move_, |rng| &other_walkers[random_index.sample(rng)])
173378
};
174379

175380
while schedule.next_step(&chain).is_continue() {
@@ -187,22 +392,22 @@ where
187392
(chain, accepted)
188393
}
189394

190-
struct Walker<M, R>
395+
struct Walker<MD, R>
191396
where
192-
M: Model,
397+
MD: Model,
193398
{
194-
state: M::Params,
399+
state: MD::Params,
195400
log_prob: f64,
196401
rng: R,
197402
accepted: usize,
198403
}
199404

200-
impl<M, R> Walker<M, R>
405+
impl<MD, R> Walker<MD, R>
201406
where
202-
M: Model,
407+
MD: Model,
203408
R: Rng,
204409
{
205-
fn new(model: &M, state: M::Params, rng: R) -> Self {
410+
fn new(model: &MD, state: MD::Params, rng: R) -> Self {
206411
let log_prob = model.log_prob(&state);
207412

208413
Self {
@@ -213,20 +418,17 @@ where
213418
}
214419
}
215420

216-
fn move_(&mut self, model: &M, other: &Self) -> M::Params {
217-
let z = ((M::SCALE - 1.) * gen_unit(&mut self.rng) + 1.).powi(2) / M::SCALE;
218-
219-
let mut new_state = M::Params::collect(
220-
self.state
221-
.values()
222-
.zip(other.state.values())
223-
.map(|(self_, other)| other - z * (other - self_)),
224-
);
421+
fn move_<'a, MV, O>(&'a mut self, model: &MD, move_: &MV, mut other: O) -> MD::Params
422+
where
423+
MV: Move<MD>,
424+
O: FnMut(&mut R) -> &'a Self,
425+
{
426+
let (mut new_state, factor) =
427+
move_.propose(&self.state, |rng| &other(rng).state, &mut self.rng);
225428

226429
let new_log_prob = model.log_prob(&new_state);
227430

228-
let log_prob_diff =
229-
(new_state.dimension() - 1) as f64 * z.ln() + new_log_prob - self.log_prob;
431+
let log_prob_diff = factor + new_log_prob - self.log_prob;
230432

231433
if log_prob_diff > gen_unit(&mut self.rng).ln() {
232434
self.state.clone_from(&new_state);
@@ -380,7 +582,7 @@ where
380582
/// Runs the inner `schedule` after calling the given `callback`
381583
///
382584
/// ```
383-
/// # use hammer_and_sample::{sample, MinChainLen, Model, Schedule, Serial, WithProgress};
585+
/// # use hammer_and_sample::{sample, MinChainLen, Model, Schedule, Serial, Stretch, WithProgress};
384586
/// # use rand::SeedableRng;
385587
/// # use rand_pcg::Pcg64Mcg;
386588
/// #
@@ -407,7 +609,7 @@ where
407609
/// callback: |chain: &[_]| eprintln!("{} %", 100 * chain.len() / 100_000),
408610
/// };
409611
///
410-
/// let (chain, accepted) = sample(&model, walkers, schedule, Serial);
612+
/// let (chain, accepted) = sample(&model, &Stretch::default(), walkers, schedule, Serial);
411613
/// ```
412614
pub struct WithProgress<S, C> {
413615
/// The inner schedule which determines the number of iterations

0 commit comments

Comments
 (0)