# Source code for stable_learning_control.algos.tf2.common.get_lr_scheduler

"""Module used for creating TensorFlow learning rate schedulers."""

import numpy as np

from stable_learning_control.utils.import_utils import import_tf

# Module-level TensorFlow handle used by ``get_lr_scheduler`` below.
tf = import_tf()  # Throw custom warning if tf is not installed.
def get_lr_scheduler(decaying_lr_type, lr_start, lr_final, steps):
    """Creates a learning rate scheduler.

    Args:
        decaying_lr_type (str): The learning rate decay type that is used (options
            are: ``linear`` and ``exponential`` and ``constant``). The comparison is
            case-insensitive; any value other than ``exponential`` or ``constant``
            results in linear decay.
        lr_start (float): Initial learning rate.
        lr_final (float): Final learning rate.
        steps (int): Number of steps/epochs used in the training. This includes the
            starting step/epoch, so the schedule reaches ``lr_final`` at step
            ``steps - 1``.

    Returns:
        Union[tensorflow.keras.optimizers.schedules.LearningRateSchedule, Callable]:
            A learning rate scheduler object that maps a step to a learning rate.
            When the decay type is ``constant`` or ``lr_start == lr_final``, a plain
            callable that always returns ``lr_start`` is returned instead.

    Raises:
        ValueError: If ``exponential`` decay is requested (with
            ``lr_start != lr_final``) while ``lr_start`` or ``lr_final`` is not
            strictly positive, since the decay rate ``lr_final / lr_start`` would
            then be zero, negative or infinite.

    .. seealso::
        See the :tf2:`TensorFlow <keras/optimizers/schedules>` documentation on how
        to implement other decay options.
    """  # noqa: E501
    decay_type = decaying_lr_type.lower()

    if decay_type == "constant" or lr_start == lr_final:
        return lambda step: lr_start  # Return a constant learning rate.

    # The schedule should hit ``lr_final`` at the last step (``steps - 1``). Clamp to
    # at least one step: both Keras schedules divide by ``decay_steps``, so a value of
    # zero (``steps == 1``) would yield a NaN learning rate.
    decay_steps = max(steps - 1, 1)

    if decay_type == "exponential":
        if lr_start <= 0 or lr_final <= 0:
            raise ValueError(
                "Exponential learning rate decay requires 'lr_start' and 'lr_final' "
                f"to be strictly positive (got lr_start={lr_start}, "
                f"lr_final={lr_final})."
            )
        # ExponentialDecay computes ``lr_start * rate ** (step / decay_steps)``, so
        # this rate makes the learning rate equal ``lr_final`` at ``decay_steps``.
        exponential_decay_rate = np.float64(lr_final) / np.float64(lr_start)
        return tf.keras.optimizers.schedules.ExponentialDecay(
            lr_start,
            decay_steps,
            exponential_decay_rate,
        )

    # Linear decay: a polynomial decay of power 1 from ``lr_start`` to ``lr_final``.
    return tf.keras.optimizers.schedules.PolynomialDecay(
        lr_start,
        decay_steps,
        lr_final,
        power=1.0,
    )