相关地址:
https://openi.pcl.ac.cn/devilmaycry812839668/SyncVectorEnv
点击查看代码
from multiprocessing import Process
from multiprocessing import Queue
import numpy as np
from numpy import ndarray
from typing import Callable, Dict, List, Tuple
"""目前主要支持gymnasium环境,gym环境未成功支持"""
from gymnasium import Env
# try:
# import gymnasium as gym
# except ImportError:
# import gym
# from gym import Env
"""
多进程实现:
以子进程的方式运行游戏环境并与主进程交互数据
"""
class ChildEnv(Process):
    """Worker process that owns a group of environments and steps them on request.

    The worker talks to its creator through two queues and a set of
    array-like buffers taken from ``shared_data``:

    * ``queue`` -- the worker blocks on it. ``None`` stops the worker; any
      other item triggers one step of every environment it owns.
    * ``barrier`` -- the worker puts ``True`` on it once after the initial
      reset and once after every completed round of steps.
    * ``actions`` is read; ``rewards``, ``next_obs``, ``next_obs_truncated``,
      ``terminated`` and ``truncated`` are written. All of them are indexed
      by the position of the environment inside ``env_fns``.

    Episode ends are handled inside the worker (auto-reset):

    * terminated: ``next_obs[i]`` receives the observation of the fresh
      episode; the final observation of the old episode is dropped.
    * truncated (and not terminated): the final observation is stored in
      ``next_obs_truncated[i]`` and ``next_obs[i]`` receives the observation
      of the fresh episode. ``next_obs_truncated[i]`` is written only in this
      case, so it holds stale data whenever ``truncated[i]`` is false.

    NOTE(review): the buffers are presumably views onto memory shared with
    the parent process and sliced per worker -- confirm against the code
    that builds ``shared_data``.
    """

    def __init__(self, id: int, env_fns: List[Callable[[], "Env"]],
                 shared_data: Dict[str, ndarray], queue: Queue,
                 barrier: Queue) -> None:
        """Store the communication handles; environments are built in ``run``.

        Args:
            id: Index of this worker process.
            env_fns: Zero-argument factories, one per environment to own.
            shared_data: Mapping that provides the ``"actions"``,
                ``"rewards"``, ``"next_obs"``, ``"next_obs_truncated"``,
                ``"terminated"`` and ``"truncated"`` buffers.
            queue: Channel on which step instructions (or ``None``) arrive.
            barrier: Channel on which completion of each phase is reported.
        """
        super().__init__()
        self.id = id  # index of this worker process
        self.env_fns = env_fns
        self.actions = shared_data["actions"]
        self.rewards = shared_data["rewards"]
        self.next_obs = shared_data["next_obs"]
        self.next_obs_truncated = shared_data["next_obs_truncated"]
        self.terminated = shared_data["terminated"]
        self.truncated = shared_data["truncated"]
        self.queue = queue
        self.barrier = barrier

    def run(self) -> None:
        """Build the environments, then serve step requests until told to stop.

        Every environment that was created is closed when the loop ends,
        whether it ends through the ``None`` stop signal or through an
        exception raised by an environment.
        """
        # No target is passed to Process.__init__, so this call does nothing;
        # it is kept so the base-class hook still runs.
        super().run()

        envs = []
        try:
            # Environments are created here (not in __init__) so that they
            # are constructed by whoever executes run().
            for env_fn in self.env_fns:
                envs.append(env_fn())

            for i, env in enumerate(envs):
                self.next_obs[i], _ = env.reset()
            # print(self.id, i, env)
            self.barrier.put(True)  # initial observations are in place

            while True:
                instruction = self.queue.get()
                if instruction is None:  # stop signal
                    break

                for i, (env, action) in enumerate(zip(envs, self.actions)):
                    next_obs, reward, terminated, truncated, _ = env.step(action)
                    if terminated:
                        # Episode is over: publish the first observation of
                        # the next episode instead of the terminal one.
                        self.next_obs[i], _ = env.reset()
                    elif truncated:
                        # Keep the last observation of the cut-off episode
                        # in its own buffer before resetting.
                        self.next_obs_truncated[i] = next_obs
                        self.next_obs[i], _ = env.reset()
                    else:
                        self.next_obs[i] = next_obs
                    self.rewards[i] = reward
                    self.terminated[i] = terminated
                    self.truncated[i] = truncated

                self.barrier.put(True)  # this round of steps is complete
        finally:
            # Release environment resources (windows, simulators, files...)
            # instead of leaking them when the worker exits.
            for env in envs:
                env.close()