1use crate::Dataset;
2use rand::prelude::*;
3use rand::rng;
4use rayon::prelude::*;
5use std::ops::Add;
6use tracing::info_span;
7
8pub trait Augmenter {
10 fn augment_batch(&self, input: &mut Dataset, parallel: bool, per_sample: bool)
14 where
15 Self: Sync,
16 {
17 let span = info_span!("", component = self.get_name());
18 let _enter = span.enter();
19 if parallel {
20 tracing::info!("Rust: parallel augment_batch called");
21 input.features.par_iter_mut().for_each(|x| {
22 if self.get_probability() > rng().random() {
23 *x = self.augment_one(x)
24 }
25 });
26 } else {
27 input.features.iter_mut().for_each(|x| {
28 if self.get_probability() > rng().random() {
29 *x = self.augment_one(x)
30 }
31 });
32 }
33 }
34
35 fn augment_one(&self, x: &[f64]) -> Vec<f64>;
39
40 fn get_probability(&self) -> f64;
42
43 fn set_probability(&mut self, probability: f64);
46
47 fn get_name(&self) -> String;
48
49 fn supports_per_sample(&self) -> bool {
53 true
54 }
55}
56
57pub struct AugmentationPipeline {
84 pub name: String,
85 augmenters: Vec<Box<dyn Augmenter + Sync>>,
86 p: f64,
87}
88
89impl AugmentationPipeline {
90 pub fn new() -> Self {
92 AugmentationPipeline {
93 name: "AugmentationPipeline".to_string(),
94 augmenters: Vec::new(),
95 p: 1.0,
96 }
97 }
98
99 pub fn add<T: Augmenter + 'static + Sync>(&mut self, augmenter: T) {
103 self.augmenters.push(Box::new(augmenter));
104 }
105}
106
107impl Augmenter for AugmentationPipeline {
108 fn augment_batch(&self, input: &mut Dataset, parallel: bool, per_sample: bool) {
109 if per_sample {
110 for augmenter in &self.augmenters {
112 if !augmenter.supports_per_sample() {
113 panic!(
114 "Augmenter '{}' is not compatible with per-sample pipelining!",
115 augmenter.get_name()
116 );
117 }
118 }
119 tracing::info!("Rust: augment_batch called with per_sample = {}", per_sample);
120 if parallel {
121 input.features.par_iter_mut().for_each(|sample| {
122 let mut chain = sample.to_vec();
123 for augmenter in self.augmenters.iter() {
124 if augmenter.get_probability() > rng().random() {
125 chain = augmenter.augment_one(&chain);
126 }
127 }
128 *sample = chain;
129 });
130 } else {
131 input.features.iter_mut().for_each(|sample| {
132 let mut chain = sample.to_vec();
133 for augmenter in self.augmenters.iter() {
134 if augmenter.get_probability() > rng().random() {
135 chain = augmenter.augment_one(&chain);
136 }
137 }
138 *sample = chain;
139 });
140 }
141 } else {
142 self.augmenters
144 .iter()
145 .for_each(|augmenter| augmenter.augment_batch(input, parallel, false));
146 }
147 }
148
149 fn augment_one(&self, x: &[f64]) -> Vec<f64> {
150 let mut res = x.to_vec();
151 for augmenter in self.augmenters.iter() {
152 res = augmenter.augment_one(&res);
153 }
154 res
155 }
156
157 fn get_probability(&self) -> f64 {
158 self.p
159 }
160
161 fn set_probability(&mut self, probability: f64) {
162 self.p = probability;
163 }
164
165 fn get_name(&self) -> String {
166 self.name.clone()
167 }
168}
169
170impl<T: Augmenter + 'static + Sync> Add<T> for AugmentationPipeline {
171 type Output = AugmentationPipeline;
172
173 fn add(self, rhs: T) -> Self::Output {
174 let mut augmenters = self.augmenters;
175 augmenters.push(Box::new(rhs));
176
177 AugmentationPipeline {
178 name: "AugmentationPipeline".to_string(),
179 augmenters,
180 p: self.p,
181 }
182 }
183}