Source code for stable_learning_control.utils.gym_utils

"""Contains utilities that can be used with the
`gymnasium package <https://gymnasium.farama.org/>`_.
"""

import importlib

import gymnasium as gym
from gymnasium import spaces

from stable_learning_control.common.helpers import friendly_err

# from textwrap import dedent


# Gymnasium space classes that are treated as discrete by this module.
DISCRETE_SPACES = (
    spaces.Discrete,
    spaces.MultiBinary,
    spaces.MultiDiscrete,
)

# Gymnasium space classes that are treated as continuous by this module.
CONTINUOUS_SPACES = (spaces.Box,)
def is_gym_env(env):
    """Tell whether an object is a gymnasium environment.

    Args:
        env (object): Any python object.

    Returns:
        bool: ``True`` when ``env`` is an instance of :class:`gymnasium.Env`
        (subclasses included), ``False`` otherwise.
    """
    env_base_cls = gym.Env
    return isinstance(env, env_base_cls)
def is_continuous_space(space):
    """Checks whether a given space is continuous.

    Args:
        space (:obj:`gym.spaces`): The gymnasium space object.

    Returns:
        bool: Boolean specifying whether the space is continuous, i.e. whether
        it is an instance of one of the classes in ``CONTINUOUS_SPACES``.
    """
    return isinstance(space, CONTINUOUS_SPACES)
def is_discrete_space(space):
    """Report whether a gymnasium space is of a discrete type.

    Args:
        space (:obj:`gym.spaces`): The gymnasium space object to inspect.

    Returns:
        bool: ``True`` if ``space`` is an instance of any class listed in
        ``DISCRETE_SPACES``, ``False`` otherwise.
    """
    discrete_types = DISCRETE_SPACES
    return isinstance(space, discrete_types)
def validate_gym_env(arg_dict):
    """Ensure each requested ``env_name`` is a registered gymnasium environment.

    Args:
        arg_dict (dict): The command argument dictionary. Its ``"env_name"``
            entry holds the environment name(s) to validate: either a single
            string or an iterable of strings. A name of the form
            ``module:env_id`` causes ``module`` to be imported first, after
            which ``env_id`` is looked up in the gymnasium registry.

    Raises:
        AssertionError: Raised when ``"env_name"`` is missing (or ``None``), or
            when an environment is supplied that is not a valid (registered)
            gymnasium environment.
        ImportError: Raised when the module prefix of a ``module:env_id`` name
            cannot be imported.
    """
    # Special handling for environment: make sure that env_name is a real,
    # registered gymnasium environment. Explicit ``raise`` statements are used
    # instead of ``assert`` so the checks are not stripped under ``python -O``;
    # AssertionError is kept so existing callers that catch it keep working.
    if "env_name" not in arg_dict or arg_dict["env_name"] is None:
        raise AssertionError(
            friendly_err(
                "You did not give a valid value for --env_name! Please try again."
            )
        )

    env_names = arg_dict["env_name"]
    # A bare string would otherwise be iterated character by character.
    if isinstance(env_names, str):
        env_names = [env_names]

    # Check if the environment is a valid gymnasium environment.
    for env_name in env_names:
        if ":" in env_name:
            # Custom environment: import the package that registers it, then
            # validate the id that follows the first ':'.
            module_name, _, env_name = env_name.partition(":")
            try:
                importlib.import_module(module_name)
            except ImportError as e:
                raise ImportError(
                    friendly_err(
                        "Could not import custom gymnasium environment package: "
                        + str(e)
                    )
                ) from e

        # Check if the environment is a valid gymnasium environment.
        if env_name not in gym.envs.registry:
            err_msg = friendly_err(
                """

                %s is not registered with gymnasium.

                Recommendations:

                    * Check for a typo (did you include the version tag?)

                    * Gymnasium environments: View the complete list of valid
                      gymnasium environments at https://gymnasium.farama.org/api/env/

                    * Custom environments: Ensure the custom environment is
                      installed and you specify the module prefix (e.g.
                      `custom_module:env_name`).
                """
                % (env_name,)
            )
            raise AssertionError(err_msg)