fraug/augmenters/
drift.rs1use super::base::Augmenter;
2use rand::Rng;
3use tracing::{info_span};
/// Augmenter that adds a randomly generated, piecewise-linear drift to a series.
///
/// The drift curve is built from random knots and added element-wise to the
/// input in `augment_one`.
pub struct Drift {
    /// Name reported by `get_name`; `new` sets it to `"Drift"`.
    pub name: String,
    /// Bound on every sampled knot: knots are drawn from `[-max_drift, max_drift]`.
    pub max_drift: f64,
    /// Number of random knots the drift is interpolated between.
    /// `new` raises values below 2 to 2.
    pub n_drift_points: usize,
    /// Value exposed through `get_probability` / `set_probability`; `new` starts it at `1.0`.
    /// NOTE(review): presumably the chance that this augmenter is applied —
    /// confirm against the `Augmenter` trait, which is not visible here.
    p: f64,
}
15
16impl Drift {
17 pub fn new(max_drift: f64, n_drift_points: usize) -> Self {
19 Drift {
20 name: "Drift".to_string(),
21 max_drift,
22 n_drift_points: n_drift_points.max(2), p: 1.0,
24 }
25 }
26
27 fn make_drift(&self, len: usize) -> Vec<f64> {
28 let mut rng = rand::rng();
29 let n = self.n_drift_points.min(len);
30 let mut drift_points = Vec::with_capacity(n);
31 for _ in 0..n {
32 drift_points.push(rng.random_range(-self.max_drift..=self.max_drift));
33 }
34 let mut drift = vec![0.0; len];
36 let seg_len = len as f64 / (n - 1) as f64;
37 for i in 0..len {
38 let pos = i as f64 / seg_len;
39 let left = pos.floor() as usize;
40 let right = pos.ceil() as usize;
41 let alpha = pos - left as f64;
42 let left_val = drift_points[left.min(n - 1)];
43 let right_val = drift_points[right.min(n - 1)];
44 drift[i] = (1.0 - alpha) * left_val + alpha * right_val;
45 }
46 drift
47 }
48}
49
50impl Augmenter for Drift {
51 fn augment_one(&self, x: &[f64]) -> Vec<f64> {
52 let span = info_span!("", step = "augment_one");
53 let _enter = span.enter();
54 let drift = self.make_drift(x.len());
55 x.iter().zip(drift.iter()).map(|(xi, di)| xi + di).collect()
56 }
57
58 fn get_probability(&self) -> f64 {
59 self.p
60 }
61
62 fn set_probability(&mut self, probability: f64) {
63 self.p = probability;
64 }
65
66 fn get_name(&self) ->String {
67 self.name.clone()
68 }
69}