Analysis of the common/vec_env/util.py module in the baselines library

TiaNa_na 2022-05-19

Source code of util.py:

```python
"""
Helpers for dealing with vectorized environments.
"""

from collections import OrderedDict

import gym
import numpy as np


def copy_obs_dict(obs):
    """
    Deep-copy an observation dict.
    """
    return {k: np.copy(v) for k, v in obs.items()}


def dict_to_obs(obs_dict):
    """
    Convert an observation dict into a raw array if the
    original observation space was not a Dict space.
    """
    if set(obs_dict.keys()) == {None}:
        return obs_dict[None]
    return obs_dict


def obs_space_info(obs_space):
    """
    Get dict-structured information about a gym.Space.

    Returns:
      A tuple (keys, shapes, dtypes):
        keys: a list of dict keys.
        shapes: a dict mapping keys to shapes.
        dtypes: a dict mapping keys to dtypes.
    """
    if isinstance(obs_space, gym.spaces.Dict):
        assert isinstance(obs_space.spaces, OrderedDict)
        subspaces = obs_space.spaces
    elif isinstance(obs_space, gym.spaces.Tuple):
        assert isinstance(obs_space.spaces, tuple)
        subspaces = {i: obs_space.spaces[i] for i in range(len(obs_space.spaces))}
    else:
        subspaces = {None: obs_space}
    keys = []
    shapes = {}
    dtypes = {}
    for key, box in subspaces.items():
        keys.append(key)
        shapes[key] = box.shape
        dtypes[key] = box.dtype
    return keys, shapes, dtypes


def obs_to_dict(obs):
    """
    Convert an observation into a dict.
    """
    if isinstance(obs, dict):
        return obs
    return {None: obs}
```




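Functions:

def copy_obs_dict(obs):
def obs_to_dict(obs):

copy_obs_dict assumes the observation passed in is already a dict. It returns a new dict with the same keys in which every value has been copied with np.copy, so later writes to the original arrays do not affect the copy.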

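In obs_to_dict, an observation that is not already a dict (for example, a single np.ndarray from a Box space) is wrapped into a dict whose only key is None. A dict observation is returned unchanged.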
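The two helpers can be exercised on their own. The sketch below copies their bodies from util.py (so it runs without gym) and uses made-up observations: a plain array standing in for a Box observation and a hypothetical dict with "position" and "velocity" entries.

```python
import numpy as np


def copy_obs_dict(obs):
    """Copy every array in an observation dict (same body as util.py)."""
    return {k: np.copy(v) for k, v in obs.items()}


def obs_to_dict(obs):
    """Wrap a non-dict observation as {None: obs} (same body as util.py)."""
    if isinstance(obs, dict):
        return obs
    return {None: obs}


# A plain array observation (as a Box space would produce) gets the key None.
raw_obs = np.zeros(3, dtype=np.float32)
wrapped = obs_to_dict(raw_obs)
print(list(wrapped.keys()))  # [None]

# A dict observation passes through unchanged (same object).
dict_obs = {"position": np.array([1.0, 2.0]), "velocity": np.array([0.5])}
print(obs_to_dict(dict_obs) is dict_obs)  # True

# copy_obs_dict returns independent arrays: writing to the copy leaves the original alone.
copied = copy_obs_dict(dict_obs)
copied["position"][0] = 99.0
print(dict_obs["position"][0])  # 1.0
```

Using None as the key means the rest of the vectorized-environment code can always treat observations as dicts, whether or not the underlying space is a Dict space.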




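Function

def dict_to_obs(obs_dict)

performs the reverse conversion. If the only key of the dict is None, meaning the original observation space was not a Dict space, it unwraps the dict and returns obs_dict[None], the raw np.ndarray observation.

If the dict has any other keys, it is returned unchanged.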
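A minimal round-trip sketch, again with the function bodies copied from util.py and made-up observations, shows that dict_to_obs undoes the wrapping done by obs_to_dict and leaves genuine dict observations alone.

```python
import numpy as np


def obs_to_dict(obs):
    """Same body as util.py."""
    if isinstance(obs, dict):
        return obs
    return {None: obs}


def dict_to_obs(obs_dict):
    """Same body as util.py: unwrap only when None is the sole key."""
    if set(obs_dict.keys()) == {None}:
        return obs_dict[None]
    return obs_dict


# Non-Dict space: wrap then unwrap gives back the very same array.
raw_obs = np.arange(4)
restored = dict_to_obs(obs_to_dict(raw_obs))
print(restored is raw_obs)  # True

# Dict space: the dict is returned as-is.
dict_obs = {"image": np.zeros((2, 2)), "speed": np.array([0.3])}
print(dict_to_obs(dict_obs) is dict_obs)  # True

# A dict that contains None alongside other keys is NOT unwrapped.
mixed = {None: np.ones(1), "extra": np.zeros(1)}
print(dict_to_obs(mixed) is mixed)  # True
```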





Function

def obs_space_info(obs_space):

Its argument is the environment's observation space itself (typically env.observation_space).

```python
if isinstance(obs_space, gym.spaces.Dict):
    assert isinstance(obs_space.spaces, OrderedDict)
    subspaces = obs_space.spaces
elif isinstance(obs_space, gym.spaces.Tuple):
    assert isinstance(obs_space.spaces, tuple)
    subspaces = {i: obs_space.spaces[i] for i in range(len(obs_space.spaces))}
else:
    subspaces = {None: obs_space}
```

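The function first checks the type of obs_space and normalizes it into a dict of subspaces: for a gym.spaces.Dict it uses the space's OrderedDict of subspaces directly, for a gym.spaces.Tuple it builds a dict keyed by position index (0, 1, 2, ...), and for any other space it produces {None: obs_space}.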
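The normalization step can be illustrated without installing gym. In the sketch below, Box, DictSpace and TupleSpace are hypothetical stand-ins, not the real gym.spaces classes; they expose only the attributes obs_space_info reads (.shape, .dtype, .spaces), and to_subspaces is an illustrative name for the first half of obs_space_info.

```python
from collections import OrderedDict
from dataclasses import dataclass

import numpy as np


# Hypothetical stand-ins for gym.spaces.Box / Dict / Tuple (NOT the real gym classes).
@dataclass
class Box:
    shape: tuple
    dtype: np.dtype


@dataclass
class DictSpace:
    spaces: OrderedDict


@dataclass
class TupleSpace:
    spaces: tuple


def to_subspaces(obs_space):
    """Normalize a space into {key: subspace}, mirroring the branch logic above."""
    if isinstance(obs_space, DictSpace):
        return obs_space.spaces  # keys are the Dict space's own keys
    if isinstance(obs_space, TupleSpace):
        # keys are positions, equivalent to the range(len(...)) comprehension
        return {i: s for i, s in enumerate(obs_space.spaces)}
    return {None: obs_space}  # any single space gets the key None


box = Box(shape=(4,), dtype=np.dtype(np.float32))
print(list(to_subspaces(box)))                                       # [None]
print(list(to_subspaces(TupleSpace((box, box)))))                     # [0, 1]
print(list(to_subspaces(DictSpace(OrderedDict(pos=box, vel=box)))))  # ['pos', 'vel']
```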



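It then iterates over the subspaces and records the key, shape and dtype of each one, which yields the return value described in the docstring: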

```
Returns:
  A tuple (keys, shapes, dtypes):
    keys: a list of dict keys.
    shapes: a dict mapping keys to shapes.
    dtypes: a dict mapping keys to dtypes.
```

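The three pieces of information are returned together as a tuple; shapes and dtypes are indexed by the same keys listed in keys.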
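This tuple is what lets a vectorized environment allocate one batched buffer per key. The sketch below is modeled loosely on how baselines' DummyVecEnv consumes obs_space_info; the (keys, shapes, dtypes) values are made up for a hypothetical Dict space, and num_envs, buf_obs and save_obs are illustrative names, not a definitive implementation.

```python
import numpy as np

# (keys, shapes, dtypes) as obs_space_info might return them for a hypothetical
# Dict space with a 2-element "position" Box and a 1-element "speed" Box.
keys = ["position", "speed"]
shapes = {"position": (2,), "speed": (1,)}
dtypes = {"position": np.dtype(np.float32), "speed": np.dtype(np.float64)}

num_envs = 3  # hypothetical number of parallel environments

# One batched buffer per key, with a leading num_envs axis.
buf_obs = {k: np.zeros((num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in keys}


def save_obs(env_idx, obs):
    """Write one environment's observation into row env_idx of each buffer."""
    for k in keys:
        if k is None:
            buf_obs[k][env_idx] = obs      # non-Dict space: obs is the raw array
        else:
            buf_obs[k][env_idx] = obs[k]   # Dict space: pick the matching entry


save_obs(1, {"position": np.array([0.5, -0.5]), "speed": np.array([2.0])})
print(buf_obs["position"].shape)       # (3, 2)
print(buf_obs["position"][1].tolist())  # [0.5, -0.5]
print(buf_obs["speed"].dtype)           # float64
```

For a non-Dict space, keys would be [None], and the same loop writes the raw array into buf_obs[None]; dict_to_obs can then unwrap the batched result.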




