fraug/augmenters/frequency_mask.rs

1use super::base::Augmenter;
2use crate::Dataset;
3use crate::transforms::fastfourier::{dataset_fft, dataset_ifft};
4use rand::{Rng, rng};
5use tracing::info_span;
6
7
/// Frequency-domain masking augmenter: zeroes one contiguous block of frequency bins in each
/// selected time series.
///
/// The block is `mask_width` bins wide and its position is drawn at random, independently for
/// every sample.
/// - When `is_time_domain` is `true`, samples are converted to the frequency domain with an FFT,
///   masked, and converted back with an inverse FFT.
/// - Otherwise the stored values are masked directly and are treated as interleaved
///   real/imaginary pairs.
pub struct FrequencyMask {
    /// Name returned by `get_name`; also used as the `component` field of the tracing span.
    pub name: String,
    /// Number of consecutive frequency bins zeroed by a single mask.
    pub mask_width: usize,
    /// Whether samples must be FFT-transformed before masking and inverse-transformed after.
    pub is_time_domain: bool,
    /// Probability that any given sample is masked; changed through `set_probability`.
    p: f64,
}
17
18impl FrequencyMask {
19    pub fn new(mask_width: usize, is_time_domain: bool) -> Self {
20        FrequencyMask {
21            name: "FrequencyMask".to_string(),
22            mask_width,
23            is_time_domain,
24            p: 1.0,
25        }
26    }
27}
28
29impl Augmenter for FrequencyMask {
30    fn augment_batch(&self, data: &mut Dataset, _parallel: bool, per_sample: bool) {
31        let span = info_span!("", component = self.get_name());
32        let _enter = span.enter();
33        if self.is_time_domain {
34            let mut transformed_dataset = dataset_fft(data, true);
35
36            transformed_dataset.features.iter_mut().for_each(|sample| {
37                if self.get_probability() > rng().random() {
38                    *sample = self.augment_one(sample)
39                }
40            });
41
42            let inverse_dataset = dataset_ifft(&transformed_dataset, true);
43            *data = inverse_dataset;
44        } else {
45            data.features.iter_mut().for_each(|sample| {
46                if self.get_probability() > rng().random() {
47                    *sample = self.augment_one(sample)
48                }
49            });
50        }
51    }
52
53    fn augment_one(&self, x: &[f64]) -> Vec<f64> {
54        let span = info_span!("", step = "augment_one");
55        let _enter = span.enter();
56        let mut res = x.to_vec();
57
58        let num_bins = x.len() / 2;
59        if num_bins < self.mask_width {
60            return res;
61        }
62
63        let mut rng = rand::rng();
64        let center = rng.random_range(self.mask_width / 2..(num_bins - self.mask_width / 2));
65        let start = center - self.mask_width / 2;
66        let end = start + self.mask_width;
67        for bin in start..end {
68            let re_idx = 2 * bin;
69            let im_idx = 2 * bin + 1;
70            res[re_idx] = 0.0;
71            res[im_idx] = 0.0;
72        }
73
74        res
75    }
76
77    fn get_probability(&self) -> f64 {
78        self.p
79    }
80
81    fn set_probability(&mut self, probability: f64) {
82        self.p = probability;
83    }
84
85    fn get_name(&self) -> String {
86        self.name.clone()
87    }
88    fn supports_per_sample(&self) -> bool {
89        // if in time-domain mode, disable per-sample chaining because of the FFT/IFFT used in the batch
90        !self.is_time_domain
91    }
92}