fraug/augmenters/
quantize.rs1use super::base::Augmenter;
2use tracing::{info_span};
/// Augmenter that snaps every sample of a series onto a small set of evenly
/// spaced levels derived from the series' own value range.
#[derive(Debug, Clone)]
pub struct Quantize {
    /// Display name of the augmenter (returned by `get_name`).
    pub name: String,
    /// Number of quantization levels the series is snapped onto.
    levels: usize,
    /// Probability stored for this augmenter (read/written via
    /// `get_probability`/`set_probability`); presumably the chance the
    /// augmentation is applied — confirm against the `Augmenter` trait.
    p: f64,
}
12
13impl Quantize {
14 pub fn new(levels: usize) -> Self {
16 Quantize {
17 name: "Quantize".to_string(),
18 levels,
19 p: 1.0,
20 }
21 }
22}
23
24impl Augmenter for Quantize {
25 fn augment_one(&self, x: &[f64]) -> Vec<f64> {
26 let span = info_span!("", step = "augment_one");
27 let _enter = span.enter();
28 let max = x.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
29 let min = x.iter().fold(f64::INFINITY, |a, &b| a.min(b));
30 let range = max - min;
31 let step = range / self.levels as f64;
32 let level_set = (0..self.levels)
33 .map(|level| min + level as f64 * step)
34 .collect::<Vec<_>>();
35
36 x.iter()
38 .map(|v| {
39 let i = level_set
40 .iter()
41 .map(|&l| (l - *v).abs())
42 .enumerate()
43 .fold(
44 (0, f64::INFINITY),
45 |(i, a), (j, b)| if a > b { (j, b) } else { (i, a) },
46 )
47 .0;
48 level_set[i]
49 })
50 .collect::<Vec<_>>()
51 }
52
53 fn get_probability(&self) -> f64 {
54 self.p
55 }
56
57 fn set_probability(&mut self, probability: f64) {
58 self.p = probability;
59 }
60
61 fn get_name(&self) ->String {
62 self.name.clone()
63 }
64}