1use super::base::Augmenter;
2use tracing::{info_span};
3pub struct Pool {
5 pub name: String,
6 pub kind: PoolingMethod,
8 pub size: usize,
10 p: f64,
11}
12
/// Reduction used to collapse one pooling window into a single value.
///
/// The enum is a plain tag without payload, so it is cheap to copy and
/// compare; the derives make it usable in logs, assertions and by-value APIs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PoolingMethod {
    /// Use the largest sample of the window.
    Max,
    /// Use the smallest sample of the window.
    Min,
    /// Use the arithmetic mean of the window.
    Average,
}
19
20impl Pool {
21 pub fn new(kind: PoolingMethod, size: usize) -> Self {
23 Pool {
24 name: "Pool".to_string(),
25 kind,
26 size,
27 p: 1.0,
28 }
29 }
30}
31
32impl Augmenter for Pool {
33 fn augment_one(&self, x: &[f64]) -> Vec<f64> {
34 let span = info_span!("", step = "augment_one");
35 let _enter = span.enter();
36 let mut res = Vec::with_capacity(x.len());
37
38 let mut i = 0;
39 while i < x.len() {
40 let cur_size = if i + self.size < x.len() {
41 self.size
42 } else {
43 x.len() - i
44 };
45
46 let new_val = {
47 match self.kind {
48 PoolingMethod::Max => *x[i..i + cur_size]
49 .iter()
50 .reduce(|a, b| if a < b { b } else { a })
51 .unwrap(),
52 PoolingMethod::Min => *x[i..i + cur_size]
53 .iter()
54 .reduce(|a, b| if a < b { a } else { b })
55 .unwrap(),
56 PoolingMethod::Average => {
57 x[i..i + cur_size].iter().sum::<f64>() / cur_size as f64
58 }
59 }
60 };
61
62 for _ in i..i + cur_size {
63 res.push(new_val);
64 }
65
66 i += self.size;
67 }
68
69 res
70 }
71
72 fn get_probability(&self) -> f64 {
73 self.p
74 }
75
76 fn set_probability(&mut self, probability: f64) {
77 self.p = probability;
78 }
79
80 fn get_name(&self) ->String {
81 self.name.clone()
82 }
83}