Soft Actor-Critic

Important

The SAC algorithm has no stability guarantees. Please use the LAC algorithm if you require stability guarantees.

Background

Soft Actor-Critic (SAC) is an algorithm that optimises a stochastic policy in an off-policy way, forming a bridge between stochastic policy optimisation and DDPG-style approaches. It isn’t a direct successor to TD3 (having been published roughly concurrently). Still, it incorporates the clipped double-Q trick, and due to the inherent stochasticity of the policy in SAC, it also winds up benefiting from something like target policy smoothing.

A central feature of SAC is entropy regularisation. The policy is trained to maximise a trade-off between expected return and entropy, a measure of randomness in the policy. This is closely connected to the exploration-exploitation trade-off: increasing entropy results in more exploration, which can accelerate learning later on. It can also prevent the policy from prematurely converging to a bad local optimum.
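As a rough illustration of how the entropy term and the clipped double-Q trick combine, the sketch below computes the entropy-regularised Bellman target for a single transition, in the form used by SpinningUp-style SAC: y = r + \gamma (1 - d) (\min(Q_1', Q_2') - \alpha \log \pi(a'|s')). It is a plain-Python illustration under that assumption, not SLC code. The helper name soft_q_target is hypothetical, and a real implementation operates on batched tensors.

```python
def soft_q_target(reward, done, q1_next, q2_next, logp_next, *, gamma, alpha):
    """Entropy-regularised TD target for one transition (hypothetical helper).

    q1_next, q2_next: target-critic estimates Q1'(s', a') and Q2'(s', a'),
        where a' is a fresh sample from the current policy at s'.
    logp_next: log pi(a' | s') for that sampled action.
    """
    # Clipped double-Q: take the smaller (more pessimistic) target estimate.
    q_min = min(q1_next, q2_next)
    # Entropy bonus: -alpha * log pi is larger for less likely (more random) actions.
    soft_value = q_min - alpha * logp_next
    # No bootstrapping past a terminal state.
    return reward + gamma * (1.0 - float(done)) * soft_value


# Example: the entropy bonus raises the target from 1 + 0.99 * 10 to 1 + 0.99 * 10.3.
print(soft_q_target(1.0, False, 10.0, 12.0, -1.5, gamma=0.99, alpha=0.2))
```

The value alpha = 0.2 is only an example; in the signatures of this page alpha defaults to 0.99.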

Quick Facts

  • SAC is an off-policy algorithm.

  • The version of SAC implemented here can only be used for environments with continuous action spaces.

  • An alternate version of SAC, which slightly changes the policy update rule, can be implemented to handle discrete action spaces.

  • The SLC implementation of SAC does not support parallelisation.

Further Reading

The version implemented here is based on the version in the SpinningUp repository. For more information on the SAC algorithm, see the SpinningUp documentation or the original paper of Haarnoja et al., 2019. Our implementation differs slightly from the SpinningUp version in that it also includes the Automatic Entropy Tuning scheme introduced by Haarnoja et al., 2019. As a result, during training, the entropy Lagrange multiplier \alpha is updated by

\alpha \leftarrow \max(0, \alpha + \delta \nabla_{\alpha} J(\alpha))

where \delta is the learning rate. As explained in Haarnoja et al., 2019, this constrains the policy’s average entropy to stay at or above a target entropy (see the target_entropy parameter).
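A minimal plain-Python sketch of this update follows. It assumes \nabla_{\alpha} J(\alpha) = \mathbb{E}[\log \pi(a|s)] + \bar{\mathcal{H}}, with \bar{\mathcal{H}} the target entropy. Under that sign convention the projected step raises \alpha when the policy's entropy falls below the target and lowers it otherwise, which is the behaviour Haarnoja et al., 2019 describe. The function names are hypothetical, the default target entropy follows the target_entropy formula in the parameter list, and practical implementations often optimise \log \alpha with an optimiser such as Adam rather than taking this raw step, so treat it as an illustration rather than the SLC code.

```python
import math


def default_target_entropy(action_shape):
    """Heuristic target entropy: minus the product of the action dimensions."""
    return -float(math.prod(action_shape))


def update_alpha(alpha, logp_batch, target_entropy, lr):
    """One projected gradient step on the entropy Lagrange multiplier (sketch).

    Assumes grad_alpha J(alpha) = E[log pi(a|s)] + target_entropy, which is
    positive when the policy's entropy estimate -E[log pi] is below the target.
    """
    mean_logp = sum(logp_batch) / len(logp_batch)
    grad = mean_logp + target_entropy
    # alpha <- max(0, alpha + lr * grad); the max keeps the multiplier non-negative.
    return max(0.0, alpha + lr * grad)


target = default_target_entropy((1,))  # -1.0 for a 1-D action space
# Entropy estimate 0.5 exceeds the target, so alpha shrinks (0.2 -> about 0.05).
print(update_alpha(0.2, [-0.5, -0.5], target, lr=0.1))
# Entropy estimate -2.0 is below the target, so alpha grows (0.2 -> about 0.3).
print(update_alpha(0.2, [2.0, 2.0], target, lr=0.1))
```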

Implementation

You Should Know

In what follows, we give documentation for the PyTorch and TensorFlow implementations of SAC in SLC. They have nearly identical function calls and docstrings, except for details relating to model construction. However, we include both full docstrings for completeness.

Algorithm: PyTorch Version

stable_learning_control.algos.pytorch.sac.sac(env_fn, actor_critic=None, ac_kwargs={'activation': {'actor': <class 'torch.nn.modules.activation.ReLU'>, 'critic': <class 'torch.nn.modules.activation.ReLU'>}, 'hidden_sizes': {'actor': [256, 256], 'critic': [256, 256]}, 'output_activation': {'actor': <class 'torch.nn.modules.activation.ReLU'>, 'critic': <class 'torch.nn.modules.linear.Identity'>}}, opt_type='maximize', max_ep_len=None, epochs=100, steps_per_epoch=2048, start_steps=0, update_every=100, update_after=1000, steps_per_update=100, num_test_episodes=10, alpha=0.99, gamma=0.99, polyak=0.995, target_entropy=None, adaptive_temperature=True, lr_a=0.0001, lr_c=0.0003, lr_alpha=0.0001, lr_a_final=1e-10, lr_c_final=1e-10, lr_alpha_final=1e-10, lr_decay_type='linear', lr_a_decay_type=None, lr_c_decay_type=None, lr_alpha_decay_type=None, lr_decay_ref='epoch', batch_size=256, replay_size=1000000, seed=None, device='cpu', logger_kwargs={}, save_freq=1, start_policy=None, export=False)

Trains the SAC algorithm in a given environment.

Parameters:
  • env_fn – A function which creates a copy of the environment. The environment must satisfy the gymnasium API.

  • actor_critic (torch.nn.Module, optional) –

    The constructor method for a Torch Module with an act method, a pi module and several Q or L modules. The act method and pi module should accept batches of observations as inputs, and the Q* and L modules should accept a batch of observations and a batch of actions as inputs. When called, these modules should return:

    Call    Output Shape        Description
    act     (batch, act_dim)    Numpy array of actions for each observation.
    Q*/L    (batch,)            Tensor containing one current estimate of Q*/L for the provided observations and actions. (Critical: make sure to flatten this!)

    Calling pi should return:

    Symbol     Shape               Description
    a          (batch, act_dim)    Tensor containing actions from policy given observations.
    logp_pi    (batch,)            Tensor containing log probabilities of actions in a. Importantly: gradients should be able to flow back into a.

    Defaults to SoftActorCritic

  • ac_kwargs (dict, optional) –

    Any kwargs appropriate for the ActorCritic object you provided to SAC. Defaults to:

    Kwarg              Value
    hidden_sizes       actor: 256 x 2, critic: 256 x 2
    activation         actor: torch.nn.ReLU, critic: torch.nn.ReLU
    output_activation  actor: torch.nn.ReLU, critic: torch.nn.Identity

  • opt_type (str, optional) – The optimization type you want to use. Options are maximize and minimize. Defaults to maximize.

  • max_ep_len (int, optional) – Maximum length of trajectory / episode / rollout. Defaults to the environment maximum.

  • epochs (int, optional) – Number of epochs to run and train the agent. Defaults to 100.

  • steps_per_epoch (int, optional) – Number of steps of interaction (state-action pairs) for the agent and the environment in each epoch. Defaults to 2048.

  • start_steps (int, optional) – Number of steps for uniform-random action selection before running the real policy. Helps exploration. Defaults to 0.

  • update_every (int, optional) – Number of env interactions that should elapse between gradient descent updates. Defaults to 100.

  • update_after (int, optional) – Number of env interactions to collect before starting to do gradient descent updates. Ensures the replay buffer is full enough for useful updates. Defaults to 1000.

  • steps_per_update (int, optional) – Number of gradient descent steps that are performed for each gradient descent update. This determines the ratio of env steps to gradient steps (i.e. update_every / steps_per_update). Defaults to 100.

  • num_test_episodes (int, optional) – Number of episodes used to test the deterministic policy at the end of each epoch. This is used for logging the performance. Defaults to 10.

  • alpha (float, optional) – Entropy regularization coefficient (equivalent to the inverse of the reward scale in the original SAC paper). Defaults to 0.99.

  • gamma (float, optional) – Discount factor. (Always between 0 and 1.). Defaults to 0.99.

  • polyak (float, optional) –

    Interpolation factor in polyak averaging for target networks. Target networks are updated towards main networks according to:

    \theta_{\text{targ}} \leftarrow \rho \theta_{\text{targ}} + (1-\rho) \theta

    where \rho is polyak (Always between 0 and 1, usually close to 1.). In some papers \rho is defined as (1 - \tau) where \tau is the soft replacement factor. Defaults to 0.995.

  • target_entropy (float, optional) –

    Initial target entropy used while learning the entropy temperature (alpha). Defaults to the maximum information (bits) contained in the action space. This can be calculated according to:

    -\prod_{i=0}^{n} \text{action\_dim}_{i}

  • adaptive_temperature (bool, optional) – Enables the automatic entropy adjustment scheme (Automating Entropy Adjustment for Maximum Entropy RL) of Haarnoja et al., 2019. Defaults to True.

  • lr_a (float, optional) – Learning rate used for the actor. Defaults to 1e-4.

  • lr_c (float, optional) – Learning rate used for the (soft) critic. Defaults to 3e-4.

  • lr_alpha (float, optional) – Learning rate used for the entropy temperature. Defaults to 1e-4.

  • lr_a_final (float, optional) – The final actor learning rate that is achieved at the end of the training. Defaults to 1e-10.

  • lr_c_final (float, optional) – The final critic learning rate that is achieved at the end of the training. Defaults to 1e-10.

  • lr_alpha_final (float, optional) – The final alpha learning rate that is achieved at the end of the training. Defaults to 1e-10.

  • lr_decay_type (str, optional) – The learning rate decay type that is used (options are: linear, exponential and constant). Defaults to linear. Can be overridden by the specific learning rate decay types.

  • lr_a_decay_type (str, optional) – The learning rate decay type that is used for the actor learning rate (options are: linear, exponential and constant). If not specified, the general learning rate decay type is used.

  • lr_c_decay_type (str, optional) – The learning rate decay type that is used for the critic learning rate (options are: linear, exponential and constant). If not specified, the general learning rate decay type is used.

  • lr_alpha_decay_type (str, optional) – The learning rate decay type that is used for the alpha learning rate (options are: linear, exponential and constant). If not specified, the general learning rate decay type is used.

  • lr_decay_ref (str, optional) – The reference variable that is used for decaying the learning rate (options: epoch and step). Defaults to epoch.

  • batch_size (int, optional) – Minibatch size for SGD. Defaults to 256.

  • replay_size (int, optional) – Maximum length of replay buffer. Defaults to 1e6.

  • seed (int) – Seed for random number generators. Defaults to None.

  • device (str, optional) – The device the networks are placed on (options: cpu, gpu, gpu:0, gpu:1, etc.). Defaults to cpu.

  • logger_kwargs (dict, optional) – Keyword args for EpochLogger.

  • save_freq (int, optional) – How often (in terms of gap between epochs) to save the current policy and value function. Defaults to 1.

  • start_policy (str, optional) – Path of an already trained policy to use as the starting point for the training. By default a new policy is created.

  • export (bool) – Whether you want to export the model as TorchScript so that it can be deployed on hardware. Defaults to False.
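To see how start_steps, update_after, update_every and steps_per_update interact, here is a hypothetical sketch that assumes the SpinningUp-style loop this implementation is based on: an update is triggered once t >= update_after and every update_every interactions, and each update runs steps_per_update gradient steps. The helper names are illustrative only, and the exact trigger condition in SLC may differ.

```python
def action_source(t, start_steps):
    """Where the action for interaction t (1-based) comes from."""
    return "uniform-random" if t <= start_steps else "policy"


def count_gradient_steps(total_env_steps, update_after=1000, update_every=100,
                         steps_per_update=100):
    """Total gradient steps taken after `total_env_steps` interactions."""
    grad_steps = 0
    for t in range(1, total_env_steps + 1):  # t = interactions collected so far
        if t >= update_after and t % update_every == 0:
            grad_steps += steps_per_update
    return grad_steps


# Defaults: 100 epochs x 2048 steps; updates start at t = 1000.
total = 100 * 2048
print(count_gradient_steps(total))  # 203900, roughly one gradient step per env step
```

With the defaults, update_every / steps_per_update = 1, so once updates start the agent takes about one gradient step per environment step; raising steps_per_update relative to update_every increases that ratio.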
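The polyak rule given for the polyak parameter can be illustrated on a flat list of scalar parameters. This is a plain-Python sketch for illustration only; the function name polyak_update is hypothetical. With \rho = 0.995, each update moves the target only 0.5% of the way toward the main parameters.

```python
def polyak_update(target_params, main_params, polyak):
    """Soft target update, in place: theta_targ <- rho * theta_targ + (1 - rho) * theta."""
    if not 0.0 <= polyak <= 1.0:
        raise ValueError("polyak must lie in [0, 1]")
    for i, (targ, main) in enumerate(zip(target_params, main_params)):
        target_params[i] = polyak * targ + (1.0 - polyak) * main
    return target_params


targ = [0.0, 0.0]
for _ in range(3):
    polyak_update(targ, [1.0, -1.0], polyak=0.995)
print(targ)  # moved 1 - 0.995**3 of the way toward the main parameters
```

In PyTorch the same rule is typically applied to parameter tensors inside a torch.no_grad() block, for example with p_targ.mul_(polyak) followed by p_targ.add_((1 - polyak) * p).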

Returns:

tuple containing:

Return type:

(tuple)

Saved Model Contents: PyTorch Version

The PyTorch version of the SAC algorithm is implemented by subclassing the torch.nn.Module class. As a result, the model weights are saved using the model_state dictionary (state_dict). These saved weights can be found in the torch_save/model_state.pt file. For an example of how to load a model using this file, see Experiment Outputs or the PyTorch documentation.

Algorithm: TensorFlow Version

Attention

The TensorFlow version is still experimental. It is not guaranteed to work, and it is not guaranteed to be up-to-date with the PyTorch version.

stable_learning_control.algos.tf2.sac.sac(env_fn, actor_critic=None, ac_kwargs={'activation': {'actor': <function relu>, 'critic': <function relu>}, 'hidden_sizes': {'actor': [256, 256], 'critic': [256, 256]}, 'output_activation': {'actor': <function relu>, 'critic': None}}, opt_type='maximize', max_ep_len=None, epochs=100, steps_per_epoch=2048, start_steps=0, update_every=100, update_after=1000, steps_per_update=100, num_test_episodes=10, alpha=0.99, gamma=0.99, polyak=0.995, target_entropy=None, adaptive_temperature=True, lr_a=0.0001, lr_c=0.0003, lr_alpha=0.0001, lr_a_final=1e-10, lr_c_final=1e-10, lr_alpha_final=1e-10, lr_decay_type='linear', lr_a_decay_type=None, lr_c_decay_type=None, lr_alpha_decay_type=None, lr_decay_ref='epoch', batch_size=256, replay_size=1000000, seed=None, device='cpu', logger_kwargs={}, save_freq=1, start_policy=None, export=False)

Trains the SAC algorithm in a given environment.

Parameters:
  • env_fn – A function which creates a copy of the environment. The environment must satisfy the gymnasium API.

  • actor_critic (tf.Module, optional) –

    The constructor method for a TensorFlow Module with an act method, a pi module and several Q or L modules. The act method and pi module should accept batches of observations as inputs, and the Q* and L modules should accept a batch of observations and a batch of actions as inputs. When called, these modules should return:

    Call    Output Shape        Description
    act     (batch, act_dim)    Numpy array of actions for each observation.
    Q*/L    (batch,)            Tensor containing one current estimate of Q*/L for the provided observations and actions. (Critical: make sure to flatten this!)

    Calling pi should return:

    Symbol     Shape               Description
    a          (batch, act_dim)    Tensor containing actions from policy given observations.
    logp_pi    (batch,)            Tensor containing log probabilities of actions in a. Importantly: gradients should be able to flow back into a.

    Defaults to SoftActorCritic

  • ac_kwargs (dict, optional) –

    Any kwargs appropriate for the ActorCritic object you provided to SAC. Defaults to:

    Kwarg              Value
    hidden_sizes       actor: 256 x 2, critic: 256 x 2
    activation         actor: tf.nn.relu, critic: tf.nn.relu
    output_activation  actor: tf.nn.relu, critic: None

  • opt_type (str, optional) – The optimization type you want to use. Options are maximize and minimize. Defaults to maximize.

  • max_ep_len (int, optional) – Maximum length of trajectory / episode / rollout. Defaults to the environment maximum.

  • epochs (int, optional) – Number of epochs to run and train the agent. Defaults to 100.

  • steps_per_epoch (int, optional) – Number of steps of interaction (state-action pairs) for the agent and the environment in each epoch. Defaults to 2048.

  • start_steps (int, optional) – Number of steps for uniform-random action selection before running the real policy. Helps exploration. Defaults to 0.

  • update_every (int, optional) – Number of env interactions that should elapse between gradient descent updates. Defaults to 100.

  • update_after (int, optional) – Number of env interactions to collect before starting to do gradient descent updates. Ensures the replay buffer is full enough for useful updates. Defaults to 1000.

  • steps_per_update (int, optional) – Number of gradient descent steps that are performed for each gradient descent update. This determines the ratio of env steps to gradient steps (i.e. update_every / steps_per_update). Defaults to 100.

  • num_test_episodes (int, optional) – Number of episodes used to test the deterministic policy at the end of each epoch. This is used for logging the performance. Defaults to 10.

  • alpha (float, optional) – Entropy regularization coefficient (equivalent to the inverse of the reward scale in the original SAC paper). Defaults to 0.99.

  • gamma (float, optional) – Discount factor. (Always between 0 and 1.). Defaults to 0.99.

  • polyak (float, optional) –

    Interpolation factor in polyak averaging for target networks. Target networks are updated towards main networks according to:

    \theta_{\text{targ}} \leftarrow \rho \theta_{\text{targ}} + (1-\rho) \theta

    where \rho is polyak (Always between 0 and 1, usually close to 1.). In some papers \rho is defined as (1 - \tau) where \tau is the soft replacement factor. Defaults to 0.995.

  • target_entropy (float, optional) –

    Initial target entropy used while learning the entropy temperature (alpha). Defaults to the maximum information (bits) contained in the action space. This can be calculated according to:

    -\prod_{i=0}^{n} \text{action\_dim}_{i}

  • adaptive_temperature (bool, optional) – Enables the automatic entropy adjustment scheme (Automating Entropy Adjustment for Maximum Entropy RL) of Haarnoja et al., 2019. Defaults to True.

  • lr_a (float, optional) – Learning rate used for the actor. Defaults to 1e-4.

  • lr_c (float, optional) – Learning rate used for the (soft) critic. Defaults to 3e-4.

  • lr_alpha (float, optional) – Learning rate used for the entropy temperature. Defaults to 1e-4.

  • lr_a_final (float, optional) – The final actor learning rate that is achieved at the end of the training. Defaults to 1e-10.

  • lr_c_final (float, optional) – The final critic learning rate that is achieved at the end of the training. Defaults to 1e-10.

  • lr_alpha_final (float, optional) – The final alpha learning rate that is achieved at the end of the training. Defaults to 1e-10.

  • lr_decay_type (str, optional) – The learning rate decay type that is used (options are: linear, exponential and constant). Defaults to linear. Can be overridden by the specific learning rate decay types.

  • lr_a_decay_type (str, optional) – The learning rate decay type that is used for the actor learning rate (options are: linear, exponential and constant). If not specified, the general learning rate decay type is used.

  • lr_c_decay_type (str, optional) – The learning rate decay type that is used for the critic learning rate (options are: linear, exponential and constant). If not specified, the general learning rate decay type is used.

  • lr_alpha_decay_type (str, optional) – The learning rate decay type that is used for the alpha learning rate (options are: linear, exponential and constant). If not specified, the general learning rate decay type is used.

  • lr_decay_ref (str, optional) – The reference variable that is used for decaying the learning rate (options: epoch and step). Defaults to epoch.

  • batch_size (int, optional) – Minibatch size for SGD. Defaults to 256.

  • replay_size (int, optional) – Maximum length of replay buffer. Defaults to 1e6.

  • seed (int) – Seed for random number generators. Defaults to None.

  • device (str, optional) – The device the networks are placed on (options: cpu, gpu, gpu:0, gpu:1, etc.). Defaults to cpu.

  • logger_kwargs (dict, optional) – Keyword args for EpochLogger.

  • save_freq (int, optional) – How often (in terms of gap between epochs) to save the current policy and value function. Defaults to 1.

  • start_policy (str, optional) – Path of an already trained policy to use as the starting point for the training. By default a new policy is created.

  • export (bool) – Whether you want to export the model in the SavedModel format so that it can be deployed to hardware. Defaults to False.

Returns:

tuple containing:

Return type:

(tuple)

Saved Model Contents: TensorFlow Version

The TensorFlow version of the SAC algorithm is implemented by subclassing the tf.keras.Model class. As a result, both the full model and the current model weights are saved. The full model can be found in the saved_model.pb file, while the current weight checkpoints are stored in the tf_safe/weights_checkpoint* files. For an example of how to load a model from either of these files, see Experiment Outputs or the TensorFlow documentation.

References

Relevant Papers

Other Public Implementations