# Source code for stable_gym.envs.biological.oscillator.oscillator

"""The oscillator gymnasium environment."""

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
from gymnasium import logger, spaces

# Number of environment episodes to run when this module is executed directly.
EPISODES = 10
# Whether ``__main__`` samples random actions (``True``) or applies zero actions.
RANDOM_STEP = True
# TODO: Update solving criteria after training.
class Oscillator(gym.Env):
    r"""Synthetic oscillatory network environment.

    .. Note::
        Can also be used in a vectorized manner. See the
        :gymnasium:`gym.vector <api/vector>` documentation.

    Description:
        The goal of the agent in the oscillator environment is to act in such a way
        that one of the proteins of the synthetic oscillatory network follows a
        supplied reference signal.

    Source:
        This environment corresponds to the Oscillator environment used in the paper
        `Han et al. 2020`_. In our implementation several additional features were
        added to the environment to make it more flexible and easier to use:

        - Environment arguments now allow for modification of the reference signal
          parameters.
        - System parameters can now be individually adjusted for each protein,
          rather than applying the same parameters across all proteins.
        - The reference can be omitted from the observation.
        - Reference error can be included in the info dictionary.
        - The observation space was expanded to accurately reproduce the plots
          presented in `Han et al. 2020`_, which was not possible with the original
          code's observation space.
        - Added an adjustable ``max_cost`` threshold for episode termination,
          defaulting to 100 to match the original environment.

    .. _`Han et al. 2020`: https://arxiv.org/abs/2004.14288

    Observation:
        **Type**: Box(6), Box(7) or Box(8) depending on the
        ``exclude_reference_from_observation`` and
        ``exclude_reference_error_from_observation`` arguments.

        +-----+--------------------------------------------------+-------------------+-------------------+
        | Num | Observation                                      | Min               | Max               |
        +=====+==================================================+===================+===================+
        | 0   | Lacl mRNA transcripts concentration              | 0                 | :math:`\infty`    |
        +-----+--------------------------------------------------+-------------------+-------------------+
        | 1   | tetR mRNA transcripts concentration              | 0                 | :math:`\infty`    |
        +-----+--------------------------------------------------+-------------------+-------------------+
        | 2   | CI mRNA transcripts concentration                | 0                 | :math:`\infty`    |
        +-----+--------------------------------------------------+-------------------+-------------------+
        | 3   || lacI (repressor) protein concentration          | 0                 | :math:`\infty`    |
        |     || (Inhibits transcription of the tetR gene)       |                   |                   |
        +-----+--------------------------------------------------+-------------------+-------------------+
        | 4   || tetR (repressor) protein concentration          | 0                 | :math:`\infty`    |
        |     || (Inhibits transcription of CI gene)             |                   |                   |
        +-----+--------------------------------------------------+-------------------+-------------------+
        | 5   || CI (repressor) protein concentration            | 0                 | :math:`\infty`    |
        |     || (Inhibits transcription of lacI gene)           |                   |                   |
        +-----+--------------------------------------------------+-------------------+-------------------+
        | (6) || **Optional** - The reference we want            | 0                 | :math:`\infty`    |
        |     || to follow                                       |                   |                   |
        +-----+--------------------------------------------------+-------------------+-------------------+
        | (7) || **Optional** - The error between the current    | :math:`-\infty`   | :math:`\infty`    |
        |     || value of protein 1 and the reference            |                   |                   |
        +-----+--------------------------------------------------+-------------------+-------------------+

    Actions:
        **Type**: Box(3)

        +-----+--------------------------------------------------------------+---------+---------+
        | Num | Action                                                       | Min     | Max     |
        +=====+==============================================================+=========+=========+
        | 0   || Relative intensity of light signal that induces the         | 0       | 1       |
        |     || expression of the Lacl mRNA gene.                           |         |         |
        +-----+--------------------------------------------------------------+---------+---------+
        | 1   || Relative intensity of light signal that induces the         | 0       | 1       |
        |     || expression of the tetR mRNA gene.                           |         |         |
        +-----+--------------------------------------------------------------+---------+---------+
        | 2   || Relative intensity of light signal that induces the         | 0       | 1       |
        |     || expression of the CI mRNA gene.                             |         |         |
        +-----+--------------------------------------------------------------+---------+---------+

    Cost:
        A cost, computed as the squared difference between the concentration of the
        first protein (lacI, observation index 3) and the reference:

        .. math::

            C = (p_1 - r_1)^2

    Starting State:
        When ``random=True`` (default), the six system states are sampled uniformly
        from ``[0, 5]``. Otherwise a fixed initial state is used.

    Episode Termination:
        -   An episode is truncated when the maximum step limit is reached (this is
            handled by a time limit wrapper).
        -   An episode is terminated when the cost exceeds a threshold (default is
            100). This threshold can be adjusted using the ``max_cost`` environment
            argument.

    Solved Requirements:
        Considered solved when the average cost is lower than 300.

    How to use:
        .. code-block:: python

            import stable_gym
            import gymnasium as gym
            env = gym.make("stable_gym:Oscillator-v1")

        On reset, the ``options`` parameter allows the user to change the bounds used
        to determine the new random state when ``random=True``.

    Attributes:
        state (numpy.ndarray): The current system state.
        t (float): The current time.
        dt (float): The environment step size. Also available as :attr:`.tau`.
        delta1, delta2, delta3, delta4, delta5, delta6 (float): Bounds of the
            zero-mean uniform noise added to the ``m1``, ``m2``, ``m3``, ``p1``,
            ``p2`` and ``p3`` updates respectively. Zero by default.
        max_cost (float): The maximum cost allowed before the episode is terminated.
    """  # noqa: E501

    def __init__(
        self,
        render_mode=None,
        # NOTE: Custom environment arguments.
        max_cost=100.0,
        reference_target_position=8.0,
        reference_amplitude=7.0,
        reference_frequency=(1 / 200),  # NOTE: Han et al. 2020 uses a period of 200.
        reference_phase_shift=0.0,
        clip_action=True,
        exclude_reference_from_observation=False,
        exclude_reference_error_from_observation=False,
        action_space_dtype=np.float64,
        observation_space_dtype=np.float64,
    ):
        """Initialise a new Oscillator environment instance.

        Args:
            render_mode (str, optional): The render mode you want to use. Defaults to
                ``None``. Stored as :attr:`render_mode`; rendering is not implemented
                for this environment.
            max_cost (float, optional): The maximum cost allowed before the episode is
                terminated. Defaults to ``100.0``.
            reference_target_position: The reference target position, by default
                ``8.0`` (i.e. the mean of the reference signal).
            reference_amplitude: The reference amplitude, by default ``7.0``.
            reference_frequency: The reference frequency, by default ``0.005``.
            reference_phase_shift: The reference phase shift, by default ``0.0``.
            clip_action (bool, optional): Whether actions outside the action space
                bounds should be clipped. Defaults to ``True``.
            exclude_reference_from_observation (bool, optional): Whether the reference
                should be excluded from the observation. Defaults to ``False``.
            exclude_reference_error_from_observation (bool, optional): Whether the
                error should be excluded from the observation. Defaults to ``False``.
            action_space_dtype (union[numpy.dtype, str], optional): The data type of
                the action space. Defaults to ``np.float64``.
            observation_space_dtype (union[numpy.dtype, str], optional): The data type
                of the observation space. Defaults to ``np.float64``.

        Raises:
            AssertionError: If ``max_cost`` is not positive, if the reference
                frequency is negative, or if both the reference and the reference
                error are excluded from the observation while the reference is not
                constant.
        """
        super().__init__()
        assert max_cost > 0, "The maximum cost must be greater than 0."
        self.render_mode = render_mode
        self.max_cost = max_cost
        self._action_clip_warning = False
        self._clip_action = clip_action
        self._exclude_reference_from_observation = exclude_reference_from_observation
        self._exclude_reference_error_from_observation = (
            exclude_reference_error_from_observation
        )
        self._action_space_dtype = action_space_dtype
        self._observation_space_dtype = observation_space_dtype
        self._action_dtype_conversion_warning = False

        # Validate input arguments.
        # NOTE: A constant reference (zero amplitude or zero frequency) can be learned
        # without observing it; a time-varying one cannot.
        assert (reference_amplitude == 0 or reference_frequency == 0) or not (
            exclude_reference_from_observation
            and exclude_reference_error_from_observation
        ), (
            "The agent needs to observe either the reference or the reference error "
            "for it to be able to learn when the reference is not constant."
        )
        assert (
            reference_frequency >= 0
        ), "The reference frequency must be greater than or equal to zero."

        self.t = 0.0
        self.dt = 1.0
        self._init_state = np.array(
            [0.8, 1.5, 0.5, 3.3, 3, 3], dtype=self._observation_space_dtype
        )  # Used when random is disabled in reset.
        self._init_state_range = {
            "low": [0, 0, 0, 0, 0, 0],
            "high": [5, 5, 5, 5, 5, 5],
        }  # Used when random is enabled in reset.

        # Set oscillator network parameters.
        self.K1 = 1.0  # mRNA dissociation constants m1.
        self.K2 = 1.0  # mRNA dissociation constant m2.
        self.K3 = 1.0  # mRNA dissociation constant m3.
        self.a1 = 1.6  # Maximum promoter strength m1.
        self.a2 = 1.6  # Maximum promoter strength m2.
        self.a3 = 1.6  # Maximum promoter strength m3.
        self.gamma1 = 0.16  # mRNA degradation rate m1.
        self.gamma2 = 0.16  # mRNA degradation rate m2.
        self.gamma3 = 0.16  # mRNA degradation rate m3.
        self.beta1 = 0.16  # Protein production rate p1.
        self.beta2 = 0.16  # Protein production rate p2.
        self.beta3 = 0.16  # Protein production rate p3.
        self.c1 = 0.06  # Protein degradation rate p1.
        self.c2 = 0.06  # Protein degradation rate p2.
        self.c3 = 0.06  # Protein degradation rate p3.
        self.b1 = 5.0  # Control input gain u1.
        self.b2 = 5.0  # Control input gain u2.
        self.b3 = 5.0  # Control input gain u3.

        # Set noise parameters (half-widths of zero-mean uniform noise).
        # NOTE: Zero during training.
        self.delta1 = 0.0  # m1 noise.
        self.delta2 = 0.0  # m2 noise.
        self.delta3 = 0.0  # m3 noise.
        self.delta4 = 0.0  # p1 noise.
        self.delta5 = 0.0  # p2 noise.
        self.delta6 = 0.0  # p3 noise.

        # NOTE: Observation space was changed compared to the original codebase of
        # Han et al. 2020 to match paper's plots.
        obs_low = np.array(
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
        )  # NOTE: Han's original code used -1.0.
        obs_high = np.array(
            [np.inf, np.inf, np.inf, np.inf, np.inf, np.inf]
        )  # NOTE: Han's original code used 1.0.
        if not self._exclude_reference_from_observation:
            obs_low = np.append(obs_low, 0.0)
            obs_high = np.append(obs_high, np.inf)
        if not self._exclude_reference_error_from_observation:
            obs_low = np.append(obs_low, -np.inf)
            obs_high = np.append(obs_high, np.inf)

        # NOTE: Han et al. 2020 did not clearly detail the action space in their paper.
        # As a result the action space from their original code is used.
        self.action_space = spaces.Box(
            low=np.array([0.0, 0.0, 0.0]),
            high=np.array([1.0, 1.0, 1.0]),
            dtype=self._action_space_dtype,
        )
        self.observation_space = spaces.Box(
            obs_low, obs_high, dtype=self._observation_space_dtype
        )
        self.reward_range = (0.0, self.max_cost)

        self.viewer = None
        self.state = None
        self.steps_beyond_done = None

        # Reference target, amplitude, frequency and phase shift.
        self.reference_target_pos = reference_target_position
        self.reference_amplitude = reference_amplitude
        self.reference_frequency = reference_frequency
        self.phase_shift = reference_phase_shift

    def step(self, action):
        """Take step into the environment.

        Args:
            action (numpy.ndarray): The action we want to perform in the environment.
                Array-likes (e.g. lists) are converted with :func:`numpy.asarray`.

        Returns:
            (tuple): tuple containing:

                -   obs (:obj:`np.ndarray`): Environment observation.
                -   cost (:obj:`float`): Cost of the action.
                -   terminated (:obj:`bool`): Whether the episode is terminated
                    (i.e. the cost exceeded :attr:`max_cost`).
                -   truncated (:obj:`bool`): Whether the episode was truncated. Always
                    ``False`` here; this value is set by wrappers when for example a
                    time limit is reached or the agent goes out of bounds.
                -   info (:obj:`dict`): Additional information about the environment.

        Raises:
            AssertionError: If ``clip_action`` is disabled and the action is not in
                the action space, or if :meth:`reset` has not been called yet.
        """
        action = np.asarray(action)

        # Convert action to correct data type if needed.
        if action.dtype != self._action_space_dtype:
            if not self._action_dtype_conversion_warning:
                # NOTE: 'spec' is None when the env is not created via 'gym.make'.
                env_id = self.spec.id if self.spec is not None else type(self).__name__
                logger.warn(
                    "The data type of the action that is supplied to the "
                    f"'stable_gym:{env_id}' environment ({action.dtype}) "
                    "does not match the data type of the action space "
                    f"({np.dtype(self._action_space_dtype).name}). The action data "
                    "type will be converted to the action space data type."
                )
                self._action_dtype_conversion_warning = True
            action = action.astype(self._action_space_dtype)

        # Clip action if needed.
        if self._clip_action:
            # Throw warning if clipped and not already thrown.
            if not self.action_space.contains(action) and not self._action_clip_warning:
                logger.warn(
                    f"Action '{action}' was clipped as it is not in the action_space "
                    f"'high: {self.action_space.high}, low: {self.action_space.low}'."
                )
                self._action_clip_warning = True
            u1, u2, u3 = np.clip(action, self.action_space.low, self.action_space.high)
        else:
            assert self.action_space.contains(
                action
            ), f"{action!r} ({type(action)}) invalid"
            u1, u2, u3 = action

        assert self.state is not None, "Call reset before using step method."

        # Perform action in the environment and return the new state.
        # NOTE: The new state is found by solving 6 first-order differential equations.
        m1, m2, m3, p1, p2, p3 = self.state  # NOTE: [x1, x2, x3, x4, x5, x6] in paper.
        m1_dot = -self.gamma1 * m1 + self.a1 / (self.K1 + np.square(p3)) + self.b1 * u1
        m2_dot = -self.gamma2 * m2 + self.a2 / (self.K2 + np.square(p1)) + self.b2 * u2
        m3_dot = -self.gamma3 * m3 + self.a3 / (self.K3 + np.square(p2)) + self.b3 * u3
        p1_dot = -self.c1 * p1 + self.beta1 * m1
        p2_dot = -self.c2 * p2 + self.beta2 * m2
        p3_dot = -self.c3 * p3 + self.beta3 * m3

        # Integrate the mRNA and then the protein concentrations.
        # NOTE: The call order matters because each call draws noise from np_random.
        m1 = self._integrate_concentration(m1, m1_dot, self.delta1)
        m2 = self._integrate_concentration(m2, m2_dot, self.delta2)
        m3 = self._integrate_concentration(m3, m3_dot, self.delta3)
        p1 = self._integrate_concentration(p1, p1_dot, self.delta4)
        p2 = self._integrate_concentration(p2, p2_dot, self.delta5)
        p3 = self._integrate_concentration(p3, p3_dot, self.delta6)

        # Store the new state and advance time.
        self.state = np.array([m1, m2, m3, p1, p2, p3])
        self.t = self.t + self.dt

        # Calculate cost.
        r1 = self.reference(self.t).astype(self._observation_space_dtype)
        cost = np.square(p1 - r1)

        # Define stopping criteria.
        # NOTE: Uses the public 'max_cost' so runtime changes to it take effect.
        terminated = bool(cost > self.max_cost)

        # Create observation and info_dict.
        obs, info_dict = self._observation_and_info(r1)

        # Return state, cost, terminated, truncated and info_dict.
        return (
            obs,
            cost,
            terminated,
            False,
            info_dict,
        )

    def _integrate_concentration(self, value, derivative, noise_bound):
        """Advance one concentration by a single forward-Euler step.

        Uniform noise in ``[-noise_bound, noise_bound]`` is added and the result is
        clipped at zero so that concentrations cannot become negative.

        Args:
            value (float): The current concentration.
            derivative (float): The time derivative of the concentration.
            noise_bound (float): The half-width of the uniform noise.

        Returns:
            numpy scalar: The new, non-negative concentration.
        """
        noise = self.np_random.uniform(-noise_bound, noise_bound, 1)
        return np.max([value + derivative * self.dt + noise, np.zeros([1])])

    def _observation_and_info(self, reference):
        """Build the observation vector and info dictionary for the current state.

        Args:
            reference (numpy scalar): The current reference value, already cast to the
                observation space dtype.

        Returns:
            (tuple): The observation (a fresh array, never aliasing :attr:`state`) and
            the info dictionary with the ``reference``, ``state_of_interest`` and
            ``reference_error`` keys.
        """
        obs = np.array(self.state, dtype=self._observation_space_dtype)
        p1 = obs[3]  # The protein that has to track the reference.
        reference_error = p1 - reference
        if not self._exclude_reference_from_observation:
            obs = np.append(obs, reference)
        if not self._exclude_reference_error_from_observation:
            obs = np.append(obs, reference_error)
        info_dict = dict(
            reference=reference,
            state_of_interest=p1,
            reference_error=reference_error,
        )
        return obs, info_dict

    def reset(
        self,
        seed=None,
        options=None,
        random=True,
    ):
        """Reset gymnasium environment.

        Args:
            seed (int, optional): A random seed for the environment. By default
                ``None``.
            options (dict, optional): A dictionary containing additional options for
                resetting the environment. By default ``None``. The optional ``low``
                and ``high`` keys override the bounds (of the six system states) from
                which the random initial state is sampled when ``random=True``.
            random (bool, optional): Whether we want to randomly initialise the
                environment. By default True.

        Returns:
            (tuple): tuple containing:

                -   obs (:obj:`numpy.ndarray`): Initial environment observation.
                -   info (:obj:`dict`): Dictionary containing additional information.

        Raises:
            AssertionError: If the reset bounds are not within the observation space.
        """
        super().reset(seed=seed)

        # Initialise custom bounds while ensuring that the bounds are valid.
        # NOTE: If you use custom reset bounds, it may lead to out-of-bound
        # state/observations.
        options = {} if options is None else options
        low = np.array(
            options["low"] if "low" in options else self._init_state_range["low"],
            dtype=self._observation_space_dtype,
        )
        high = np.array(
            options["high"] if "high" in options else self._init_state_range["high"],
            dtype=self._observation_space_dtype,
        )

        def _pad_to_observation(bound):
            # Pad the reference/error entries with zeros so the bound can be checked
            # against the full observation space.
            return np.append(
                bound,
                np.zeros(
                    self.observation_space.shape[0] - bound.shape[0],
                    dtype=self._observation_space_dtype,
                ),
            )

        assert self.observation_space.contains(
            _pad_to_observation(low)
        ) and self.observation_space.contains(_pad_to_observation(high)), (
            "Reset bounds must be within the observation space bounds "
            f"({self.observation_space})."
        )

        # Set initial state, reset time, retrieve initial observation and info_dict.
        # NOTE: Copy the fixed initial state so later in-place edits of 'state'
        # cannot corrupt it.
        self.state = (
            self.np_random.uniform(low=low, high=high, size=(6,))
            if random
            else self._init_state.copy()
        )
        self.t = 0.0
        r1 = self.reference(self.t).astype(self._observation_space_dtype)

        # Return initial observation and info_dict.
        return self._observation_and_info(r1)

    def reference(self, t):
        r"""Returns the current value of the periodic reference signal that is tracked by
        the Synthetic oscillatory network.

        Args:
            t (float): The current time step.

        Returns:
            float: The current reference value.

        .. note::
            This uses the general form of a periodic signal (note that the phase shift
            is subtracted):

            .. math::

                y(t) = A \sin(\omega t - \phi) + C \\
                y(t) = A \sin(2 \pi f t - \phi) + C \\
                y(t) = A \sin(\frac{2 \pi}{T} t - \phi) + C

            Where:

            - :math:`t` is the time.
            - :math:`A` is the amplitude of the signal.
            - :math:`\omega` is the angular frequency of the signal.
            - :math:`f` is the frequency of the signal.
            - :math:`T` is the period of the signal.
            - :math:`\phi` is the phase shift of the signal.
            - :math:`C` is the offset of the signal.
        """
        return self.reference_target_pos + self.reference_amplitude * np.sin(
            ((2 * np.pi) * self.reference_frequency * t) - self.phase_shift
        )

    def render(self, mode="human"):
        """Render one frame of the environment.

        Args:
            mode (str, optional): Gym rendering mode. The default mode will do
                something human friendly, such as pop up a window.

        Raises:
            NotImplementedError: Always, since the render method has not yet been
                implemented.

        Note:
            This currently is not yet implemented.
        """
        raise NotImplementedError(
            "No render method was implemented yet for the Oscillator environment."
        )

    @property
    def tau(self):
        """Alias for the environment step size. Done for compatibility with the
        other gymnasium environments.
        """
        return self.dt

    @property
    def physics_time(self):
        """Returns the physics time. Alias for :attr:`.t`."""
        return self.t
if __name__ == "__main__":
    print("Setting up 'Oscillator' environment.")
    env = gym.make("stable_gym:Oscillator")

    # Run episodes.
    # NOTE: Each stored path holds the observations of one episode, including the
    # initial observation returned by 'reset'.
    episode = 0
    path, paths = [], []
    reference, references = [], []
    s, info = env.reset()
    path.append(s)
    reference.append(info["reference"])
    print(f"\nPerforming '{EPISODES}' episodes in the 'Oscillator' environment...\n")
    print(f"Episode: {episode}")
    while episode < EPISODES:
        # Create zero actions in the action space dtype so no conversion is needed.
        action = (
            env.action_space.sample()
            if RANDOM_STEP
            else np.zeros(env.action_space.shape, dtype=env.action_space.dtype)
        )
        s, r, terminated, truncated, info = env.step(action)
        path.append(s)
        reference.append(info["reference"])
        if terminated or truncated:
            paths.append(path)
            references.append(reference)
            episode += 1

            # Only start a new episode when there are episodes left to run.
            if episode < EPISODES:
                path, reference = [], []
                s, info = env.reset()
                path.append(s)
                reference.append(info["reference"])
                print(f"Episode: {episode}")
    print("\nFinished 'Oscillator' environment simulation.")

    # Plot results per episode.
    print("\nPlotting episode data...")
    # NOTE: Read 'dt' from the unwrapped env; recent Gymnasium versions no longer
    # forward attribute access through wrappers.
    dt = env.unwrapped.dt
    for i, (path, episode_reference) in enumerate(zip(paths, references), start=1):
        print(f"\nEpisode: {i}")
        path = np.array(path)
        fig, ax = plt.subplots()

        # Observation k of an episode is recorded at time k * dt.
        t = np.arange(path.shape[0]) * dt
        for j in range(path.shape[1]):  # NOTE: Change if you want to plot less states.
            ax.plot(t, path[:, j], label=f"State {j+1}")
        ax.set_xlabel("Time (s)")
        ax.set_title(f"Oscillator episode '{i}'")

        # Plot reference signal.
        ax.plot(
            t,
            np.array(episode_reference),
            color="black",
            linestyle="--",
            label="Reference",
        )
        ax.legend()
        print("Close plot to see next episode...")
        plt.show()

    print("\nDone")
    env.close()