forked from b06b01073/go_thesis
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathGameUnroller.py
49 lines (32 loc) · 1.41 KB
/
GameUnroller.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from GoEnv import Go
import goutils
import govars
from seed import set_seed
# Seed the random number generators once, at import time, via the project's
# seed helper (presumably so training-data generation is reproducible —
# confirm against seed.set_seed).
set_seed()
class GameUnroller:
    '''
    Replays recorded Go games move by move to build model training pairs.

    For each non-pass move, the board features observed *before* the move
    are paired with the move itself, so a model can learn to predict the
    next move from the current position.
    '''

    def __init__(self):
        # A single Go environment is created here and reused across calls;
        # `unroll` resets it at the start of every game.
        self.go_env = Go()

    def unroll(self, game):
        '''
        Replay `game` on the internal Go environment and collect training pairs.

        Args:
            game (list of str): the moves of one game, in play order. Each
                entry has the form COLOR[column row], where COLOR is 'B' or
                'W' and the coordinates follow the SGF convention (column
                first, origin at the top-left corner).

        Returns:
            tuple (game_states, game_moves):
                game_states (list): the result of `Go.game_features()` for the
                    board position just before each non-pass move; this is the
                    model input.
                game_moves (list): the corresponding move as encoded by
                    `goutils.move_encode`; game_moves[i] is the label for
                    game_states[i].

            Pass moves produce no entry in either list, but they are still
            applied to the environment.
        '''
        self.go_env.reset()  # clear all state left over from a previous game

        game_states = []
        game_moves = []

        for move in game:
            action1d = goutils.move_encode(move)
            # Only non-pass transitions become training samples, so features
            # are extracted only for them (and before the move is applied, so
            # they describe the position the label is played from).
            if action1d != govars.PASS:
                game_states.append(self.go_env.game_features())
                game_moves.append(action1d)
            # Every move, passes included, is played so the environment's
            # state stays in sync with the game record.
            self.go_env.make_move(action1d)

        return game_states, game_moves