stable_learning_control.algos.pytorch.policies.soft_actor_critic
This module contains a PyTorch implementation of the Soft Actor-Critic policy of Haarnoja et al. 2019.
Attributes
- HIDDEN_SIZES_DEFAULT
- ACTIVATION_DEFAULT
- OUTPUT_ACTIVATION_DEFAULT
Classes
- SoftActorCritic: Soft Actor-Critic network.
Module Contents
- class stable_learning_control.algos.pytorch.policies.soft_actor_critic.SoftActorCritic(observation_space, action_space, hidden_sizes=HIDDEN_SIZES_DEFAULT, activation=ACTIVATION_DEFAULT, output_activation=OUTPUT_ACTIVATION_DEFAULT)[source]
Bases: torch.nn.Module

Soft Actor-Critic network.
- self.pi
The squashed Gaussian policy network (actor).
- Type:
Initialise the SoftActorCritic object.
- Parameters:
  - observation_space (gym.space.box.Box) – A gymnasium observation space.
  - action_space (gym.space.box.Box) – A gymnasium action space.
  - hidden_sizes (Union[dict, tuple, list], optional) – Sizes of the hidden layers for the actor. Defaults to (256, 256).
  - activation (Union[dict, torch.nn.modules.activation], optional) – The (actor and critic) hidden layers activation function. Defaults to torch.nn.ReLU.
  - output_activation (Union[dict, torch.nn.modules.activation], optional) – The (actor and critic) output activation function. Defaults to torch.nn.ReLU for the actor and nn.Identity for the critic.
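For orientation, the following is a minimal construction sketch based on the documented signature; the environment name is illustrative, and any gymnasium environment with Box observation and action spaces should work:

```python
import gymnasium as gym
import torch.nn as nn

from stable_learning_control.algos.pytorch.policies.soft_actor_critic import (
    SoftActorCritic,
)

# Illustrative continuous-control environment; any Box obs/action spaces work.
env = gym.make("Pendulum-v1")

# Construct the actor-critic network, spelling out the documented defaults.
ac = SoftActorCritic(
    observation_space=env.observation_space,
    action_space=env.action_space,
    hidden_sizes=(256, 256),  # documented default
    activation=nn.ReLU,       # documented default
)
```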
- forward(obs, act, deterministic=False, with_logprob=True)[source]
Performs a forward pass through all the networks (Actor, Q critic 1 and Q critic 2).
- Parameters:
  - obs (torch.Tensor) – The tensor of observations.
  - act (torch.Tensor) – The tensor of actions.
  - deterministic (bool, optional) – Whether we want to use a deterministic policy (used at test time). When True, the mean action of the stochastic policy is returned. If False, the action is sampled from the stochastic policy. Defaults to False.
  - with_logprob (bool, optional) – Whether we want to return the log probability of an action. Defaults to True.
- Returns:
  tuple containing:
  - pi_action (torch.Tensor): The actions given by the policy.
  - logp_pi (torch.Tensor): The log probabilities of each of these actions.
  - Q1 (torch.Tensor): Q-values of the first critic.
  - Q2 (torch.Tensor): Q-values of the second critic.
- Return type:
(tuple)
Note
Useful when you want to print out the full network graph using TensorBoard.
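To make the documented inputs and outputs concrete, the sketch below runs one forward pass on a dummy batch; it assumes the `ac` network and `env` environment from the construction sketch above:

```python
import torch

# Dummy batch of one observation and one action, sampled from the spaces of
# the `env` created in the construction sketch above.
obs = torch.as_tensor(env.observation_space.sample(), dtype=torch.float32).unsqueeze(0)
act = torch.as_tensor(env.action_space.sample(), dtype=torch.float32).unsqueeze(0)

# Calling the module invokes `forward`, which evaluates the actor and both critics.
pi_action, logp_pi, q1, q2 = ac(obs, act, deterministic=False, with_logprob=True)
print(pi_action.shape, logp_pi.shape, q1.shape, q2.shape)
```

Because a single call exercises the actor and both critics, this is also the call a TensorBoard graph export can trace, as mentioned in the note above.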
- act(obs, deterministic=False)[source]
Returns the action from the current state given the current policy.
- Parameters:
  - obs (torch.Tensor) – The current observation (state).
  - deterministic (bool, optional) – Whether we want to use a deterministic policy (used at test time). When True, the mean action of the stochastic policy is returned. If False, the action is sampled from the stochastic policy. Defaults to False.
- Returns:
The action from the current state given the current policy.
- Return type:
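As a usage sketch for evaluation, again assuming the `ac` and `env` objects from the construction sketch above, and assuming (not stated in the docstring here) that the returned action can be passed directly to `env.step`:

```python
import torch

obs, _ = env.reset()
done = False
while not done:
    # Request the deterministic (mean) action, as is typical at test time.
    action = ac.act(torch.as_tensor(obs, dtype=torch.float32), deterministic=True)
    obs, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
```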