-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathenvironment.py
387 lines (319 loc) · 13.1 KB
/
environment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
'''
The main city simulation implementation.
'''
from __future__ import annotations
from copy import deepcopy
from typing import List
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation
from utils import Direction, opposite
# How many subsequent Junction.set_state() calls are ignored (only counted down)
# after a junction actually switches its light.
JUNCTION_COOLDOWN: int = 10
class GridWorldEnv:
    """
    Grid-based city traffic simulation.

    ``self.map`` is a 2-D numpy array: 0 marks an empty tile and values >= 1
    mark road tiles, where the value is the road's priority (1 - outer ring and
    minor roads, 2 - secondary roads, 3 - main roads). Junctions use these
    priorities as weights when a vehicle picks its exit.
    """

    def __init__(self, render_mode=None, map_size=42, num_of_vehicles=80, hor_jun=3, ver_jun=3):
        """
        :param render_mode: stored as ``self.render_mode``; not used by the simulation itself
        :param map_size: height and width of the square map in tiles
        :param num_of_vehicles: number of vehicles placed on distinct random road tiles
        :param hor_jun: number of inner horizontal roads
        :param ver_jun: number of inner vertical roads
        """
        self.render_mode = render_mode
        # Generate grid map and its junctions
        self.map, self.junctions = self.generate_map(map_size, map_size, hor_jun, ver_jun)
        n, m = self.map.shape
        self.day = 0
        # Generate vehicles on random positions
        self.vehicles = self.generate_vehicles(num_of_vehicles)
        # Heading a vehicle is forced to take when it stands on one of these
        # outer-ring corner tiles (applied by Vehicle.find_direction)
        self.corner_roads = {
            (1, 1): Direction.DOWN,
            (2, 2): Direction.RIGHT,
            (n - 2, 1): Direction.RIGHT,
            (n - 3, 2): Direction.UP,
            (n - 2, m - 2): Direction.UP,
            (n - 3, m - 3): Direction.LEFT,
            (1, m - 2): Direction.LEFT,
            (2, m - 3): Direction.DOWN
        }

    def step(self) -> int:
        """
        Advance the simulation by one tick.

        Vehicles move in list order; afterwards every junction is unblocked and
        every vehicle's left-turn swap flag is cleared.

        :return: number of vehicles that moved this tick (used as the reward)
        """
        self.day += 1
        vehicles_moved = 0
        # Move vehicles
        for vehicle in self.vehicles:
            if vehicle.move():
                vehicles_moved += 1
        # Unblock junctions after one tick
        for junction in self.junctions:
            junction.unblock()
        # Remove left turn blocks
        for vehicle in self.vehicles:
            vehicle.unswap()
        return vehicles_moved

    def change_random_lights(self):
        """Request a uniformly random light state (0 or 1) for every junction."""
        for junction in self.junctions:
            junction.set_state(np.random.randint(0, 2))

    @staticmethod
    def generate_map(i: int, j: int, hor_jun: int, ver_jun: int) -> tuple[np.ndarray, list[Junction]]:
        """
        Generate the road map and its junctions.

        :param i: height
        :param j: width
        :param hor_jun: number of inner horizontal roads
        :param ver_jun: number of inner vertical roads
        :return: ``(map, junctions)`` - the road grid and the list of Junction objects
        """
        new_map = GridWorldEnv._paint_roads(i, j, hor_jun, ver_jun)
        return new_map, GridWorldEnv._find_junctions(new_map, i, j, hor_jun, ver_jun)

    @staticmethod
    def _road_priorities(locs) -> list[tuple[int, int]]:
        """
        Assign a priority to every inner road offset in ``locs``.

        The middle road gets 3, the roads at the first and third quartile get 2,
        and all others get 1. The result is ordered by ascending priority, so
        painting it in order lets the higher priority win wherever two entries
        collide. With only one or two inner roads, the middle index coincides
        with a quartile index; the main road must still end up as priority 3.
        """
        prios = {loc: 1 for loc in locs}
        if locs:
            count = len(locs)
            prios[locs[count // 4]] = 2
            prios[locs[3 * count // 4]] = 2
            prios[locs[count // 2]] = 3  # assigned last: the main road always keeps priority 3
        return sorted(prios.items(), key=lambda item: item[1])

    @staticmethod
    def _paint_roads(i: int, j: int, hor_jun: int, ver_jun: int) -> np.ndarray:
        """Return an ``i x j`` grid with the outer ring road and the prioritized inner roads."""
        new_map = np.zeros((i, j))
        # Outer ring: two-tile-wide edge roads with priority 1
        new_map[1:3, 1:-1] = 1
        new_map[-3:-1, 1:-1] = 1
        new_map[1:-1, 1:3] = 1
        new_map[1:-1, -3:-1] = 1
        # Evenly spaced inner roads; the first/last offsets belong to the ring and are dropped
        hor_locs = list(np.linspace(1, i - 3, hor_jun + 2, dtype=int)[1:-1])
        ver_locs = list(np.linspace(1, j - 3, ver_jun + 2, dtype=int)[1:-1])
        for loc, prio in GridWorldEnv._road_priorities(hor_locs):
            new_map[loc:loc + 2, 1:-1] = prio
        for loc, prio in GridWorldEnv._road_priorities(ver_locs):
            new_map[1:-1, loc:loc + 2] = prio
        return new_map

    @staticmethod
    def _find_junctions(new_map: np.ndarray, i: int, j: int, hor_jun: int, ver_jun: int) -> list[Junction]:
        """
        Create a Junction at every crossing of two roads (ring roads included).

        Crossings with only two exits are skipped; vehicles at the ring corners
        are steered by ``corner_roads``. Each junction receives
        ``(priority, direction)`` pairs read from the tile just outside its
        2x2 square, in the order UP, DOWN, LEFT, RIGHT.
        """
        hor_locs = np.linspace(1, i - 3, hor_jun + 2, dtype=int)
        ver_locs = np.linspace(1, j - 3, ver_jun + 2, dtype=int)
        junctions = []
        for top in hor_locs:
            for left in ver_locs:
                exits = []
                if top != 1:
                    exits.append((new_map[top - 1, left], Direction.UP))
                if top != i - 3:
                    exits.append((new_map[top + 2, left], Direction.DOWN))
                if left != 1:
                    exits.append((new_map[top, left - 1], Direction.LEFT))
                if left != j - 3:
                    exits.append((new_map[top, left + 2], Direction.RIGHT))
                if len(exits) == 2:
                    continue
                junctions.append(Junction(top, left, exits))
        return junctions

    def generate_vehicles(self, n: int) -> List[Vehicle]:
        """
        Place ``n`` vehicles on distinct, uniformly chosen road tiles.

        :param n: number of vehicles
        :return: list of Vehicle objects
        :raises ValueError: if ``n`` exceeds the number of road tiles
        """
        road_tiles = list(zip(*np.nonzero(self.map)))
        if n > len(road_tiles):
            raise ValueError(f"Cannot place {n} vehicles on {len(road_tiles)} road tiles")
        chosen = np.random.choice(len(road_tiles), n, replace=False)
        vehicles = []
        for k in chosen:
            row, col = road_tiles[k]
            direction = self._initial_direction(row, col)
            assert direction is not None, f"no start direction for road tile {(row, col)}"
            vehicles.append(Vehicle(self, row, col, direction))
        return vehicles

    def _initial_direction(self, row, col):
        """
        Pick a start heading for a vehicle placed on road tile ``(row, col)``.

        Tiles with an empty orthogonal neighbour get a heading determined by the
        empty side. Tiles inside a junction square fall back to the row index
        and to their empty diagonal neighbours.

        :return: a Direction, or None if no rule matched
        """
        grid = self.map
        if grid[row, col - 1] == 0:
            return Direction.DOWN
        if grid[row, col + 1] == 0:
            return Direction.UP
        if grid[row - 1, col] == 0:
            return Direction.LEFT
        if grid[row + 1, col] == 0:
            return Direction.RIGHT
        # The tile is on a junction
        if row == 2:
            return Direction.RIGHT
        if row == grid.shape[0] - 3:
            return Direction.LEFT
        if grid[row + 1, col + 1] == 0 or grid[row - 1, col + 1] == 0:
            return Direction.UP
        if grid[row - 1, col - 1] == 0 or grid[row + 1, col - 1] == 0:
            return Direction.DOWN
        return None

    def _frame(self) -> np.ndarray:
        """
        Return a copy of the map with junctions and vehicles drawn in.

        Junction squares get ``-2 + state`` and vehicle tiles get 5;
        ``self.map`` itself is not modified.
        """
        frame = self.map.copy()
        self.render_junctions(frame)
        for vehicle in self.vehicles:
            frame[vehicle.i, vehicle.j] = 5
        return frame

    def render(self):
        """
        Renders current state of the map and shows it as matplotlib image
        """
        plt.matshow(self._frame())
        plt.show()

    def render_junctions(self, map_to_edit):
        """
        Renders junctions on the map to display it on matplotlib image
        :param map_to_edit: map to edit
        :return: inplace
        """
        for junction in self.junctions:
            map_to_edit[junction.i:junction.i + 2,
                        junction.j:junction.j + 2] = -2 + junction.state

    def animate(self, n=1000):
        """
        Run ``n`` simulation steps and show them as a matplotlib animation.

        After each step whose ``self.day`` is a multiple of 10, the junction
        lights are re-randomized via change_random_lights.

        :param n: steps
        """
        fig, ax = plt.subplots()
        ims = []
        for i in range(n):
            self.step()
            frame = self._frame()
            im = ax.imshow(frame, animated=True)
            if i == 0:
                ax.imshow(frame)  # show an initial image first
            if self.day % 10 == 0:
                self.change_random_lights()
            ims.append([im])
        # The Animation object must stay referenced until plt.show() returns,
        # otherwise it can be garbage-collected and the animation stops.
        ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True,
                                        repeat_delay=1000)
        plt.show()
class Junction:
    """
    A 2x2 road crossing with a two-state traffic light.
    """

    def __init__(self, top_i, top_j, valid_turns, state=0):
        """
        :param top_i: row of the junction's top-left tile
        :param top_j: column of the junction's top-left tile
        :param valid_turns: list of ``(priority, direction)`` tuples, one per road leaving the junction
        :param state: current light state (0 - left/right traffic may enter, 1 - up/down)
        """
        self.i = top_i
        self.j = top_j
        self.valid_turns = valid_turns
        self.state = state
        # Set by a light change; vehicles may not enter until unblock() is called
        self.is_blocked = False
        # Number of further set_state() calls to ignore after a light change
        self.cooldown = 0

    def is_in_junction(self, i, j) -> bool:
        """Return True if tile ``(i, j)`` lies inside this junction's 2x2 square."""
        return self.i <= i <= self.i + 1 and self.j <= j <= self.j + 1

    def get_random_turn(self, from_direction):
        """
        Pick an exit for a vehicle travelling in ``from_direction``.

        U-turns (the opposite of ``from_direction``) are excluded; the remaining
        exits are chosen with probability proportional to their road priority.

        :raises ValueError: if no exit with a positive total priority remains
        """
        turns = [turn for turn in self.valid_turns if turn[1] != opposite(from_direction)]
        total = sum(priority for priority, _ in turns)
        if not turns or total <= 0:
            raise ValueError(f"No valid turn from {from_direction} at junction ({self.i}, {self.j})")
        # Sample an index rather than the directions themselves, so the exact
        # Direction object is returned whatever Direction's underlying type is.
        index = np.random.choice(len(turns), p=[priority / total for priority, _ in turns])
        return turns[index][1]

    def set_state(self, new_state):
        """
        Switch the light to ``new_state`` unless the junction is cooling down.

        While ``cooldown`` is positive, the call only decrements it, so the
        cooldown counts set_state() calls, not simulation ticks. A real change
        blocks the junction (until unblock()) and restarts the cooldown at
        JUNCTION_COOLDOWN.

        :param new_state: 0 (left/right) or 1 (up/down)
        """
        if self.cooldown > 0:
            self.cooldown -= 1
            return
        if new_state == self.state:
            return
        self.is_blocked = True
        self.state = new_state
        self.cooldown = JUNCTION_COOLDOWN

    def unblock(self):
        """Allow vehicles to enter the junction again."""
        self.is_blocked = False
class Vehicle:
    """A single vehicle occupying one road tile of a GridWorldEnv."""

    def __init__(self, world: GridWorldEnv, i: int, j: int, start_direction: Direction) -> None:
        """
        :param world: simulation the vehicle belongs to
        :param i: row of the vehicle's tile
        :param j: column of the vehicle's tile
        :param start_direction: initial heading
        """
        self.world = world
        self.i = i
        self.j = j
        self.direction = start_direction
        # Pending left turn chosen when entering a junction; completed on the next move
        self.temp_direction = None
        # True after swapping places with another left-turning vehicle this tick
        self.swapped = False

    def move(self) -> bool:
        """
        Try to advance the vehicle by one tile.

        A vehicle with a pending left turn steps diagonally (heading + turn).
        The move fails if the target tile is occupied, unless both vehicles
        have a pending left turn; in that case they swap places and both
        complete their turns. Entering a junction requires it to be unblocked
        and its light state to allow the current heading; on entry, an exit is
        drawn at random unless a left turn is already pending.

        :return: True if the vehicle moved (or already moved via a swap this tick)
        """
        self.find_direction()
        # Already moved this tick by swapping with another left-turner
        if self.swapped:
            return True
        di, dj = self.direction.value[0], self.direction.value[1]
        # A pending left turn makes the step diagonal
        if self.temp_direction is not None:
            di += self.temp_direction.value[0]
            dj += self.temp_direction.value[1]
        next_i, next_j = self.i + di, self.j + dj
        # Another vehicle occupies the target tile
        for other in self.world.vehicles:
            if other.i != next_i or other.j != next_j:
                continue
            if self.temp_direction is None or other.temp_direction is None:
                return False
            # Both are turning left into each other's tile: swap and finish both turns
            self.i, self.j, other.i, other.j = other.i, other.j, self.i, self.j
            other.swapped = True
            self.swapped = True
            self.direction = self.temp_direction
            other.direction = other.temp_direction
            self.temp_direction = None
            other.temp_direction = None
            return True
        in_junction = any(junction.is_in_junction(self.i, self.j) for junction in self.world.junctions)
        if not in_junction:
            # About to enter a junction: obey its block flag and light, then choose an exit
            for junction in self.world.junctions:
                if not junction.is_in_junction(next_i, next_j):
                    continue
                if junction.is_blocked:
                    return False
                if junction.state == 0:
                    allowed = (Direction.LEFT, Direction.RIGHT)
                else:
                    allowed = (Direction.DOWN, Direction.UP)
                if self.direction not in allowed:
                    return False
                if self.temp_direction is None:
                    turn = junction.get_random_turn(self.direction)
                    if self.direction.is_left_turn(turn):
                        self.temp_direction = turn
                    else:
                        self.direction = turn
        elif self.temp_direction is not None:
            # Inside the junction: this diagonal step completes the left turn
            self.direction = self.temp_direction
            self.temp_direction = None
        # The target must be a road tile
        assert self.world.map[next_i, next_j] >= 1
        self.i, self.j = next_i, next_j
        return True

    def find_direction(self):
        """Adopt the heading prescribed by ``world.corner_roads`` when standing on a corner tile."""
        forced = self.world.corner_roads.get((self.i, self.j))
        if forced is not None:
            self.direction = forced

    def unswap(self):
        """Clear the per-tick swap flag."""
        self.swapped = False

    def __repr__(self):
        return f"Vehicle({self.i}, {self.j}, {self.direction})"
if __name__ == '__main__':
    # Run a 500-step animated simulation with the default map and fleet size.
    # For step-by-step debugging, call step() and render() in a loop instead
    # (calling change_random_lights() every few steps), rather than animate().
    simulation = GridWorldEnv()
    simulation.animate(n=500)