fraug/augmenters/
convolve.rs1use super::base::Augmenter;
2use rand::{Rng, rng};
3use rayon::prelude::*;
4use tracing::{info_span};
5
/// Shape of the kernel used by [`Convolve`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConvolveWindow {
    /// Every tap has the same weight (`1 / size`), i.e. a moving average.
    Flat,
    /// Taps follow a Gaussian bell centered on the kernel midpoint and are
    /// normalized so they sum to 1.
    Gaussian,
}
11
12pub struct Convolve {
19 pub name: String,
20 window: ConvolveWindow,
21 size: usize,
22 p: f64,
23}
24
25impl Convolve {
26 pub fn new(window: ConvolveWindow, size: usize) -> Self {
27 assert!(size > 0, "Kernel size must be greater than 0");
28 Convolve {
29 name: "Convolve".to_string(),
30 window,
31 size,
32 p: 1.0,
33 }
34 }
35
36 fn make_kernel(&self) -> Vec<f64> {
37 let n = self.size;
38 match self.window {
39 ConvolveWindow::Flat => vec![1.0 / n as f64; n],
40 ConvolveWindow::Gaussian => {
41 let sigma = 0.3 * ((n - 1) as f64) * 0.5 + 0.8;
42 let mid = (n as f64 - 1.0) / 2.0;
43 let mut kernel = Vec::with_capacity(n);
44 for i in 0..n {
45 let x = i as f64 - mid;
46 kernel.push((-0.5 * (x / sigma).powi(2)).exp());
47 }
48 let sum: f64 = kernel.iter().sum();
49 kernel.iter_mut().for_each(|v| *v /= sum);
50 kernel
51 }
52 }
53 }
54
55 fn convolve(&self, x: &[f64], kernel: &[f64]) -> Vec<f64> {
56 let n = kernel.len();
57 let len = x.len();
58 if len < n {
59 return x.to_vec();
60 }
61 let mut out = vec![0.0; len];
62 let half = n / 2;
63 for i in 0..len {
64 let mut acc = 0.0;
65 for k in 0..n {
66 let idx = if i + k >= half && i + k < len + half {
67 i + k - half
68 } else {
69 continue;
70 };
71 if idx < len {
72 acc += x[idx] * kernel[k];
73 }
74 }
75 out[i] = acc;
76 }
77 out
78 }
79}
80
81impl Augmenter for Convolve {
82 fn augment_one(&self, x: &[f64]) -> Vec<f64> {
83 let span = info_span!("", step = "augment_one");
84 let _enter = span.enter();
85 let kernel = self.make_kernel();
86 self.convolve(x, &kernel)
87 }
88
89 fn augment_batch(&self, input: &mut crate::Dataset, parallel: bool, per_sample: bool)
91 where
92 Self: Sync,
93 {
94 let kernel = self.make_kernel();
95 if parallel {
96 input.features.par_iter_mut().for_each(|x| {
97 if self.get_probability() > rng().random() {
98 *x = self.convolve(x, &kernel)
99 }
100 });
101 } else {
102 input.features.iter_mut().for_each(|x| {
103 if self.get_probability() > rng().random() {
104 *x = self.convolve(x, &kernel)
105 }
106 });
107 }
108 }
109
110 fn get_probability(&self) -> f64 {
111 self.p
112 }
113
114 fn set_probability(&mut self, probability: f64) {
115 self.p = probability;
116 }
117
118 fn get_name(&self) ->String {
119 self.name.clone()
120 }
121}