3
3
//! Simplistic MCMC ensemble sampler based on [emcee](https://emcee.readthedocs.io/), the MCMC hammer
4
4
//!
5
5
//! ```
6
- //! use hammer_and_sample::{sample, MinChainLen, Model, Serial};
6
+ //! use hammer_and_sample::{sample, MinChainLen, Model, Serial, Stretch };
7
7
//! use rand::{Rng, SeedableRng};
8
8
//! use rand_pcg::Pcg64;
9
9
//!
39
39
//! ([p], rng)
40
40
//! });
41
41
//!
42
- //! let (chain, _accepted) = sample(&model, walkers, MinChainLen(10 * 1000), Serial);
42
+ //! let (chain, _accepted) = sample(&model, &Stretch::default(), walkers, MinChainLen(10 * 1000), Serial);
43
43
//!
44
44
//! // 100 iterations of 10 walkers as burn-in
45
45
//! let chain = &chain[10 * 100..];
48
48
//! }
49
49
//! ```
50
50
use std:: ops:: ControlFlow ;
51
+ use std:: ptr;
51
52
52
53
use rand:: {
53
54
distr:: { Distribution , StandardUniform , Uniform } ,
54
55
Rng ,
55
56
} ;
57
+ use rand_distr:: {
58
+ weighted:: { AliasableWeight , WeightedAliasIndex } ,
59
+ Normal ,
60
+ } ;
56
61
#[ cfg( feature = "rayon" ) ]
57
62
use rayon:: iter:: { IntoParallelRefMutIterator , ParallelExtend , ParallelIterator } ;
58
63
@@ -117,6 +122,209 @@ impl Params for Box<[f64]> {
117
122
}
118
123
}
119
124
125
+ /// TODO
126
+ pub trait Move < M >
127
+ where
128
+ M : Model ,
129
+ {
130
+ /// TODO
131
+ fn propose < ' a , O , R > ( & self , self_ : & ' a M :: Params , other : O , rng : & mut R ) -> ( M :: Params , f64 )
132
+ where
133
+ O : FnMut ( & mut R ) -> & ' a M :: Params ,
134
+ R : Rng ;
135
+ }
136
+
137
+ /// TODO
138
+ pub struct Stretch {
139
+ scale : f64 ,
140
+ }
141
+
142
+ impl Stretch {
143
+ /// TODO
144
+ pub fn new ( scale : f64 ) -> Self {
145
+ Self { scale }
146
+ }
147
+ }
148
+
149
+ impl Default for Stretch {
150
+ fn default ( ) -> Self {
151
+ Self :: new ( 2. )
152
+ }
153
+ }
154
+
155
+ impl < M > Move < M > for Stretch
156
+ where
157
+ M : Model ,
158
+ {
159
+ fn propose < ' a , O , R > ( & self , self_ : & ' a M :: Params , mut other : O , rng : & mut R ) -> ( M :: Params , f64 )
160
+ where
161
+ O : FnMut ( & mut R ) -> & ' a M :: Params ,
162
+ R : Rng ,
163
+ {
164
+ let other = other ( rng) ;
165
+
166
+ let z = ( ( self . scale - 1. ) * gen_unit ( rng) + 1. ) . powi ( 2 ) / self . scale ;
167
+
168
+ let new_state = M :: Params :: collect (
169
+ self_
170
+ . values ( )
171
+ . zip ( other. values ( ) )
172
+ . map ( |( self_, other) | other - z * ( other - self_) ) ,
173
+ ) ;
174
+
175
+ let factor = ( new_state. dimension ( ) - 1 ) as f64 * z. ln ( ) ;
176
+
177
+ ( new_state, factor)
178
+ }
179
+ }
180
+
181
+ /// TODO
182
+ pub struct DifferentialEvolution {
183
+ gamma : Normal < f64 > ,
184
+ }
185
+
186
+ impl DifferentialEvolution {
187
+ /// TODO
188
+ pub fn new ( gamma_mean : f64 , gamma_std_dev : f64 ) -> Self {
189
+ Self {
190
+ gamma : Normal :: new ( gamma_mean, gamma_std_dev) . unwrap ( ) ,
191
+ }
192
+ }
193
+ }
194
+
195
+ impl < M > Move < M > for DifferentialEvolution
196
+ where
197
+ M : Model ,
198
+ {
199
+ fn propose < ' a , O , R > ( & self , self_ : & ' a M :: Params , mut other : O , rng : & mut R ) -> ( M :: Params , f64 )
200
+ where
201
+ O : FnMut ( & mut R ) -> & ' a M :: Params ,
202
+ R : Rng ,
203
+ {
204
+ let first_other = other ( rng) ;
205
+ let mut second_other = other ( rng) ;
206
+
207
+ while ptr:: eq ( first_other, second_other) {
208
+ second_other = other ( rng) ;
209
+ }
210
+
211
+ let gamma = self . gamma . sample ( rng) ;
212
+
213
+ let new_state = M :: Params :: collect (
214
+ self_
215
+ . values ( )
216
+ . zip ( first_other. values ( ) )
217
+ . zip ( second_other. values ( ) )
218
+ . map ( |( ( self_, first_other) , second_other) | {
219
+ self_ + gamma * ( first_other - second_other)
220
+ } ) ,
221
+ ) ;
222
+
223
+ ( new_state, 0. )
224
+ }
225
+ }
226
+
227
+ /// TODO
228
+ pub struct RandomGaussian {
229
+ dist : Normal < f64 > ,
230
+ }
231
+
232
+ impl RandomGaussian {
233
+ /// TODO
234
+ pub fn new ( scale : f64 ) -> Self {
235
+ Self {
236
+ dist : Normal :: new ( 0. , scale) . unwrap ( ) ,
237
+ }
238
+ }
239
+ }
240
+
241
+ impl < M > Move < M > for RandomGaussian
242
+ where
243
+ M : Model ,
244
+ {
245
+ fn propose < ' a , O , R > ( & self , self_ : & ' a M :: Params , _other : O , rng : & mut R ) -> ( M :: Params , f64 )
246
+ where
247
+ O : FnMut ( & mut R ) -> & ' a M :: Params ,
248
+ R : Rng ,
249
+ {
250
+ let dir = rng. random_range ( 0 ..self_. dimension ( ) ) ;
251
+
252
+ let new_state = M :: Params :: collect ( self_. values ( ) . enumerate ( ) . map ( |( idx, value) | {
253
+ if idx == dir {
254
+ value + self . dist . sample ( rng)
255
+ } else {
256
+ * value
257
+ }
258
+ } ) ) ;
259
+
260
+ ( new_state, 0. )
261
+ }
262
+ }
263
+
264
+ /// TODO
265
+ pub struct Mixture < W , M > ( WeightedAliasIndex < W > , M )
266
+ where
267
+ W : AliasableWeight ;
268
+
269
+ macro_rules! impl_mixture {
270
+ ( $( $types: ident @ $weights: ident) ,+ ) => {
271
+ impl <W , $( $types ) ,+> From <( $( ( $types, W ) ) ,+ ) > for Mixture <W , ( $( $types ) ,+ ) >
272
+ where
273
+ W : AliasableWeight
274
+ {
275
+ #[ allow( non_snake_case) ]
276
+ fn from( ( $( ( $types, $weights ) ) ,+ ) : ( $( ( $types, W ) ) ,+ ) ) -> Self {
277
+ let index = WeightedAliasIndex :: new( vec![ $( $weights ) ,+] ) . unwrap( ) ;
278
+
279
+ Self ( index, ( $( $types ) ,+ ) )
280
+ }
281
+ }
282
+
283
+ impl <W , $( $types ) ,+, M > Move <M > for Mixture <W , ( $( $types ) ,+ ) >
284
+ where
285
+ W : AliasableWeight ,
286
+ M : Model ,
287
+ $( $types: Move <M > ) ,+
288
+ {
289
+ #[ allow( non_snake_case) ]
290
+ fn propose<' a, O , R >( & self , self_: & ' a M :: Params , other: O , rng: & mut R ) -> ( M :: Params , f64 )
291
+ where
292
+ O : FnMut ( & mut R ) -> & ' a M :: Params ,
293
+ R : Rng ,
294
+ {
295
+ let Self ( index, ( $( $types ) ,+ ) ) = self ;
296
+
297
+ let chosen_index = index. sample( rng) ;
298
+
299
+ let mut index = 0 ;
300
+
301
+ $(
302
+
303
+ #[ allow( unused_assignments) ]
304
+ if chosen_index == index {
305
+ return $types. propose( self_, other, rng)
306
+ } else {
307
+ index += 1 ;
308
+ }
309
+
310
+ ) +
311
+
312
+ unreachable!( )
313
+ }
314
+ }
315
+ } ;
316
+ }
317
+
318
+ impl_mixture ! ( A @ a, B @ b) ;
319
+ impl_mixture ! ( A @ a, B @ b, C @ c) ;
320
+ impl_mixture ! ( A @ a, B @ b, C @ c, D @ d) ;
321
+ impl_mixture ! ( A @ a, B @ b, C @ c, D @ d, E @ e) ;
322
+ impl_mixture ! ( A @ a, B @ b, C @ c, D @ d, E @ e, F @ f) ;
323
+ impl_mixture ! ( A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g) ;
324
+ impl_mixture ! ( A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g, H @ h) ;
325
+ impl_mixture ! ( A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g, H @ h, I @ i) ;
326
+ impl_mixture ! ( A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g, H @ h, I @ i, J @ j) ;
327
+
120
328
/// Models are defined by the type of their parameters and their probability functions
121
329
pub trait Model : Send + Sync {
122
330
/// Type used to store the model parameters, e.g. `[f64; N]` or `Vec<f64>`
@@ -126,9 +334,6 @@ pub trait Model: Send + Sync {
126
334
///
127
335
/// The sampler will only ever consider differences of these values, i.e. any addititive constant that does _not_ depend on `state` can be omitted when computing them.
128
336
fn log_prob ( & self , state : & Self :: Params ) -> f64 ;
129
-
130
- /// Scale parameter for stretch moves
131
- const SCALE : f64 = 2. ;
132
337
}
133
338
134
339
/// Runs the sampler on the given [`model`][Model] using the chosen [`schedule`][Schedule] and [`execution`][Execution] strategy
@@ -138,17 +343,19 @@ pub trait Model: Send + Sync {
138
343
/// The number of walkers must be non-zero, even and at least twice the number of parameters.
139
344
///
140
345
/// A vector of samples and the number of accepted moves are returned.
141
- pub fn sample < M , W , R , S , E > (
142
- model : & M ,
346
+ pub fn sample < MD , MV , W , R , S , E > (
347
+ model : & MD ,
348
+ move_ : & MV ,
143
349
walkers : W ,
144
350
mut schedule : S ,
145
351
execution : E ,
146
- ) -> ( Vec < M :: Params > , usize )
352
+ ) -> ( Vec < MD :: Params > , usize )
147
353
where
148
- M : Model ,
149
- W : Iterator < Item = ( M :: Params , R ) > ,
354
+ MD : Model ,
355
+ MV : Move < MD > + Send + Sync ,
356
+ W : Iterator < Item = ( MD :: Params , R ) > ,
150
357
R : Rng + Send + Sync ,
151
- S : Schedule < M :: Params > ,
358
+ S : Schedule < MD :: Params > ,
152
359
E : Execution ,
153
360
{
154
361
let mut walkers = walkers
@@ -166,10 +373,8 @@ where
166
373
167
374
let random_index = Uniform :: new ( 0 , half) . unwrap ( ) ;
168
375
169
- let update_walker = move |walker : & mut Walker < M , R > , other_walkers : & [ Walker < M , R > ] | {
170
- let other = & other_walkers[ random_index. sample ( & mut walker. rng ) ] ;
171
-
172
- walker. move_ ( model, other)
376
+ let update_walker = move |walker : & mut Walker < MD , R > , other_walkers : & [ Walker < MD , R > ] | {
377
+ walker. move_ ( model, move_, |rng| & other_walkers[ random_index. sample ( rng) ] )
173
378
} ;
174
379
175
380
while schedule. next_step ( & chain) . is_continue ( ) {
@@ -187,22 +392,22 @@ where
187
392
( chain, accepted)
188
393
}
189
394
190
- struct Walker < M , R >
395
+ struct Walker < MD , R >
191
396
where
192
- M : Model ,
397
+ MD : Model ,
193
398
{
194
- state : M :: Params ,
399
+ state : MD :: Params ,
195
400
log_prob : f64 ,
196
401
rng : R ,
197
402
accepted : usize ,
198
403
}
199
404
200
- impl < M , R > Walker < M , R >
405
+ impl < MD , R > Walker < MD , R >
201
406
where
202
- M : Model ,
407
+ MD : Model ,
203
408
R : Rng ,
204
409
{
205
- fn new ( model : & M , state : M :: Params , rng : R ) -> Self {
410
+ fn new ( model : & MD , state : MD :: Params , rng : R ) -> Self {
206
411
let log_prob = model. log_prob ( & state) ;
207
412
208
413
Self {
@@ -213,20 +418,17 @@ where
213
418
}
214
419
}
215
420
216
- fn move_ ( & mut self , model : & M , other : & Self ) -> M :: Params {
217
- let z = ( ( M :: SCALE - 1. ) * gen_unit ( & mut self . rng ) + 1. ) . powi ( 2 ) / M :: SCALE ;
218
-
219
- let mut new_state = M :: Params :: collect (
220
- self . state
221
- . values ( )
222
- . zip ( other. state . values ( ) )
223
- . map ( |( self_, other) | other - z * ( other - self_) ) ,
224
- ) ;
421
+ fn move_ < ' a , MV , O > ( & ' a mut self , model : & MD , move_ : & MV , mut other : O ) -> MD :: Params
422
+ where
423
+ MV : Move < MD > ,
424
+ O : FnMut ( & mut R ) -> & ' a Self ,
425
+ {
426
+ let ( mut new_state, factor) =
427
+ move_. propose ( & self . state , |rng| & other ( rng) . state , & mut self . rng ) ;
225
428
226
429
let new_log_prob = model. log_prob ( & new_state) ;
227
430
228
- let log_prob_diff =
229
- ( new_state. dimension ( ) - 1 ) as f64 * z. ln ( ) + new_log_prob - self . log_prob ;
431
+ let log_prob_diff = factor + new_log_prob - self . log_prob ;
230
432
231
433
if log_prob_diff > gen_unit ( & mut self . rng ) . ln ( ) {
232
434
self . state . clone_from ( & new_state) ;
@@ -380,7 +582,7 @@ where
380
582
/// Runs the inner `schedule` after calling the given `callback`
381
583
///
382
584
/// ```
383
- /// # use hammer_and_sample::{sample, MinChainLen, Model, Schedule, Serial, WithProgress};
585
+ /// # use hammer_and_sample::{sample, MinChainLen, Model, Schedule, Serial, Stretch, WithProgress};
384
586
/// # use rand::SeedableRng;
385
587
/// # use rand_pcg::Pcg64Mcg;
386
588
/// #
@@ -407,7 +609,7 @@ where
407
609
/// callback: |chain: &[_]| eprintln!("{} %", 100 * chain.len() / 100_000),
408
610
/// };
409
611
///
410
- /// let (chain, accepted) = sample(&model, walkers, schedule, Serial);
612
+ /// let (chain, accepted) = sample(&model, &Stretch::default(), walkers, schedule, Serial);
411
613
/// ```
412
614
pub struct WithProgress < S , C > {
413
615
/// The inner schedule which determines the number of iterations
0 commit comments