Custom extensions

Extensions are a unique feature that allows the library to infer missing observations that the environment does not natively provide. You can either choose one of our built-in extensions or implement your own with the help of this short guide.

Key concepts of extensions

There are three main benefits of using extensions:

  1. Automatic initialization of agents - an extension can provide default arguments that are used to initialize an agent. For example, if we wanted to train a deep Q-learning agent on the cart-pole environment without using any extension, we would probably do it in the following way:

    rl = RLib(
        agent_type=DQN,
        agent_params={
            'q_network': q_network,
            'obs_space_shape': (4,),
            'act_space_size': 2
        },
        no_ext_mode=True
    )
    

    On the other hand, if we decide to use the Gymnasium extension, some of the parameters can be automatically provided by the extension:

    rl = RLib(
        agent_type=DQN,
        agent_params={'q_network': q_network},
        ext_type=Gymnasium,
        ext_params={'env_id': 'CartPole-v1'},
    )
    

    We can also overwrite all or only part of the default values provided by the extension:

    rl = RLib(
        agent_type=DQN,
        agent_params={
            'q_network': q_network,
            'act_space_size': 3
        },
        ext_type=Gymnasium,
        ext_params={'env_id': 'CartPole-v1'},
    )
    
  2. Simplification of parameter passing - extensions automatically match the observations returned by the environment to the appropriate methods of the agent. The code snippet below shows the agent and environment interaction loop without using any extension:

    while not terminal:
        env_state, reward, terminal, truncated, info = env.step(action.item())
    
        action = rl.sample(
            update_observations={
                'env_state': env_state,
                'action': action,
                'reward': reward,
                'terminal': terminal
            },
            sample_observations={
                'env_state': env_state
            }
        )
    

    The following code is equivalent to the above but makes use of the properly defined Gymnasium extension:

    while not env_state[2]:
        env_state = env.step(action.item())
        action = rl.sample(*env_state)
    
  3. Filling missing parameters - some parameters required by the agent can be filled with known values or calculated from a set of basic observations. For example, the sample method of the Thompson sampling agent requires a context vector. Because this is domain-specific knowledge, these values can be found in the appropriate extension. Below is sample code that could be used to sample the next action in the IEEE 802.11ax rate adaptation problem without using any extension:

    rl = RLib(
        agent_type=ThompsonSampling,
        agent_params={'n_arms': 12},
        no_ext_mode=True
    )
    
    observations = {
        'delta_time': 0.18232,
        'n_successful': 10,
        'n_failed': 0,
        'context': jnp.array(
            [7.3, 14.6, 21.9, 29.3, 43.9, 58.5,
            65.8, 73.1, 87.8, 97.5, 109.7, 121.9]
        )
    }
    action = rl.sample(**observations)
    

    If we use the IEEE 802.11ax RA extension, part of these parameters can be provided by the extension:

    rl = RLib(
        agent_type=ThompsonSampling,
        ext_type=IEEE_802_11_ax_RA
    )
    
    observations = {
        'delta_time': 0.18232,
        'n_successful': 10,
        'n_failed': 0
    }
    action = rl.sample(**observations)
    

    We can also overwrite the values provided by the extension:

    observations = {
        'delta_time': 0.18232,
        'n_successful': 10,
        'n_failed': 0,
        'context': jnp.array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.])
    }
    action = rl.sample(**observations)
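
The override behaviour described in items 1 and 3 can be pictured as a dictionary merge in which user-supplied values win. The following is a minimal sketch in plain Python, not Reinforced-lib's actual implementation; `merge_with_defaults` and the sample default values are hypothetical names chosen for illustration.

```python
def merge_with_defaults(defaults: dict, user_values: dict) -> dict:
    """Return a new dict in which user-supplied values override defaults."""
    merged = dict(defaults)      # start from the extension's defaults
    merged.update(user_values)   # user-provided entries take priority
    return merged


# Hypothetical defaults that an extension might supply for a cart-pole task.
extension_defaults = {'obs_space_shape': (4,), 'act_space_size': 2}

# The user overrides only part of the defaults and adds a new entry.
user_params = {'q_network': 'my_network', 'act_space_size': 3}

agent_params = merge_with_defaults(extension_defaults, user_params)
print(agent_params)
```

Only `act_space_size` is replaced here; `obs_space_shape` still comes from the defaults, and the defaults themselves are left unmodified.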
    

You can define default values for the initialization arguments of agents through parameter functions. Additionally, default values or functions that calculate missing observations can be defined using observation functions. To designate these functions correctly, they are decorated with the @parameter and @observation decorators respectively. A more detailed description of these decorators can be found in the section below.

Implementing new extensions

To create your own extension, you should inherit from the abstract class BaseExt. We demonstrate how to add a custom extension using the example of the extension for the IEEE 802.11ax rate adaptation problem.

class IEEE_802_11_ax_RA(BaseExt):

First, we must specify the observation space of the extension. It is a basic set of environment observations that can be used by the extension to compute missing values. Note that a complete set of all parameters is not necessarily required to use the extension - if an agent does not require a given parameter and it is not used to compute missing values, the extension will ignore it. In the case of the IEEE 802.11ax environment, the observation space can look like this:

observation_space = gym.spaces.Dict({
    'time': gym.spaces.Box(0.0, jnp.inf, (1,), float),
    'n_successful': gym.spaces.Box(0, jnp.inf, (1,), int),
    'n_failed': gym.spaces.Box(0, jnp.inf, (1,), int),
    'n_wifi': gym.spaces.Box(1, jnp.inf, (1,), int),
    'power': gym.spaces.Box(-jnp.inf, jnp.inf, (1,), float),
    'cw': gym.spaces.Discrete(32767)
})

Next, we define the parameter function that will provide the default power value for agents that require this parameter as a constructor argument. We can do this by creating an appropriate method and decorating it with the @parameter decorator. The parameter functions are methods of the extension and cannot take any additional arguments:

@parameter()
def default_power(self):
    return 16.0206

We can also specify the type of the returned value in the Gymnasium (formerly OpenAI Gym) format. This helps the library check whether the type of a given value is compatible with the argument required by the agent:

@parameter(parameter_type=gym.spaces.Box(-jnp.inf, jnp.inf, (1,)))
def default_power(self) -> float:
    return 16.0206

Note that the name of the function must match the name of the argument required by the agent. If there already exists a function with that name, we can name the function differently and explicitly define the argument name in the decorator:

@parameter(parameter_name='default_power', parameter_type=gym.spaces.Box(-jnp.inf, jnp.inf, (1,)))
def default_pow(self) -> float:
    return 16.0206
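
The naming rule above can be illustrated with a simplified stand-in decorator. This is a sketch under stated assumptions, not the library's @parameter implementation: `toy_parameter`, `collect_parameters`, `ToyExt`, and the `_param_name` attribute are all hypothetical.

```python
from typing import Callable, Optional


def toy_parameter(parameter_name: Optional[str] = None) -> Callable:
    """Hypothetical stand-in for a parameter decorator: tags a method with a name."""
    def decorator(func: Callable) -> Callable:
        # Fall back to the function's own name when no explicit name is given.
        func._param_name = parameter_name or func.__name__
        return func
    return decorator


class ToyExt:
    @toy_parameter()
    def n_arms(self) -> int:
        return 12

    @toy_parameter(parameter_name='default_power')
    def default_pow(self) -> float:
        return 16.0206


def collect_parameters(ext: object) -> dict:
    """Map each tagged name to the value returned by its method."""
    params = {}
    for attr in dir(ext):
        method = getattr(ext, attr)
        if callable(method) and hasattr(method, '_param_name'):
            params[method._param_name] = method()
    return params


params = collect_parameters(ToyExt())
print(params)
```

In this sketch the value from `default_pow` is exposed under the name `default_power`, while `n_arms` keeps the name of its function.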

We define the observation functions by analogy to parameter functions. The differences are that we use the @observation decorator and that the implemented methods can take additional parameters. Below is an example observation function that provides a reward calculated as an approximated throughput in the IEEE 802.11ax environment:

@observation()
def reward(self, mcs, n_successful, n_failed, *args, **kwargs):
    if n_successful + n_failed > 0:
        return self._wifi_modes_rates[mcs] * n_successful / (n_successful + n_failed)
    else:
        return 0.0

Note that the observation function can take parameters that are specified in the observation space. BaseExt will automatically pass the given observation to the function to allow dynamic computation of the returned value. Importantly, observation methods must take *args and **kwargs as the last parameters (this is required by the internal behavior of the setup_transformations function). As before, the name of the function should match the name of the filled parameter, but we can specify the parameter name and returned type in the decorator:

@observation(observation_name='reward', observation_type=gym.spaces.Box(-jnp.inf, jnp.inf, (1,)))
def custom_reward(self, mcs: int, n_successful: int, n_failed: int, *args, **kwargs) -> float:
    if n_successful + n_failed > 0:
        return self._wifi_modes_rates[mcs] * n_successful / (n_successful + n_failed)
    else:
        return 0.0
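
As a worked example of the throughput approximation, and of why observation functions accept *args and **kwargs, the sketch below is a standalone plain-Python version of the reward function. It is not called through the library; the `rates` list copies the values used elsewhere in this guide, and the observation dictionary is a made-up sample.

```python
# Rates of the 12 Wi-Fi modes, as listed in this guide.
rates = [7.3, 14.6, 21.9, 29.3, 43.9, 58.5, 65.8, 73.1, 87.8, 97.5, 109.7, 121.9]


def reward(mcs: int, n_successful: int, n_failed: int, *args, **kwargs) -> float:
    """Approximate throughput: the mode's rate scaled by the frame success ratio."""
    total = n_successful + n_failed
    if total > 0:
        return rates[mcs] * n_successful / total
    return 0.0


# A made-up observation with more entries than the function needs;
# the extras ('time', 'power') are absorbed by **kwargs.
observation = {'mcs': 5, 'n_successful': 8, 'n_failed': 2, 'time': 0.5, 'power': 16.0206}

value = reward(**observation)                           # 58.5 * 8 / 10
no_frames = reward(mcs=5, n_successful=0, n_failed=0)   # guarded division
print(value, no_frames)
```

Without **kwargs, the call with the full observation dictionary would fail with a TypeError because of the unexpected `time` and `power` keywords.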

Template extension

To simplify the process of creating new extensions, we provide an example extension that can be used as a starting point for your own. The IEEE 802.11ax rate adaptation extension is shown below:

import gymnasium as gym
import jax.numpy as jnp
from chex import Array, Scalar

from reinforced_lib.exts import BaseExt, observation, parameter


class IEEE_802_11_ax_RA(BaseExt):
    def __init__(self) -> None:
        super().__init__()
        self.last_time = 0.0

    observation_space = gym.spaces.Dict({
        'time': gym.spaces.Box(0.0, jnp.inf, (1,), float),
        'n_successful': gym.spaces.Box(0, jnp.inf, (1,), int),
        'n_failed': gym.spaces.Box(0, jnp.inf, (1,), int),
        'n_wifi': gym.spaces.Box(1, jnp.inf, (1,), int),
        'power': gym.spaces.Box(-jnp.inf, jnp.inf, (1,), float),
        'cw': gym.spaces.Discrete(32767)
    })

    _wifi_modes_rates = jnp.array([
        7.3, 14.6, 21.9, 29.3, 43.9, 58.5, 65.8, 73.1, 87.8, 97.5, 109.7, 121.9
    ], dtype=float)

    @observation(observation_type=gym.spaces.Box(0.0, jnp.inf, (len(_wifi_modes_rates),), float))
    def rates(self, *args, **kwargs) -> Array:
        return self._wifi_modes_rates

    @observation(observation_type=gym.spaces.Box(-jnp.inf, jnp.inf, (len(_wifi_modes_rates),), float))
    def context(self, *args, **kwargs) -> Array:
        return self.rates()

    @observation(observation_type=gym.spaces.Box(-jnp.inf, jnp.inf, (1,), float))
    def reward(self, action: int, n_successful: int, n_failed: int, *args, **kwargs) -> float:
        if n_successful + n_failed > 0:
            return self._wifi_modes_rates[action] * n_successful / (n_successful + n_failed)
        else:
            return 0.0

    @observation(observation_type=gym.spaces.Box(0.0, jnp.inf, (1,), float))
    def delta_time(self, time: Scalar, *args, **kwargs) -> float:
        delta_time = time - self.last_time
        self.last_time = time
        return delta_time

    @observation(observation_type=gym.spaces.Box(-jnp.inf, jnp.inf, (6,), float))
    def env_state(
            self, time: Scalar, n_successful: int, n_failed: int,
            n_wifi: int, power: Scalar, cw: int, *args, **kwargs
    ) -> Array:
        return jnp.array([self.delta_time(time), n_successful, n_failed, n_wifi, power, cw], dtype=float)

    @observation(observation_type=gym.spaces.MultiBinary(1))
    def terminal(self, *args, **kwargs) -> bool:
        return False

    @parameter(parameter_type=gym.spaces.Box(1, jnp.inf, (1,), int))
    def n_mcs(self) -> int:
        return len(self._wifi_modes_rates)

    @parameter(parameter_type=gym.spaces.Box(1, jnp.inf, (1,), int))
    def n_arms(self) -> int:
        return self.n_mcs()

    @parameter(parameter_type=gym.spaces.Box(-jnp.inf, jnp.inf, (1,), float))
    def default_power(self) -> Scalar:
        return 16.0206

    @parameter(parameter_type=gym.spaces.Box(-jnp.inf, jnp.inf, (1,), float))
    def min_reward(self) -> Scalar:
        return 0

    @parameter(parameter_type=gym.spaces.Box(-jnp.inf, jnp.inf, (1,), float))
    def max_reward(self) -> Scalar:
        return self._wifi_modes_rates.max()

    @parameter(parameter_type=gym.spaces.Sequence(gym.spaces.Box(1, jnp.inf, (1,), int)))
    def obs_space_shape(self) -> tuple:
        return tuple((6,))

    @parameter(parameter_type=gym.spaces.Sequence(gym.spaces.Box(1, jnp.inf, (1,), int)))
    def act_space_shape(self) -> tuple:
        return tuple((1,))

    @parameter(parameter_type=gym.spaces.Box(1, jnp.inf, (1,), int))
    def act_space_size(self) -> int:
        return 12

    @parameter(parameter_type=gym.spaces.Sequence(gym.spaces.Box(-jnp.inf, jnp.inf, (1,), float)))
    def min_action(self) -> tuple:
        return 0

    @parameter(parameter_type=gym.spaces.Sequence(gym.spaces.Box(-jnp.inf, jnp.inf, (1,), float)))
    def max_action(self) -> tuple:
        return 11
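
One detail of the template worth noting is that delta_time is stateful: it stores the last timestamp and returns the difference on each call. The sketch below reproduces just that logic in plain Python to show what consecutive calls return. The `DeltaTimer` class is hypothetical and independent of BaseExt.

```python
class DeltaTimer:
    """Hypothetical standalone copy of the template's delta_time logic."""

    def __init__(self) -> None:
        self.last_time = 0.0

    def delta_time(self, time: float) -> float:
        # Time elapsed since the previous call; the first call measures from 0.0.
        delta = time - self.last_time
        self.last_time = time
        return delta


timer = DeltaTimer()
deltas = [timer.delta_time(t) for t in (0.5, 0.75, 2.0)]
repeated = timer.delta_time(2.0)  # same timestamp as the previous call
print(deltas, repeated)
```

Because the stored timestamp is updated on every call, calling the method a second time with the same timestamp yields 0.0 in this sketch.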

Rules and limitations

Extensions are a powerful mechanism that makes Reinforced-lib easy to use. The BaseExt methods can handle complex and nested observation spaces. However, there are some rules and limitations that programmers and users must consider:

  • arguments and parameters provided by the user have higher priority than the default or calculated values provided by the extension,

  • parameter functions cannot take any arguments (except self),

  • you cannot use an extension with a given agent if the agent requires a parameter that is not listed in the extension's observation space and cannot be provided by an observation function - you have to add an observation to the observation space, implement the appropriate observation function, or use the agent without any extension,

  • missing parameter filling is supported only if the type of the extension observation space and the type of agent space can be matched - that means they both must be:

    • a dict type - gym.spaces.Dict,

    • or a “simple” type - gym.spaces.Box, gym.spaces.Discrete, gym.spaces.MultiBinary, gym.spaces.MultiDiscrete, gym.spaces.Space,

  • missing parameter filling is not supported if spaces inherit from gym.spaces.Tuple - values would have to be matched based on the type and this can lead to ambiguities if there are multiple parameters with the same type,

  • if spaces do not inherit from gym.spaces.Dict, missing values are matched based on the type of the value, not the name, so the first function whose type matches the agent space is chosen,

  • if an observation function requires some parameter and it is not provided by a named argument, BaseExt will select the first (possibly nested) positional argument and pass it to the function, but if there are no positional arguments, the library will raise an exception.
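
The ambiguity of type-based matching mentioned in the rules above can be seen in a small sketch. It is plain Python with a hypothetical `first_type_match` helper and uses Python types in place of Gymnasium spaces; it does not reflect how BaseExt is actually implemented.

```python
from typing import Optional


def first_type_match(candidates: list, wanted_type: type) -> Optional[tuple]:
    """Return the first (name, value) pair whose value has the wanted type."""
    for name, func in candidates:
        value = func()
        if isinstance(value, wanted_type):
            return name, value
    return None


# Two hypothetical functions that both produce a float.
candidates = [
    ('default_power', lambda: 16.0206),
    ('min_reward', lambda: 0.0),
]

# Without names, only the order decides which float is used.
match = first_type_match(candidates, float)
reordered = first_type_match(list(reversed(candidates)), float)
print(match, reordered)
```

Both candidates satisfy the type check, so the result depends purely on their order, which is why name-based matching through a dict space is the safer choice when several values share a type.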

How do extensions work?

The core of this module is the abstract class BaseExt, which provides the main functionality of extensions. It implements important methods, such as get_agent_params, transform, and setup_transformations. The class uses these methods internally to provide the simple and powerful API of Reinforced-lib. You can read more about the BaseExt class in its API documentation or check out the source code.