fraug/augmenters/
frequency_mask.rs1use super::base::Augmenter;
2use crate::Dataset;
3use crate::transforms::fastfourier::{dataset_fft, dataset_ifft};
4use rand::{Rng, rng};
5use tracing::info_span;
6
7
/// Augmenter that zeroes a randomly placed band of consecutive frequency bins.
///
/// Each augmented sample is treated as interleaved complex bins
/// (`[re0, im0, re1, im1, ...]`) and a contiguous run of `mask_width` bins is
/// set to zero. See the `Augmenter` impl for how samples are selected.
#[derive(Debug, Clone)]
pub struct FrequencyMask {
    /// Component name attached to the tracing span of each batch.
    pub name: String,
    /// Number of consecutive complex bins zeroed per augmented sample.
    pub mask_width: usize,
    /// If true, samples are converted with an FFT before masking and converted
    /// back afterwards; otherwise they are masked as-is.
    pub is_time_domain: bool,
    /// Probability that any individual sample is augmented.
    p: f64,
}
17
18impl FrequencyMask {
19 pub fn new(mask_width: usize, is_time_domain: bool) -> Self {
20 FrequencyMask {
21 name: "FrequencyMask".to_string(),
22 mask_width,
23 is_time_domain,
24 p: 1.0,
25 }
26 }
27}
28
29impl Augmenter for FrequencyMask {
30 fn augment_batch(&self, data: &mut Dataset, _parallel: bool, per_sample: bool) {
31 let span = info_span!("", component = self.get_name());
32 let _enter = span.enter();
33 if self.is_time_domain {
34 let mut transformed_dataset = dataset_fft(data, true);
35
36 transformed_dataset.features.iter_mut().for_each(|sample| {
37 if self.get_probability() > rng().random() {
38 *sample = self.augment_one(sample)
39 }
40 });
41
42 let inverse_dataset = dataset_ifft(&transformed_dataset, true);
43 *data = inverse_dataset;
44 } else {
45 data.features.iter_mut().for_each(|sample| {
46 if self.get_probability() > rng().random() {
47 *sample = self.augment_one(sample)
48 }
49 });
50 }
51 }
52
53 fn augment_one(&self, x: &[f64]) -> Vec<f64> {
54 let span = info_span!("", step = "augment_one");
55 let _enter = span.enter();
56 let mut res = x.to_vec();
57
58 let num_bins = x.len() / 2;
59 if num_bins < self.mask_width {
60 return res;
61 }
62
63 let mut rng = rand::rng();
64 let center = rng.random_range(self.mask_width / 2..(num_bins - self.mask_width / 2));
65 let start = center - self.mask_width / 2;
66 let end = start + self.mask_width;
67 for bin in start..end {
68 let re_idx = 2 * bin;
69 let im_idx = 2 * bin + 1;
70 res[re_idx] = 0.0;
71 res[im_idx] = 0.0;
72 }
73
74 res
75 }
76
77 fn get_probability(&self) -> f64 {
78 self.p
79 }
80
81 fn set_probability(&mut self, probability: f64) {
82 self.p = probability;
83 }
84
85 fn get_name(&self) -> String {
86 self.name.clone()
87 }
88 fn supports_per_sample(&self) -> bool {
89 !self.is_time_domain
91 }
92}