fraug/augmenters/
pool.rs

1use super::base::Augmenter;
2use tracing::{info_span};
3/// Reduces the temporal resolution without changing the length by pooling multiple samples together
4pub struct Pool {
5    pub name: String,
6    /// Pooling function to be used
7    pub kind: PoolingMethod,
8    /// Size of one pool
9    pub size: usize,
10    p: f64,
11}
12
/// Pooling function used by the `Pool` augmenter to collapse a window of
/// samples into one value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PoolingMethod {
    /// Largest sample in the window.
    Max,
    /// Smallest sample in the window.
    Min,
    /// Arithmetic mean of the samples in the window.
    Average,
}
19
20impl Pool {
21    /// Creates new pool augmenter
22    pub fn new(kind: PoolingMethod, size: usize) -> Self {
23        Pool {
24            name: "Pool".to_string(),
25            kind,
26            size,
27            p: 1.0,
28        }
29    }
30}
31
32impl Augmenter for Pool {
33    fn augment_one(&self, x: &[f64]) -> Vec<f64> {
34        let span = info_span!("", step = "augment_one");
35        let _enter = span.enter();
36        let mut res = Vec::with_capacity(x.len());
37
38        let mut i = 0;
39        while i < x.len() {
40            let cur_size = if i + self.size < x.len() {
41                self.size
42            } else {
43                x.len() - i
44            };
45
46            let new_val = {
47                match self.kind {
48                    PoolingMethod::Max => *x[i..i + cur_size]
49                        .iter()
50                        .reduce(|a, b| if a < b { b } else { a })
51                        .unwrap(),
52                    PoolingMethod::Min => *x[i..i + cur_size]
53                        .iter()
54                        .reduce(|a, b| if a < b { a } else { b })
55                        .unwrap(),
56                    PoolingMethod::Average => {
57                        x[i..i + cur_size].iter().sum::<f64>() / cur_size as f64
58                    }
59                }
60            };
61
62            for _ in i..i + cur_size {
63                res.push(new_val);
64            }
65
66            i += self.size;
67        }
68
69        res
70    }
71
72    fn get_probability(&self) -> f64 {
73        self.p
74    }
75
76    fn set_probability(&mut self, probability: f64) {
77        self.p = probability;
78    }
79
80    fn get_name(&self) ->String {
81        self.name.clone()
82    }
83}