stable_learning_control.algos.tf2.policies.soft_actor_critic
Soft actor critic policy.
This module contains a TensorFlow 2.x implementation of the Soft Actor Critic policy of Haarnoja et al. 2019.
Attributes
Classes
SoftActorCritic – Soft Actor-Critic network.
Module Contents
- class stable_learning_control.algos.tf2.policies.soft_actor_critic.SoftActorCritic(observation_space, action_space, hidden_sizes=HIDDEN_SIZES_DEFAULT, activation=ACTIVATION_DEFAULT, output_activation=OUTPUT_ACTIVATION_DEFAULT, name='soft_actor_critic')
Bases: tf.keras.Model
Soft Actor-Critic network.
- self.pi
The squashed gaussian policy network (actor).
- Type:
Initialise the SoftActorCritic object.
- Parameters:
observation_space (gym.space.box.Box) – A gymnasium observation space.
action_space (gym.space.box.Box) – A gymnasium action space.
hidden_sizes (Union[dict, tuple, list], optional) – Sizes of the hidden layers for the actor. Defaults to (256, 256).
activation (Union[dict, tf.keras.activations], optional) – The (actor and critic) hidden layers activation function. Defaults to tf.nn.relu.
output_activation (Union[dict, tf.keras.activations], optional) – The (actor and critic) output activation function. Defaults to tf.nn.relu for the actor and the Identity function for the critic.
name (str, optional) – The name given to the SoftActorCritic. Defaults to “soft_actor_critic”.
- call(inputs, deterministic=False, with_logprob=True)
Performs a forward pass through all the networks (Actor, Q critic 1 and Q critic 2).
- Parameters:
inputs (tuple) – tuple containing:
obs (tf.Tensor): The tensor of observations.
act (tf.Tensor): The tensor of actions.
deterministic (bool, optional) – Whether to use a deterministic policy (used at test time). When True, the mean action of the stochastic policy is returned. When False, the action is sampled from the stochastic policy. Defaults to False.
with_logprob (bool, optional) – Whether to return the log probability of the action. Defaults to True.
- Returns:
tuple containing:
- Return type:
(tuple)
Note
Useful for when you want to print out the full network graph using TensorBoard.
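Example
The effect of the deterministic and with_logprob arguments can be illustrated with a minimal, pure-Python sketch of a squashed Gaussian action step. This is not the library's implementation: the function name squashed_gaussian_action, the helper _softplus and the list-based inputs are hypothetical, and the sketch assumes the usual formulation in which a Gaussian sample is squashed with tanh and the log probability is corrected for that change of variables.

```python
import math
import random


def _softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def squashed_gaussian_action(mean, log_std, deterministic=False,
                             with_logprob=True, rng=None):
    """Hypothetical stand-in for the output step of a squashed Gaussian actor.

    mean, log_std: per-dimension Gaussian parameters (lists of floats).
    Returns (action, logp); logp is None when with_logprob is False.
    """
    rng = rng or random.Random()
    std = [math.exp(ls) for ls in log_std]

    if deterministic:
        # Test time: use the mean of the Gaussian instead of sampling.
        u = list(mean)
    else:
        # Training time: sample the pre-squash action u ~ N(mean, std^2).
        u = [rng.gauss(m, s) for m, s in zip(mean, std)]

    # Squash so that every action component is bounded by tanh.
    action = [math.tanh(x) for x in u]

    logp = None
    if with_logprob:
        # Gaussian log-density of u, summed over action dimensions.
        logp = sum(
            -0.5 * ((x - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2.0 * math.pi)
            for x, m, s in zip(u, mean, std)
        )
        # Change of variables for a = tanh(u): subtract log(1 - tanh(u)^2),
        # written in the stable form 2 * (log 2 - u - softplus(-2u)).
        logp -= sum(2.0 * (math.log(2.0) - x - _softplus(-2.0 * x)) for x in u)
    return action, logp


# Deterministic call: the action is tanh(mean).
action, logp = squashed_gaussian_action([0.0, 0.5], [0.0, 0.0], deterministic=True)
```

In this sketch the deterministic branch skips sampling entirely, so repeated calls with the same mean give the same action, while the stochastic branch draws a fresh sample on every call.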