fraug/augmenters/
jittering.rs

1use super::base::Augmenter;
2use rand::prelude::*;
3use rand_distr::Normal;
4use tracing::{info_span};
/// Augmenter that adds white gaussian noise of the specified standard deviation and a mean of 0
///
/// A special case of the `AddNoise` augmenter
#[derive(Debug, Clone)]
pub struct Jittering {
    /// Human-readable name of this augmenter (returned by `get_name`).
    pub name: String,
    /// Standard deviation of the zero-mean Gaussian noise added to each value.
    pub deviation: f64,
    /// Probability value exposed through `get_probability` / `set_probability`.
    p: f64,
}

14impl Jittering {
15    pub fn new(standard_deviation: f64) -> Self {
16        Jittering {
17            name: "Jittering".to_string(),
18            deviation: standard_deviation,
19            p: 1.0,
20        }
21    }
22}
23
24impl Augmenter for Jittering {
25    fn augment_one(&self, x: &[f64]) -> Vec<f64> {
26        let span = info_span!("", step = "augment_one");
27        let _enter = span.enter();
28        let _enter = span.enter();
29        let mut rng = rand::rng();
30        let dist = Normal::new(0.0, self.deviation)
31            .expect("Couldn't create normal distribution from specified standard deviation");
32        x.iter().map(|val| *val + dist.sample(&mut rng)).collect()
33    }
34
35    fn get_probability(&self) -> f64 {
36        self.p
37    }
38
39    fn set_probability(&mut self, probability: f64) {
40        self.p = probability;
41    }
42
43    fn get_name(&self) ->String {
44        self.name.clone()
45    }
46}