Understanding the FilterObservation(ObservationWrapper) class in the gym library


These are notes on the FilterObservation(ObservationWrapper) class defined in the filter_observation.py module.



Code:


import copy

from gym import spaces
from gym import ObservationWrapper


class FilterObservation(ObservationWrapper):
    """Filter dictionary observations by their keys.

    Args:
        env: The environment to wrap.
        filter_keys: List of keys to be included in the observations.

    Raises:
        ValueError: If observation keys in not instance of None or
            iterable.
        ValueError: If any of the `filter_keys` are not included in
            the original `env`'s observation space

    """
    def __init__(self, env, filter_keys=None):
        super(FilterObservation, self).__init__(env)

        wrapped_observation_space = env.observation_space
        assert isinstance(wrapped_observation_space, spaces.Dict), (
            "FilterObservationWrapper is only usable with dict observations.")

        observation_keys = wrapped_observation_space.spaces.keys()

        if filter_keys is None:
            filter_keys = tuple(observation_keys)

        missing_keys = set(
            key for key in filter_keys if key not in observation_keys)

        if missing_keys:
            raise ValueError(
                "All the filter_keys must be included in the "
                "original obsrevation space.\n"
                "Filter keys: {filter_keys}\n"
                "Observation keys: {observation_keys}\n"
                "Missing keys: {missing_keys}".format(
                    filter_keys=filter_keys,
                    observation_keys=observation_keys,
                    missing_keys=missing_keys,
                ))

        self.observation_space = type(wrapped_observation_space)([
            (name, copy.deepcopy(space))
            for name, space in wrapped_observation_space.spaces.items()
            if name in filter_keys
        ])

        self._env = env
        self._filter_keys = tuple(filter_keys)

    def observation(self, observation):
        filter_observation = self._filter_observation(observation)
        return filter_observation

    def _filter_observation(self, observation):
        observation = type(observation)([
            (name, value)
            for name, value in observation.items()
            if name in self._filter_keys
        ])
        return observation





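A precondition of this class is that the observation space of the wrapped env must be an instance of spaces.Dict, as the following assertion enforces: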

    assert isinstance(wrapped_observation_space, spaces.Dict), (
        "FilterObservationWrapper is only usable with dict observations.")

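In short, given an env whose observation space is a spaces.Dict, the wrapper keeps only the keys listed in filter_keys and drops all the others. If filter_keys is None, every key is kept.

The full set of keys of the wrapped env is env.observation_space.spaces.keys().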
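
The key selection can be tried out without gym installed. The following is only a minimal sketch: an OrderedDict of strings stands in for env.observation_space.spaces, and the key names and the helper filter_spaces are hypothetical, chosen for illustration rather than taken from gym.

```python
from collections import OrderedDict

# Stand-in for env.observation_space.spaces: plain strings take the place
# of real gym space objects (hypothetical keys, for illustration only).
wrapped_spaces = OrderedDict([
    ("position", "Box(3,)"),
    ("velocity", "Box(3,)"),
    ("image", "Box(64, 64, 3)"),
])


def filter_spaces(spaces_by_name, filter_keys=None):
    """Keep only the entries whose name appears in filter_keys."""
    if filter_keys is None:
        # Same default as the wrapper: keep every key.
        filter_keys = tuple(spaces_by_name.keys())
    return OrderedDict(
        (name, space)
        for name, space in spaces_by_name.items()
        if name in filter_keys
    )


kept = filter_spaces(wrapped_spaces, filter_keys=("image", "position"))
kept_names = list(kept.keys())
all_names = list(filter_spaces(wrapped_spaces).keys())
```

In this sketch the kept entries come out in the order of the original mapping, not in the order of filter_keys, because the comprehension iterates over the original items. The wrapper's own comprehension iterates over wrapped_observation_space.spaces.items() in the same way.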



If a key requested in filter_keys does not exist in the observation space of the wrapped env, it is collected into missing_keys:

    missing_keys = set(
        key for key in filter_keys if key not in observation_keys)

If missing_keys is not empty, a ValueError is raised:

    if missing_keys:
        raise ValueError(
            "All the filter_keys must be included in the "
            "original obsrevation space.\n"
            "Filter keys: {filter_keys}\n"
            "Observation keys: {observation_keys}\n"
            "Missing keys: {missing_keys}".format(
                filter_keys=filter_keys,
                observation_keys=observation_keys,
                missing_keys=missing_keys,
            ))
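
The validation step can be reproduced in isolation. This is a rough sketch that mirrors the check above with plain dict keys instead of a gym space. The function name check_filter_keys, the shortened error message, and the key names are assumptions made for illustration.

```python
def check_filter_keys(filter_keys, observation_keys):
    """Raise ValueError if any requested key is absent (mirrors the check above)."""
    missing_keys = set(
        key for key in filter_keys if key not in observation_keys)
    if missing_keys:
        # Shortened message; the wrapper's own message is longer.
        raise ValueError(
            "Filter keys: {}\nObservation keys: {}\nMissing keys: {}".format(
                filter_keys, list(observation_keys), missing_keys))
    return tuple(filter_keys)


# Hypothetical observation keys standing in for
# env.observation_space.spaces.keys().
observation_keys = {"position": None, "velocity": None}.keys()

# Every requested key exists, so the keys are returned as a tuple.
ok = check_filter_keys(["position"], observation_keys)

# "goal" is not an observation key, so the check raises.
try:
    check_filter_keys(["position", "goal"], observation_keys)
    error_message = None
except ValueError as exc:
    error_message = str(exc)
```

Because the check runs in __init__, a misspelled key fails when the wrapper is constructed instead of silently producing observations with fewer entries later.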




Each time an observation is obtained from the wrapped env, it is filtered by the kept keys, and the filtered observation is passed up to the caller:

    def observation(self, observation):
        filter_observation = self._filter_observation(observation)
        return filter_observation

    def _filter_observation(self, observation):
        observation = type(observation)([
            (name, value)
            for name, value in observation.items()
            if name in self._filter_keys
        ])
        return observation
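
The type(observation)([...]) idiom rebuilds the observation with the same mapping type it arrived in. The following standalone sketch shows this with plain Python mappings. The free function filter_observation and the sample observations are illustrative assumptions, not part of gym.

```python
from collections import OrderedDict


def filter_observation(observation, filter_keys):
    """Rebuild the mapping with its own type, keeping only filter_keys."""
    return type(observation)([
        (name, value)
        for name, value in observation.items()
        if name in filter_keys
    ])


# A plain dict stays a plain dict.
plain = filter_observation(
    {"position": [0.0, 1.0], "velocity": [0.5, 0.5]}, ("position",))

# An OrderedDict stays an OrderedDict and keeps its original key order.
ordered = filter_observation(
    OrderedDict([("a", 1), ("b", 2), ("c", 3)]), ("c", "a"))
```

The idiom relies on the observation type accepting a list of (key, value) pairs in its constructor, which holds for dict and OrderedDict.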




=========================================

