fraug/augmenters/
base.rs

1use crate::Dataset;
2use rand::prelude::*;
3use rand::rng;
4use rayon::prelude::*;
5use std::ops::Add;
6use tracing::info_span;
7
8/// Trait for all augmenters, allows for augmentation of one time series or a batch
9pub trait Augmenter {
10    /// Augment a whole batch
11    ///
12    /// Parallelized using rayon when `parallell` is set
13    fn augment_batch(&self, input: &mut Dataset, parallel: bool, per_sample: bool)
14    where
15        Self: Sync,
16    {
17        let span = info_span!("", component = self.get_name());
18        let _enter = span.enter();
19        if parallel {
20            tracing::info!("Rust: parallel augment_batch called");
21            input.features.par_iter_mut().for_each(|x| {
22                if self.get_probability() > rng().random() {
23                    *x = self.augment_one(x)
24                }
25            });
26        } else {
27            input.features.iter_mut().for_each(|x| {
28                if self.get_probability() > rng().random() {
29                    *x = self.augment_one(x)
30                }
31            });
32        }
33    }
34
35    /// Augment one time series
36    ///
37    /// When called, the augmenter will always augment the series no matter what the probability for this augmenter is
38    fn augment_one(&self, x: &[f64]) -> Vec<f64>;
39
40    /// Get the probability that this augmenter will augment a series in a batch
41    fn get_probability(&self) -> f64;
42
43    /// By setting a probability with this function the augmenter will only augment a series in a
44    /// batch with the specified probability
45    fn set_probability(&mut self, probability: f64);
46
47    fn get_name(&self) -> String;
48
49    /// Indicate whether this augmenter supports per-sample chaining.
50    /// By default, return true. Augmenters that need a batch level view
51    /// should override this to return false.
52    fn supports_per_sample(&self) -> bool {
53        true
54    }
55}
56
57/// A pipeline of augmenters
58///
59/// Executes many augmenters at once
60///
61/// # Example
62///
63/// ```
64///  use fraug::Dataset;
65///  use fraug::augmenters::*;
66///
67///  let series = vec![1.0; 100];
68///  let mut set = Dataset {
69///     features: vec![series],
70///     labels: vec![String::from("1")],
71///  };
72///
73///  let pipeline = AugmentationPipeline::new()
74///                 + Repeat::new(5)
75///                 + Crop::new(20)
76///                 + Jittering::new(0.2);
77///
78///  pipeline.augment_batch(&mut set, true, false);
79///
80///  assert_eq!(set.features.len(), 5);
81///  assert_eq!(set.features[3].len(), 20);
82/// ```
83pub struct AugmentationPipeline {
84    pub name: String,
85    augmenters: Vec<Box<dyn Augmenter + Sync>>,
86    p: f64,
87}
88
89impl AugmentationPipeline {
90    /// Creates an empty pipeline
91    pub fn new() -> Self {
92        AugmentationPipeline {
93            name: "AugmentationPipeline".to_string(),
94            augmenters: Vec::new(),
95            p: 1.0,
96        }
97    }
98
99    /// Add an augmenter to the pipeline
100    ///
101    /// Has the same effect as using the `+` operator
102    pub fn add<T: Augmenter + 'static + Sync>(&mut self, augmenter: T) {
103        self.augmenters.push(Box::new(augmenter));
104    }
105}
106
107impl Augmenter for AugmentationPipeline {
108    fn augment_batch(&self, input: &mut Dataset, parallel: bool, per_sample: bool) {
109        if per_sample {
110            // Compatibility check : reject if any augmenter has per-sample chaining disabled in pipeline
111            for augmenter in &self.augmenters {
112                if !augmenter.supports_per_sample() {
113                    panic!(
114                        "Augmenter '{}' is not compatible with per-sample pipelining!",
115                        augmenter.get_name()
116                    );
117                }
118            }
119            tracing::info!("Rust: augment_batch called with per_sample = {}", per_sample);
120            if parallel {
121                input.features.par_iter_mut().for_each(|sample| {
122                    let mut chain = sample.to_vec();
123                    for augmenter in self.augmenters.iter() {
124                        if augmenter.get_probability() > rng().random() {
125                            chain = augmenter.augment_one(&chain);
126                        }
127                    }
128                    *sample = chain;
129                });
130            } else {
131                input.features.iter_mut().for_each(|sample| {
132                    let mut chain = sample.to_vec();
133                    for augmenter in self.augmenters.iter() {
134                        if augmenter.get_probability() > rng().random() {
135                            chain = augmenter.augment_one(&chain);
136                        }
137                    }
138                    *sample = chain;
139                });
140            }
141        } else {
142            // Existing batch approach: each augmenter processes the entire dataset in sequence
143            self.augmenters
144                .iter()
145                .for_each(|augmenter| augmenter.augment_batch(input, parallel, false));
146        }
147    }
148
149    fn augment_one(&self, x: &[f64]) -> Vec<f64> {
150        let mut res = x.to_vec();
151        for augmenter in self.augmenters.iter() {
152            res = augmenter.augment_one(&res);
153        }
154        res
155    }
156
157    fn get_probability(&self) -> f64 {
158        self.p
159    }
160
161    fn set_probability(&mut self, probability: f64) {
162        self.p = probability;
163    }
164
165    fn get_name(&self) -> String {
166        self.name.clone()
167    }
168}
169
170impl<T: Augmenter + 'static + Sync> Add<T> for AugmentationPipeline {
171    type Output = AugmentationPipeline;
172
173    fn add(self, rhs: T) -> Self::Output {
174        let mut augmenters = self.augmenters;
175        augmenters.push(Box::new(rhs));
176
177        AugmentationPipeline {
178            name: "AugmentationPipeline".to_string(),
179            augmenters,
180            p: self.p,
181        }
182    }
183}