"""This file contains a small gymnasium wrapper that injects the `max_episode_steps`
argument of a potentially nested `TimeLimit` wrapper into the base environment under the
`_time_limit_max_episode_steps` attribute.
"""
import gymnasium as gym
def get_time_limit_wrapper_max_episode_steps(env):
    """Returns the ``max_episode_steps`` value of a potentially nested ``TimeLimit``
    wrapper.

    The wrapper stack is walked from the outside in by following the ``env``
    attribute. The first ``TimeLimit`` wrapper that is encountered (i.e. the
    outermost one) determines the result.

    Args:
        env (gym.Env): The gymnasium environment.

    Returns:
        int or None: The value of the ``_max_episode_steps`` attribute of the
            outermost ``TimeLimit`` wrapper. If the environment is not wrapped in a
            ``TimeLimit`` wrapper, then this function returns ``None``.
    """
    current = env
    # Only objects exposing an ``env`` attribute are treated as wrappers; the first
    # object without one is the base environment and ends the search.
    while hasattr(current, "env"):
        if isinstance(current, gym.wrappers.TimeLimit):
            return current._max_episode_steps
        # Descend one level. The previous recursive version discarded the result of
        # this step, so nested ``TimeLimit`` wrappers were never reported.
        current = current.env
    return None
def inject_attribute_into_base_env(env, attribute_name, attribute_value):
    """Sets an attribute on the base environment of a (possibly nested) wrapper
    stack.

    The stack is unwrapped by following the ``env`` attribute until an object
    without an ``env`` attribute is reached. That object is treated as the base
    environment and receives the attribute; the wrappers themselves are not
    modified.

    Args:
        env (gym.Env): The gymnasium environment.
        attribute_name (str): The attribute's name to inject into the base
            environment.
        attribute_value (object): The attribute's value to inject into the base
            environment.
    """
    base_env = env
    # Iterative unwrapping; equivalent to recursing on ``env.env`` but independent
    # of the interpreter's recursion limit.
    while hasattr(base_env, "env"):
        base_env = base_env.env
    setattr(base_env, attribute_name, attribute_value)
class MaxEpisodeStepsInjectionWrapper(gym.Wrapper):
    """A gymnasium wrapper that copies the ``max_episode_steps`` setting of a
    ``TimeLimit`` wrapper onto the base environment as the
    ``_time_limit_max_episode_steps`` attribute.

    The injection happens once, when the wrapper is constructed. The injected value
    is whatever ``get_time_limit_wrapper_max_episode_steps`` returns for the wrapped
    environment, which is ``None`` when that lookup finds no ``TimeLimit`` wrapper.
    """

    def __init__(self, env):
        """Wrap a gymnasium environment and perform the injection.

        Args:
            env (gym.Env): The gymnasium environment.
        """
        super().__init__(env)

        # Look up max_episode_steps from potentially nested TimeLimit wrappers.
        max_episode_steps = get_time_limit_wrapper_max_episode_steps(self.env)

        # Store the value on the base environment so that it is available there
        # regardless of how many wrappers sit in between.
        inject_attribute_into_base_env(
            self.env, "_time_limit_max_episode_steps", max_episode_steps
        )