fraug/augmenters/
amplitude_phase_perturbation.rs

1use super::base::Augmenter;
2use crate::Dataset;
3use crate::transforms::fastfourier::{dataset_fft, dataset_ifft};
4use rand::{Rng, rng};
5use rand_distr::{Distribution, Normal};
6use tracing::{info_span};
7
/// Augmenter that adds Gaussian noise to the magnitude and phase of every frequency bin
/// of a series.
///
/// Each sample is treated as interleaved `(re, im)` pairs. Every pair is converted to
/// polar form, zero-mean normal noise is added to the magnitude (the result is clamped
/// at zero) and to the phase, and the pair is converted back to cartesian form.
///
/// If `is_time_domain` is true, the batch is first passed through `dataset_fft`, the
/// perturbation is applied to the transformed samples, and the result is passed through
/// `dataset_ifft` before being written back.
/// The standard deviations of the noise for magnitude and phase are controlled by
/// `magnitude_std` and `phase_std`, respectively.
#[derive(Debug, Clone, PartialEq)]
pub struct AmplitudePhasePerturbation {
    /// Name of the augmenter; returned by `get_name` and attached to the tracing span
    /// opened in `augment_batch`.
    pub name: String,
    /// Standard deviation of the zero-mean normal noise added to each bin's magnitude.
    pub magnitude_std: f64,
    /// Standard deviation of the zero-mean normal noise added to each bin's phase.
    pub phase_std: f64,
    /// Whether the incoming batch must be transformed with `dataset_fft` before the
    /// perturbation and transformed back with `dataset_ifft` afterwards.
    pub is_time_domain: bool,
    /// Probability threshold compared against a fresh random draw for every sample in
    /// `augment_batch`; set to `1.0` by `new` and changed through `set_probability`.
    p: f64,
}
21
22impl AmplitudePhasePerturbation {
23    pub fn new(magnitude_std: f64, phase_std: f64, is_time_domain: bool) -> Self {
24        Self {
25            name: "AmplitudePhasePerturbation".to_string(),
26            magnitude_std,
27            phase_std,
28            is_time_domain,
29            p: 1.0,
30        }
31    }
32}
33
34impl Augmenter for AmplitudePhasePerturbation {
35    fn augment_batch(&self, data: &mut Dataset, _parallel: bool, per_sample: bool) {
36        // tracing::info!("Rust: augment_batch called with per_sample = {}", per_sample);
37        let span = info_span!("", component = self.get_name());
38        let _enter = span.enter();
39        if self.is_time_domain {
40            let mut transformed_dataset = dataset_fft(data, true);
41
42            transformed_dataset.features.iter_mut().for_each(|sample| {
43                if self.get_probability() > rng().random() {
44                    *sample = self.augment_one(sample)
45                }
46            });
47
48            let inverse_dataset = dataset_ifft(&transformed_dataset, true);
49            *data = inverse_dataset;
50        } else {
51            data.features.iter_mut().for_each(|sample| {
52                if self.get_probability() > rng().random() {
53                    *sample = self.augment_one(sample)
54                }
55            });
56        }
57    }
58
59    fn augment_one(&self, x: &[f64]) -> Vec<f64> {
60        let span = info_span!("", step = "augment_one");
61        let _enter = span.enter();
62        let num_bins = x.len() / 2;
63        let mut rng = rng();
64        let mag_noise = Normal::new(0.0, self.magnitude_std).unwrap();
65        let phase_noise = Normal::new(0.0, self.phase_std).unwrap();
66
67        let mut x = x.to_vec();
68
69        for bin in 0..num_bins {
70            let re_idx = 2 * bin;
71            let im_idx = 2 * bin + 1;
72            let re = x[re_idx];
73            let im = x[im_idx];
74
75            // Convert to polar
76            let mag = (re * re + im * im).sqrt();
77            let phase = im.atan2(re);
78
79            // Add noise
80            let mag_perturbed = (mag + mag_noise.sample(&mut rng)).max(0.0);
81            let phase_perturbed = phase + phase_noise.sample(&mut rng);
82
83            // Convert back to cartesian
84            x[re_idx] = mag_perturbed * phase_perturbed.cos();
85            x[im_idx] = mag_perturbed * phase_perturbed.sin();
86        }
87
88        x
89    }
90
91    fn get_probability(&self) -> f64 {
92        self.p
93    }
94
95    fn set_probability(&mut self, probability: f64) {
96        self.p = probability;
97    }
98
99    fn get_name(&self) ->String {
100        self.name.clone()
101    }
102
103    fn supports_per_sample(&self) -> bool {
104        // if in time-domain mode, disable per-sample chaining because of the FFT/IFFT used in the batch
105        !self.is_time_domain
106    }
107    
108}