Source code for stable_learning_control.utils.test_policy

"""A set of functions that can be used to see a algorithm perform in the environment
it was trained on.
"""

import glob
import os
import os.path as osp
import re
from pathlib import Path

import joblib
import torch

from stable_learning_control.common.exceptions import EnvLoadError, PolicyLoadError
from stable_learning_control.common.helpers import friendly_err, get_env_id
from stable_learning_control.utils.import_utils import import_tf
from stable_learning_control.utils.log_utils.helpers import log_to_std_out
from stable_learning_control.utils.log_utils.logx import EpochLogger
from stable_learning_control.utils.serialization_utils import load_from_json


def _retrieve_iter_folder(fpath, itr):
    """Retrieves the path of the requested model iteration.

    Args:
        fpath (str): The path where the model is found.
        itr (int): The current policy iteration (checkpoint).

    Returns:
        str: The model iteration path.
    """
    cpath = Path(fpath).joinpath("checkpoints", f"iter{itr}")
    if osp.isdir(cpath):
        log_to_std_out(f"Using model iteration {itr}.", type="info")
        return cpath
    else:
        log_to_std_out(
            f"Iteration {itr} not found inside the supplied model path '{fpath}'. "
            "Last iteration used instead.",
            type="warning",
        )
        return fpath

def _retrieve_model_folder(fpath):
    """Tries to retrieve the model folder and backend from the given path.

    Args:
        fpath (str): The path where the model is found.

    Raises:
        IOError: Raised if the model folder is corrupt.
        FileNotFoundError: Raised if the model path did not exist.

    Returns:
        (tuple): tuple containing:

            - model_folder (:obj:`str`): The model folder.
            - backend (:obj:`str`): The inferred backend. Options are ``tf2`` and
              ``torch``.
    """
    data_folders = (
        glob.glob(fpath + r"/*_save")
        if not bool(re.search(r"/*_save", fpath))
        else glob.glob(fpath)
    )
    if any(["tf2_save" in item for item in data_folders]) and any(
        ["torch_save" in item for item in data_folders]
    ):
        raise IOError(
            friendly_err(
                f"Policy could not be loaded since the specified model folder "
                f"'{fpath}' seems to be corrupted. It contains both a 'torch_save' and "
                "'tf2_save' folder. Please check your model path (fpath) and try again."
            )
        )
    elif (
        len([item for item in data_folders if "tf2_save" in item]) > 1
        or len([item for item in data_folders if "torch_save" in item]) > 1
    ):
        raise IOError(
            friendly_err(
                "Policy could not be loaded since the specified model folder '{}' "
                "seems to be corrupted. It contains multiple '{}' folders. Please "
                "check your model path (fpath) and try again.".format(
                    fpath,
                    (
                        "tf2_save"
                        if any(["tf2_save" in item for item in data_folders])
                        else "torch_save"
                    ),
                )
            )
        )
    elif any(["tf2_save" in item for item in data_folders]):
        model_path = [item for item in data_folders if "tf2_save" in item][0]
        return model_path, "tf2"
    elif any(["torch_save" in item for item in data_folders]):
        model_path = [item for item in data_folders if "torch_save" in item][0]
        return model_path, "torch"
    else:
        raise FileNotFoundError(
            friendly_err(
                f"No model was found inside the supplied model path '{fpath}'. "
                "Please check your model path (fpath) and try again."
            )
        )

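# NOTE: The helpers above assume a folder layout like the one saved by the Stable
# Learning Control logger. A minimal sketch of such an output folder (the name
# 'my_experiment' is illustrative; a 'tf2_save' folder takes the place of
# 'torch_save' for the TensorFlow backend):
#
#   my_experiment/
#   ├── vars.pkl                  # Pickled training state (holds the 'env' object).
#   └── torch_save/
#       ├── model_state.pt        # Saved model parameters (PyTorch backend).
#       ├── save_info.json        # Algorithm name and actor-critic setup kwargs.
#       └── checkpoints/
#           └── iter<N>/          # Optional per-iteration checkpoints.
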
def load_policy_and_env(fpath, itr="last"):
    """Load a policy from save, whether it's TF or PyTorch, along with the RL
    environment it was trained on.

    Args:
        fpath (str): The path where the model is found.
        itr (str, optional): The current policy iteration (checkpoint). Defaults to
            ``last``.

    Raises:
        FileNotFoundError: Thrown when the fpath does not exist.
        EnvLoadError: Thrown when something went wrong trying to load the saved
            environment.
        PolicyLoadError: Thrown when something went wrong trying to load the saved
            policy.

    Returns:
        (tuple): tuple containing:

            - env (:obj:`gym.env`): The gymnasium environment.
            - policy (Union[tf.keras.Model, torch.nn.Module]): The loaded policy.
    """
    if not os.path.isdir(fpath):
        raise FileNotFoundError(
            friendly_err(
                f"The model folder you specified '{fpath}' does not exist. Please "
                "specify a valid model folder (fpath) and try again."
            )
        )

    # Retrieve model path and backend.
    fpath, backend = _retrieve_model_folder(fpath)
    if itr != "last":
        assert isinstance(itr, int), friendly_err(
            "Bad value provided for itr (needs to be int or 'last')."
        )
        itr = "%d" % itr

    # Try to load the environment from save.
    # NOTE: Sometimes this will fail because the environment could not be pickled.
    try:
        state = joblib.load(Path(fpath).parent.joinpath("vars.pkl"))
        env = state["env"]
    except Exception as e:
        raise EnvLoadError(
            friendly_err(
                (
                    "Environment not found!\n\n It looks like the environment wasn't "
                    "saved, and we can't run the agent in it. :( \n\n Check out the "
                    "documentation page on the Test Policy utility for how to handle "
                    "this situation."
                )
            )
        ) from e

    # Load the policy.
    try:
        if backend.lower() == "tf2":
            policy = load_tf_policy(fpath, env=env, itr=itr)
        else:
            policy = load_pytorch_policy(fpath, env=env, itr=itr)
    except Exception as e:
        raise PolicyLoadError(
            friendly_err(
                (
                    "Policy could not be loaded!\n\n It looks like the policy wasn't "
                    "successfully saved. :( \n\n Check out the documentation page on "
                    "the Test Policy utility for how to handle this situation."
                )
            )
        ) from e

    return env, policy

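# Example usage (a minimal sketch; 'data/my_experiment' is a hypothetical experiment
# output folder):
#
#   env, policy = load_policy_and_env("data/my_experiment", itr="last")
#   run_policy(env, policy, max_ep_len=0, num_episodes=5, render=False)
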
def load_tf_policy(fpath, env, itr="last"):
    """Load a TensorFlow policy saved with the Stable Learning Control Logger.

    Args:
        fpath (str): The path where the model is found.
        env (:obj:`gym.env`): The gymnasium environment in which you want to test
            the policy.
        itr (str, optional): The current policy iteration. Defaults to ``last``.

    Returns:
        tf.keras.Model: The policy.
    """
    if itr != "last":
        model_path = _retrieve_iter_folder(fpath, itr)
    else:
        model_path = fpath
    tf = import_tf()  # Throw custom warning if tf is not installed.

    print("\n")
    log_to_std_out("Loading model from '%s'.\n\n" % fpath, type="info")

    # Rebuild the model from the saved setup information.
    save_info = load_from_json(Path(fpath).joinpath("save_info.json"))
    import stable_learning_control.algos.tf2 as tf2_algos

    try:
        ac_kwargs = {"ac_kwargs": save_info["setup_kwargs"]["ac_kwargs"]}
    except KeyError:
        ac_kwargs = {}
    model = getattr(tf2_algos, save_info["alg_name"])(env=env, **ac_kwargs)

    # Restore the latest checkpoint weights.
    latest = tf.train.latest_checkpoint(model_path)
    model.load_weights(latest)

    return model

def load_pytorch_policy(fpath, env, itr="last"):
    """Load a PyTorch policy saved with the Stable Learning Control Logger.

    Args:
        fpath (str): The path where the model is found.
        env (:obj:`gym.env`): The gymnasium environment in which you want to test
            the policy.
        itr (str, optional): The current policy iteration. Defaults to ``last``.

    Returns:
        torch.nn.Module: The policy.
    """
    fpath, _ = _retrieve_model_folder(fpath)
    if itr != "last":
        fpath = _retrieve_iter_folder(fpath, itr)
    model_file = Path(fpath).joinpath("model_state.pt")

    print("\n")
    log_to_std_out("Loading model from '%s'.\n\n" % model_file, type="info")

    # Rebuild the model from the saved setup information.
    save_info = load_from_json(Path(fpath).joinpath("save_info.json"))
    import stable_learning_control.algos.pytorch as torch_algos

    model_data = torch.load(model_file)
    try:
        ac_kwargs = {"ac_kwargs": save_info["setup_kwargs"]["ac_kwargs"]}
    except KeyError:
        ac_kwargs = {}
    model = getattr(torch_algos, save_info["alg_name"])(env=env, **ac_kwargs)
    model.load_state_dict(model_data)  # Restore model parameters.

    return model

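# Example usage (a minimal sketch): loading a policy through one of the backend
# loaders when the environment is created manually. The path below is hypothetical
# and the environment must match the one the policy was trained on. Note that
# 'load_pytorch_policy' resolves the 'torch_save' folder itself, whereas
# 'load_tf_policy' expects to be pointed at the 'tf2_save' folder directly (as is
# done by 'load_policy_and_env').
#
#   import gymnasium as gym
#
#   env = gym.make("Pendulum-v1")
#   policy = load_pytorch_policy("data/my_experiment", env=env, itr="last")
#   action = policy.get_action(env.reset()[0])
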
def run_policy(
    env, policy, max_ep_len=None, num_episodes=100, render=True, deterministic=True
):
    """Evaluates a policy inside a given gymnasium environment.

    Args:
        env (:obj:`gym.env`): The gymnasium environment.
        policy (Union[tf.keras.Model, torch.nn.Module]): The policy.
        max_ep_len (int, optional): The maximum episode length. Defaults to ``None``.
        num_episodes (int, optional): Number of episodes you want to perform in the
            environment. Defaults to ``100``.
        render (bool, optional): Whether you want to render the episode to the
            screen. Defaults to ``True``.
        deterministic (bool, optional): Whether you want the action from the policy
            to be deterministic. Defaults to ``True``.
    """
    logger = EpochLogger(verbose_fmt="table")
    assert env is not None, friendly_err(
        "Environment not found!\n\n It looks like the environment wasn't saved, and we "
        "can't run the agent in it. :( \n\n Check out the documentation page on the "
        "Test Policy utility for how to handle this situation."
    )
    assert policy is not None, friendly_err(
        "Policy not found!\n\n It looks like the policy could not be loaded. :( \n\n "
        "Check out the documentation page on the Test Policy utility for how to "
        "handle this situation."
    )

    # Apply the requested maximum episode length.
    if max_ep_len is not None and max_ep_len != 0:
        if max_ep_len > env.env._max_episode_steps:
            logger.log(
                (
                    f"You defined your 'max_ep_len' to be {max_ep_len} "
                    "while the environment 'max_episode_steps' is "
                    f"{env.env._max_episode_steps}. As a result the environment "
                    f"'max_episode_steps' has been increased to {max_ep_len}"
                ),
                type="warning",
            )
            env.env._max_episode_steps = max_ep_len

    # Set render mode.
    if render:
        render_modes = env.unwrapped.metadata.get("render_modes", [])
        if render_modes:
            env.unwrapped.render_mode = "human" if "human" in render_modes else None
        else:
            logger.log(
                (
                    f"Nothing was rendered since the '{get_env_id(env)}' "
                    "environment does not contain a 'human' render mode."
                ),
                type="warning",
            )

    # Perform episodes.
    o, _ = env.reset()
    r, d, ep_ret, ep_len, n = 0, False, 0, 0, 0
    supports_deterministic = True  # Only supported with gaussian algorithms.
    while n < num_episodes:
        # Retrieve action.
        if deterministic and supports_deterministic:
            try:
                a = policy.get_action(o, deterministic=deterministic)
            except TypeError:
                supports_deterministic = False
                logger.log(
                    "Input argument 'deterministic' ignored as the algorithm does "
                    "not support deterministic actions. This is only supported for "
                    "gaussian algorithms.",
                    type="warning",
                )
                a = policy.get_action(o)
        else:
            a = policy.get_action(o)

        # Perform action in the environment and store result.
        o, r, d, truncated, _ = env.step(a)
        ep_ret += r
        ep_len += 1
        if d or truncated:
            logger.store(EpRet=ep_ret, EpLen=ep_len)
            logger.log("Episode %d \t EpRet %.3f \t EpLen %d" % (n, ep_ret, ep_len))
            o, _ = env.reset()
            r, d, ep_ret, ep_len = 0, False, 0, 0
            n += 1

    print("")
    logger.log_tabular("EpRet", with_min_and_max=True)
    logger.log_tabular("EpLen", average_only=True)
    logger.dump_tabular()

if __name__ == "__main__": import argparse
[docs] parser = argparse.ArgumentParser()
parser.add_argument("fpath", type=str, help="The path where the policy is stored") parser.add_argument( "--len", "-l", type=int, default=0, help="The episode length (default: 0)" ) parser.add_argument( "--episodes", "-n", type=int, default=100, help="The number of episodes you want to run per disturbance (default: 100)", ) parser.add_argument( "--norender", "-nr", action="store_true", help="Whether you want to render the environment step (default: False)", ) parser.add_argument( "--itr", "-i", type=int, default=-1, help="The policy iteration (epoch) you want to use (default: last)", ) parser.add_argument( "--deterministic", "-d", action="store_true", help=( "Whether you want to use a deterministic policy. Only available for " "Gaussian policies (default: False)" ), ) args = parser.parse_args() env, policy = load_policy_and_env(args.fpath, args.itr if args.itr >= 0 else "last") run_policy(env, policy, args.len, args.episodes, not (args.norender))