Source code for stable_learning_control.algos.pytorch.policies.actors.squashed_gaussian_actor

"""Squashed Gaussian Actor policy.

This module contains a Pytorch implementation of the Squashed Gaussian Actor policy of
`Haarnoja et al. 2019 <https://arxiv.org/abs/1812.05905>`_.
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal

from stable_learning_control.algos.pytorch.common.helpers import mlp, rescale
from stable_learning_control.utils.log_utils.helpers import log_to_std_out


class SquashedGaussianActor(nn.Module):
    """The squashed Gaussian actor network.

    Attributes:
        net (torch.nn.Sequential): The input/hidden layers of the network.
        mu_layer (torch.nn.Linear): The output layer which returns the mean of the
            actions.
        log_std_layer (torch.nn.Linear): The output layer which returns the log
            standard deviation of the actions.
        act_limits (dict, optional): The ``high`` and ``low`` action bounds of the
            environment. Used for rescaling the actions that come out of the network
            from ``(-1, 1)`` to ``(low, high)``. No scaling will be applied if left
            empty.
    """

    def __init__(
        self,
        obs_dim,
        act_dim,
        hidden_sizes,
        activation=nn.ReLU,
        output_activation=nn.ReLU,
        act_limits=None,
        log_std_min=-20,
        log_std_max=2.0,
    ):
        """Initialise the SquashedGaussianActor object.

        Args:
            obs_dim (int): Dimension of the observation space.
            act_dim (int): Dimension of the action space.
            hidden_sizes (list): Sizes of the hidden layers.
            activation (:obj:`torch.nn.modules.activation`): The activation function.
                Defaults to :obj:`torch.nn.ReLU`.
            output_activation (:obj:`torch.nn.modules.activation`, optional): The
                activation function used for the output layers. Defaults to
                :obj:`torch.nn.ReLU`.
            act_limits (dict): The ``high`` and ``low`` action bounds of the
                environment. Used for rescaling the actions that come out of the
                network from ``(-1, 1)`` to ``(low, high)``.
            log_std_min (int, optional): The minimum log standard deviation. Defaults
                to ``-20``.
            log_std_max (float, optional): The maximum log standard deviation.
                Defaults to ``2.0``.
        """
        super().__init__()
        self.__device_warning_logged = False
        self.act_limits = act_limits
        self._log_std_min = log_std_min
        self._log_std_max = log_std_max
        self.net = mlp([obs_dim] + list(hidden_sizes), activation, output_activation)
        self.mu_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.log_std_layer = nn.Linear(hidden_sizes[-1], act_dim)
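    # Layer layout sketch (illustrative only; the concrete sizes assume
    # ``obs_dim=8``, ``act_dim=2`` and ``hidden_sizes=(64, 64)``):
    #
    #   net:            8 -> 64 -> 64   MLP trunk built by the ``mlp`` helper.
    #   mu_layer:       64 -> 2         Mean of the pre-squash Gaussian.
    #   log_std_layer:  64 -> 2         Log standard deviation, clamped to
    #                                   [log_std_min, log_std_max] in ``forward``.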
    def forward(self, obs, deterministic=False, with_logprob=True):
        """Perform forward pass through the network.

        Args:
            obs (torch.Tensor): The tensor of observations.
            deterministic (bool, optional): Whether we want to use a deterministic
                policy (used at test time). When ``True`` the mean action of the
                stochastic policy is returned. If ``False`` the action is sampled
                from the stochastic policy. Defaults to ``False``.
            with_logprob (bool, optional): Whether we want to return the log
                probability of an action. Defaults to ``True``.

        Returns:
            (tuple): tuple containing:

                - pi_action (:obj:`torch.Tensor`): The actions given by the policy.
                - logp_pi (:obj:`torch.Tensor`): The log probabilities of each of
                  these actions.
        """  # noqa: E501
        # Make sure the observations are on the right device.
        if obs.device != self.net[0].weight.device:
            if not self.__device_warning_logged:
                device_warn_msg = (
                    "The observations were automatically moved from '{}' to '{}' "
                    "during the '{}' forward pass. ".format(
                        obs.device, self.net[0].weight.device, self.__class__.__name__
                    )
                    + "Please place your observations on the '{}' device ".format(
                        self.net[0].weight.device
                    )
                    + "before calling the '{}' as converting them ".format(
                        self.__class__.__name__
                    )
                    + "during the forward pass slows down the algorithm."
                )
                log_to_std_out(device_warn_msg, type="warning")
                self.__device_warning_logged = True
            obs = obs.to(self.net[0].weight.device)

        # Calculate the mean action and standard deviation.
        net_out = self.net(obs)
        mu = self.mu_layer(net_out)
        log_std = self.log_std_layer(net_out)
        log_std = torch.clamp(log_std, self._log_std_min, self._log_std_max)
        std = torch.exp(log_std)

        # Pre-squash distribution and sample.
        pi_distribution = Normal(mu, std)
        if deterministic:
            # Only used for evaluating the policy at test time.
            pi_action = mu
        else:
            # Sample while using the reparameterization trick.
            pi_action = pi_distribution.rsample()

        # Compute logprob from the Gaussian and then apply a correction for the Tanh
        # squashing.
        if with_logprob:
            # NOTE: The correction formula is a little bit magic. To get an
            # understanding of where it comes from, check out the original SAC paper
            # (arXiv 1801.01290) and look in appendix C. This is a more
            # numerically-stable equivalent to Eq 21. Try deriving it yourself as a
            # (very difficult) exercise. :)
            sum_axis = 0 if obs.dim() == 1 else 1
            logp_pi = pi_distribution.log_prob(pi_action).sum(axis=-1)
            logp_pi -= (2 * (np.log(2) - pi_action - F.softplus(-2 * pi_action))).sum(
                axis=sum_axis
            )
        else:
            logp_pi = None

        # Squash the Gaussian sample so the actions lie between -1 and 1.
        pi_action = torch.tanh(pi_action)

        # Rescale the normalised actions such that they are in the range of the
        # environment.
        if self.act_limits is not None:
            pi_action = rescale(
                pi_action,
                min_bound=self.act_limits["low"],
                max_bound=self.act_limits["high"],
            )

        return pi_action, logp_pi
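    # NOTE on the squashing correction used in ``forward`` above: for a squashed
    # action ``a = tanh(u)`` the change-of-variables formula subtracts
    # ``log(1 - tanh(u)^2)`` per action dimension from the Gaussian log-likelihood.
    # Using ``1 - tanh(u)^2 = 4 / (exp(u) + exp(-u))^2`` this equals
    # ``2 * (log(2) - u - softplus(-2 * u))``, which is the expression applied above
    # and avoids taking the log of a value that underflows to zero for large ``|u|``.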
    def act(self, obs, deterministic=False):
        """Returns the action from the current state given the current policy.

        Args:
            obs (torch.Tensor): The current observation (state).
            deterministic (bool, optional): Whether we want to use a deterministic
                policy (used at test time). When ``True`` the mean action of the
                stochastic policy is returned. If ``False`` the action is sampled
                from the stochastic policy. Defaults to ``False``.

        Returns:
            numpy.ndarray: The action from the current state given the current
            policy.
        """
        with torch.no_grad():
            a, _ = self(obs, deterministic, False)
            return a.cpu().numpy()
    def get_action(self, obs, deterministic=False):
        """Simple wrapper for making the :meth:`act` method available under the
        'get_action' alias.

        Args:
            obs (torch.Tensor): The current observation (state).
            deterministic (bool, optional): Whether we want to use a deterministic
                policy (used at test time). When ``True`` the mean action of the
                stochastic policy is returned. If ``False`` the action is sampled
                from the stochastic policy. Defaults to ``False``.

        Returns:
            numpy.ndarray: The action from the current state given the current
            policy.
        """
        return self.act(obs, deterministic=deterministic)
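A minimal usage sketch of the actor is given below. The observation/action dimensions, hidden sizes and ``act_limits`` values are made-up examples, and depending on the ``rescale`` helper the bounds may need to be arrays or tensors matching the action shape.

import torch

from stable_learning_control.algos.pytorch.policies.actors.squashed_gaussian_actor import (
    SquashedGaussianActor,
)

# Actor for a hypothetical environment with 8 observations and 2 actions in [-2, 2].
actor = SquashedGaussianActor(
    obs_dim=8,
    act_dim=2,
    hidden_sizes=[64, 64],
    act_limits={"low": -2.0, "high": 2.0},
)

obs = torch.zeros(8)  # Dummy observation (place it on the same device as the actor).

# Stochastic action plus its log probability (used during training).
pi_action, logp_pi = actor(obs)

# Deterministic action as a NumPy array, without gradients (used at test time).
eval_action = actor.act(obs, deterministic=True)  # `get_action` is an alias of `act`.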