fraug/augmenters/
amplitude_phase_perturbation.rs1use super::base::Augmenter;
2use crate::Dataset;
3use crate::transforms::fastfourier::{dataset_fft, dataset_ifft};
4use rand::{Rng, rng};
5use rand_distr::{Distribution, Normal};
6use tracing::{info_span};
7
8pub struct AmplitudePhasePerturbation {
15 pub name: String,
16 pub magnitude_std: f64,
17 pub phase_std: f64,
18 pub is_time_domain: bool,
19 p: f64,
20}
21
22impl AmplitudePhasePerturbation {
23 pub fn new(magnitude_std: f64, phase_std: f64, is_time_domain: bool) -> Self {
24 Self {
25 name: "AmplitudePhasePerturbation".to_string(),
26 magnitude_std,
27 phase_std,
28 is_time_domain,
29 p: 1.0,
30 }
31 }
32}
33
34impl Augmenter for AmplitudePhasePerturbation {
35 fn augment_batch(&self, data: &mut Dataset, _parallel: bool, per_sample: bool) {
36 let span = info_span!("", component = self.get_name());
38 let _enter = span.enter();
39 if self.is_time_domain {
40 let mut transformed_dataset = dataset_fft(data, true);
41
42 transformed_dataset.features.iter_mut().for_each(|sample| {
43 if self.get_probability() > rng().random() {
44 *sample = self.augment_one(sample)
45 }
46 });
47
48 let inverse_dataset = dataset_ifft(&transformed_dataset, true);
49 *data = inverse_dataset;
50 } else {
51 data.features.iter_mut().for_each(|sample| {
52 if self.get_probability() > rng().random() {
53 *sample = self.augment_one(sample)
54 }
55 });
56 }
57 }
58
59 fn augment_one(&self, x: &[f64]) -> Vec<f64> {
60 let span = info_span!("", step = "augment_one");
61 let _enter = span.enter();
62 let num_bins = x.len() / 2;
63 let mut rng = rng();
64 let mag_noise = Normal::new(0.0, self.magnitude_std).unwrap();
65 let phase_noise = Normal::new(0.0, self.phase_std).unwrap();
66
67 let mut x = x.to_vec();
68
69 for bin in 0..num_bins {
70 let re_idx = 2 * bin;
71 let im_idx = 2 * bin + 1;
72 let re = x[re_idx];
73 let im = x[im_idx];
74
75 let mag = (re * re + im * im).sqrt();
77 let phase = im.atan2(re);
78
79 let mag_perturbed = (mag + mag_noise.sample(&mut rng)).max(0.0);
81 let phase_perturbed = phase + phase_noise.sample(&mut rng);
82
83 x[re_idx] = mag_perturbed * phase_perturbed.cos();
85 x[im_idx] = mag_perturbed * phase_perturbed.sin();
86 }
87
88 x
89 }
90
91 fn get_probability(&self) -> f64 {
92 self.p
93 }
94
95 fn set_probability(&mut self, probability: f64) {
96 self.p = probability;
97 }
98
99 fn get_name(&self) ->String {
100 self.name.clone()
101 }
102
103 fn supports_per_sample(&self) -> bool {
104 !self.is_time_domain
106 }
107
108}