0
点赞
收藏
分享

微信扫一扫

Python 多进程共享内存操作出现的同步不一致问题

相关地址:

https://openi.pcl.ac.cn/devilmaycry812839668/SyncVectorEnv



点击查看代码

from multiprocessing import Process
from multiprocessing import Queue
import numpy as np
from numpy import ndarray
from typing import Callable, Dict, List, Tuple

"""目前主要支持gymnasium环境,gym环境未成功支持"""
from gymnasium import Env
# try:
#     import gymnasium as gym
# except ImportError:
#     import gym

# from gym import Env


"""
    多进程实现:
        以子进程的方式运行游戏环境并与主进程交互数据
"""
class ChildEnv(Process):
    """Worker process that owns a group of environments and steps them on request.

    Protocol, as implemented by :meth:`run`:

    * After every environment has been created and reset, ``True`` is put on
      ``barrier``.
    * Each item received from ``queue`` that is not ``None`` triggers one step
      of every environment, using the current contents of the ``actions``
      array; ``True`` is put on ``barrier`` once the round is finished.
    * ``None`` on ``queue`` stops the worker.

    Results are written in place into the arrays taken from ``shared_data``,
    indexed by the position of the environment inside ``env_fns``.

    NOTE(review): the arrays are presumably views onto memory shared with the
    parent process; nothing in this class enforces that -- confirm at the
    call site.
    """

    def __init__(self, id:int, env_fns:List[Callable[[], Env]], shared_data:Dict[str, ndarray], queue:Queue, barrier:Queue):
        """Store the worker's configuration; nothing is started here.

        Args:
            id: Index of this worker. (The name shadows the ``id`` builtin but
                is kept because it is part of the constructor's interface.)
            env_fns: Zero-argument factories, one per environment. They are
                called inside :meth:`run`, not here.
            shared_data: Mapping that must contain the keys ``"actions"``,
                ``"rewards"``, ``"next_obs"``, ``"next_obs_truncated"``,
                ``"terminated"`` and ``"truncated"``.
            queue: Source of step instructions; ``None`` means "stop".
            barrier: Queue on which ``True`` is put after the initial reset
                and after every completed step round.
        """
        super().__init__()
        self.id = id    # index of this worker
        self.env_fns = env_fns

        self.actions = shared_data["actions"]
        self.rewards = shared_data["rewards"]
        self.next_obs = shared_data["next_obs"]
        self.next_obs_truncated = shared_data["next_obs_truncated"]
        self.terminated = shared_data["terminated"]
        self.truncated = shared_data["truncated"]

        self.queue = queue
        self.barrier = barrier

    def run(self):
        """Create the environments, serve step requests, then close the environments.

        Every environment that was successfully created is closed when this
        method exits, whether through the ``None`` stop instruction or through
        an exception raised while creating, resetting or stepping.
        """
        # Process.run() invokes the ``target`` callable if one was supplied;
        # __init__ never supplies one, so this is normally a no-op.
        super().run()

        envs: List[Env] = []
        try:
            # Built one at a time (rather than with a comprehension) so that
            # environments created before a failing factory still get closed.
            for env_fn in self.env_fns:
                envs.append(env_fn())

            for i, env in enumerate(envs):
                self.next_obs[i], _ = env.reset()
            self.barrier.put(True)

            while self.queue.get() is not None:
                for i, (env, action) in enumerate(zip(envs, self.actions)):
                    next_obs, reward, terminated, truncated, _ = env.step(action)
                    if terminated:
                        # The terminal observation is dropped; slot i receives
                        # the first observation of the next episode.
                        self.next_obs[i], _ = env.reset()
                    elif truncated:
                        # Keep the last observation of the cut-off episode
                        # separately before the slot is overwritten by reset.
                        self.next_obs_truncated[i] = next_obs
                        self.next_obs[i], _ = env.reset()
                    else:
                        self.next_obs[i] = next_obs

                    self.rewards[i] = reward
                    self.terminated[i] = terminated
                    self.truncated[i] = truncated

                self.barrier.put(True)
        finally:
            for env in envs:
                env.close()







举报

相关推荐

0 条评论