Source code for stable_gym.common.utils

"""Utility functions that are used in multiple Stable Gym gymnasium environments."""

import re

import gymnasium as gym
import numpy as np
from gymnasium.utils import colorize as gym_colorize


def colorize(string, color, bold=False, highlight=False):
    """Return ``string`` decorated with terminal color codes when a color is given.

    .. seealso::
        This is a thin wrapper around :meth:`gym.utils.colorize` that also accepts
        an empty color: in that case the text is handed back unchanged instead of
        being forwarded to the gymnasium helper.

    Args:
        string (str): The text to colorize.
        color (str): Name of the color to apply. An empty (falsy) value disables
            colorization.
        bold (bool, optional): Whether the text should be bold. Defaults to
            ``False``.
        highlight (bool, optional): Whether the text should be highlighted.
            Defaults to ``False``.

    Returns:
        str: The colorized string, or the original string if no color was given.
    """
    if not color:
        # Nothing to apply, so skip the gymnasium helper entirely.
        return string
    return gym_colorize(string, color, bold, highlight)
def get_flattened_values(input_obj):
    """Collect every leaf value of a (possibly nested) dictionary into a flat list.

    Behaves like a recursive :meth:`dict.values`: nested dictionaries are walked
    depth-first and only their non-dictionary values end up in the result.

    Args:
        input_obj (dict): The dictionary whose values should be collected. A
            non-dictionary input is treated as a single leaf value.

    Returns:
        list: All leaf values in traversal order. For a non-dictionary input this
        is a one-element list holding that input.
    """
    if not isinstance(input_obj, dict):
        # A leaf: wrap it so callers (and the recursion) can always extend.
        return [input_obj]

    collected = []
    for entry in input_obj.values():
        collected.extend(get_flattened_values(entry))
    return collected
def get_flattened_keys(input_obj, include_root=False):
    """Collect the keys of a (possibly nested) dictionary into a flat list.

    Behaves like a recursive :meth:`dict.keys`. By default only keys that point to
    non-dictionary values (the leaves) are returned. With ``include_root`` every
    top-level key is returned as well, including those whose value is a nested
    dictionary, and each top-level key is listed exactly once.

    Args:
        input_obj (dict): The dictionary whose keys should be collected. A
            non-dictionary input is treated as a single key.
        include_root (bool): Whether to include all top-level keys. This only
            applies to the outermost dictionary; it is not propagated to nested
            dictionaries. Defaults to ``False``.

    Returns:
        list: The collected keys. When ``include_root`` is set the top-level keys
        come first, followed by the leaf keys of the nested dictionaries.
    """
    if not isinstance(input_obj, dict):
        return [input_obj]

    # Seed the result with the root keys when requested.
    flat_keys = list(input_obj) if include_root else []
    for key, val in input_obj.items():
        if isinstance(val, dict):
            flat_keys.extend(get_flattened_keys(val))
        elif not include_root:
            # With include_root the key was already added above; appending it
            # again would duplicate every top-level leaf key.
            flat_keys.append(key)
    return flat_keys
def abbreviate(input_item, length=1, max_length=4, capitalize=True):
    """Create abbreviations for a string or unique abbreviations for a list of strings.

    For a list, every item starts with its first ``length`` characters. When that
    abbreviation is already taken by a *different* item, the abbreviation is
    lengthened one character at a time up to ``max_length``; after that a numeric
    suffix obtained from :func:`get_lowest_next_int` is appended. Identical items
    share the same abbreviation.

    Args:
        input_item (union[str, list]): The string or list of strings to abbreviate.
        length (int, optional): The desired length of the abbreviation. Defaults to
            ``1``.
        max_length (int, optional): The maximum number of characters taken from an
            item before a numeric suffix is used. Defaults to ``4``.
        capitalize (bool, optional): Whether the abbreviations should be
            capitalized. Defaults to ``True``.

    Returns:
        union[list, str]: A list with one abbreviation per input item when a list
        was given, otherwise the abbreviation of the single input string.
    """
    if not isinstance(input_item, list):
        short = input_item[:length]
        return short.capitalize() if capitalize else short

    abbreviations = []
    owners = {}  # Maps each abbreviation to the first item that claimed it.
    for it in input_item:
        length_tmp = length
        suffix = ""
        while True:
            stem = it[:length_tmp]
            if capitalize:
                stem = stem.capitalize()
            # Format the suffix as text in both modes: it becomes an int once the
            # numeric fallback kicks in, and ``str + int`` would raise TypeError.
            abbreviation = f"{stem}{suffix}"

            if abbreviation not in owners:
                owners[abbreviation] = it
            if owners[abbreviation] == it:
                # Either newly claimed or a repeat of the same item.
                abbreviations.append(abbreviation)
                break

            # Clash with a different item: lengthen first, then fall back to a
            # numeric suffix once the maximum length has been reached.
            if length_tmp < max_length:
                length_tmp += 1
            else:
                suffix = get_lowest_next_int(abbreviations)
    return abbreviations
def get_lowest_next_int(input_item):
    """Return the lowest integer above the minimum that is not yet in the input.

    Strings are reduced to their digits and dots and rounded to the nearest
    integer; strings that hold no usable number are ignored. Other items are
    rounded to the nearest integer. When no number is found at all the input is
    treated as if it contained only ``0``, so the result is ``1``.

    Args:
        input_item (union[int, float, str, list]): A single value or a list of
            values in which to look for the next free integer.

    Returns:
        int: The smallest integer greater than the lowest number found that does
        not occur in the input.
    """
    candidates = input_item if isinstance(input_item, list) else [input_item]

    found = set()
    for item in candidates:
        if isinstance(item, str):
            numeric_part = re.sub(r"[^0-9.]", "", item)  # Trim non-numeric chars.
            try:
                found.add(round(float(numeric_part)))
            except ValueError:
                # No usable number in this string (e.g. "", "." or "1.2.3").
                continue
        else:
            found.add(round(item))
    if not found:
        found = {0}

    # Walk upwards from the minimum; this does not rely on set iteration order,
    # so the result is guaranteed to be the lowest free integer.
    candidate = min(found)
    while candidate in found:
        candidate += 1
    return candidate
def friendly_list(input_list, apostrophes=False):
    """Format a list as human friendly text (commas with an ampersand before the last item).

    For example ``["a", "b", "c"]`` becomes ``"a, b & c"``. Items are converted
    with :class:`str`, and the ampersand is always placed before the last *item*,
    even when that item itself contains a comma.

    Args:
        input_list (list): The items to format.
        apostrophes (bool, optional): Whether each item should be wrapped in
            apostrophes. Defaults to ``False``.

    Returns:
        str: The human friendly list string. An empty input gives an empty string
        and a single item is returned on its own.
    """
    items = [f"'{item}'" if apostrophes else str(item) for item in input_list]
    if len(items) < 2:
        return "".join(items)
    # Join on the list structure rather than re-splitting the joined text, so an
    # item that contains ", " is never cut in half.
    return f"{', '.join(items[:-1])} & {items[-1]}"
def strip_underscores(text, position="all"):
    """Remove leading and/or trailing underscores from a string.

    Args:
        text (str): The input string.
        position (str, optional): Which side to strip, compared case-insensitively.
            ``"leading"`` strips only the start and ``"trailing"`` only the end.
            Any other value strips both sides. Defaults to ``"all"``.

    Returns:
        str: The string without the requested underscores.
    """
    side = position.lower()
    if side == "leading":
        return text.lstrip("_")
    if side == "trailing":
        return text.rstrip("_")
    return text.strip("_")
def inject_value(input_item, value, round_accuracy=2, order=False, axis=0):
    """Make sure ``value`` occurs exactly once, at the front, of a sequence.

    Existing entries equal to ``value`` are removed and ``value`` is placed in
    front of the remaining entries. Dictionaries are processed value by value, and
    arrays with more than one dimension are processed along ``axis``.

    Args:
        input_item (union[list, dict, numpy.ndarray]): The input sequence,
            dictionary of sequences, or array.
        value (float): The value to inject.
        round_accuracy (int, optional): Accepted and forwarded to nested calls.
            The presence check itself is an exact ``!=`` comparison and does not
            use this value. Defaults to ``2``.
        order (bool, optional): Whether each resulting sequence should be sorted.
            Defaults to ``False``.
        axis (int, optional): The axis along which to inject the value. Only used
            when the input is a numpy array with more than one dimension. Defaults
            to ``0``.

    Returns:
        union[list, dict, numpy.ndarray]: A dictionary for dictionary input, an
        array for multi-dimensional array input, and a list otherwise.
    """

    def _with_value(sequence):
        """Return ``value`` followed by the entries of ``sequence`` that differ from it."""
        merged = [value]
        merged.extend(entry for entry in sequence if entry != value)
        return sorted(merged) if order else merged

    if isinstance(input_item, dict):
        return {
            key: inject_value(
                entry,
                value=value,
                round_accuracy=round_accuracy,
                order=order,
                axis=axis,
            )
            for key, entry in input_item.items()
        }

    if isinstance(input_item, np.ndarray) and input_item.ndim > 1:
        # Row ``axis`` of the identity matrix is used as the axes permutation,
        # both before processing and to map the result back.
        axes = np.eye(input_item.ndim, dtype=np.int16)[axis]
        rows = [_with_value(row) for row in np.transpose(input_item, axes)]
        return np.transpose(np.array(rows), axes)

    return _with_value(input_item)
def verify_number_and_cast(x):
    """Check that ``x`` is a single number and return it as a float.

    Args:
        x (object): The value to convert.

    Returns:
        float: The value converted with :class:`float`.

    Raises:
        ValueError: If the value cannot be converted to a float.
    """
    try:
        number = float(x)
    except (ValueError, TypeError) as e:
        raise ValueError(f"An option ({x}) could not be converted to a float.") from e
    return number
def maybe_parse_reset_bounds(options, default_low, default_high):
    """Resolve the sampling bounds used for the initial state during a ``reset()``.

    Args:
        options (dict): Options passed in to ``reset()``. May contain ``"low"``
            and/or ``"high"`` entries that override the defaults. May be ``None``.
        default_low: Lower limit to use when ``options`` holds no ``"low"`` entry.
        default_high: Upper limit to use when ``options`` holds no ``"high"``
            entry.

    Returns:
        (tuple): tuple containing:

            - low: The lower limit.
            - high: The upper limit.

        When ``options`` is ``None`` the defaults are returned unchanged;
        otherwise both limits are returned as floats.

    Raises:
        ValueError: If a limit cannot be converted to a float, or if the lower
            limit is greater than the upper limit.
    """
    if options is None:
        return default_low, default_high

    raw_low = options.get("low", default_low)
    raw_high = options.get("high", default_high)

    # We expect only numerical inputs.
    low = verify_number_and_cast(raw_low)
    high = verify_number_and_cast(raw_high)
    if low > high:
        raise ValueError(
            f"Lower bound ({low}) must be lower than higher bound ({high})."
        )
    return low, high
def change_dict_key(d, old_key, new_key, default_value=None):
    """Rename a key of a dictionary in place.

    Args:
        d (dict): The dictionary to modify.
        old_key (str): The key to remove.
        new_key (str): The key under which the value is stored afterwards.
        default_value (any, optional): The value stored under ``new_key`` when
            ``old_key`` is not present in the dictionary. Defaults to ``None``.

    Returns:
        dict: The same dictionary object that was passed in.
    """
    moved_value = d.pop(old_key, default_value)
    d[new_key] = moved_value
    return d
def convert_gym_box_to_gymnasium_box(gym_box_space, **kwargs):
    """Build a gymnasium box space from a gym box space.

    Args:
        gym_box_space (gym.spaces.Box): The gym box space to convert.
        **kwargs: Optional overrides for the ``low``, ``high``, ``shape``,
            ``dtype`` and ``seed`` arguments of the gymnasium box space. Any other
            keyword is ignored.

    Returns:
        gymnasium.spaces.Box: The gymnasium box space. Arguments that were not
        overridden are taken from ``gym_box_space`` (``low``, ``high``, ``shape``,
        ``dtype``, and its ``np_random`` as the seed).
    """
    # Each default is read from the source space, in this order, even when the
    # corresponding keyword override is supplied.
    box_arguments = {
        "low": kwargs.get("low", gym_box_space.low),
        "high": kwargs.get("high", gym_box_space.high),
        "shape": kwargs.get("shape", gym_box_space.shape),
        "dtype": kwargs.get("dtype", gym_box_space.dtype),
        "seed": kwargs.get("seed", gym_box_space.np_random),
    }
    return gym.spaces.Box(**box_arguments)
def change_precision(input_value, precision=16):
    """Round a value, or the values inside a dictionary, to a number of decimals.

    Floats are rounded with :func:`round` and numpy arrays with
    :func:`numpy.around`. Dictionaries are processed recursively and updated in
    place. Any other type is returned untouched.

    Args:
        input_value (object): The input value.
        precision (int, optional): The precision (i.e. number of decimals) to use.
            Defaults to ``16``. If ``None``, the input value is returned as is.

    Returns:
        object: The input value with the new precision. For a dictionary this is
        the same dictionary object that was passed in.
    """
    if precision is None:
        return input_value

    if isinstance(input_value, dict):
        # Only existing keys are reassigned, so the dictionary size is stable
        # while iterating.
        for key in input_value:
            input_value[key] = change_precision(input_value[key], precision)
        return input_value
    if isinstance(input_value, np.ndarray):
        return np.around(input_value, decimals=precision)
    if isinstance(input_value, float):
        return round(input_value, precision)
    return input_value