1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#![allow(non_snake_case)]
use ggpf::deep::evaluator::PredictionEvaluatorChannel;
use ggpf::deep::tf;
use ggpf::game::breakthrough::{Breakthrough, BreakthroughBuilder};
use ggpf::game::meta::with_history::*;
use ggpf::game::*;
use ndarray::Dimension;
use std::path::Path;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::sync::RwLock;
use tokio::runtime;
use tokio::sync::mpsc;
/// Directory holding the trained model that `run` loads via `tf::load_model`;
/// generation does not start if this path does not exist.
const MODEL_PATH: &str = "data/alpha-breakthrough-5/model/";
/// Game type used throughout this binary: `Breakthrough` wrapped in the
/// `WithHistory` meta-game adapter (history length is chosen where the
/// builder is constructed in `run`).
type G = WithHistory<Breakthrough>;
/// Binary entry point.
///
/// Creates a multi-threaded Tokio runtime (8 core worker threads, with all
/// drivers such as I/O and time enabled) and drives `run` to completion on it.
///
/// # Panics
///
/// Panics if the runtime cannot be built.
fn main() {
    let mut rt = runtime::Builder::new()
        .threaded_scheduler()
        .core_threads(8)
        .enable_all()
        .build()
        .unwrap();
    let generation = run();
    rt.block_on(generation);
}
use indicatif::{ProgressBar, ProgressStyle};
/// Batch size handed to each evaluator's `prediction_task`; each evaluator's
/// request channel is created with capacity for twice this many requests.
const GPU_BATCH_SIZE: usize = 128;
/// Number of generator tasks spawned per evaluator
/// (total generators = `N_EVALUATORS * N_GENERATORS`).
const N_GENERATORS: usize = 256;
/// Number of independent evaluator pipelines, each with its own request
/// channel and `prediction_task`, all sharing one loaded model handle.
const N_EVALUATORS: usize = 4;
async fn run() {
flexi_logger::Logger::with_env().start().unwrap();
log::info!("AlphaZero generate: starting!");
if !Path::new(MODEL_PATH).exists() {
println!("Couldn't find model at {}", MODEL_PATH);
return;
};
let prediction_tensorflow = Arc::new((
AtomicBool::new(false),
RwLock::new(tf::load_model(&MODEL_PATH)),
));
let game_builder = WithHistoryGB::new(BreakthroughBuilder { size: 5 }, 2);
let breakthrough: G = game_builder.create(Breakthrough::players()[0]).await;
let ft = breakthrough.get_features();
let board_size = G::state_dimension(&ft).size();
let action_size = G::action_dimension(&ft).size();
let indicator_bar = ProgressBar::new_spinner();
indicator_bar.set_style(
ProgressStyle::default_spinner()
.template("[{spinner}] {wide_bar} {pos} steps generated ({elapsed_precise})"),
);
indicator_bar.enable_steady_tick(200);
let bar_box = Arc::new(Box::new(indicator_bar));
let mut jh = vec![];
for _ in 0..N_EVALUATORS {
let (pred_tx, pred_rx) = mpsc::channel::<PredictionEvaluatorChannel>(2 * GPU_BATCH_SIZE);
for _ in 0..N_GENERATORS {
let ptx = pred_tx.clone();
let bt = breakthrough.clone();
tokio::spawn(async move {
loop {
ggpf::deep::evaluator::prediction(
ptx.clone(),
Breakthrough::players()[0],
&bt,
1,
)
.await;
}
});
}
let prediction_tensorflow = prediction_tensorflow.clone();
let bb = bar_box.clone();
jh.push(tokio::spawn(ggpf::deep::evaluator::prediction_task(
GPU_BATCH_SIZE,
board_size,
action_size,
1,
prediction_tensorflow,
pred_rx,
Some(bb),
)));
}
for i in jh {
i.await.unwrap()
}
}