-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsrtf_env.py
77 lines (66 loc) · 3.07 KB
/
srtf_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import queue as Queue
import time
import random
import numpy as np
import parameters as pm
from scheduler_base import Scheduler
class SRTF_Env(Scheduler):
    """Scheduler that serves jobs in order of the smallest remaining fraction of epochs.

    Each uncompleted job is ranked by ``1 - progress / num_epochs``; jobs with
    less work left are placed first and ties are broken randomly.
    """

    def _schedule(self):
        """Greedily place workers (and parameter servers) for uncompleted jobs.

        Jobs are taken from a priority queue keyed by remaining fraction of
        epochs. For each job up to ``pm.MAX_NUM_WORKERS`` placements are
        attempted; every placement pops one node from
        ``self.node_used_resr_queue``, asks ``self.cluster.alloc`` for the
        resources, and pushes the node back keyed by the sum of its used
        resources. The first failed allocation stops scheduling for all
        remaining jobs.

        Side effects: updates ``num_workers``, ``num_ps``, the placement lists
        and ``dom_share`` of each scheduled job, adds scheduled jobs to
        ``self.running_jobs``, and writes timing/result lines to
        ``self.logger`` at debug level.
        """
        tic = time.time()

        srtf_queue = Queue.PriorityQueue()
        for seq, job in enumerate(self.uncompleted_jobs):
            remaining = 1 - job.progress / job.num_epochs
            # The unique ``seq`` guarantees the tuple comparison never reaches
            # the job object itself, even if priority and random tie-breaker
            # are both equal (job objects may not be orderable).
            srtf_queue.put((remaining, random.random(), seq, job))

        def refresh_dom_share(job):
            # Largest per-resource share of the cluster capacity used by the job.
            demand = job.num_workers * job.resr_worker + job.num_ps * job.resr_ps
            job.dom_share = np.max(1.0 * demand / self.cluster.CLUSTER_RESR_CAPS)

        out_of_resources = False
        while not srtf_queue.empty() and not out_of_resources:
            job = srtf_queue.get()[-1]
            # Allocate up to the maximal number of workers; when pm.PS_WORKER
            # is set, one ps and one worker are requested together on a node.
            for _attempt in range(pm.MAX_NUM_WORKERS):
                _used, node = self.node_used_resr_queue.get()
                if pm.PS_WORKER:
                    resr_reqs = job.resr_worker + job.resr_ps
                else:
                    resr_reqs = job.resr_worker
                succ, node_used_resrs = self.cluster.alloc(resr_reqs, node)
                # The node goes back into the queue whether or not the
                # allocation succeeded, re-keyed by its total used resources.
                self.node_used_resr_queue.put((np.sum(node_used_resrs), node))
                if not succ:
                    # Failing to allocate ends scheduling for every job.
                    out_of_resources = True
                    break

                # Worker and ps placements are recorded as two separate steps
                # (no bundled action). dom_share is refreshed after each one so
                # it is current before the next _state call.
                # NOTE(review): presumably _state reads job fields - confirm.
                self._state(job.id, "worker")
                job.num_workers += 1
                job.curr_worker_placement.append(node)
                refresh_dom_share(job)
                if pm.PS_WORKER:
                    self._state(job.id, "ps")
                    job.num_ps += 1
                    job.curr_ps_placement.append(node)
                    refresh_dom_share(job)
                self.running_jobs.add(job)

        toc = time.time()
        self.logger.debug(self.name + ":: scheduling time: %.3f seconds." % (toc - tic))
        for job in self.uncompleted_jobs:
            self.logger.debug(self.name + ":: scheduling results num_worker: " + str(job.num_workers))
def test():
    """Run the SRTF scheduler over a job trace until it ends, then print the results."""
    # Project-local modules, imported lazily so they are only needed for this manual run.
    import log
    import trace

    logger = log.getLogger(name="test.log", level="INFO")
    env = SRTF_Env("SRTF", trace.Trace(logger).get_trace(), logger)
    # Step the environment until it flags completion.
    while not env.end:
        env.step()
    results = env.get_results()
    print(results)
# Run the manual smoke test only when executed as a script, not on import.
if __name__ == "__main__":
    test()