Custom agents
Although the library ships with a range of ready-made agents, you may want to add your own to the collection. This guide walks you through that task.
Implementing new agents
To fully benefit from Reinforced-lib features, including JAX JIT optimization, your agent
should inherit from the abstract class BaseAgent. We demonstrate how to add a
custom agent using a simple epsilon-greedy agent as an example:
class EGreedy(BaseAgent):
First, we need to define the state of our agent, which in our case will hold
the quality value of each arm (Q),
the number of times each arm has been tried (N),
and will inherit from AgentState:
@dataclass
class EGreedyState(AgentState):
    Q: Array
    N: Array
The BaseAgent interface breaks the agent’s behaviour into three methods:
init(PRNGKey, …) -> AgentState - initializes the agent’s state,
update(AgentState, PRNGKey, …) -> AgentState - updates the agent’s state after performing some action and receiving a reward,
sample(AgentState, PRNGKey, …) -> Action - samples new action according to the agent’s and environment’s state.
We now define the epsilon-greedy agent, which implements these three static methods:
    # This method initializes the agent with 'n_arms' arms
    @staticmethod
    def init(
        key: PRNGKey,
        n_arms: int
    ) -> EGreedyState:
        return EGreedyState(
            # The initial Q values are set to zeros, due to the lack of prior knowledge
            Q=jnp.zeros(n_arms),
            # The numbers of tries are set to ones, to avoid division by zero in the Q value update
            N=jnp.ones(n_arms, dtype=int)
        )

    # This method updates the agent's state
    @staticmethod
    def update(
        state: EGreedyState,
        key: PRNGKey,
        action: int,
        reward: Scalar,
    ) -> EGreedyState:
        return EGreedyState(
            # Q value update
            Q=state.Q.at[action].add((reward - state.Q[action]) / state.N[action]),
            # Incrementing the number of tries on the appropriate arm
            N=state.N.at[action].add(1)
        )

    # This method samples a new action according to the agent's state (experience)
    @staticmethod
    def sample(
        state: EGreedyState,
        key: PRNGKey,
        e: Scalar
    ) -> int:
        # Split the PRNG key to use it twice
        epsilon_key, choice_key = jax.random.split(key)
        # We will later jax.jit this function, so a plain Python 'if' on a traced value is not allowed here
        return jax.lax.cond(
            # The agent explores with probability e
            jax.random.uniform(epsilon_key) < e,
            # On exploration, the agent chooses a random arm
            lambda: jax.random.choice(choice_key, state.Q.size),
            # On exploitation, the agent chooses the best known arm
            lambda: jnp.argmax(state.Q)
        )
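The Q value update above is an incremental mean: because N starts at one and is incremented after every use, the divisor on the k-th pull of an arm equals k, so Q ends up as the average reward observed for that arm. The following is a minimal plain-Python sketch of that arithmetic only. The helper update_lists and the reward sequence are illustrative assumptions, not part of Reinforced-lib.

```python
def update_lists(Q, N, action, reward):
    """Plain-Python mirror of the Q and N updates shown in the guide."""
    Q, N = list(Q), list(N)
    # Move the estimate towards the reward by a step of 1 / N[action]
    Q[action] += (reward - Q[action]) / N[action]
    # Count the pull only after it has been used as the divisor
    N[action] += 1
    return Q, N


# Hypothetical rewards observed for arm 0 of a two-armed bandit
rewards = [1.0, 0.0, 1.0, 1.0]

Q, N = [0.0, 0.0], [1, 1]
for r in rewards:
    Q, N = update_lists(Q, N, action=0, reward=r)

mean_reward = sum(rewards) / len(rewards)
print(Q[0], mean_reward)
```

Note that N therefore holds the number of tries plus one, which is harmless for the update but worth remembering when inspecting the state.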
Having defined these static methods, we can implement the class constructor:
    def __init__(
        self,
        n_arms: int,
        e: Scalar
    ) -> None:
        # Make sure that epsilon has correct value
        assert 0 <= e <= 1

        # We specify the features of our agent
        self.n_arms = n_arms

        # Here we can use the jax.jit() functionality with the previously
        # defined behaviour functions, to make the agent super fast.
        # Note that we use partial() to specify the parameters that are
        # constant during the agent's lifetime to avoid passing them
        # every time the function is called.
        self.init = jax.jit(partial(self.init, n_arms=self.n_arms))
        self.update = jax.jit(self.update)
        self.sample = jax.jit(partial(self.sample, e=e))
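The effect of combining partial() with jax.jit() can be seen in isolation in the minimal sketch below. It assumes only that JAX is installed. The free function init_q and the value n_arms=4 are illustrative and not part of Reinforced-lib.

```python
from functools import partial

import jax
import jax.numpy as jnp


def init_q(key, n_arms):
    # n_arms determines the array shape, so it has to be a concrete
    # Python int at the time JAX traces the function.
    return jnp.zeros(n_arms)


# Binding n_arms with partial() makes it a constant of the compiled
# function; the jitted callable only takes the dynamic argument (the key).
jitted_init_q = jax.jit(partial(init_q, n_arms=4))

q_values = jitted_init_q(jax.random.PRNGKey(0))
print(q_values.shape)
```

If n_arms were passed as a regular argument of the jitted function, it would become a traced value, which cannot be used as an array shape. The static_argnames option of jax.jit is an alternative way to mark such arguments.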
Now we specify the initialization arguments of our agent (i.e., the parameters that are required by the
agent’s constructor). This is done by implementing the static method parameter_space() which returns
a dictionary in the format of a Gymnasium (formerly OpenAI Gym) space.
It is not required to implement this method, but it is a good practice to do so. This enables the library to automatically
provide initialization arguments specified by extensions.
    # Parameters required by the agent constructor in Gymnasium format.
    # Type of returned value is required to be gym.spaces.Dict.
    @staticmethod
    def parameter_space() -> gym.spaces.Dict:
        return gym.spaces.Dict({
            'n_arms': gym.spaces.Box(1, jnp.inf, (1,), int),
            'e': gym.spaces.Box(0.0, 1.0, (1,), float)
        })
Specifying the action space of the agent is accomplished by implementing the action_space property.
While not mandatory, adhering to this practice is recommended as it allows users to conveniently inspect
the agent’s action space through the action_space method of the RLib class.
    # Action returned by the agent in Gymnasium format.
    @property
    def action_space(self) -> gym.spaces.Space:
        return gym.spaces.Discrete(self.n_arms)
Finally, we define the observation spaces for our agent by implementing the properties called
update_observation_space and sample_observation_space. Although not mandatory, we strongly
encourage their implementation as it allows the library to deduce absent values from raw observations
and functions defined in the extensions. Moreover, having these properties
implemented facilitates a seamless export of the agent to the TensorFlow Lite format, where
the library can automatically generate an example set of parameters during the export procedure.
    # Parameters required by the 'update' method in Gymnasium format.
    @property
    def update_observation_space(self) -> gym.spaces.Dict:
        return gym.spaces.Dict({
            'action': gym.spaces.Discrete(self.n_arms),
            'reward': gym.spaces.Box(-jnp.inf, jnp.inf, (1,), float)
        })

    # Parameters required by the 'sample' method in Gymnasium format.
    @property
    def sample_observation_space(self) -> gym.spaces.Dict:
        return gym.spaces.Dict({})
Now we have a ready-to-use epsilon-greedy agent!
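To see how init, sample, and update interact, here is a minimal sketch of an interaction loop on a simulated Bernoulli bandit. So that it runs with JAX alone, it restates the three behaviour functions as free functions and uses a NamedTuple as a stand-in for the chex dataclass. The names BanditState and true_probs, the arm probabilities, and the loop itself are assumptions made for illustration; in practice the RLib class drives the agent for you.

```python
from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp


class BanditState(NamedTuple):  # stand-in for EGreedyState
    Q: jax.Array
    N: jax.Array


def init(key, n_arms):
    return BanditState(Q=jnp.zeros(n_arms), N=jnp.ones(n_arms, dtype=int))


def update(state, key, action, reward):
    return BanditState(
        Q=state.Q.at[action].add((reward - state.Q[action]) / state.N[action]),
        N=state.N.at[action].add(1),
    )


def sample(state, key, e):
    epsilon_key, choice_key = jax.random.split(key)
    return jax.lax.cond(
        jax.random.uniform(epsilon_key) < e,
        lambda: jax.random.choice(choice_key, state.Q.size),
        lambda: jnp.argmax(state.Q),
    )


n_arms, n_steps = 3, 200
# Hypothetical success probability of each arm
true_probs = jnp.array([0.2, 0.5, 0.8])

sample_fn = jax.jit(partial(sample, e=0.1))
update_fn = jax.jit(update)

key, init_key = jax.random.split(jax.random.PRNGKey(42))
state = init(init_key, n_arms)

for _ in range(n_steps):
    # Every random operation gets its own fresh key
    key, sample_key, reward_key, update_key = jax.random.split(key, 4)
    action = sample_fn(state, sample_key)
    reward = jax.random.bernoulli(reward_key, true_probs[action]).astype(jnp.float32)
    state = update_fn(state, update_key, action, reward)

print(state.Q, state.N)
```

The state is never modified in place: each call returns a new state that is passed to the next call, which is what allows the functions to be jitted.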
Template agent
Here is all of the above code in one piece. You can copy and paste it and use it as a starting point for your own agent.
from functools import partial

import gymnasium as gym
import jax
import jax.numpy as jnp
from chex import dataclass, Array, Scalar, PRNGKey

from reinforced_lib.agents import BaseAgent, AgentState


@dataclass
class EGreedyState(AgentState):
    Q: Array
    N: Array


class EGreedy(BaseAgent):

    def __init__(
        self,
        n_arms: int,
        e: Scalar
    ) -> None:
        assert 0 <= e <= 1

        self.n_arms = n_arms

        self.init = jax.jit(partial(self.init, n_arms=n_arms))
        self.update = jax.jit(self.update)
        self.sample = jax.jit(partial(self.sample, e=e))

    @staticmethod
    def parameter_space() -> gym.spaces.Dict:
        return gym.spaces.Dict({
            'n_arms': gym.spaces.Box(1, jnp.inf, (1,), int),
            'e': gym.spaces.Box(0.0, 1.0, (1,), float)
        })

    @property
    def update_observation_space(self) -> gym.spaces.Dict:
        return gym.spaces.Dict({
            'action': gym.spaces.Discrete(self.n_arms),
            'reward': gym.spaces.Box(-jnp.inf, jnp.inf, (1,), float)
        })

    @property
    def sample_observation_space(self) -> gym.spaces.Dict:
        return gym.spaces.Dict({})

    @property
    def action_space(self) -> gym.spaces.Space:
        return gym.spaces.Discrete(self.n_arms)

    @staticmethod
    def init(
        key: PRNGKey,
        n_arms: int
    ) -> EGreedyState:
        return EGreedyState(
            Q=jnp.zeros(n_arms),
            N=jnp.ones(n_arms, dtype=int)
        )

    @staticmethod
    def update(
        state: EGreedyState,
        key: PRNGKey,
        action: int,
        reward: Scalar
    ) -> EGreedyState:
        return EGreedyState(
            Q=state.Q.at[action].add((reward - state.Q[action]) / state.N[action]),
            N=state.N.at[action].add(1)
        )

    @staticmethod
    def sample(
        state: EGreedyState,
        key: PRNGKey,
        e: Scalar
    ) -> int:
        epsilon_key, choice_key = jax.random.split(key)

        return jax.lax.cond(
            jax.random.uniform(epsilon_key) < e,
            lambda: jax.random.choice(choice_key, state.Q.size),
            lambda: jnp.argmax(state.Q)
        )
Deep learning agents
Although the above example is a simple one, it is not hard to extend it to deep reinforcement learning (DRL) agents. This can be achieved by leveraging the JAX ecosystem, along with the flax library, which provides a convenient way to define neural networks, and optax, which provides a set of optimizers. Below, we provide excerpts of the code for the deep Q-learning agent.
The state of the DRL agent often contains parameters and state of the neural network as well as an experience replay buffer:
@dataclass
class DQNState(AgentState):
    params: dict
    state: dict
    opt_state: optax.OptState
    replay_buffer: ReplayBuffer
    prev_env_state: Array
    epsilon: Scalar
The agent’s constructor accepts the neural network architecture and the optimizer as parameters, which gives users full control over both and makes the agent more flexible:
    def __init__(
        self,
        q_network: nn.Module,
        optimizer: optax.GradientTransformation = None,
        ...
    ) -> None:
        if optimizer is None:
            optimizer = optax.adam(1e-3)

        self.init = jax.jit(partial(self.init, q_network=q_network, optimizer=optimizer, ...))
        ...
By implementing the constructor in this manner, users gain the flexibility to define their own architecture as follows:
class QNetwork(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        return nn.Dense(2)(x)


rl = RLib(
    agent_type=DQN,
    agent_params={
        'q_network': QNetwork(),
        'optimizer': optax.rmsprop(3e-4, decay=0.95, eps=1e-2)
    },
    ...
)
Note
In some cases, it is necessary to use a PRNG key in the definition of a neural network to allow the stochastic
behavior of the model. The flax library provides a make_rng(stream_name) method that can be used to generate
a PRNG key from a given stream. The DRL algorithms implemented in Reinforced-lib offer a stream called rlib
by default, so you can use it in your custom model as follows: key = self.make_rng('rlib').
During the development of a DRL agent, our library offers a set of utility functions for your convenience.
Among these functions is gradient_step, designed to streamline parameter updates for the agent using JAX and optax.
In the following example code snippet, we showcase the implementation of a step function responsible for performing
a single step, taking into account the network, optimizer, and the implemented loss function:
from reinforced_lib.utils.jax_utils import gradient_step

step_fn = partial(
    gradient_step,
    optimizer=optimizer,
    loss_fn=partial(self.loss_fn, q_network=q_network, ...)
)
There are also other utility functions that can make it easier to implement DRL agents with flax. These are the
init and forward methods which are used to initialize the network and to perform a forward pass through the
network. You can find more information about these functions in the documentation.
Our Python library also includes a pre-built experience replay buffer, which is commonly utilized in DRL agents. The following code provides an illustrative example of how to use this utility:
from reinforced_lib.utils.experience_replay import experience_replay, ExperienceReplay, ReplayBuffer

er = experience_replay(
    experience_replay_buffer_size,
    experience_replay_batch_size,
    obs_space_shape,
    act_space_shape
)

...

replay_buffer = er.init()

...

replay_buffer = er.append(replay_buffer, prev_env_state, action, reward, terminal, env_state)
perform_update = er.is_ready(replay_buffer)

for _ in range(experience_replay_steps):
    batch = er.sample(replay_buffer, key)
    ...
Developing a DRL agent can be challenging, so we strongly recommend studying the code of one of our DRL agents before building your own.
Summary
To sum everything up one more time:
All agents inherit from the BaseAgent class.
The agent’s state is defined as a dataclass that inherits from the AgentState class.
The agent’s behavior is determined by implementing the static methods init, update, and sample.
Utilizing jax.jit can significantly increase the agent’s performance.
Although not mandatory, it is highly recommended to implement the parameter_space, update_observation_space, and sample_observation_space properties.
Implementing a custom DRL agent is possible using the JAX ecosystem and utility functions provided by the library.