use serde_derive::Deserialize;
/// Which game a configuration targets.
///
/// Deserialized as an internally tagged enum: the `kind` field of the input
/// picks the variant, and the remaining fields populate that variant.
#[derive(Clone, Debug, Deserialize)]
#[serde(tag = "kind")]
pub enum Game {
    /// Variant chosen when `kind` is `"Breakthrough"`.
    Breakthrough {
        /// Optional history value, returned as-is by [`Game::history`].
        // NOTE(review): presumably the number of past positions kept — confirm.
        history: Option<usize>,
        /// Size value; appended to the `breakthrough-` prefix by [`Game::name`].
        size: usize,
    },
    /// Variant chosen when `kind` is `"Gym"`.
    Gym {
        /// Optional history value, returned as-is by [`Game::history`].
        // NOTE(review): presumably the number of past observations kept — confirm.
        history: Option<usize>,
        /// Name value; appended to the `gym-` prefix by [`Game::name`].
        name: String,
        /// Remote string (presumably a `host:port` address — confirm).
        /// Filled in by `default_remote()` when missing from the input.
        #[serde(default = "default_remote")]
        remote: String,
    },
}
/// Serde fallback for a missing `remote` field: `"localhost:1337"`.
fn default_remote() -> String {
    String::from("localhost:1337")
}
impl Game {
    /// Builds a short identifier for the game: `breakthrough-<size>` for
    /// `Breakthrough` and `gym-<name>` for `Gym`.
    pub fn name(&self) -> String {
        match *self {
            Game::Breakthrough { size, .. } => format!("breakthrough-{}", size),
            Game::Gym { ref name, .. } => format!("gym-{}", name),
        }
    }

    /// Returns the `history` field of whichever variant is active.
    pub fn history(&self) -> Option<usize> {
        match *self {
            // Both variants carry the same optional field.
            Game::Breakthrough { history, .. } | Game::Gym { history, .. } => history,
        }
    }
}
/// Self-play settings; all three fields are required in the input.
#[derive(Clone, Copy, Debug, Deserialize)]
pub struct SelfPlay {
    /// Batch size; copied into the `batch_size` field of the network configs
    /// built by `Config::get_alphazero` and `Config::get_muzero`.
    pub batch_size: usize,
    /// Evaluator count.
    // NOTE(review): presumably the number of evaluator workers — confirm.
    pub evaluators: usize,
    /// Generator count.
    // NOTE(review): presumably the number of game-generating workers — confirm.
    pub generators: usize,
}
/// Playout count used by the `Default` implementations of the policy
/// settings structs in this file.
const DEFAULT_PLAYOUTS: usize = 200;
/// MCTS settings.
#[derive(Clone, Copy, Debug, Deserialize)]
pub struct MCTS {
    /// Playout count (required in the input); copied into the `n_playouts`
    /// field of the network configs built by `Config`.
    pub playouts: usize,
}
impl Default for MCTS {
fn default() -> Self {
Self {
playouts: DEFAULT_PLAYOUTS,
}
}
}
/// Settings stored in the `rave` entry of `Policies`.
#[derive(Clone, Copy, Debug, Deserialize)]
pub struct RAVE {
    /// UCT weight; filled in by `default_uct()` when missing from the input.
    #[serde(default = "default_uct")]
    pub uct_weight: f32,
    /// Playout count (required when this section is present in the input).
    pub playouts: usize,
}
impl Default for RAVE {
fn default() -> Self {
Self {
uct_weight: default_uct(),
playouts: DEFAULT_PLAYOUTS,
}
}
}
/// Settings stored in the `uct` entry of `Policies`.
#[derive(Clone, Copy, Debug, Deserialize)]
pub struct UCT {
    /// UCT weight; filled in by `default_uct()` when missing from the input.
    #[serde(default = "default_uct")]
    pub uct_weight: f32,
    /// Playout count (required when this section is present in the input).
    pub playouts: usize,
}
impl Default for UCT {
fn default() -> Self {
Self {
uct_weight: default_uct(),
playouts: DEFAULT_PLAYOUTS,
}
}
}
/// Settings stored in the `flat_ucb` entry of `Policies`.
#[derive(Clone, Copy, Debug, Deserialize)]
pub struct FlatUCBMonteCarlo {
    /// Playout count (required when this section is present in the input).
    pub playouts: usize,
    /// UCB weight; filled in by `default_uct()` when missing from the input.
    #[serde(default = "default_uct")]
    pub ucb_weight: f32,
}
impl Default for FlatUCBMonteCarlo {
fn default() -> Self {
Self {
ucb_weight: default_uct(),
playouts: DEFAULT_PLAYOUTS,
}
}
}
/// Settings stored in the `flat` entry of `Policies`.
#[derive(Clone, Copy, Debug, Deserialize)]
pub struct FlatMonteCarlo {
    /// Playout count (required when this section is present in the input).
    pub playouts: usize,
}
impl Default for FlatMonteCarlo {
fn default() -> Self {
Self {
playouts: DEFAULT_PLAYOUTS,
}
}
}
/// Settings stored in the `ppa` entry of `Policies`.
#[derive(Clone, Copy, Debug, Deserialize)]
pub struct PPA {
    /// UCT weight; filled in by `default_uct()` when missing from the input.
    #[serde(default = "default_uct")]
    pub uct_weight: f32,
    /// Playout count (required when this section is present in the input).
    pub playouts: usize,
    /// Alpha value (required when this section is present in the input).
    // NOTE(review): presumably the policy adaptation step size — confirm.
    pub alpha: f32,
}
impl Default for PPA {
fn default() -> Self {
Self {
uct_weight: default_uct(),
playouts: DEFAULT_PLAYOUTS,
alpha: 0.1,
}
}
}
/// Fallback weight (`0.4`) used by serde for missing `uct_weight` /
/// `ucb_weight` fields and by the `Default` implementations in this file.
fn default_uct() -> f32 {
    0.4_f32
}
/// Per-policy settings.
///
/// Every field is optional in the input: a missing entry falls back to the
/// `Default` implementation of its type.
#[derive(Clone, Copy, Debug, Default, Deserialize)]
pub struct Policies {
    /// `rave` section.
    #[serde(default)]
    pub rave: RAVE,
    /// `ppa` section.
    #[serde(default)]
    pub ppa: PPA,
    /// `flat` section.
    #[serde(default)]
    pub flat: FlatMonteCarlo,
    /// `flat_ucb` section.
    #[serde(default)]
    pub flat_ucb: FlatUCBMonteCarlo,
    /// `uct` section.
    #[serde(default)]
    pub uct: UCT,
}
/// PUCT parameters embedded in both [`AlphaZero`] and [`MuZero`].
///
/// All fields except `value_support` are required in the input.
// NOTE(review): field meanings below are inferred from their names only —
// confirm against the PUCT implementation that consumes them.
#[derive(Clone, Copy, Debug, Deserialize)]
pub struct PUCT {
    /// Discount value.
    pub discount: f32,
    /// `c_base` constant.
    pub c_base: f32,
    /// `c_init` constant.
    pub c_init: f32,
    /// Alpha value, presumably for Dirichlet noise at the root.
    pub root_dirichlet_alpha: f32,
    /// Fraction value, presumably the share of root exploration noise.
    pub root_exploration_fraction: f32,
    /// Optional value-support size.
    pub value_support: Option<usize>,
}
/// Contents of the optional `alpha` section of [`Config`].
#[derive(Clone, Copy, Debug, Deserialize)]
pub struct AlphaZero {
    /// PUCT parameters; copied into the `puct` field of the `AlphaZeroConfig`
    /// built by `Config::get_alphazero`.
    pub puct: PUCT,
}
/// Contents of the optional `mu` section of [`Config`].
///
/// The whole struct is copied into the `muz` field of the `MuZeroConfig`
/// built by `Config::get_muzero`.
#[derive(Clone, Copy, Debug, Deserialize)]
pub struct MuZero {
    /// PUCT parameters.
    pub puct: PUCT,
    /// Optional reward-support size.
    pub reward_support: Option<usize>,
    /// Three-dimensional shape.
    // NOTE(review): presumably the shape of the learned representation — confirm.
    pub repr_shape: ndarray::Ix3,
    /// Unroll step count.
    // NOTE(review): presumably the number of unrolled steps during training — confirm.
    pub unroll_steps: usize,
    /// TD step count.
    // NOTE(review): presumably the n-step return horizon — confirm.
    pub td_steps: usize,
}
/// Top-level configuration.
///
/// `game`, `self_play` and `mcts` are required in the input; `alpha` and `mu`
/// are optional sections, and `policies` falls back to `Policies::default()`.
#[derive(Clone, Debug, Deserialize)]
pub struct Config {
    /// Game selection.
    pub game: Game,
    /// Self-play settings.
    pub self_play: SelfPlay,
    /// MCTS settings.
    pub mcts: MCTS,
    /// AlphaZero section; `get_alphazero` returns `None` when it is absent.
    pub alpha: Option<AlphaZero>,
    /// MuZero section; `get_muzero` returns `None` when it is absent.
    pub mu: Option<MuZero>,
    /// Per-policy settings.
    #[serde(default)]
    pub policies: Policies,
}
use crate::policies::mcts::{muz::MuZeroConfig, puct::AlphaZeroConfig};
impl Config {
pub fn get_alphazero<A, B>(
&self,
action_shape: A,
board_shape: B,
) -> Option<AlphaZeroConfig<A, B>> {
if let Some(alpha_config) = self.alpha {
let model_path = format!("data/alpha-{}/model/", self.game.name());
let alpha_config = AlphaZeroConfig {
action_shape,
board_shape,
puct: alpha_config.puct,
network_path: model_path,
watch_models: true,
batch_size: self.self_play.batch_size,
n_playouts: self.mcts.playouts,
};
Some(alpha_config)
} else {
None
}
}
pub fn get_muzero<A, B>(&self, action_shape: A, board_shape: B) -> Option<MuZeroConfig<B, A>> {
if let Some(mu_config) = self.mu {
let models_path = format!("data/mu-{}/models/", self.game.name());
let mu_config = MuZeroConfig {
action_shape,
board_shape,
muz: mu_config,
networks_path: models_path,
watch_models: true,
batch_size: self.self_play.batch_size,
n_playouts: self.mcts.playouts,
};
Some(mu_config)
} else {
None
}
}
}
/// The two learning methods named in this file.
pub enum Method {
    /// MuZero; short name `"mu"`.
    MuZero,
    /// AlphaZero; short name `"alpha"`.
    AlphaZero,
}

impl Method {
    /// Returns the short name of the method: `"mu"` or `"alpha"`.
    ///
    /// Both names are string literals, so the result is `&'static str` and is
    /// not tied to the lifetime of the `&self` borrow.
    pub fn name(&self) -> &'static str {
        match self {
            Method::MuZero => "mu",
            Method::AlphaZero => "alpha",
        }
    }
}
use std::{error, fmt};

/// Error type that wraps a plain message string.
#[derive(Clone, Debug)]
pub struct StrError(pub String);

impl fmt::Display for StrError {
    /// Writes the wrapped message verbatim; formatter width and padding
    /// flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.0)
    }
}

/// `StrError` has no underlying cause, so the trait's default `source`
/// (which returns `None`) is used.
impl error::Error for StrError {}