Source code for stable_gym.common.max_episode_steps_injection_wrapper

"""This file contains a small gymnasium wrapper that injects the `max_episode_steps`
argument of a potentially nested `TimeLimit` wrapper into the base environment under the
`_time_limit_max_episode_steps` attribute.
"""

import gymnasium as gym


def get_time_limit_wrapper_max_episode_steps(env):
    """Returns the ``max_episode_steps`` attribute of a potentially nested
    ``TimeLimit`` wrapper.

    The wrapper chain is walked from the outside in by following the ``env``
    attribute of each wrapper. The value of the first (outermost) ``TimeLimit``
    wrapper that is encountered is returned.

    Args:
        env (gym.Env): The gymnasium environment.

    Returns:
        int: The value of the ``max_episode_steps`` attribute of a potentially nested
            ``TimeLimit`` wrapper. If the environment is not wrapped in a
            ``TimeLimit`` wrapper, then this function returns ``None``.
    """
    # Only objects exposing an ``env`` attribute are treated as wrappers; the
    # loop stops as soon as the innermost (base) environment is reached.
    while hasattr(env, "env"):
        if isinstance(env, gym.wrappers.TimeLimit):
            return env._max_episode_steps
        # Descend one level. The previous recursive implementation dropped the
        # result of this step, so nested ``TimeLimit`` wrappers were never found.
        env = env.env
    return None
def inject_attribute_into_base_env(env, attribute_name, attribute_value):
    """Sets an attribute on the innermost environment of a wrapper chain.

    Starting at ``env``, the ``env`` attribute is followed for as long as the
    current object has one. The object at the end of that chain is treated as
    the base environment and receives the attribute.

    Args:
        env (gym.Env): The gymnasium environment.
        attribute_name (str): The attribute's name to inject into the base
            environment.
        attribute_value (object): The attribute's value to inject into the base
            environment.
    """
    base_env = env
    # Unwrap until an object without an ``env`` attribute is reached.
    while hasattr(base_env, "env"):
        base_env = base_env.env
    setattr(base_env, attribute_name, attribute_value)
class MaxEpisodeStepsInjectionWrapper(gym.Wrapper):
    """A gymnasium wrapper that copies the ``TimeLimit`` step limit onto the base
    environment.

    On construction the value returned by
    :func:`get_time_limit_wrapper_max_episode_steps` for the wrapped environment
    is stored on the innermost environment as the
    ``_time_limit_max_episode_steps`` attribute. When that lookup returns
    ``None``, the attribute is set to ``None``.
    """

    def __init__(self, env):
        """Wrap a gymnasium environment.

        Args:
            env (gym.Env): The gymnasium environment.
        """
        super().__init__(env)

        # Look up the step limit and write it straight onto the base environment.
        inject_attribute_into_base_env(
            self.env,
            "_time_limit_max_episode_steps",
            get_time_limit_wrapper_max_episode_steps(self.env),
        )