Skip to content

Commit 944da6d

Browse files
committed
Make the Params trait more generic to isolate it from the move semantics
1 parent 401871d commit 944da6d

File tree

2 files changed

+28
-14
lines changed

2 files changed

+28
-14
lines changed

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[package]
22
name = "hammer-and-sample"
33
description = "Simplistic MCMC ensemble sampler based on emcee, the MCMC hammer"
4-
version = "0.3.1"
4+
version = "0.4.0"
55
edition = "2021"
66
license = "MIT OR Apache-2.0"
77
readme = "README.md"

src/lib.rs

+27-13
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,11 @@ pub trait Params: Send + Sync + Clone {
6363
/// This can depend on `self` in situations where the number of parameters depends on the data itself, e.g. the number of groups in a hierarchical model.
6464
fn dimension(&self) -> usize;
6565

66-
/// Propose new parameters doing a stretch move based on the parameters `other` and the scale `z`
66+
/// Compute new parameters by mapping the given closure `f` over all coordinate pairs
6767
#[must_use]
68-
fn propose(&self, other: &Self, z: f64) -> Self;
68+
fn map<F>(&self, other: &Self, f: F) -> Self
69+
where
70+
F: Fn(f64, f64) -> f64;
6971
}
7072

7173
/// Model parameters stored as an array of length `N` considered as an element of the vector space `R^N`
@@ -74,10 +76,13 @@ impl<const N: usize> Params for [f64; N] {
7476
N
7577
}
7678

77-
fn propose(&self, other: &Self, z: f64) -> Self {
79+
fn map<F>(&self, other: &Self, f: F) -> Self
80+
where
81+
F: Fn(f64, f64) -> f64,
82+
{
7883
let mut new = [0.; N];
7984
for i in 0..N {
80-
new[i] = other[i] - z * (other[i] - self[i]);
85+
new[i] = f(self[i], other[i]);
8186
}
8287
new
8388
}
@@ -89,10 +94,13 @@ impl Params for Vec<f64> {
8994
self.len()
9095
}
9196

92-
fn propose(&self, other: &Self, z: f64) -> Self {
97+
fn map<F>(&self, other: &Self, f: F) -> Self
98+
where
99+
F: Fn(f64, f64) -> f64,
100+
{
93101
self.iter()
94102
.zip(other)
95-
.map(|(self_, other)| other - z * (other - self_))
103+
.map(|(self_, other)| f(*self_, *other))
96104
.collect()
97105
}
98106
}
@@ -103,10 +111,13 @@ impl Params for Box<[f64]> {
103111
self.len()
104112
}
105113

106-
fn propose(&self, other: &Self, z: f64) -> Self {
114+
fn map<F>(&self, other: &Self, f: F) -> Self
115+
where
116+
F: Fn(f64, f64) -> f64,
117+
{
107118
self.iter()
108119
.zip(other.iter())
109-
.map(|(self_, other)| other - z * (other - self_))
120+
.map(|(self_, other)| f(*self_, *other))
110121
.collect()
111122
}
112123
}
@@ -210,7 +221,10 @@ where
210221
fn move_(&mut self, model: &M, other: &Self) -> M::Params {
211222
let z = ((M::SCALE - 1.) * gen_unit(&mut self.rng) + 1.).powi(2) / M::SCALE;
212223

213-
let mut new_state = self.state.propose(&other.state, z);
224+
let mut new_state = self
225+
.state
226+
.map(&other.state, |self_, other| other - z * (other - self_));
227+
214228
let new_log_prob = model.log_prob(&new_state);
215229

216230
let log_prob_diff =
@@ -372,22 +386,22 @@ where
372386
/// # use rand::SeedableRng;
373387
/// # use rand_pcg::Pcg64Mcg;
374388
/// #
375-
/// # struct CoinFlip;
389+
/// # struct Dummy;
376390
/// #
377-
/// # impl Model for CoinFlip {
391+
/// # impl Model for Dummy {
378392
/// # type Params = [f64; 1];
379393
/// #
380394
/// # fn log_prob(&self, state: &Self::Params) -> f64 {
381395
/// # f64::NEG_INFINITY
382396
/// # }
383397
/// # }
384398
/// #
385-
/// # let model = CoinFlip;
399+
/// # let model = Dummy;
386400
/// #
387401
/// # let walkers = (0..100).map(|idx| {
388402
/// # let mut rng = Pcg64Mcg::seed_from_u64(idx);
389403
/// #
390-
/// # ([0.5], rng)
404+
/// # ([0.], rng)
391405
/// # });
392406
/// #
393407
/// let schedule = WithProgress {

0 commit comments

Comments
 (0)