stable_learning_control.utils.mpi_utils.mpi_tf2

Helper methods for managing TF2 MPI processes.

Note

This module is not yet translated to TF2. It is not used by any of the current algorithms, but is kept here for future reference.

Module Contents

Classes

MpiAdamOptimizer

Adam optimizer that averages gradients across MPI processes.

Functions

flat_concat(xs)

Flatten a list of tensors and concatenate them into a single rank-1 tensor.

assign_params_from_flat(x, params)

Assign values from a flat tensor back to a list of parameters.

sync_params(params)

Sync the given parameters across MPI processes.

sync_all_params()

Sync all tf variables across MPI processes.

stable_learning_control.utils.mpi_utils.mpi_tf2.flat_concat(xs)[source]

Flatten a list of tensors and concatenate them into a single rank-1 tensor.

stable_learning_control.utils.mpi_utils.mpi_tf2.assign_params_from_flat(x, params)[source]

Assign values from a flat tensor back to a list of parameters, restoring each parameter's original shape.

stable_learning_control.utils.mpi_utils.mpi_tf2.sync_params(params)[source]

Sync the given parameters across MPI processes by broadcasting the root process's values.

stable_learning_control.utils.mpi_utils.mpi_tf2.sync_all_params()[source]

Sync all tf variables across MPI processes.
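
Because the module is not yet translated to TF2, the sketch below only illustrates what TF2 versions of these helpers could look like; it is not the module's actual code. It assumes mpi4py is available, and the model argument to sync_all_params is hypothetical (TF2 has no global variable collection, so the TF1 zero-argument signature cannot carry over directly):

    import numpy as np
    import tensorflow as tf
    from mpi4py import MPI

    comm = MPI.COMM_WORLD


    def flat_concat(xs):
        # Flatten every tensor and concatenate into one rank-1 tensor.
        return tf.concat([tf.reshape(x, (-1,)) for x in xs], axis=0)


    def assign_params_from_flat(x, params):
        # Split the flat vector into per-parameter chunks and assign them
        # back, restoring each parameter's original shape.
        sizes = [int(np.prod(p.shape.as_list())) for p in params]
        for p, p_new in zip(params, tf.split(x, sizes)):
            p.assign(tf.reshape(p_new, p.shape))


    def sync_params(params):
        # Broadcast the root process's parameter values to all processes.
        flat = flat_concat(params).numpy()
        comm.Bcast(flat, root=0)
        assign_params_from_flat(tf.constant(flat), params)


    def sync_all_params(model):
        # Hypothetical TF2 signature: a model is passed explicitly because
        # TF2 has no tf.global_variables() collection to fall back on.
        sync_params(model.trainable_variables)

Calling sync_params(model.trainable_variables) right after model construction gives every process identical starting weights, which is the usual precondition for averaged-gradient training.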

class stable_learning_control.utils.mpi_utils.mpi_tf2.MpiAdamOptimizer(**kwargs)[source]

Bases: object

Adam optimizer that averages gradients across MPI processes.

The compute_gradients method is adapted from the OpenAI Baselines MpiAdamOptimizer. For documentation of the method arguments, see the TensorFlow documentation for the base AdamOptimizer.
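
As a hedged illustration of the gradient-averaging idea (not the class's actual code), the TF2-style sketch below sums each gradient over all MPI processes with Allreduce, divides by the number of processes, and hands the averaged gradients to a stock Adam optimizer. The names comm, mpi_avg_gradients, and train_step are hypothetical:

    import numpy as np
    import tensorflow as tf
    from mpi4py import MPI

    comm = MPI.COMM_WORLD


    def mpi_avg_gradients(grads):
        # Average each gradient tensor over all MPI processes.
        num_procs = comm.Get_size()
        avg_grads = []
        for g in grads:
            buf = np.zeros_like(g.numpy())
            comm.Allreduce(g.numpy(), buf, op=MPI.SUM)
            avg_grads.append(tf.constant(buf / num_procs))
        return avg_grads


    optimizer = tf.keras.optimizers.Adam()


    def train_step(model, loss_fn, x, y):
        # Runs eagerly on purpose: .numpy() would fail inside tf.function.
        with tf.GradientTape() as tape:
            loss = loss_fn(y, model(x))
        grads = tape.gradient(loss, model.trainable_variables)
        grads = mpi_avg_gradients(grads)  # identical updates on every process
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss

Because every process applies the same averaged gradients to parameters that started out synced, the replicated models stay identical without any further communication.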