-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathGridShield.py
371 lines (307 loc) · 16.2 KB
/
GridShield.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
import json
import numpy as np
from copy import deepcopy
import random
class GridShield:
    """Composed shield: one small shield automaton per block of the grid.

    The grid is tiled with blocks of ``self.res`` cells.  Every block owns a
    shield loaded from a JSON ``.shield`` file.  As used by this class, such a
    file is a dict keyed by the stringified state number; each entry holds

    * ``'State'``: a dict with, per slot ``k`` (``0 .. max_per_shield - 1``),
      ``'a<k>'`` (cell currently occupied inside the block), ``'a_req<k>'``
      (cell requested) and ``'a_shield<k>'`` (truthy when the request is
      allowed).  The value ``9`` in ``'a<k>'`` / ``'a_req<k>'`` denotes
      "not inside this block".
    * ``'Successors'``: the list of state numbers reachable from this state.

    Attributes set by ``__init__``:
        smap: float array (nrows, ncols); block/shield number of every cell.
        nshields: number of shields (max block number + 1).
        shield_json: list of the parsed shield files, indexed by shield number.
        current_state / start_state: int arrays, one automaton state per shield.
        agent_pos: float array (nagents, 2) holding ``[slot in shield, shield #]``
            for every agent (``-1`` when unknown).
        start_pos / prev_pos: copies of ``agent_pos`` at start / previous step.
        states: ``res``-shaped array mapping in-block coordinates to a cell id.
    """

    def __init__(self, env, nagents=2, start=None, file='example/example_2_agents_'):
        """Build the block map, load every shield file and locate start states.

        Args:
            env: object exposing ``nrows`` and ``ncols``.
            nagents: number of agents.
            start: (nagents, 2) array of start coordinates; defaults to
                ``[[7, 0], [6, 3]]``.
            file: path prefix (relative to ``shields/grid_shields/``) of the
                shield files; ``'_<shield #>.shield'`` is appended.
        """
        if start is None:
            # Built per call instead of sharing one array between all instances.
            start = np.array([[7, 0], [6, 3]])
        start = np.asarray(start)

        self.nagents = nagents
        base = 'shields/grid_shields/'
        self.res = [3, 3]
        self.max_per_shield = 2

        # Row-major numbering of the blocks.
        # NOTE(review): if ncols is not a multiple of res[1], the cells of the
        # partial last block column share numbers with the next block row.
        # Numbering is left untouched because the shield files on disk are
        # keyed by it -- confirm all maps are multiples of the block size.
        self.smap = np.zeros([env.nrows, env.ncols])
        blocks_per_row = int(env.ncols / self.res[1])
        for i in range(env.nrows):
            for j in range(env.ncols):
                self.smap[i][j] = int(i / self.res[0]) * blocks_per_row + int(j / self.res[1])

        self.nshields = int(np.max(self.smap)) + 1
        self.shield_json = []
        self.current_state = np.zeros([self.nshields], dtype=int)
        self.start_state = np.zeros([self.nshields], dtype=int)
        self.agent_pos = np.ones([self.nagents, 2]) * -1  # [index in shield, shield #]
        self.states = np.reshape(np.arange(0, self.res[0] * self.res[1]), self.res)

        # Load the shields; the context manager closes the file even when the
        # JSON is malformed.
        for sh in range(self.nshields):
            with open(base + file + '_' + str(sh) + '.shield') as f:
                self.shield_json.append(json.load(f))

        # Record in which shield every agent starts.
        for i in range(self.nagents):
            self.agent_pos[i][1] = self.smap[start[i][0]][start[i][1]]

        # Find the start state of every shield that has agents in it.
        for sh in range(self.nshields):
            ag = np.where(self.agent_pos[:, 1] == sh)[0]
            if len(ag) > 0:
                if self._find_start_state(start, sh, ag):
                    self.current_state[sh] = self.start_state[sh]
                else:
                    print('State not found for shield : ', str(sh))

        self.start_pos = deepcopy(self.agent_pos)
        self.prev_pos = deepcopy(self.agent_pos)

    def _find_start_state(self, start, i, agents):
        """Search shield ``i`` for the state matching the agents' start cells.

        Args:
            start: start coordinates of all agents.
            i: shield number.
            agents: indices of the agents that start inside shield ``i``; the
                k-th entry is assigned slot ``k``.

        Returns:
            True when a state was found.  In that case ``start_state[i]`` and
            the ``agent_pos`` rows of ``agents`` have been updated.
        """
        cells = [self._to_state(start[a]) for a in agents]
        n_present = len(agents)

        for ind in range(len(self.shield_json[i])):
            cur = self.shield_json[i][str(ind)]['State']
            # Occupied slots: the agent stands on its cell and requests it.
            present_ok = all(
                cur['a' + str(k)] == cells[k] and cur['a_req' + str(k)] == cells[k]
                for k in range(n_present))
            # Remaining slots must be marked as empty (9).
            empty_ok = all(
                cur['a' + str(k)] == 9 and cur['a_req' + str(k)] == 9
                for k in range(n_present, self.max_per_shield))
            if present_ok and empty_ok:
                self.start_state[i] = ind
                for slot, a in enumerate(agents):
                    self.agent_pos[a][0] = slot
                    self.agent_pos[a][1] = i
                return True
        return False

    def _to_state(self, pos):
        """Translate grid coordinates ``pos`` into the cell id inside its block."""
        sh = self.smap[pos[0]][pos[1]]
        # Number of blocks per grid row (float division kept from the original).
        res_width = self.smap.shape[1] / self.res[1]
        # Top-left grid coordinate of the block.
        corner = np.array([int(sh / res_width) * self.res[0],
                           (sh % res_width) * self.res[1]], dtype=int)
        in_shield_pos = pos - corner
        return self.states[in_shield_pos[0]][in_shield_pos[1]]

    def _consult_shield(self, sh, goal_flag, members, act):
        """Query shield ``sh`` for ``members`` and fold the verdicts into ``act``.

        Args:
            sh: shield number.
            goal_flag: per-agent goal flags (indexed by agent).
            members: list of ``(agent, slot, request, cell)`` tuples, one per
                agent taking part (residents and admitted newcomers).
            act: boolean per-agent array, updated in place.
        """
        # step_one expects its arrays ordered by ascending agent index.
        members = sorted(members, key=lambda m: m[0])
        by_slot = {slot: agent for agent, slot, _, _ in members}
        verdict = self.step_one(
            sh,
            [goal_flag[agent] for agent, _, _, _ in members],
            [req for _, _, req, _ in members],
            [cell for _, _, _, cell in members],
            agent0=by_slot.get(0), agent1=by_slot.get(1))
        for k, (agent, _, _, _) in enumerate(members):
            act[agent] = bool(act[agent]) and bool(verdict[k])

    def step(self, actions, pos, goal_flag, env):
        """Filter ``actions`` through every shield.

        Assumptions carried over from the original implementation: a shield
        holds at most two agents; when it is full, agents trying to enter are
        blocked for this turn.

        Args:
            actions: numpy array with one action per agent.  It is modified in
                place: blocked entries are overwritten with ``False`` (0).
            pos: current grid coordinates of every agent.
            goal_flag: per-agent flag, truthy once the agent reached its goal.
            env: object exposing ``get_next_state(pos, action, goal_flag)``
                returning ``(desired position, out-of-bounds, obstacle)``.

        Returns:
            The (mutated) ``actions`` array.
        """
        goal_flag = np.asarray(goal_flag)
        n = self.nagents
        act = np.ones(n, dtype=bool)             # True = action allowed
        a_states = np.zeros(n, dtype=int)        # current cell inside own block
        a_req = np.full((n, 2), -1, dtype=int)   # [requested cell, 9 if changing block]
        pos_shield = np.zeros(n, dtype=int)      # shield the agent is in
        desired_shield = np.zeros(n)             # shield the agent wants to be in
        desired = np.zeros((n, 2), dtype=int)    # desired grid coordinates
        shield_idx = np.full((n, 2), -1.0)       # [current slot, slot after the move]
        oob = np.zeros(n)
        obs = np.zeros(n)

        # Work out, per agent, where it is and where it wants to go.
        for i in range(n):
            desired[i], oob[i], obs[i] = env.get_next_state(pos[i], actions[i], goal_flag[i])
            pos_shield[i] = self.smap[pos[i][0]][pos[i][1]]
            if not goal_flag[i]:
                desired_shield[i] = self.smap[desired[i][0]][desired[i][1]]
            else:
                desired_shield[i] = pos_shield[i]

            a_states[i] = self._to_state(pos[i])
            a_req[i][0] = self._to_state(desired[i])
            shield_idx[i] = [self.agent_pos[i][0], self.agent_pos[i][0]]
            a_req[i][1] = 9 if pos_shield[i] != desired_shield[i] else -1

        for sh in range(self.nshields):
            residents = np.where(pos_shield == sh)[0]
            # Agents that want to enter but are not inside yet.  setdiff1d
            # returns them sorted, so newcomers[0] < newcomers[1].
            newcomers = np.setdiff1d(np.where(desired_shield == sh)[0], residents)

            # An agent whose bookkeeping says it is elsewhere tried an invalid
            # move last turn and never left: roll its bookkeeping back.
            for a in residents:
                if sh != int(self.agent_pos[a][1]):
                    self.agent_pos[a] = deepcopy(self.prev_pos[a])
                    shield_idx[a] = [self.agent_pos[a][0], self.agent_pos[a][0]]

            # Residents leaving this shield request 9 ("outside") from it.
            temp_req = deepcopy(a_req)
            for a in residents:
                if desired_shield[a] != sh and a_req[a][1] == 9:
                    temp_req[a] = [9, 9]

            if len(residents) > self.max_per_shield:
                print('Error too many agents in shield : ', sh)

            elif len(residents) == 0 and len(newcomers) > 0:
                # Empty shield: admit up to two newcomers, block the rest.
                for a in newcomers[2:]:
                    act[a] = False
                members = []
                for slot, a in enumerate(newcomers[:2]):
                    members.append((a, slot, a_req[a], 9))
                    shield_idx[a][1] = slot
                self._consult_shield(sh, goal_flag, members, act)

            elif len(residents) == 1:
                # One resident: at most one newcomer fits, in the free slot.
                for a in newcomers[1:]:
                    act[a] = False
                resident = residents[0]
                slot = self.agent_pos[resident][0]
                if slot in (0, 1):
                    slot = int(slot)
                    members = [(resident, slot, temp_req[resident], a_states[resident])]
                    if len(newcomers) > 0:
                        entering = newcomers[0]
                        members.append((entering, 1 - slot, a_req[entering], 9))
                        shield_idx[entering][1] = 1 - slot
                    self._consult_shield(sh, goal_flag, members, act)

            elif len(residents) == 2:
                # Shield is full: nobody may enter this turn.
                for a in newcomers:
                    act[a] = False
                first, second = residents[0], residents[1]
                first_slot = 0 if self.agent_pos[first][0] == 0 else 1
                members = [(first, first_slot, temp_req[first], a_states[first]),
                           (second, 1 - first_slot, temp_req[second], a_states[second])]
                self._consult_shield(sh, goal_flag, members, act)

            else:  # no agents and no desired
                self.current_state[sh] = 0

        self.prev_pos = deepcopy(self.agent_pos)
        for i in range(n):
            if act[i] and not oob[i] and not obs[i]:
                self.agent_pos[i] = [shield_idx[i][1], desired_shield[i]]

        # act says which ones need to be changed
        actions[~act] = False  # blocked = False
        return actions

    def _get_arr_idx(self, agent0=None, agent1=None):
        """Return, per slot, the index into the agent-ordered argument arrays.

        With a single agent the only array index is 0; with two agents the
        arrays are ordered by ascending agent index.
        """
        if agent0 is None or agent1 is None:
            return [0]
        return [1, 0] if agent0 > agent1 else [0, 1]

    def _get_agent_idx(self, agent0=None, agent1=None):
        """Return, per slot, the array index used by ``_compute_condition``.

        For a single agent the occupied slot maps to array index 0 and the
        empty slot to index 1.  Returns ``[]`` when both slots are empty.
        """
        if agent0 is None and agent1 is None:
            return []
        if agent1 is None:
            return [0, 1]
        if agent0 is None:
            return [1, 0]
        return [1, 0] if agent0 > agent1 else [0, 1]

    def _compute_condition(self, cur, goal_flag, a_states, a_req, idx, agent0=None, agent1=None):
        """Check whether shield state ``cur`` matches the agents' situation.

        Args:
            cur: the ``'State'`` dict of a candidate successor.
            goal_flag, a_states, a_req: per-agent sequences in array order.
            idx: slot -> array index mapping (see ``_get_agent_idx``).
            agent0, agent1: agents in slot 0 / slot 1, ``None`` for an empty slot.

        Returns:
            Array of ``max_per_shield`` flags; all non-zero means ``cur`` matches.
        """
        condition = np.zeros([self.max_per_shield])
        for slot, agent in enumerate((agent0, agent1)[:self.max_per_shield]):
            a_str = 'a' + str(slot)
            a_req_str = 'a_req' + str(slot)
            k = idx[slot]
            if agent is None:
                # Empty slot: the shield must see nobody there.
                condition[k] = (cur[a_str] == 9 and cur[a_req_str] == 9)
            elif goal_flag[k] == 1:
                # Agent at its goal stays put: it requests its own cell.
                condition[k] = (cur[a_str] == a_states[k] and cur[a_req_str] == a_states[k])
            else:
                condition[k] = (cur[a_str] == a_states[k] and cur[a_req_str] == a_req[k][0])
        return condition

    def step_one(self, sh, goal_flag, a_req, a_states, agent0=None, agent1=None):
        """Advance shield ``sh`` by one step for the agents in its slots.

        Args:
            sh: shield number.
            goal_flag, a_req, a_states: per-agent sequences ordered by
                ascending agent index, or bare scalars / one request row when
                there is a single agent.
            agent0, agent1: agents in slot 0 / slot 1 (``None`` if empty).

        Returns:
            Int array in the same order as ``goal_flag``; non-zero means the
            agent's request is allowed.  All zeros when no successor matches.
        """
        a_idx = self._get_agent_idx(agent0=agent0, agent1=agent1)

        # A lone agent may be passed as scalars; np.ndim recognises every
        # scalar flavour (int, bool, np.bool_, np.int32, np.int64).
        if np.ndim(goal_flag) == 0:
            goal_flag = np.array([goal_flag])
            a_states = np.array([a_states])
            a_req = np.array([a_req])

        act = np.zeros([len(goal_flag)], dtype=int)
        successors = self.shield_json[sh][str(self.current_state[sh])]['Successors']
        for s in successors:
            node = self.shield_json[sh][str(s)]
            cur = node['State']
            condition = self._compute_condition(cur, goal_flag, a_states, a_req, a_idx, agent0, agent1)
            if not np.all(condition):
                continue
            # Found the matching successor.  Read each agent's verdict from
            # the slot it actually occupies, so a lone agent in slot 1 gets
            # 'a_shield1' rather than the empty slot's 'a_shield0'.
            for slot, agent in enumerate((agent0, agent1)):
                if agent is not None:
                    act[a_idx[slot]] = cur['a_shield' + str(slot)]
            # Only move into states that can be left again.
            if len(node['Successors']) > 0:
                self.current_state[sh] = s
            break
        return act

    def reset(self):
        """Restore every shield and the agent bookkeeping to the start state."""
        self.current_state = deepcopy(self.start_state)
        self.agent_pos = deepcopy(self.start_pos)
        self.prev_pos = deepcopy(self.agent_pos)
# for testing
if __name__ == "__main__":
    # Smoke test.  The original called GridShield() and shield.step(actions)
    # without the arguments both require (env, pos, goal_flag), so it could
    # never run.  It needs the shield files under shields/grid_shields/.

    class _StubEnv:
        """Stand-in exposing only what GridShield reads from an environment.

        NOTE(review): the grid size below is a guess -- it must match the grid
        the example shield files were generated for.  Replace this stub with
        the project's real environment if one is available.
        """
        nrows = 9
        ncols = 9

        def get_next_state(self, pos, action, goal_flag):
            # Agents never move in this stub: the desired cell is the current
            # one, and it is neither out of bounds nor an obstacle.
            return pos, 0, 0

    env = _StubEnv()
    start = np.array([[7, 0], [6, 3]])
    goal_flag = np.zeros(len(start), dtype=int)
    shield = GridShield(env, nagents=len(start), start=start)

    for requested in ([4, 3], [4, 2], [4, 3]):
        actions = np.array(requested)
        # step() overwrites blocked entries in place, so pass a copy to keep
        # the requested actions printable next to the shielded ones.
        sactions = shield.step(actions.copy(), start, goal_flag, env)
        print('Start actions:', actions, 'Shield actions:', sactions)