Understanding the FilterObservation(ObservationWrapper) class in the filter_observation.py module.
Code:
import copy

from gym import spaces
from gym import ObservationWrapper


class FilterObservation(ObservationWrapper):
    """Filter dictionary observations by their keys.

    Args:
        env: The environment to wrap.
        filter_keys: List of keys to be included in the observations.

    Raises:
        ValueError: If observation keys in not instance of None or
            iterable.
        ValueError: If any of the `filter_keys` are not included in
            the original `env`'s observation space
    """

    def __init__(self, env, filter_keys=None):
        super(FilterObservation, self).__init__(env)

        wrapped_observation_space = env.observation_space
        assert isinstance(wrapped_observation_space, spaces.Dict), (
            "FilterObservationWrapper is only usable with dict observations.")

        observation_keys = wrapped_observation_space.spaces.keys()

        if filter_keys is None:
            filter_keys = tuple(observation_keys)

        missing_keys = set(
            key for key in filter_keys if key not in observation_keys)

        if missing_keys:
            raise ValueError(
                "All the filter_keys must be included in the "
                "original obsrevation space.\n"
                "Filter keys: {filter_keys}\n"
                "Observation keys: {observation_keys}\n"
                "Missing keys: {missing_keys}".format(
                    filter_keys=filter_keys,
                    observation_keys=observation_keys,
                    missing_keys=missing_keys,
                ))

        self.observation_space = type(wrapped_observation_space)([
            (name, copy.deepcopy(space))
            for name, space in wrapped_observation_space.spaces.items()
            if name in filter_keys
        ])

        self._env = env
        self._filter_keys = tuple(filter_keys)

    def observation(self, observation):
        filter_observation = self._filter_observation(observation)
        return filter_observation

    def _filter_observation(self, observation):
        observation = type(observation)([
            (name, value)
            for name, value in observation.items()
            if name in self._filter_keys
        ])
        return observation
A prerequisite of this class is that the observation space of the wrapped env must be an instance of spaces.Dict, as enforced here:
assert isinstance(wrapped_observation_space, spaces.Dict), (
    "FilterObservationWrapper is only usable with dict observations.")
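In short, the class takes an env whose observation space is a spaces.Dict, keeps the entries whose keys appear in filter_keys, and drops all the others. If filter_keys is None, every key is kept.
The full set of keys of the wrapped env is env.observation_space.spaces.keys().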
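The key-selection step can be illustrated without gym. The following is a minimal sketch, assuming the sub-spaces are represented by plain lists in an OrderedDict (a stand-in for the real spaces.Dict contents); the helper name filter_space_items is hypothetical and not part of gym.

```python
import copy
from collections import OrderedDict


def filter_space_items(space_items, filter_keys):
    """Keep only the entries whose name is in filter_keys (hypothetical helper).

    Each kept value is deep-copied, mirroring what the wrapper does, so the
    filtered mapping does not share objects with the original one.
    """
    return OrderedDict(
        (name, copy.deepcopy(space))
        for name, space in space_items.items()
        if name in filter_keys
    )


# Stand-in for env.observation_space.spaces: plain lists instead of real spaces.
wrapped = OrderedDict([
    ("position", [0.0, 1.0]),
    ("velocity", [-1.0, 1.0]),
    ("image", [0, 255]),
])

filtered = filter_space_items(wrapped, ("position", "image"))
print(list(filtered.keys()))  # ['position', 'image']

# Mutating the copy leaves the original untouched.
filtered["position"].append(2.0)
print(wrapped["position"])  # [0.0, 1.0]
```

The deep copy is what lets the wrapper's observation_space be modified later without affecting the wrapped env's own space.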
If a key that should be kept does not exist in the wrapped env's observation space, it is collected into missing_keys:
missing_keys = set(
    key for key in filter_keys if key not in observation_keys)
If that set is non-empty, a ValueError is raised:
if missing_keys:
    raise ValueError(
        "All the filter_keys must be included in the "
        "original obsrevation space.\n"
        "Filter keys: {filter_keys}\n"
        "Observation keys: {observation_keys}\n"
        "Missing keys: {missing_keys}".format(
            filter_keys=filter_keys,
            observation_keys=observation_keys,
            missing_keys=missing_keys,
        ))
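The validation logic can be tried in isolation. This is a minimal sketch under the assumption that the observation keys are given as a plain tuple of strings; the function name resolve_filter_keys and the shortened error message are illustrative, not gym's.

```python
def resolve_filter_keys(observation_keys, filter_keys=None):
    """Return the keys to keep, or raise ValueError if any are unknown.

    Hypothetical helper that mirrors the checks in FilterObservation.__init__.
    """
    observation_keys = tuple(observation_keys)
    if filter_keys is None:
        # No filter given: keep every key.
        filter_keys = observation_keys
    missing_keys = set(
        key for key in filter_keys if key not in observation_keys)
    if missing_keys:
        raise ValueError("Missing keys: {}".format(sorted(missing_keys)))
    return tuple(filter_keys)


keys = ("position", "velocity")
print(resolve_filter_keys(keys))                # ('position', 'velocity')
print(resolve_filter_keys(keys, ["velocity"]))  # ('velocity',)

try:
    resolve_filter_keys(keys, ["velocity", "image"])
except ValueError as err:
    error_message = str(err)
    print(error_message)  # Missing keys: ['image']
```

Failing in the constructor means a misspelled key is reported when the wrapper is built, rather than silently producing observations with fewer entries than expected.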
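Every time an observation is obtained from the wrapped env, it is filtered by the stored keys, and only the filtered observation is passed up to the caller: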
def observation(self, observation):
    filter_observation = self._filter_observation(observation)
    return filter_observation

def _filter_observation(self, observation):
    observation = type(observation)([
        (name, value)
        for name, value in observation.items()
        if name in self._filter_keys
    ])
    return observation
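The type(observation)([...]) idiom rebuilds the result with the same mapping type as the input. A minimal sketch, assuming observations are ordinary dict or OrderedDict objects; the standalone function filter_observation is a hypothetical stand-in for the method above.

```python
from collections import OrderedDict


def filter_observation(observation, filter_keys):
    # type(observation)(pairs) constructs the same mapping type as the input,
    # so a dict stays a dict and an OrderedDict stays an OrderedDict.
    return type(observation)([
        (name, value)
        for name, value in observation.items()
        if name in filter_keys
    ])


plain = {"position": 0.5, "velocity": -0.1, "image": None}
ordered = OrderedDict(plain)

out_plain = filter_observation(plain, ("position", "velocity"))
out_ordered = filter_observation(ordered, ("position", "velocity"))

print(out_plain)                   # {'position': 0.5, 'velocity': -0.1}
print(type(out_ordered).__name__)  # OrderedDict
```

Unlike the observation space, the observation values are not copied here: the filtered mapping refers to the same value objects as the original observation.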