// fraug/augmenters/jittering.rs
use super::base::Augmenter;
use rand::prelude::*;
use rand_distr::Normal;
use tracing::info_span;
/// Augmenter that perturbs every element of its input with additive Gaussian noise.
///
/// Construct one with [`Jittering::new`]; the noise is applied through this type's
/// `Augmenter` trait implementation.
#[derive(Debug, Clone, PartialEq)]
pub struct Jittering {
    /// Human-readable name of this augmenter.
    pub name: String,
    /// Standard deviation of the zero-mean normal distribution the noise is drawn from.
    pub deviation: f64,
    /// Probability exposed through `get_probability` / `set_probability`.
    /// NOTE(review): presumably the chance that this augmenter is applied — confirm
    /// against how `Augmenter` consumes it.
    p: f64,
}

14impl Jittering {
15 pub fn new(standard_deviation: f64) -> Self {
16 Jittering {
17 name: "Jittering".to_string(),
18 deviation: standard_deviation,
19 p: 1.0,
20 }
21 }
22}
23
24impl Augmenter for Jittering {
25 fn augment_one(&self, x: &[f64]) -> Vec<f64> {
26 let span = info_span!("", step = "augment_one");
27 let _enter = span.enter();
28 let _enter = span.enter();
29 let mut rng = rand::rng();
30 let dist = Normal::new(0.0, self.deviation)
31 .expect("Couldn't create normal distribution from specified standard deviation");
32 x.iter().map(|val| *val + dist.sample(&mut rng)).collect()
33 }
34
35 fn get_probability(&self) -> f64 {
36 self.p
37 }
38
39 fn set_probability(&mut self, probability: f64) {
40 self.p = probability;
41 }
42
43 fn get_name(&self) ->String {
44 self.name.clone()
45 }
46}