Skip to content

Commit 401871d

Browse files
committed
Add a schedule combinator that makes it easier to add progress reporting
1 parent 24748cc commit 401871d

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[package]
22
name = "hammer-and-sample"
33
description = "Simplistic MCMC ensemble sampler based on emcee, the MCMC hammer"
4-
version = "0.3.0"
4+
version = "0.3.1"
55
edition = "2021"
66
license = "MIT OR Apache-2.0"
77
readme = "README.md"

src/lib.rs

+56
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,62 @@ where
365365
}
366366
}
367367

368+
/// Runs the inner `schedule` after calling the given `callback`
369+
///
370+
/// ```
371+
/// # use hammer_and_sample::{sample, MinChainLen, Model, Schedule, Serial, WithProgress};
372+
/// # use rand::SeedableRng;
373+
/// # use rand_pcg::Pcg64Mcg;
374+
/// #
375+
/// # struct CoinFlip;
376+
/// #
377+
/// # impl Model for CoinFlip {
378+
/// # type Params = [f64; 1];
379+
/// #
380+
/// # fn log_prob(&self, state: &Self::Params) -> f64 {
381+
/// # f64::NEG_INFINITY
382+
/// # }
383+
/// # }
384+
/// #
385+
/// # let model = CoinFlip;
386+
/// #
387+
/// # let walkers = (0..100).map(|idx| {
388+
/// # let mut rng = Pcg64Mcg::seed_from_u64(idx);
389+
/// #
390+
/// # ([0.5], rng)
391+
/// # });
392+
/// #
393+
/// let schedule = WithProgress {
394+
/// schedule: MinChainLen(100_000),
395+
/// callback: |chain: &[_]| eprintln!("{} %", 100 * chain.len() / 100_000),
396+
/// };
397+
///
398+
/// let (chain, accepted) = sample(&model, walkers, schedule, Serial);
399+
/// ```
400+
pub struct WithProgress<S, C> {
401+
/// The inner schedule which determines the number of iterations
402+
pub schedule: S,
403+
/// The callback which is executed after each iteration
404+
pub callback: C,
405+
}
406+
407+
impl<P, S, C> Schedule<P> for WithProgress<S, C>
408+
where
409+
P: Params,
410+
S: Schedule<P>,
411+
C: FnMut(&[P]),
412+
{
413+
fn next_step(&mut self, chain: &[P]) -> ControlFlow<()> {
414+
(self.callback)(chain);
415+
416+
self.schedule.next_step(chain)
417+
}
418+
419+
fn iterations(&self, walkers: usize) -> Option<usize> {
420+
self.schedule.iterations(walkers)
421+
}
422+
}
423+
368424
/// Execution strategy for `update`ing an ensemble of `walkers` to extend the given `chain`
369425
pub trait Execution {
370426
/// Must call `update` exactly once for all elements of `walkers` and store the results in `chain`

0 commit comments

Comments
 (0)