Skip to content

Commit 569a065

Browse files
authored
Merge pull request #77 from Hoff97/feature/mean-variance-numerically-stable
Make sure mean and variance computations are numerically stable
2 parents 4eff518 + 708d7f6 commit 569a065

File tree

2 files changed

+138
-11
lines changed

2 files changed

+138
-11
lines changed

src/statistics/stat.rs

+108-11
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,8 @@ impl Statistics for Vec<f64> {
174174

175175
/// Mean
176176
///
177+
/// Uses Welford's online algorithm for numerically stable computation.
178+
///
177179
/// # Examples
178180
/// ```
179181
/// #[macro_use]
@@ -186,11 +188,20 @@ impl Statistics for Vec<f64> {
186188
/// }
187189
/// ```
188190
fn mean(&self) -> f64 {
189-
self.reduce(0f64, |x, y| x + y) / (self.len() as f64)
191+
let mut xn = 0f64;
192+
let mut n = 0f64;
193+
194+
for x in self.iter() {
195+
n += 1f64;
196+
xn += (x - xn) / n;
197+
}
198+
xn
190199
}
191200

192201
/// Variance
193202
///
203+
/// Uses Welford's online algorithm for numerically stable computation.
204+
///
194205
/// # Examples
195206
/// ```
196207
/// #[macro_use]
@@ -203,17 +214,18 @@ impl Statistics for Vec<f64> {
203214
/// }
204215
/// ```
205216
fn var(&self) -> f64 {
206-
let mut ss = 0f64;
207-
let mut s = 0f64;
208-
let mut l = 0f64;
209-
210-
for x in self.into_iter() {
211-
ss += x.powf(2f64);
212-
s += *x;
213-
l += 1f64;
217+
let mut xn = 0f64;
218+
let mut n = 0f64;
219+
let mut m2n: f64 = 0f64;
220+
221+
for x in self.iter() {
222+
n += 1f64;
223+
let diff_1 = x - xn;
224+
xn += diff_1 / n;
225+
m2n += diff_1 * (x - xn);
214226
}
215-
assert_ne!(l, 1f64);
216-
(ss / l - (s / l).powf(2f64)) * l / (l - 1f64)
227+
assert_ne!(n, 1f64);
228+
m2n / (n - 1f64)
217229
}
218230

219231
/// Standard Deviation
@@ -241,6 +253,91 @@ impl Statistics for Vec<f64> {
241253
}
242254
}
243255

256+
impl Statistics for Vec<f32> {
257+
type Array = Vec<f32>;
258+
type Value = f32;
259+
260+
/// Mean
261+
///
262+
/// Uses Welford's online algorithm for numerically stable computation.
263+
///
264+
/// # Examples
265+
/// ```
266+
/// #[macro_use]
267+
/// extern crate peroxide;
268+
/// use peroxide::fuga::*;
269+
///
270+
/// fn main() {
271+
/// let a = c!(1,2,3,4,5);
272+
/// assert_eq!(a.mean(), 3.0);
273+
/// }
274+
/// ```
275+
fn mean(&self) -> f32 {
276+
let mut xn = 0f32;
277+
let mut n = 0f32;
278+
279+
for x in self.iter() {
280+
n += 1f32;
281+
xn += (x - xn) / n;
282+
}
283+
xn
284+
}
285+
286+
/// Variance
287+
///
288+
/// Uses Welford's online algorithm for numerically stable computation.
289+
///
290+
/// # Examples
291+
/// ```
292+
/// #[macro_use]
293+
/// extern crate peroxide;
294+
/// use peroxide::fuga::*;
295+
///
296+
/// fn main() {
297+
/// let a = c!(1,2,3,4,5);
298+
/// assert_eq!(a.var(), 2.5);
299+
/// }
300+
/// ```
301+
fn var(&self) -> f32 {
302+
let mut xn = 0f32;
303+
let mut n = 0f32;
304+
let mut m2n: f32 = 0f32;
305+
306+
for x in self.iter() {
307+
n += 1f32;
308+
let diff_1 = x - xn;
309+
xn += diff_1 / n;
310+
m2n += diff_1 * (x - xn);
311+
}
312+
assert_ne!(n, 1f32);
313+
m2n / (n - 1f32)
314+
}
315+
316+
/// Standard Deviation
317+
///
318+
/// # Examples
319+
/// ```
320+
/// #[macro_use]
321+
/// extern crate peroxide;
322+
/// use peroxide::fuga::*;
323+
///
324+
/// fn main() {
325+
/// let a = c!(1,2,3);
326+
/// assert!(nearly_eq(a.sd(), 1f64)); // Floating Number Error
327+
/// }
328+
/// ```
329+
fn sd(&self) -> f32 {
330+
self.var().sqrt()
331+
}
332+
333+
fn cov(&self) -> Vec<f32> {
334+
unimplemented!()
335+
}
336+
fn cor(&self) -> Vec<f32> {
337+
unimplemented!()
338+
}
339+
}
340+
244341
impl Statistics for Matrix {
245342
type Array = Matrix;
246343
type Value = Vec<f64>;

tests/stat.rs

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
extern crate peroxide;
2+
use peroxide::fuga::*;
3+
4+
#[test]
5+
fn test_mean() {
6+
let a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
7+
assert_eq!(a.mean() ,3.0);
8+
}
9+
10+
#[test]
11+
fn test_mean_stable() {
12+
let a: Vec<f32> = vec![1.0; 10000000];
13+
let diff = 10000.0;
14+
let b = a.iter().map(|x| x+diff).collect::<Vec<f32>>();
15+
assert_eq!(a.mean(), b.mean()-diff);
16+
}
17+
18+
#[test]
19+
fn test_variance() {
20+
let a = vec![1.0,2.0,3.0,4.0,5.0];
21+
assert_eq!(a.var(), 2.5);
22+
}
23+
24+
#[test]
25+
fn test_variance_stable() {
26+
let a = vec![1.0,2.0,3.0,4.0,5.0];
27+
let diff = 1000000000.0;
28+
let b = a.iter().map(|x| x+diff).collect::<Vec<f64>>();
29+
assert_eq!(a.var(), b.var());
30+
}

0 commit comments

Comments
 (0)