Skip to content
This repository has been archived by the owner on Aug 15, 2024. It is now read-only.

Commit

Permalink
action maskをagile_renderに追加
Browse files Browse the repository at this point in the history
  • Loading branch information
yasuohayashibara committed Nov 22, 2023
1 parent 5cf87e1 commit b248a9c
Show file tree
Hide file tree
Showing 3 changed files with 1,152 additions and 11 deletions.
10 changes: 6 additions & 4 deletions controllers/agilerl_render/agilerl_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import imageio
import numpy as np
import torch
from agilerl.algorithms.matd3 import MATD3
#from agilerl.algorithms.matd3 import MATD3
from soccer.matd3 import MATD3
from PIL import Image, ImageDraw

import soccer_v0
Expand Down Expand Up @@ -72,7 +73,7 @@ def _label_with_episode_number(frame, episode_num):
)

# Load the saved algorithm into the MATD3 object
path = "../agilerl_train/models/MATD3/231112/MATD3_trained_agent.pt"
path = "../agilerl_train/models/MATD3/231113/MATD3_trained_agent.pt"
matd3.loadCheckpoint(path)

rewards = [] # List to collect total episodic reward
Expand All @@ -86,16 +87,17 @@ def _label_with_episode_number(frame, episode_num):
state, _ = env.reset()
agent_reward = {agent_id: 0 for agent_id in agent_ids}
score = 0
info = None
for _ in range(max_steps):
# Get action
action = matd3.getAction(state, epsilon=0)
action = matd3.getAction(state, epsilon=0, action_mask=info)

# Save the frame for this step and append to frames list
#frame = env.render()
#frames.append(_label_with_episode_number(frame, episode_num=ep))

# Take action in environment
print(_, action)
#print(_, action)
state, reward, termination, truncation, info = env.step(action)

# Stop episode if any agents have terminated
Expand Down
Loading

0 comments on commit b248a9c

Please sign in to comment.