forked from rggibson/open-pure-cfr
-
Notifications
You must be signed in to change notification settings - Fork 0
/
entries.hpp
300 lines (255 loc) · 8.09 KB
/
entries.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
#ifndef __PURE_CFR_ENTRIES_HPP__
#define __PURE_CFR_ENTRIES_HPP__
/* entries.hpp
* Richard Gibson, Jul 1, 2013
* Email: richard.g.gibson@gmail.com
*
* A class for storing regret and avg strategies of variable type.
*
* Copyright (C) 2013 by Richard Gibson
*/
/* C / C++ / STL includes */
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <limits>
#include <typeinfo>
/* C project-acpc-poker includes */
extern "C" {
}
/* Pure CFR includes */
#include "constants.hpp"
/* Abstract interface to a flat table of integer entries (regrets or average
 * strategy counts).  A group of num_choices consecutive entries is addressed
 * by ( bucket, soln_idx ) through get_entry_index( ), whose definition is not
 * in this header.  Entries_der<T> below supplies the storage for a concrete T.
 */
class Entries {
public:
  /* Records the per-bucket stride and the total number of entries
   * (constructor defined outside this header). */
  Entries( size_t new_num_entries_per_bucket, size_t total_num_entries );
  virtual ~Entries( );

  /* Fills pos_values[ 0 .. num_choices-1 ] with the non-negative part of the
   * entries at ( bucket, soln_idx ).
   * Returns the sum of all pos_values in the returned pos_values array. */
  virtual uint64_t get_pos_values( const int bucket,
                                   const int64_t soln_idx,
                                   const int num_choices,
                                   uint64_t *pos_values ) const = 0;

  /* Adds values[ c ] - retval to each of the num_choices entries at
   * ( bucket, soln_idx ); Entries_der skips any addition that would overflow. */
  virtual void update_regret( const int bucket,
                              const int64_t soln_idx,
                              const int num_choices,
                              const int *values,
                              const int retval ) = 0;

  /* Adds one to entry `choice` at ( bucket, soln_idx ).
   * Returns 0 on success, 1 on overflow. */
  virtual int increment_entry( const int bucket,
                               const int64_t soln_idx,
                               const int choice ) = 0;

  /* Serialize / deserialize the whole table.
   * Both return 0 on success, 1 on failure. */
  virtual int write( FILE *file ) const = 0;
  virtual int load( FILE *file ) = 0;

  /* Tag identifying the concrete entry type; stored in dumps by write( ). */
  virtual pure_cfr_entry_type_t get_entry_type( ) const = 0;

protected:
  /* Index of the first entry for ( bucket, soln_idx ); defined elsewhere. */
  size_t get_entry_index( const int bucket, const int64_t soln_idx ) const;

  const size_t num_entries_per_bucket;
  const size_t total_num_entries;
};
/* Concrete entry table holding total_num_entries elements of type T in one
 * contiguous array.  get_entry_type( ) recognizes T = uint8_t, int, uint32_t
 * and uint64_t.
 *
 * Storage is either allocated (zero-filled) by the constructor and released
 * by the destructor, or borrowed from the caller via loaded_data; borrowed
 * storage is never freed, and write( ) / load( ) refuse to operate on it.
 *
 * Because the class owns a raw pointer, copying is disabled: an implicit
 * copy would share `entries` and free it twice.
 */
template <typename T>
class Entries_der : public Entries {
public:
  Entries_der( size_t new_num_entries_per_bucket,
               size_t new_total_num_entries,
               T *loaded_data = NULL );
  virtual ~Entries_der( );
  virtual uint64_t get_pos_values( const int bucket,
                                   const int64_t soln_idx,
                                   const int num_choices,
                                   uint64_t *pos_values ) const;
  virtual void update_regret( const int bucket,
                              const int64_t soln_idx,
                              const int num_choices,
                              const int *values,
                              const int retval );
  virtual int increment_entry( const int bucket,
                               const int64_t soln_idx,
                               const int choice );
  virtual int write( FILE *file ) const;
  virtual int load( FILE *file );
  virtual pure_cfr_entry_type_t get_entry_type( ) const;
  /* Copies the num_choices entries at ( bucket, soln_idx ) into values */
  virtual void get_values( const int bucket,
                           const int64_t soln_idx,
                           const int num_choices,
                           T *values ) const;
  /* Overwrites the num_choices entries at ( bucket, soln_idx ) with values */
  virtual void set_values( const int bucket,
                           const int64_t soln_idx,
                           const int num_choices,
                           const T *values );
protected:
  /* Entry storage; owned by this object unless data_was_loaded is nonzero */
  T *entries;
  /* 1 if entries points at caller-supplied data, 0 if allocated here */
  const int data_was_loaded;
private:
  /* Non-copyable (declared but never defined; pre-C++11 idiom) */
  Entries_der( const Entries_der & );
  Entries_der &operator=( const Entries_der & );
};
/* Factory whose definition lives outside this header.
 * NOTE(review): from the signature it presumably wraps data already in
 * memory at *data in an Entries_der<T> of the matching T (compare the
 * loaded_data constructor argument) and may advance *data past what it
 * consumed -- TODO confirm against the definition.
 */
Entries *new_loaded_entries( size_t num_entries_per_bucket, size_t total_num_entries, void **data );
/* Template member definitions must be visible in every translation unit
* that instantiates them, so they are defined here in the header rather
* than in a separate .cpp file.
*/
/* Builds a table of new_total_num_entries elements of type T.
 *
 * If loaded_data is non-NULL, the table uses that storage directly and does
 * not take ownership of it.  Otherwise a zero-filled array is allocated with
 * calloc; failure to allocate a non-empty table is fatal.
 */
template <typename T>
Entries_der<T>::Entries_der( size_t new_num_entries_per_bucket,
                             size_t new_total_num_entries,
                             T *loaded_data )
  : Entries( new_num_entries_per_bucket, new_total_num_entries ),
    entries( loaded_data ),
    data_was_loaded( loaded_data != NULL ? 1 : 0 )
{
  if( loaded_data == NULL ) {
    entries = static_cast<T *>( calloc( total_num_entries, sizeof( T ) ) );
    /* Checked explicitly instead of with assert( ) so the check also runs in
     * NDEBUG builds.  calloc( 0, ... ) may legitimately return NULL, so an
     * empty table is not treated as a failure.
     */
    if( entries == NULL && total_num_entries > 0 ) {
      /* Out of RAM: use a smaller game or coarser abstractions. */
      fprintf( stderr, "failed to allocate %jd entries; out of memory\n",
               ( intmax_t ) total_num_entries );
      abort( );
    }
  }
}
/* Releases the entry array only if this object allocated it; storage that
 * was handed in as loaded_data is left for its owner.
 */
template <typename T>
Entries_der<T>::~Entries_der( )
{
  T *owned = NULL;
  if( data_was_loaded == 0 ) {
    owned = entries;
  }
  entries = NULL;
  free( owned ); /* free( NULL ) is a no-op */
}
/* Writes max( entry, 0 ) for each of the num_choices entries at
 * ( bucket, soln_idx ) into pos_values and returns the sum of those values.
 */
template <typename T>
uint64_t Entries_der<T>::get_pos_values( const int bucket,
                                         const int64_t soln_idx,
                                         const int num_choices,
                                         uint64_t *pos_values ) const
{
  /* Read the entries in place; no scratch copy (and no non-standard
   * variable-length stack array) is needed.
   */
  const T *local_entries = &entries[ get_entry_index( bucket, soln_idx ) ];
  uint64_t sum_values = 0;
  for( int c = 0; c < num_choices; ++c ) {
    /* Entries of a signed T (e.g. int regrets) can be negative: clip to 0 */
    const uint64_t pos = local_entries[ c ] > 0
      ? static_cast<uint64_t>( local_entries[ c ] ) : 0;
    pos_values[ c ] = pos;
    sum_values += pos;
  }
  return sum_values;
}
/* Adds values[ c ] - retval to each of the num_choices entries at
 * ( bucket, soln_idx ).  An addition that would leave the range of T is
 * skipped, and that entry keeps its old value.
 *
 * The range check is done *before* adding.  Adding first and then testing
 * for wrap-around is undefined behavior for a signed T (int), and the
 * compiler may delete such a check.
 */
template <typename T>
void Entries_der<T>::update_regret( const int bucket,
                                    const int64_t soln_idx,
                                    const int num_choices,
                                    const int *values,
                                    const int retval )
{
  T *local_entries = &entries[ get_entry_index( bucket, soln_idx ) ];
  for( int c = 0; c < num_choices; ++c ) {
    /* Widen before subtracting so the difference itself cannot overflow */
    const int64_t diff = static_cast<int64_t>( values[ c ] ) - retval;
    const T cur = local_entries[ c ];
    bool fits = false;
    if( diff > 0 ) {
      /* cur + diff <= max  <=>  cur <= max - diff (no overflow on the right) */
      fits = ( cur <= std::numeric_limits<T>::max( ) - diff );
    } else if( diff < 0 ) {
      /* cur + diff >= min  <=>  cur >= min - diff (min + |diff|) */
      fits = ( cur >= std::numeric_limits<T>::min( ) - diff );
    }
    /* diff == 0 leaves the entry unchanged */
    if( fits ) {
      local_entries[ c ] = static_cast<T>( cur + diff );
    }
  }
}
/* Adds one to entry `choice` at ( bucket, soln_idx ).
 * Returns 0 on success, 1 on overflow.
 *
 * An entry already at the maximum value of T is left unchanged instead of
 * wrapping around; for T = int, incrementing past the maximum would be
 * undefined behavior.
 */
template <typename T>
int Entries_der<T>::increment_entry( const int bucket, const int64_t soln_idx, const int choice )
{
  T *local_entries = &entries[ get_entry_index( bucket, soln_idx ) ];
  const T current = local_entries[ choice ];
  if( current == std::numeric_limits<T>::max( ) ) {
    /* Overflow! Keep the saturated value rather than corrupting it */
    return 1;
  }
  local_entries[ choice ] = static_cast<T>( current + 1 );
  /* A non-positive result (possible only for a signed T whose entry was
   * negative) is also reported as failure. */
  if( local_entries[ choice ] <= 0 ) {
    return 1;
  }
  return 0;
}
/* Serializes the table to file: a pure_cfr_entry_type_t tag identifying T,
 * followed by all total_num_entries raw elements.  Borrowed (loaded) data
 * is never written.  Returns 0 on success, 1 on failure.
 */
template <typename T>
int Entries_der<T>::write( FILE *file ) const
{
  if( data_was_loaded ) {
    fprintf( stderr, "tried to write data that was loaded at instantiation, "
             "which is not allowed\n" );
    return 1;
  }

  /* Header: the type tag, which load( ) checks before reading entries */
  const pure_cfr_entry_type_t type_tag = get_entry_type( );
  if( fwrite( &type_tag, sizeof( type_tag ), 1, file ) != 1 ) {
    fprintf( stderr, "error while writing dump type [%d]\n", type_tag );
    return 1;
  }

  /* Body: every entry, in memory order */
  const size_t entries_written = fwrite( entries, sizeof( T ), total_num_entries, file );
  if( entries_written == total_num_entries ) {
    return 0;
  }
  fprintf( stderr, "error while writing; only wrote %jd of %jd entries\n",
           ( intmax_t ) entries_written, ( intmax_t ) total_num_entries );
  return 1;
}
/* Reads a table in the format produced by write( ). The stored type tag must
 * equal get_entry_type( ); then total_num_entries elements are read into the
 * existing array.  Borrowed (loaded) data is never overwritten.
 * Returns 0 on success, 1 on failure.
 */
template <typename T>
int Entries_der<T>::load( FILE *file )
{
  if( data_was_loaded ) {
    fprintf( stderr, "tried to load from file on top of loaded data at "
             "instantiation, which is not allowed\n" );
    return 1;
  }

  /* Header: verify the dump was written for the same entry type */
  pure_cfr_entry_type_t stored_type;
  if( fread( &stored_type, sizeof( stored_type ), 1, file ) != 1 ) {
    fprintf( stderr, "failed to read entry type\n" );
    return 1;
  }
  const pure_cfr_entry_type_t expected_type = get_entry_type( );
  if( stored_type != expected_type ) {
    fprintf( stderr, "type [%d] found, but expected type [%d]\n",
             stored_type, expected_type );
    return 1;
  }

  /* Body: the raw entries */
  const size_t entries_read = fread( entries, sizeof( T ), total_num_entries, file );
  if( entries_read == total_num_entries ) {
    return 0;
  }
  fprintf( stderr, "error while loading; only read %jd of %jd entries\n",
           ( intmax_t ) entries_read, ( intmax_t ) total_num_entries );
  return 1;
}
/* Maps the template parameter T to its pure_cfr_entry_type_t tag.
 * Recognized: uint8_t, int, uint32_t, uint64_t.  Any other T prints an
 * error and fails the assert; with NDEBUG it yields TYPE_NUM_TYPES.
 */
template <typename T>
pure_cfr_entry_type_t Entries_der<T>::get_entry_type( ) const
{
  const std::type_info &elem_type = typeid( T );
  if( elem_type == typeid( uint8_t ) ) {
    return TYPE_UINT8_T;
  }
  if( elem_type == typeid( int ) ) {
    return TYPE_INT;
  }
  if( elem_type == typeid( uint32_t ) ) {
    return TYPE_UINT32_T;
  }
  if( elem_type == typeid( uint64_t ) ) {
    return TYPE_UINT64_T;
  }
  fprintf( stderr, "called get_entry_type for unrecognized template type!\n" );
  assert( 0 );
  return TYPE_NUM_TYPES;
}
/* Copies the num_choices entries stored at ( bucket, soln_idx ) into the
 * caller's values array.
 */
template <typename T>
void Entries_der<T>::get_values( const int bucket,
                                 const int64_t soln_idx,
                                 const int num_choices,
                                 T *values ) const
{
  const T *source = entries + get_entry_index( bucket, soln_idx );
  memcpy( values, source, sizeof( T ) * num_choices );
}
/* Overwrites the num_choices entries stored at ( bucket, soln_idx ) with the
 * contents of the caller's values array.
 */
template <typename T>
void Entries_der<T>::set_values( const int bucket,
                                 const int64_t soln_idx,
                                 const int num_choices,
                                 const T *values )
{
  T *destination = entries + get_entry_index( bucket, soln_idx );
  memcpy( destination, values, sizeof( T ) * num_choices );
}
#endif