Skip to content

Commit

Permalink
Re-structure the neural network package
Browse files Browse the repository at this point in the history
Re-structure the `nn` package to:

* Feature an external half precision module since it is used in
  multiple separate modules.
* Use half precision in the weight JSON format to decrease the disk size
  of the file (88 MB -> 44 MB)
* Move the ffi interface to NVIDIA CUDA into an internal `ffi` module
  that is private to the `nn` package.
* Improve error checking of CUDA.
  • Loading branch information
kblomdahl committed Nov 29, 2017
1 parent 42ec765 commit 6608c06
Show file tree
Hide file tree
Showing 13 changed files with 495 additions and 407 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ publish = false
default = []

# print (a lot) of debug info during neural network evaluation
debug_nn = []
trace-cuda = []

[profile.dev]
opt-level = 2
Expand Down
8 changes: 2 additions & 6 deletions src/dataset/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use std::io::prelude::*;
use std::io::{self, BufReader, Cursor};
use std::thread::{self, JoinHandle};
use std::sync::mpsc::{sync_channel, SyncSender, Receiver};
use ::f16::*;

#[derive(Clone)]
pub struct Entry {
Expand Down Expand Up @@ -143,11 +144,6 @@ impl Entry {
}
}

extern "C" {
#[link_name = "llvm.convert.to.fp16.f32"]
fn convert_to_fp16_f32(f: f32) -> u16;
}

/// Write the given 32-bit floating point number to the given formatter
/// in the platform endianess.
///
Expand All @@ -159,7 +155,7 @@ extern "C" {
fn write_f32<T>(f: &mut T, value: f32) -> io::Result<()>
where T: io::Write
{
let value_b16: u16 = unsafe { convert_to_fp16_f32(value) };
let value_b16: u16 = f16::from(value).to_bits();

unsafe {
let bytes = transmute::<_, [u8; 2]>(value_b16);
Expand Down
80 changes: 80 additions & 0 deletions src/f16/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright 2017 Karl Sundequist Blomdahl <karl.sundequist.blomdahl@gmail.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/// 16-bit floating point numbers as defined in IEEE 754-2008 (binary16).
///
/// This type is a plain wrapper around the raw bit pattern; conversion to
/// and from `f32` is provided by the `From` implementations in this
/// module.
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Debug)]
pub struct f16(u16);

impl f16 {
    /// Wrap the given bits as a half precision floating point number.
    ///
    /// # Arguments
    ///
    /// * `bits` - the bits to wrap
    ///
    pub fn from_bits(bits: u16) -> f16 {
        f16(bits)
    }

    /// Returns the wrapped bits.
    pub fn to_bits(&self) -> u16 {
        self.0
    }
}

// LLVM intrinsics that perform the actual conversion between half and
// single precision. The `link_name` attribute maps each declaration
// directly onto the corresponding LLVM intrinsic instead of an external
// C symbol.
//
// NOTE(review): relying on `llvm.*` link names is toolchain specific --
// confirm this still links on newer rustc versions.
extern "C" {
    // f32 -> raw binary16 bits (round to nearest)
    #[link_name = "llvm.convert.to.fp16.f32"]
    fn convert_to_fp16_f32(f: f32) -> u16;

    // raw binary16 bits -> f32 (exact, widening)
    #[link_name = "llvm.convert.from.fp16.f32"]
    fn convert_from_fp16_f32(f: u16) -> f32;
}

impl From<f16> for f32 {
fn from(value: f16) -> f32 {
let f16(bits) = value;

unsafe { convert_from_fp16_f32(bits) }
}
}

impl From<f32> for f16 {
fn from(value: f32) -> f16 {
f16(unsafe { convert_to_fp16_f32(value) })
}
}

#[cfg(test)]
mod tests {
    use ::f16::*;

    /// Known binary16 bit patterns of common mathematical constants must
    /// widen to the expected (half-precision-rounded) `f32` values.
    #[test]
    fn from_f16_to_f32() {
        assert_eq!(f32::from(f16::from_bits(0x4170)), 2.71875); // e
        assert_eq!(f32::from(f16::from_bits(0x4248)), 3.140625); // pi
        assert_eq!(f32::from(f16::from_bits(0x3518)), 0.31835938); // 1/pi
        assert_eq!(f32::from(f16::from_bits(0x398c)), 0.6933594); // ln 2
        assert_eq!(f32::from(f16::from_bits(0x36f3)), 0.43432617); // log10 e
        assert_eq!(f32::from(f16::from_bits(0x3dc5)), 1.4423828); // log2 e
        assert_eq!(f32::from(f16::from_bits(0x3da8)), 1.4140625); // sqrt 2
    }

    /// `f32` constants must round to the expected binary16 bit patterns.
    #[test]
    fn from_f32_to_f16() {
        assert_eq!(f16::from(::std::f32::consts::PI).to_bits(), 0x4248); // pi
        assert_eq!(f16::from(::std::f32::consts::E).to_bits(), 0x4170); // e
    }
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ extern crate regex;
extern crate ordered_float;

pub mod dataset;
mod f16;
pub mod go;
pub mod mcts;
pub mod nn;
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use dream_go::{dataset, nn, mcts};
use std::env;
use std::path::Path;

///
/// Main function.
fn main() {
// keep everything that is before the first "--" indicator as potential
// program arguments
Expand Down
18 changes: 18 additions & 0 deletions src/mcts/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ fn choose(board: &Board, color: Color, policy: &mut [f32], greedy: bool) -> Opti
}
}

/// Performs a forward pass through the neural network for the given board
/// position using a random symmetry to increase entropy.
///
/// # Arguments
///
/// * `workspace` - the workspace to use during the forward pass
/// * `board` - the board position
/// * `color` - the current player
///
fn forward(workspace: &mut Workspace, board: &Board, color: Color) -> (f32, Box<[f32]>) {
lazy_static! {
static ref SYMM: Vec<symmetry::Transform> = vec! [
Expand Down Expand Up @@ -125,6 +134,9 @@ pub fn self_play(workspace: &mut Workspace) -> GameResult {
let mut pass_count = 0;
let mut count = 0;

// limit the maximum number of moves to `2 * 19 * 19` to avoid the
// engine playing pointless capture sequences at the end of the game
// that does not change the final result.
while count < 722 {
let (value, mut policy) = forward(workspace, &board, current);

Expand All @@ -133,8 +145,14 @@ pub fn self_play(workspace: &mut Workspace) -> GameResult {

return GameResult::Resign(sgf, board, current.opposite(), -value);
} else {
// add some random noise to the policy to increase the entropy of
// the self play dataset and avoid just overfitting to the current
// policy during training.
dirichlet::add(&mut policy, 0.03);

// choose a random move from the policy for the first 10 turns, after
// that play deterministically (discounting the dirichlet noise) to
// avoid making large blunders during life or death situations.
let policy_m = choose(&board, current, &mut policy, count >= 10);

if policy_m.is_none() || policy_m == Some(361) { // passing move
Expand Down
9 changes: 8 additions & 1 deletion src/nn/cublas.rs → src/nn/ffi/cublas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

use libc::{c_float, c_int, c_void};
use nn::cuda::Stream;
use nn::ffi::cuda::Stream;

#[repr(i32)]
#[allow(dead_code)]
Expand All @@ -31,6 +31,13 @@ pub enum Status {
LicenseError = 16
}

impl Status {
    /// Returns whether this status indicates a successful call.
    pub fn is_ok(&self) -> bool {
        match *self {
            Status::Success => true,
            _ => false
        }
    }
}

#[repr(i32)]
#[allow(dead_code)]
pub enum Operation {
Expand Down
7 changes: 7 additions & 0 deletions src/nn/cuda.rs → src/nn/ffi/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ pub enum Error {
StartupFailure = 0x7f
}

impl Error {
/// Returns whether this _error_ indicates a successful call.
pub fn is_ok(&self) -> bool {
*self == Error::Success
}
}

#[repr(i32)]
#[allow(dead_code)]
pub enum MemcpyKind {
Expand Down
9 changes: 8 additions & 1 deletion src/nn/cudnn.rs → src/nn/ffi/cudnn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

use libc::{c_double, c_int, c_void, size_t};
use nn::cuda::Stream;
use nn::ffi::cuda::Stream;

#[repr(i32)]
#[allow(dead_code)]
Expand Down Expand Up @@ -67,6 +67,13 @@ pub enum Status {
RuntimePrerequisiteMissing = 11
}

impl Status {
    /// Returns whether this status indicates a successful call.
    pub fn is_ok(&self) -> bool {
        match *self {
            Status::Success => true,
            _ => false
        }
    }
}

#[repr(i32)]
#[allow(dead_code)]
pub enum NanPropagation {
Expand Down
98 changes: 98 additions & 0 deletions src/nn/ffi/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright 2017 Karl Sundequist Blomdahl <karl.sundequist.blomdahl@gmail.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

pub mod cublas;
pub mod cuda;
pub mod cudnn;

#[cfg(test)]
mod tests {
    use libc::{c_void};
    use std::ptr;

    use ::nn::ffi::*;

    /// Computes `C = A * B` on the GPU with cuBLAS and verifies the result
    /// against the product computed by hand.
    ///
    /// cuBLAS assumes column major storage, so the operands are passed in
    /// reverse order (`b_` before `a_`), effectively computing
    /// `C^T = B^T * A^T`, which read back in row major order is `A * B`.
    #[test]
    fn sgemm() {
        let mut handle: cublas::Handle = ptr::null_mut();
        let c_0 = 0.0f32;
        let c_1 = 1.0f32;

        unsafe {
            let a = [ // 3x2
                1.0f32, 2.0f32,
                3.0f32, 4.0f32,
                5.0f32, 6.0f32
            ];
            let b = [ // 2x3
                1.0f32, 2.0f32, 3.0f32,
                4.0f32, 5.0f32, 6.0f32
            ];
            // the result buffer is written to by `cudaMemcpy` below, so it
            // must be a mutable binding. The previous code cast the result
            // of `as_ptr()` on an immutable binding to a mutable pointer,
            // which is undefined behaviour.
            let mut c = [ // 3x3
                0.0f32, 0.0f32, 0.0f32,
                0.0f32, 0.0f32, 0.0f32,
                0.0f32, 0.0f32, 0.0f32
            ];

            // C = A * B
            let mut a_ = ptr::null_mut();
            let mut b_ = ptr::null_mut();
            let mut c_ = ptr::null_mut();

            assert_eq!(cuda::cudaMalloc(&mut a_, 24), cuda::Error::Success);
            assert_eq!(cuda::cudaMalloc(&mut b_, 24), cuda::Error::Success);
            assert_eq!(cuda::cudaMalloc(&mut c_, 36), cuda::Error::Success);
            assert_eq!(cuda::cudaMemcpy(
                a_,
                a.as_ptr() as *const c_void,
                24,
                cuda::MemcpyKind::HostToDevice
            ), cuda::Error::Success);
            assert_eq!(cuda::cudaMemcpy(
                b_,
                b.as_ptr() as *const c_void,
                24,
                cuda::MemcpyKind::HostToDevice
            ), cuda::Error::Success);

            assert_eq!(cublas::cublasCreate_v2(&mut handle), cublas::Status::Success);
            assert_eq!(cublas::cublasSgemm_v2(
                handle,
                cublas::Operation::N,
                cublas::Operation::N,
                3, 3, 2,
                &c_1,
                b_, 3,
                a_, 2,
                &c_0,
                c_, 3
            ), cublas::Status::Success);
            assert_eq!(cublas::cublasDestroy_v2(handle), cublas::Status::Success);

            // check the results
            assert_eq!(cuda::cudaMemcpy(
                c.as_mut_ptr() as *mut c_void,
                c_,
                36,
                cuda::MemcpyKind::DeviceToHost
            ), cuda::Error::Success);

            // NOTE(review): `a_`, `b_`, and `c_` are never freed. That is
            // tolerable in a short-lived test process, but calling
            // `cudaFree` on each would be cleaner -- confirm the binding
            // exposes it.
            assert_eq!(c, [
                9.0f32, 12.0f32, 15.0f32,
                19.0f32, 26.0f32, 33.0f32,
                29.0f32, 40.0f32, 51.0f32
            ])
        }
    }
}
Loading

0 comments on commit 6608c06

Please sign in to comment.