"""Lyapunov actor critic policy.
This module contains a Pytorch implementation of the Q Critic policy of
`Haarnoja et al. 2019 <https://arxiv.org/abs/1812.05905>`_.
"""
import torch
import torch.nn as nn

from stable_learning_control.algos.pytorch.common.helpers import mlp
from stable_learning_control.utils.log_utils.helpers import log_to_std_out


class QCritic(nn.Module):
    """Soft Q critic network.

    Attributes:
        Q (torch.nn.Sequential): The layers of the network.
    """

def __init__(
self,
obs_dim,
act_dim,
hidden_sizes,
activation=nn.ReLU,
output_activation=nn.Identity,
):
"""Initialise the QCritic object.
Args:
obs_dim (int): Dimension of the observation space.
act_dim (int): Dimension of the action space.
hidden_sizes (list): Sizes of the hidden layers.
activation (:obj:`torch.nn.modules.activation`, optional): The activation
function. Defaults to :obj:`torch.nn.ReLU`.
output_activation (:obj:`torch.nn.modules.activation`, optional): The
activation function used for the output layers. Defaults to
:mod:`torch.nn.Identity`.
"""
super().__init__()
        self.__device_warning_logged = False
        self._obs_same_device = False
        self._act_same_device = False
        self.Q = mlp(
[obs_dim + act_dim] + list(hidden_sizes) + [1],
activation,
output_activation,
)
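
        # For intuition, a sketch of the network built above, assuming the
        # ``mlp`` helper follows the standard Spinning Up pattern and using
        # hypothetical sizes obs_dim=8, act_dim=2, hidden_sizes=[256, 256]
        # (illustrative values, not part of the module):
        #
        #   nn.Sequential(
        #       nn.Linear(10, 256), nn.ReLU(),
        #       nn.Linear(256, 256), nn.ReLU(),
        #       nn.Linear(256, 1), nn.Identity(),
        #   )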

    def forward(self, obs, act):
        """Perform a forward pass through the network.

        Args:
            obs (torch.Tensor): The tensor of observations.
            act (torch.Tensor): The tensor of actions.

        Returns:
            torch.Tensor: The tensor containing the Q values of the input
                observations and actions.
        """
        # Make sure the observations and actions are on the right device.
        # NOTE: Despite their names, these flags are ``True`` on a device mismatch.
        self._obs_same_device = obs.device != self.Q[0].weight.device
        self._act_same_device = act.device != self.Q[0].weight.device
        if self._obs_same_device or self._act_same_device:
            if not self.__device_warning_logged:
                device_warn_strs = (
                    ("observations and actions", obs.device)
                    if (self._obs_same_device and self._act_same_device)
                    else (
                        ("observations", obs.device)
                        if self._obs_same_device
                        else ("actions", act.device)
                    )
                )
                device_warn_msg = (
                    "The {} were automatically moved from ".format(device_warn_strs[0])
                    + "'{}' to '{}' during the '{}' forward pass. ".format(
                        device_warn_strs[1],
                        self.Q[0].weight.device,
                        self.__class__.__name__,
                    )
                    + "Please place your observations and actions on the "
                    + "'{}' before calling the '{}' as converting ".format(
                        self.Q[0].weight.device, self.__class__.__name__
                    )
                    + "them during the forward pass slows down the algorithm."
                )
log_to_std_out(device_warn_msg, type="warning")
self.__device_warning_logged = True
            # Move the inputs onto the network's device where needed.
            obs = (
                obs.to(self.Q[0].weight.device)
                if obs.device != self.Q[0].weight.device
                else obs
            )
            act = (
                act.to(self.Q[0].weight.device)
                if act.device != self.Q[0].weight.device
                else act
            )
        return torch.squeeze(
            self.Q(torch.cat([obs, act], dim=-1)), -1
        )  # NOTE: Squeeze is critical to ensure q has the right shape.
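

if __name__ == "__main__":
    # Minimal usage sketch. The sizes below (obs_dim=8, act_dim=2,
    # hidden_sizes=[256, 256]) are hypothetical values chosen for
    # illustration; they are not part of the original module.
    critic = QCritic(obs_dim=8, act_dim=2, hidden_sizes=[256, 256])
    obs = torch.randn(32, 8)  # Batch of 32 observations.
    act = torch.randn(32, 2)  # Batch of 32 actions.
    q_values = critic(obs, act)
    print(q_values.shape)  # torch.Size([32]); squeeze removed the last dim.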