Agents

This module provides a set of reinforcement learning (RL) agents. You can either choose one of our built-in agents or implement your own agent with the help of the Custom agents guide.

BaseAgent

class AgentState

Base class for agent state containers.

items() → a set-like object providing a view of the state's items
keys() → a set-like object providing a view of the state's keys
values() → an object providing a view of the state's values
class BaseAgent

Base interface of agents.

abstractmethod static init(key: Array, *args, **kwargs) → AgentState

Creates and initializes an instance of the agent.

abstractmethod static update(state: AgentState, key: Array, *args, **kwargs) → AgentState

Updates the state of the agent after performing some action and receiving a reward.

abstractmethod static sample(state: AgentState, key: Array, *args, **kwargs) → any

Selects the next action based on the current environment and agent state.
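The three abstract methods define a functional interface: each call receives the current state (plus a PRNG key) and returns either a new state or an action, instead of mutating the agent. The sketch below illustrates that pattern without depending on the library. `BanditState` and `EpsilonGreedyBandit` are hypothetical, the class does not actually subclass `BaseAgent`, and the integer `key` seeding Python's `random.Random` is only a stand-in for a JAX PRNG key.

```python
import random
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class BanditState:
    """Hypothetical immutable state (plays the role of an AgentState subclass)."""
    q: tuple       # estimated value of each arm
    counts: tuple  # how many times each arm has been pulled


class EpsilonGreedyBandit:
    """Hypothetical agent that follows the static init/update/sample pattern."""

    @staticmethod
    def init(key: int, n_arms: int) -> BanditState:
        # The key is unused here; a real agent could use it to initialize network parameters.
        return BanditState(q=(0.0,) * n_arms, counts=(0,) * n_arms)

    @staticmethod
    def update(state: BanditState, key: int, action: int, reward: float) -> BanditState:
        # Incremental sample-average estimate; a new state is returned, the old one is untouched.
        q, counts = list(state.q), list(state.counts)
        counts[action] += 1
        q[action] += (reward - q[action]) / counts[action]
        return replace(state, q=tuple(q), counts=tuple(counts))

    @staticmethod
    def sample(state: BanditState, key: int, epsilon: float) -> int:
        # A key-seeded generator mimics explicit PRNG key passing.
        rng = random.Random(key)
        if rng.random() < epsilon:
            return rng.randrange(len(state.q))
        return max(range(len(state.q)), key=lambda a: state.q[a])


state = EpsilonGreedyBandit.init(key=0, n_arms=3)
state = EpsilonGreedyBandit.update(state, key=1, action=2, reward=1.0)
action = EpsilonGreedyBandit.sample(state, key=2, epsilon=0.0)  # greedy choice: arm 2
```

Keeping the state immutable and passing keys explicitly is what makes this style compatible with JIT compilation in JAX, where hidden mutable state is not allowed.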

static parameter_space() → Dict

Parameters of the agent constructor in Gymnasium format. The returned value must be a gym.spaces.Dict or None. If None, the user must provide all parameters manually.

property update_observation_space: Space

Observation space of the update method in Gymnasium format. It allows extensions to infer missing observations and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all observations manually.

property sample_observation_space: Space

Observation space of the sample method in Gymnasium format. It allows extensions to infer missing observations and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all observations manually.

property action_space: Space

Action space of the agent in Gymnasium format.

export(init_key: Array, state: AgentState = None, sample_only: bool = False) → tuple[any, any, any]

Exports the agent to TensorFlow Lite format.

Parameters:
  • init_key (PRNGKey) – Key used to initialize the agent.

  • state (AgentState, optional) – State of the agent to be exported. If not specified, the agent is initialized with init_key.

  • sample_only (bool, optional) – If True, the exported agent will only be able to sample actions, but not update its state.

Deep Q-Learning (DQN)

class DQNState(params: dict, net_state: dict, opt_state: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], replay_buffer: ReplayBuffer, prev_env_state: Array | ndarray | bool | number, epsilon: float | int)

Bases: AgentState, Mapping

Container for the state of the deep Q-learning agent.

params

Parameters of the Q-network.

Type:

dict

net_state

State of the Q-network.

Type:

dict

opt_state

Optimizer state.

Type:

optax.OptState

replay_buffer

Experience replay buffer.

Type:

ReplayBuffer

prev_env_state

Previous environment state.

Type:

Array

epsilon

\(\epsilon\)-greedy parameter.

Type:

Scalar

items() → a set-like object providing a view of the state's items
keys() → a set-like object providing a view of the state's keys
values() → an object providing a view of the state's values
class DQN(q_network: Module, obs_space_shape: Sequence[int | Any], act_space_size: int, optimizer: GradientTransformation = None, experience_replay_buffer_size: int = 10000, experience_replay_batch_size: int = 64, experience_replay_steps: int = 5, discount: float | int = 0.99, epsilon: float | int = 1.0, epsilon_decay: float | int = 0.999, epsilon_min: float | int = 0.001)

Bases: BaseAgent

Deep Q-learning agent [1] with \(\epsilon\)-greedy exploration and an experience replay buffer. The agent uses a deep neural network to approximate the Q-value function. The Q-network is trained to minimize the Bellman error. This agent follows the off-policy learning paradigm and is suitable for environments with discrete action spaces.

Parameters:
  • q_network (nn.Module) – Architecture of the Q-network.

  • obs_space_shape (Shape) – Shape of the observation space.

  • act_space_size (int) – Size of the action space.

  • optimizer (optax.GradientTransformation, optional) – Optimizer of the Q-network. If None, the Adam optimizer with learning rate 1e-3 is used.

  • experience_replay_buffer_size (int, default=10000) – Size of the experience replay buffer.

  • experience_replay_batch_size (int, default=64) – Batch size of the samples from the experience replay buffer.

  • experience_replay_steps (int, default=5) – Number of experience replay steps per update.

  • discount (Scalar, default=0.99) – Discount factor. \(\gamma = 0.0\) means only the immediate reward is considered, \(\gamma = 1.0\) means future rewards are not discounted. \(0 \leq \gamma \leq 1\).

  • epsilon (Scalar, default=1.0) – Initial \(\epsilon\)-greedy parameter. \(0 \leq \epsilon \leq 1\).

  • epsilon_decay (Scalar, default=0.999) – Epsilon decay factor. \(\epsilon_{t+1} = \max(\epsilon_{t} \cdot \epsilon_{decay}, \epsilon_{min})\). \(0 \leq \epsilon_{decay} \leq 1\).

  • epsilon_min (Scalar, default=0.001) – Minimum \(\epsilon\)-greedy parameter. \(0 \leq \epsilon_{min} \leq \epsilon\).
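Because \(\epsilon\) is decayed during each call to update, the exploration schedule can be computed ahead of time. The sketch below assumes the decayed value is floored at epsilon_min and uses hypothetical helper names. With epsilon=1.0, epsilon_decay=0.999 and epsilon_min=0.001 (the values in the constructor signature), \(\epsilon\) reaches its floor after roughly 6.9 thousand updates.

```python
import math


def epsilon_schedule(epsilon: float, epsilon_decay: float, epsilon_min: float, steps: int) -> float:
    """Value of epsilon after `steps` updates, assuming multiplicative decay floored at epsilon_min."""
    for _ in range(steps):
        epsilon = max(epsilon * epsilon_decay, epsilon_min)
    return epsilon


def updates_until_min(epsilon: float, epsilon_decay: float, epsilon_min: float) -> int:
    """Smallest t with epsilon * epsilon_decay**t <= epsilon_min (closed form)."""
    return math.ceil(math.log(epsilon_min / epsilon) / math.log(epsilon_decay))


n = updates_until_min(1.0, 0.999, 0.001)  # n == 6905
```

Slower decay (epsilon_decay closer to 1) lengthens the exploration phase roughly in proportion to \(1 / (1 - \epsilon_{decay})\).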

References

static parameter_space() → Dict

Parameters of the agent constructor in Gymnasium format. The returned value must be a gym.spaces.Dict or None. If None, the user must provide all parameters manually.

property update_observation_space: Dict

Observation space of the update method in Gymnasium format. It allows extensions to infer missing observations and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all observations manually.

property sample_observation_space: Dict

Observation space of the sample method in Gymnasium format. It allows extensions to infer missing observations and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all observations manually.

property action_space: Discrete

Action space of the agent in Gymnasium format.

static init(key: Array, obs_space_shape: Sequence[int | Any], q_network: Module, optimizer: GradientTransformation, er: ExperienceReplay, epsilon: float | int) → DQNState

Initializes the Q-network, optimizer and experience replay buffer with the given parameters. The first state of the environment is assumed to be a tensor of zeros.

Parameters:
  • key (PRNGKey) – A PRNG key used as the random key.

  • obs_space_shape (Shape) – The shape of the observation space.

  • q_network (nn.Module) – The Q-network.

  • optimizer (optax.GradientTransformation) – The optimizer.

  • er (ExperienceReplay) – The experience replay buffer.

  • epsilon (Scalar) – The initial \(\epsilon\)-greedy parameter.

Returns:

Initial state of the deep Q-learning agent.

Return type:

DQNState

static loss_fn(params: dict, key: Array, net_state: dict, params_target: dict, net_state_target: dict, batch: tuple, q_network: Module, discount: float | int) → tuple[float | int, dict]

The loss is the mean squared Bellman error \(\mathcal{L}(\theta) = \mathbb{E}_{s, a, r, s'} \left[ \left( r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right)^2 \right]\) where \(s\) is the current state, \(a\) is the current action, \(r\) is the reward, \(s'\) is the next state, \(\gamma\) is the discount factor, and \(Q(s, a)\) is the Q-value of the state-action pair. The loss is computed on a batch of transitions.

Parameters:
  • params (dict) – The parameters of the Q-network.

  • key (PRNGKey) – A PRNG key used as the random key.

  • net_state (dict) – The state of the Q-network.

  • params_target (dict) – The parameters of the target Q-network.

  • net_state_target (dict) – The state of the target Q-network.

  • batch (tuple) – A batch of transitions from the experience replay buffer.

  • q_network (nn.Module) – The Q-network.

  • discount (Scalar) – The discount factor.

Returns:

The loss and the new state of the Q-network.

Return type:

tuple[Scalar, dict]
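For a batch stored as arrays, the loss above reduces to a few vectorized operations. The NumPy sketch below is illustrative only: it takes precomputed Q-values instead of calling a network, and it zeroes the bootstrap term on terminal transitions, a common convention that the formula does not spell out. In a JAX implementation the target would typically be wrapped in `jax.lax.stop_gradient` so that gradients flow only through \(Q(s, a)\).

```python
import numpy as np


def dqn_loss(q_values, actions, rewards, terminals, next_q_values, discount=0.99):
    """Mean squared Bellman error over a batch.

    q_values:      (B, A) Q(s, .) for the sampled states
    actions:       (B,)   integer actions taken in those states
    rewards:       (B,)   rewards received
    terminals:     (B,)   1.0 if the transition ended the episode, else 0.0
    next_q_values: (B, A) Q(s', .) used for the bootstrap target
    """
    q_sa = q_values[np.arange(len(actions)), actions]  # Q(s, a)
    # r + gamma * max_a' Q(s', a'), with bootstrapping switched off on terminal transitions
    target = rewards + discount * (1.0 - terminals) * next_q_values.max(axis=1)
    return float(np.mean((target - q_sa) ** 2))


loss = dqn_loss(
    q_values=np.array([[1.0, 2.0], [0.5, 0.0]]),
    actions=np.array([1, 0]),
    rewards=np.array([1.0, 0.0]),
    terminals=np.array([0.0, 1.0]),
    next_q_values=np.array([[3.0, 1.0], [10.0, 10.0]]),
    discount=0.5,
)  # both squared errors equal 0.25, so loss == 0.25
```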

static update(state: DQNState, key: Array, env_state: Array | ndarray | bool | number, action: Array | ndarray | bool | number, reward: float | int, terminal: bool, step_fn: Callable, er: ExperienceReplay, experience_replay_steps: int, epsilon_decay: float | int, epsilon_min: float | int) → DQNState

Appends the transition to the experience replay buffer and performs experience_replay_steps steps. Each step consists of sampling a batch of transitions from the experience replay buffer, calculating the loss using the loss_fn function, and performing a gradient step on the Q-network. The \(\epsilon\)-greedy parameter is multiplied by epsilon_decay, but is not decayed below epsilon_min.

Parameters:
  • state (DQNState) – The current state of the deep Q-learning agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • env_state (Array) – The current state of the environment.

  • action (Array) – The action taken by the agent.

  • reward (Scalar) – The reward received by the agent.

  • terminal (bool) – Whether the episode has terminated.

  • step_fn (Callable) – The function that performs a single gradient step on the Q-network.

  • er (ExperienceReplay) – The experience replay buffer.

  • experience_replay_steps (int) – The number of experience replay steps.

  • epsilon_decay (Scalar) – The decay rate of the \(\epsilon\)-greedy parameter.

  • epsilon_min (Scalar) – The minimum value of the \(\epsilon\)-greedy parameter.

Returns:

The updated state of the deep Q-learning agent.

Return type:

DQNState
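The update procedure described above can be sketched in plain Python. `ReplayBufferSketch`, `update_sketch` and the transition layout are hypothetical and do not mirror the library's `ReplayBuffer` or `ExperienceReplay` API. The sketch also assumes that gradient steps start only once the buffer holds at least one full batch, which the description does not specify. `step_fn` is any function mapping parameters and a batch to new parameters.

```python
import random
from collections import deque


class ReplayBufferSketch:
    """Hypothetical fixed-capacity buffer with uniform sampling."""

    def __init__(self, buffer_size: int, batch_size: int):
        self.transitions = deque(maxlen=buffer_size)  # the oldest transition is dropped when full
        self.batch_size = batch_size

    def append(self, transition: tuple) -> None:
        self.transitions.append(transition)

    def is_ready(self) -> bool:
        return len(self.transitions) >= self.batch_size

    def sample(self, rng: random.Random) -> list:
        # Uniform sampling without replacement within one batch.
        return rng.sample(list(self.transitions), self.batch_size)


def update_sketch(params, buffer, rng, transition, step_fn, experience_replay_steps):
    """Store one transition, then run several gradient steps on freshly sampled batches."""
    buffer.append(transition)
    if buffer.is_ready():
        for _ in range(experience_replay_steps):
            params = step_fn(params, buffer.sample(rng))
    return params
```

A transition could be stored as `(prev_env_state, action, reward, terminal, env_state)`, mirroring the arguments of update and the prev_env_state field of the agent state.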

static sample(state: DQNState, key: Array, env_state: Array | ndarray | bool | number, q_network: Module, act_space_size: int) → int

Samples a random action with probability \(\epsilon\) and the greedy action with probability \(1 - \epsilon\). The greedy action is the action with the highest Q-value.

Parameters:
  • state (DQNState) – The state of the deep Q-learning agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • env_state (Array) – The current state of the environment.

  • q_network (nn.Module) – The Q-network.

  • act_space_size (int) – The size of the action space.

Returns:

Selected action.

Return type:

int
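A NumPy sketch of the \(\epsilon\)-greedy rule, operating on a precomputed vector of Q-values rather than calling the Q-network (the function name is hypothetical). Because the random branch can also pick the greedy action, in this formulation the greedy action is chosen with total probability \(1 - \epsilon + \epsilon / |\mathcal{A}|\).

```python
import numpy as np


def epsilon_greedy(q_values: np.ndarray, epsilon: float, rng: np.random.Generator) -> int:
    """Uniformly random action with probability epsilon, otherwise argmax_a Q(s, a)."""
    if rng.random() < epsilon:
        return int(rng.integers(len(q_values)))
    return int(np.argmax(q_values))


rng = np.random.default_rng(0)
q = np.array([0.1, 0.9, 0.3])
greedy = epsilon_greedy(q, epsilon=0.0, rng=rng)  # always action 1
```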

Double Deep Q-Learning (DDQN)

class DDQNState(params: dict, net_state: dict, params_target: dict, net_state_target: dict, opt_state: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], replay_buffer: ReplayBuffer, prev_env_state: Array | ndarray | bool | number, epsilon: float | int)

Bases: AgentState, Mapping

Container for the state of the double deep Q-learning agent.

params

Parameters of the main Q-network.

Type:

dict

net_state

State of the main Q-network.

Type:

dict

params_target

Parameters of the target Q-network.

Type:

dict

net_state_target

State of the target Q-network.

Type:

dict

opt_state

Optimizer state of the main Q-network.

Type:

optax.OptState

replay_buffer

Experience replay buffer.

Type:

ReplayBuffer

prev_env_state

Previous environment state.

Type:

Array

epsilon

\(\epsilon\)-greedy parameter.

Type:

Scalar

items() → a set-like object providing a view of the state's items
keys() → a set-like object providing a view of the state's keys
values() → an object providing a view of the state's values
class DDQN(q_network: Module, obs_space_shape: Sequence[int | Any], act_space_size: int, optimizer: GradientTransformation = None, experience_replay_buffer_size: int = 10000, experience_replay_batch_size: int = 64, experience_replay_steps: int = 5, discount: float | int = 0.99, epsilon: float | int = 1.0, epsilon_decay: float | int = 0.999, epsilon_min: float | int = 0.001, tau: float | int = 0.01)

Bases: BaseAgent

Double deep Q-learning agent [2] with \(\epsilon\)-greedy exploration and an experience replay buffer. The agent uses two Q-networks (a main and a target network) to stabilize the learning process and reduce overestimation of the Q-values. The main Q-network is trained to minimize the Bellman error. The target Q-network is updated with a soft update. This agent follows the off-policy learning paradigm and is suitable for environments with discrete action spaces.

Parameters:
  • q_network (nn.Module) – Architecture of the Q-networks.

  • obs_space_shape (Shape) – Shape of the observation space.

  • act_space_size (int) – Size of the action space.

  • optimizer (optax.GradientTransformation, optional) – Optimizer of the Q-networks. If None, the Adam optimizer with learning rate 1e-3 is used.

  • experience_replay_buffer_size (int, default=10000) – Size of the experience replay buffer.

  • experience_replay_batch_size (int, default=64) – Batch size of the samples from the experience replay buffer.

  • experience_replay_steps (int, default=5) – Number of experience replay steps per update.

  • discount (Scalar, default=0.99) – Discount factor. \(\gamma = 0.0\) means only the immediate reward is considered, \(\gamma = 1.0\) means future rewards are not discounted. \(0 \leq \gamma \leq 1\).

  • epsilon (Scalar, default=1.0) – Initial \(\epsilon\)-greedy parameter. \(0 \leq \epsilon \leq 1\).

  • epsilon_decay (Scalar, default=0.999) – Epsilon decay factor. \(\epsilon_{t+1} = \max(\epsilon_{t} \cdot \epsilon_{decay}, \epsilon_{min})\). \(0 \leq \epsilon_{decay} \leq 1\).

  • epsilon_min (Scalar, default=0.001) – Minimum \(\epsilon\)-greedy parameter. \(0 \leq \epsilon_{min} \leq \epsilon\).

  • tau (Scalar, default=0.01) – Soft update factor. \(\tau = 0.0\) means the target network is never updated, \(\tau = 1.0\) means a hard update (the target network is overwritten with the main network). \(0 \leq \tau \leq 1\).

References

static parameter_space() → Dict

Parameters of the agent constructor in Gymnasium format. The returned value must be a gym.spaces.Dict or None. If None, the user must provide all parameters manually.

property update_observation_space: Dict

Observation space of the update method in Gymnasium format. It allows extensions to infer missing observations and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all observations manually.

property sample_observation_space: Dict

Observation space of the sample method in Gymnasium format. It allows extensions to infer missing observations and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all observations manually.

property action_space: Discrete

Action space of the agent in Gymnasium format.

static init(key: Array, obs_space_shape: Sequence[int | Any], q_network: Module, optimizer: GradientTransformation, er: ExperienceReplay, epsilon: float | int) → DDQNState

Initializes the Q-networks, optimizer and experience replay buffer with the given parameters. The first state of the environment is assumed to be a tensor of zeros.

Parameters:
  • key (PRNGKey) – A PRNG key used as the random key.

  • obs_space_shape (Shape) – The shape of the observation space.

  • q_network (nn.Module) – The Q-network.

  • optimizer (optax.GradientTransformation) – The optimizer.

  • er (ExperienceReplay) – The experience replay buffer.

  • epsilon (Scalar) – The initial \(\epsilon\)-greedy parameter.

Returns:

Initial state of the double Q-learning agent.

Return type:

DDQNState

static loss_fn(params: dict, key: Array, state: DDQNState, batch: tuple, q_network: Module, discount: float | int) → tuple[float | int, dict]

The loss is the mean squared Bellman error \(\mathcal{L}(\theta) = \mathbb{E}_{s, a, r, s'} \left[ \left( r + \gamma \max_{a'} Q'(s', a') - Q(s, a) \right)^2 \right]\) where \(s\) is the current state, \(a\) is the current action, \(r\) is the reward, \(s'\) is the next state, \(\gamma\) is the discount factor, \(Q(s, a)\) is the Q-value of the main Q-network, and \(Q'(s', a')\) is the Q-value of the target Q-network. The loss is computed on a batch of transitions.

Parameters:
  • params (dict) – The parameters of the Q-network.

  • key (PRNGKey) – A PRNG key used as the random key.

  • state (DDQNState) – The state of the double deep Q-learning agent.

  • batch (tuple) – A batch of transitions from the experience replay buffer.

  • q_network (nn.Module) – The Q-network.

  • discount (Scalar) – The discount factor.

Returns:

The loss and the new state of the Q-network.

Return type:

tuple[Scalar, dict]

static update(state: DDQNState, key: Array, env_state: Array | ndarray | bool | number, action: Array | ndarray | bool | number, reward: float | int, terminal: bool, step_fn: Callable, er: ExperienceReplay, experience_replay_steps: int, epsilon_decay: float | int, epsilon_min: float | int, tau: float | int) → DDQNState

Appends the transition to the experience replay buffer and performs experience_replay_steps steps. Each step consists of sampling a batch of transitions from the experience replay buffer, calculating the loss using the loss_fn function, performing a gradient step on the main Q-network, and soft updating the target Q-network. The soft update of the parameters is defined as \(\theta_{target} \leftarrow \tau \theta + (1 - \tau) \theta_{target}\). The \(\epsilon\)-greedy parameter is multiplied by epsilon_decay, but is not decayed below epsilon_min.

Parameters:
  • state (DDQNState) – The current state of the double Q-learning agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • env_state (Array) – The current state of the environment.

  • action (Array) – The action taken by the agent.

  • reward (Scalar) – The reward received by the agent.

  • terminal (bool) – Whether the episode has terminated.

  • step_fn (Callable) – The function that performs a single gradient step on the Q-network.

  • er (ExperienceReplay) – The experience replay buffer.

  • experience_replay_steps (int) – The number of experience replay steps.

  • epsilon_decay (Scalar) – The decay rate of the \(\epsilon\)-greedy parameter.

  • epsilon_min (Scalar) – The minimum value of the \(\epsilon\)-greedy parameter.

  • tau (Scalar) – The soft update parameter.

Returns:

The updated state of the double Q-learning agent.

Return type:

DDQNState
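The soft update formula above can be applied leaf by leaf to a nested parameter dictionary. The pure-Python sketch below assumes parameters are nested dicts of floats (a real agent holds arrays in a pytree). For JAX pytrees, `optax.incremental_update(new_tensors, old_tensors, step_size)` computes the same blend \(\tau \theta + (1 - \tau) \theta_{target}\), although whether the library uses it is not stated here.

```python
def soft_update(params: dict, params_target: dict, tau: float) -> dict:
    """Return tau * params + (1 - tau) * params_target, applied to every leaf of nested dicts."""
    return {
        name: soft_update(value, params_target[name], tau)
        if isinstance(value, dict)
        else tau * value + (1.0 - tau) * params_target[name]
        for name, value in params.items()
    }


params = {"dense": {"w": 1.0, "b": 0.0}}
params_target = {"dense": {"w": 0.0, "b": 1.0}}
params_target = soft_update(params, params_target, tau=0.01)  # w -> 0.01, b -> 0.99
```

If the main network stayed fixed, the gap to the target would shrink by a factor of \(1 - \tau\) per update; with \(\tau = 0.01\) it halves after about 69 updates, which is what makes the target a slowly moving reference.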

static sample(state: DDQNState, key: Array, env_state: Array | ndarray | bool | number, q_network: Module, act_space_size: int) → int

Samples a random action with probability \(\epsilon\) and the greedy action with probability \(1 - \epsilon\) using the main Q-network. The greedy action is the action with the highest Q-value.

Parameters:
  • state (DDQNState) – The state of the double Q-learning agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • env_state (Array) – The current state of the environment.

  • q_network (nn.Module) – The Q-network.

  • act_space_size (int) – The size of the action space.

Returns:

Selected action.

Return type:

int

Deep Expected SARSA

class ExpectedSarsaState(params: dict, net_state: dict, opt_state: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], replay_buffer: ReplayBuffer, prev_env_state: Array | ndarray | bool | number)

Bases: AgentState, Mapping

Container for the state of the deep expected SARSA agent.

params

Parameters of the Q-network.

Type:

dict

net_state

State of the Q-network.

Type:

dict

opt_state

Optimizer state.

Type:

optax.OptState

replay_buffer

Experience replay buffer.

Type:

ReplayBuffer

prev_env_state

Previous environment state.

Type:

Array

items() → a set-like object providing a view of the state's items
keys() → a set-like object providing a view of the state's keys
values() → an object providing a view of the state's values
class ExpectedSarsa(q_network: Module, obs_space_shape: Sequence[int | Any], act_space_size: int, optimizer: GradientTransformation = None, experience_replay_buffer_size: int = 10000, experience_replay_batch_size: int = 64, experience_replay_steps: int = 5, discount: float | int = 0.99, tau: float | int = 1.0)

Bases: BaseAgent

Deep expected SARSA agent with softmax exploration controlled by the temperature parameter \(\tau\) and an experience replay buffer. The agent uses a deep neural network to approximate the Q-value function. The Q-network is trained to minimize the Bellman error. This agent follows the on-policy learning paradigm and is suitable for environments with discrete action spaces.

Parameters:
  • q_network (nn.Module) – Architecture of the Q-network.

  • obs_space_shape (Shape) – Shape of the observation space.

  • act_space_size (int) – Size of the action space.

  • optimizer (optax.GradientTransformation, optional) – Optimizer of the Q-network. If None, the Adam optimizer with learning rate 1e-3 is used.

  • experience_replay_buffer_size (int, default=10000) – Size of the experience replay buffer.

  • experience_replay_batch_size (int, default=64) – Batch size of the samples from the experience replay buffer.

  • experience_replay_steps (int, default=5) – Number of experience replay steps per update.

  • discount (Scalar, default=0.99) – Discount factor. \(\gamma = 0.0\) means only the immediate reward is considered, \(\gamma = 1.0\) means future rewards are not discounted. \(0 \leq \gamma \leq 1\).

  • tau (Scalar, default=1.0) – Temperature parameter. As \(\tau \to 0\) the policy becomes greedy (no exploration); as \(\tau \to \infty\) it approaches a uniform random policy. \(\tau > 0\).

static parameter_space() → Dict

Parameters of the agent constructor in Gymnasium format. The returned value must be a gym.spaces.Dict or None. If None, the user must provide all parameters manually.

property update_observation_space: Dict

Observation space of the update method in Gymnasium format. It allows extensions to infer missing observations and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all observations manually.

property sample_observation_space: Dict

Observation space of the sample method in Gymnasium format. It allows extensions to infer missing observations and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all observations manually.

property action_space: Discrete

Action space of the agent in Gymnasium format.

static init(key: Array, obs_space_shape: Sequence[int | Any], q_network: Module, optimizer: GradientTransformation, er: ExperienceReplay) → ExpectedSarsaState

Initializes the Q-network, optimizer and experience replay buffer with the given parameters. The first state of the environment is assumed to be a tensor of zeros.

Parameters:
  • key (PRNGKey) – A PRNG key used as the random key.

  • obs_space_shape (Shape) – The shape of the observation space.

  • q_network (nn.Module) – The Q-network.

  • optimizer (optax.GradientTransformation) – The optimizer.

  • er (ExperienceReplay) – The experience replay buffer.

Returns:

Initial state of the deep expected SARSA agent.

Return type:

ExpectedSarsaState

static loss_fn(params: dict, key: Array, net_state: dict, params_target: dict, net_state_target: dict, batch: tuple, q_network: Module, discount: float | int, tau: float | int) → tuple[float | int, dict]

The loss is the mean squared Bellman error \(\mathcal{L}(\theta) = \mathbb{E}_{s, a, r, s'} \left[ \left( r + \gamma \sum_{a'} \pi(a'|s') Q(s', a') - Q(s, a) \right)^2 \right]\) where \(s\) is the current state, \(a\) is the current action, \(r\) is the reward, \(s'\) is the next state, \(\gamma\) is the discount factor, \(Q(s, a)\) is the Q-value of the state-action pair, and \(\pi(a'|s')\) is the softmax policy with temperature \(\tau\). The loss is computed on a batch of transitions.

Parameters:
  • params (dict) – The parameters of the Q-network.

  • key (PRNGKey) – A PRNG key used as the random key.

  • net_state (dict) – The state of the Q-network.

  • params_target (dict) – The parameters of the target Q-network.

  • net_state_target (dict) – The state of the target Q-network.

  • batch (tuple) – A batch of transitions from the experience replay buffer.

  • q_network (nn.Module) – The Q-network.

  • discount (Scalar) – The discount factor.

  • tau (Scalar) – The temperature parameter.

Returns:

The loss and the new state of the Q-network.

Return type:

tuple[Scalar, dict]
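The expected SARSA target replaces the maximum over next actions with an expectation under the softmax policy. The NumPy sketch below is illustrative only: the function names are hypothetical, Q-values are precomputed instead of produced by network calls, and the bootstrap term is zeroed on terminal transitions, a common convention not spelled out in the formula above.

```python
import numpy as np


def softmax(x: np.ndarray, tau: float) -> np.ndarray:
    """Row-wise softmax of x / tau; subtracting the row max avoids overflow."""
    z = x / tau
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def expected_sarsa_loss(q_values, actions, rewards, terminals, next_q_values, discount=0.99, tau=1.0):
    """Mean squared error against r + gamma * sum_a' pi(a'|s') Q(s', a')."""
    q_sa = q_values[np.arange(len(actions)), actions]      # Q(s, a)
    pi_next = softmax(next_q_values, tau)                  # pi(a' | s')
    expected_next = (pi_next * next_q_values).sum(axis=1)  # expectation over a'
    target = rewards + discount * (1.0 - terminals) * expected_next
    return float(np.mean((target - q_sa) ** 2))
```

At a low temperature the expectation approaches \(\max_{a'} Q(s', a')\), so the target approaches the Q-learning target; at a high temperature it approaches the plain average of the next Q-values.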

static update(state: ExpectedSarsaState, key: Array, env_state: Array | ndarray | bool | number, action: Array | ndarray | bool | number, reward: float | int, terminal: bool, q_network: Module, step_fn: Callable, er: ExperienceReplay, experience_replay_steps: int) → ExpectedSarsaState

Appends the transition to the experience replay buffer and performs experience_replay_steps steps. Each step consists of sampling a batch of transitions from the experience replay buffer, calculating the loss using the loss_fn function and performing a gradient step on the Q-network.

Parameters:
  • state (ExpectedSarsaState) – The current state of the deep expected SARSA agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • env_state (Array) – The current state of the environment.

  • action (Array) – The action taken by the agent.

  • reward (Scalar) – The reward received by the agent.

  • terminal (bool) – Whether the episode has terminated.

  • q_network (nn.Module) – The Q-network.

  • step_fn (Callable) – The function that performs a single gradient step on the Q-network.

  • er (ExperienceReplay) – The experience replay buffer.

  • experience_replay_steps (int) – The number of experience replay steps.

Returns:

The updated state of the deep expected SARSA agent.

Return type:

ExpectedSarsaState

static sample(state: ExpectedSarsaState, key: Array, env_state: Array | ndarray | bool | number, q_network: Module, act_space_size: int, tau: float | int) → int

Selects an action using the softmax policy with the temperature parameter \(\tau\):

\[\pi(a|s) = \frac{e^{Q(s, a) / \tau}}{\sum_{a'} e^{Q(s, a') / \tau}}\]

Parameters:
  • state (ExpectedSarsaState) – The state of the deep expected SARSA agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • env_state (Array) – The current state of the environment.

  • q_network (nn.Module) – The Q-network.

  • act_space_size (int) – The size of the action space.

  • tau (Scalar) – The temperature parameter.

Returns:

Selected action.

Return type:

int
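Sampling from this policy only needs the vector of Q-values for the current state. The NumPy sketch below (function names hypothetical, Q-values precomputed instead of evaluated by the network) shows how the temperature moves the policy between near-greedy and near-uniform behavior.

```python
import numpy as np


def softmax_policy(q_values: np.ndarray, tau: float) -> np.ndarray:
    """pi(a|s) = exp(Q(s,a)/tau) / sum_a' exp(Q(s,a')/tau), computed stably."""
    z = (q_values - q_values.max()) / tau  # shifting by the max leaves the softmax unchanged
    p = np.exp(z)
    return p / p.sum()


def sample_softmax(q_values: np.ndarray, tau: float, rng: np.random.Generator) -> int:
    """Draw one action index according to the softmax policy."""
    return int(rng.choice(len(q_values), p=softmax_policy(q_values, tau)))


q = np.array([1.0, 2.0, 3.0])
near_greedy = softmax_policy(q, tau=0.1)     # almost all probability mass on action 2
balanced = softmax_policy(q, tau=1.0)
near_uniform = softmax_policy(q, tau=100.0)  # close to 1/3 for every action
action = sample_softmax(q, tau=1.0, rng=np.random.default_rng(0))
```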

Deep Deterministic Policy Gradient (DDPG)

class DDPGState(q_params: dict, q_net_state: dict, q_params_target: dict, q_net_state_target: dict, q_opt_state: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], a_params: dict, a_net_state: dict, a_params_target: dict, a_net_state_target: dict, a_opt_state: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], replay_buffer: ReplayBuffer, prev_env_state: Array | ndarray | bool | number, noise: float | int)

Bases: AgentState, Mapping

Container for the state of the deep deterministic policy gradient agent.

q_params

Parameters of the Q-network.

Type:

dict

q_net_state

State of the Q-network.

Type:

dict

q_params_target

Parameters of the target Q-network.

Type:

dict

q_net_state_target

State of the target Q-network.

Type:

dict

q_opt_state

Optimizer state of the Q-network.

Type:

optax.OptState

a_params

Parameters of the policy network.

Type:

dict

a_net_state

State of the policy network.

Type:

dict

a_params_target

Parameters of the target policy network.

Type:

dict

a_net_state_target

State of the target policy network.

Type:

dict

a_opt_state

Optimizer state of the policy network.

Type:

optax.OptState

replay_buffer

Experience replay buffer.

Type:

ReplayBuffer

prev_env_state

Previous environment state.

Type:

Array

noise

Current noise level.

Type:

Scalar

items() → a set-like object providing a view of the state's items
keys() → a set-like object providing a view of the state's keys
values() → an object providing a view of the state's values
class DDPG(q_network: Module, a_network: Module, obs_space_shape: Sequence[int | Any], act_space_shape: Sequence[int | Any], min_action: Array | ndarray | bool | number | float | int, max_action: Array | ndarray | bool | number | float | int, q_optimizer: GradientTransformation = None, a_optimizer: GradientTransformation = None, experience_replay_buffer_size: int = 10000, experience_replay_batch_size: int = 64, experience_replay_steps: int = 5, discount: float | int = 0.99, noise: float | int = None, noise_decay: float | int = 0.99, noise_min: float | int = 0.01, tau: float | int = 0.01)

Bases: BaseAgent

Deep deterministic policy gradient [3] [4] agent with white Gaussian noise exploration and an experience replay buffer. The agent simultaneously learns a Q-function and a policy. The Q-function is updated using the Bellman equation. The policy is trained with the gradient of the Q-function with respect to the policy parameters so that it maximizes the Q-value. The agent maintains a main and a target copy of both the Q-network and the policy network to stabilize the learning process. The target networks are updated with a soft update. This agent follows the off-policy learning paradigm and is suitable for environments with continuous action spaces.

Parameters:
  • q_network (nn.Module) – Architecture of the Q-networks. The input to the network should be two tensors of observations and actions respectively.

  • a_network (nn.Module) – Architecture of the policy networks.

  • obs_space_shape (Shape) – Shape of the observation space.

  • act_space_shape (Shape) – Shape of the action space.

  • min_action (Scalar or Array) – Minimum action value.

  • max_action (Scalar or Array) – Maximum action value.

  • q_optimizer (optax.GradientTransformation, optional) – Optimizer of the Q-networks. If None, the Adam optimizer with learning rate 1e-3 is used.

  • a_optimizer (optax.GradientTransformation, optional) – Optimizer of the policy networks. If None, the Adam optimizer with learning rate 1e-3 is used.

  • experience_replay_buffer_size (int, default=10000) – Size of the experience replay buffer.

  • experience_replay_batch_size (int, default=64) – Batch size of the samples from the experience replay buffer.

  • experience_replay_steps (int, default=5) – Number of experience replay steps per update.

  • discount (Scalar, default=0.99) – Discount factor. \(\gamma = 0.0\) means only the immediate reward is considered, \(\gamma = 1.0\) means future rewards are not discounted. \(0 \leq \gamma \leq 1\).

  • noise (Scalar, default=(max_action - min_action) / 2) – Initial Gaussian noise level. \(0 \leq \sigma\).

  • noise_decay (Scalar, default=0.99) – Gaussian noise decay factor. \(\sigma_{t+1} = \max(\sigma_{t} \cdot \sigma_{decay}, \sigma_{min})\). \(0 \leq \sigma_{decay} \leq 1\).

  • noise_min (Scalar, default=0.01) – Minimum Gaussian noise level. \(0 \leq \sigma_{min} \leq \sigma\).

  • tau (Scalar, default=0.01) – Soft update factor. \(\tau = 0.0\) means the target networks are never updated, \(\tau = 1.0\) means a hard update (the target networks are overwritten with the main networks). \(0 \leq \tau \leq 1\).
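Exploration with white Gaussian noise, together with the noise decay rule from the parameter list, can be sketched in NumPy as follows. The helper names are hypothetical. Clipping the noisy action to `[min_action, max_action]` and flooring the decayed noise at noise_min are assumptions of this sketch rather than documented behavior; `mu` stands for the deterministic action produced by the policy network.

```python
import numpy as np


def noisy_action(mu, noise, min_action, max_action, rng):
    """Deterministic action plus zero-mean Gaussian noise with std `noise`, clipped to the action bounds."""
    return np.clip(mu + rng.normal(0.0, noise, size=np.shape(mu)), min_action, max_action)


def decay_noise(noise, noise_decay, noise_min):
    """sigma_{t+1} = max(sigma_t * noise_decay, noise_min)."""
    return max(noise * noise_decay, noise_min)


min_action, max_action = -2.0, 2.0
noise = (max_action - min_action) / 2  # default initial noise level, here 2.0
rng = np.random.default_rng(0)
action = noisy_action(np.array([0.5]), noise, min_action, max_action, rng)
noise = decay_noise(noise, noise_decay=0.99, noise_min=0.01)  # 2.0 * 0.99 = 1.98
```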

References

static parameter_space() → Dict

Parameters of the agent constructor in Gymnasium format. The returned value must be a gym.spaces.Dict or None. If None, the user must provide all parameters manually.

property update_observation_space: Dict

Observation space of the update method in Gymnasium format. It allows extensions to infer missing observations and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all observations manually.

property sample_observation_space: Dict

Observation space of the sample method in Gymnasium format. It allows extensions to infer missing observations and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all observations manually.

property action_space: Box

Action space of the agent in Gymnasium format.

static init(key: Array, obs_space_shape: Sequence[int | Any], act_space_shape: Sequence[int | Any], q_network: Module, a_network: Module, q_optimizer: GradientTransformation, a_optimizer: GradientTransformation, er: ExperienceReplay, noise: float | int) → DDPGState

Initializes the Q-networks and the policy networks, optimizers, and experience replay buffer. The first state of the environment is assumed to be a tensor of zeros.

Parameters:
  • key (PRNGKey) – A PRNG key used as the random key.

  • obs_space_shape (Shape) – The shape of the observation space.

  • act_space_shape (Shape) – The shape of the action space.

  • q_network (nn.Module) – The Q-network.

  • a_network (nn.Module) – The policy network.

  • q_optimizer (optax.GradientTransformation) – The Q-network optimizer.

  • a_optimizer (optax.GradientTransformation) – The policy network optimizer.

  • er (ExperienceReplay) – The experience replay buffer.

  • noise (Scalar) – The initial noise value.

Returns:

Initial state of the deep deterministic policy gradient agent.

Return type:

DDPGState

static q_loss_fn(q_params: dict, key: Array, state: DDPGState, batch: tuple, q_network: Module, a_network: Module, discount: float | int) tuple[float | int, dict]

Loss is the mean squared Bellman error \(\mathcal{L}(\theta) = \mathbb{E}_{s, a, r, s'} \left[ \left( r + \gamma Q'(s', \pi'(s')) - Q(s, a) \right)^2 \right]\), where \(s\) is the current state, \(a\) is the current action, \(r\) is the reward, \(s'\) is the next state, \(\gamma\) is the discount factor, \(Q(s, a)\) is the Q-value of the main Q-network, \(Q'(s, a)\) is the Q-value of the target Q-network, and \(\pi'(s)\) is the action of the target policy network. Because the target policy is deterministic, the target requires no maximization over actions. The policy network parameters are treated as fixed. The loss is computed over a batch of transitions.

Parameters:
  • q_params (dict) – The parameters of the Q-network.

  • key (PRNGKey) – A PRNG key used as the random key.

  • state (DDPGState) – The state of the deep deterministic policy gradient agent.

  • batch (tuple) – A batch of transitions from the experience replay buffer.

  • q_network (nn.Module) – The Q-network.

  • a_network (nn.Module) – The policy network.

  • discount (Scalar) – The discount factor.

Returns:

The loss and the new state of the Q-network.

Return type:

tuple[Scalar, dict]
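The Bellman error above can be illustrated with a minimal plain-Python sketch. It assumes scalar states and actions, toy networks passed as ordinary callables, and transitions stored as `(s, a, r, s_next)` tuples; `ddpg_q_loss` is a hypothetical helper, not the library's `q_loss_fn`, and it applies no terminal mask because the formula above has none.

```python
from typing import Callable, Sequence, Tuple

# (s, a, r, s_next) with scalar states and actions, for illustration only
Transition = Tuple[float, float, float, float]


def ddpg_q_loss(
    q: Callable[[float, float], float],
    q_target: Callable[[float, float], float],
    pi_target: Callable[[float], float],
    batch: Sequence[Transition],
    discount: float,
) -> float:
    """Mean squared Bellman error over a batch (hypothetical helper)."""
    total = 0.0
    for s, a, r, s_next in batch:
        # Bootstrapped target from the target networks; the deterministic target
        # policy supplies the next action, so there is no max over actions.
        y = r + discount * q_target(s_next, pi_target(s_next))
        total += (y - q(s, a)) ** 2
    return total / len(batch)


# Toy networks: Q(s, a) = s + a (used as both main and target) and pi'(s) = 0.5 * s
q = lambda s, a: s + a
pi_target = lambda s: 0.5 * s
batch = [(1.0, 0.0, 1.0, 2.0), (0.0, 1.0, 0.0, 0.0)]
loss = ddpg_q_loss(q, q, pi_target, batch, discount=0.9)
```

For the first transition the target is \(1 + 0.9 \cdot 3 = 3.7\) against a prediction of 1; for the second the target is 0 against a prediction of 1, so the batch loss is the mean of 7.29 and 1.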

static a_loss_fn(a_params: dict, key: Array, state: DDPGState, batch: tuple, q_network: Module, a_network: Module) tuple[float | int, dict]

The policy network is updated using the gradient of the Q-network to maximize the expected Q-value of the actions it selects, \(\max_{\theta} \mathbb{E}_{s} \left[ Q(s, \pi_{\theta}(s)) \right]\). The Q-network parameters are treated as fixed. The expectation is estimated on a batch of transitions.

Parameters:
  • a_params (dict) – The parameters of the policy network.

  • key (PRNGKey) – A PRNG key used as the random key.

  • state (DDPGState) – The state of the deep deterministic policy gradient agent.

  • batch (tuple) – A batch of transitions from the experience replay buffer.

  • q_network (nn.Module) – The Q-network.

  • a_network (nn.Module) – The policy network.

Returns:

The loss and the new state of the policy network.

Return type:

tuple[Scalar, dict]
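Maximizing \(\mathbb{E}_{s}[Q(s, \pi_\theta(s))]\) is equivalent to minimizing its negative. The plain-Python sketch below assumes that sign convention and uses a hypothetical fixed toy critic and a one-parameter linear policy, so only \(\theta\) changes between evaluations; it is not the library's `a_loss_fn`.

```python
from typing import Callable, Sequence


def ddpg_actor_loss(
    q: Callable[[float, float], float],
    pi: Callable[[float], float],
    states: Sequence[float],
) -> float:
    """Negative mean Q(s, pi(s)); minimizing it maximizes E_s[Q(s, pi(s))]."""
    return -sum(q(s, pi(s)) for s in states) / len(states)


def linear_policy(theta: float) -> Callable[[float], float]:
    # Toy deterministic policy pi_theta(s) = theta * s (hypothetical).
    return lambda s: theta * s


# Fixed toy critic that prefers a == s; only the policy parameter theta varies.
critic = lambda s, a: -(a - s) ** 2
states = [1.0, 2.0]
losses = {theta: ddpg_actor_loss(critic, linear_policy(theta), states) for theta in (0.0, 0.5, 1.0)}
```

The loss shrinks as \(\theta\) approaches 1, the policy that the fixed critic rates highest, which is the direction a gradient step on the policy parameters would follow.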

static update(state: DDPGState, key: Array, env_state: Array | ndarray | bool | number, action: Array | ndarray | bool | number, reward: float | int, terminal: bool, q_step_fn: Callable, a_step_fn: Callable, er: ExperienceReplay, experience_replay_steps: int, noise_decay: float | int, noise_min: float | int, tau: float | int) DDPGState

Appends the transition to the experience replay buffer and performs experience_replay_steps steps. Each step samples a batch of transitions from the experience replay buffer, computes the Q-network loss and the policy network loss with q_loss_fn and a_loss_fn respectively, performs a gradient step on both networks, and soft-updates the target networks. The soft update of the parameters is defined as \(\theta_{target} = \tau \theta + (1 - \tau) \theta_{target}\). The noise parameter is decayed by noise_decay, but not below noise_min.

Parameters:
  • state (DDPGState) – The current state of the deep deterministic policy gradient agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • env_state (Array) – The current state of the environment.

  • action (Array) – The action taken by the agent.

  • reward (Scalar) – The reward received by the agent.

  • terminal (bool) – Whether the episode has terminated.

  • q_step_fn (Callable) – The function that performs a single gradient step on the Q-network.

  • a_step_fn (Callable) – The function that performs a single gradient step on the policy network.

  • er (ExperienceReplay) – The experience replay buffer.

  • experience_replay_steps (int) – The number of experience replay steps.

  • noise_decay (Scalar) – The decay rate of the noise parameter.

  • noise_min (Scalar) – The minimum value of the noise parameter.

  • tau (Scalar) – The soft update parameter.

Returns:

The updated state of the deep deterministic policy gradient agent.

Return type:

DDPGState
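The soft update and the noise schedule can be sketched in plain Python. Parameters are modelled here as a flat `dict` of floats instead of a JAX pytree, and the noise decay is assumed to be multiplicative and floored at `noise_min`; both helpers are hypothetical illustrations rather than the library's code.

```python
from typing import Dict


def soft_update(params: Dict[str, float], target: Dict[str, float], tau: float) -> Dict[str, float]:
    """theta_target <- tau * theta + (1 - tau) * theta_target, applied to every parameter."""
    return {name: tau * params[name] + (1.0 - tau) * target[name] for name in target}


def decay_noise(noise: float, noise_decay: float, noise_min: float) -> float:
    """Assumed multiplicative decay of the exploration noise, floored at noise_min."""
    return max(noise_min, noise * noise_decay)


online = {"w": 1.0, "b": -2.0}
target = {"w": 0.0, "b": 0.0}
for _ in range(3):
    # With tau = 0.5 the target closes half of the remaining gap at every step.
    target = soft_update(online, target, tau=0.5)
```

A small `tau` makes the target networks trail the online networks slowly, which keeps the bootstrapped targets in `q_loss_fn` stable.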

static sample(state: DDPGState, key: Array, env_state: Array | ndarray | bool | number, a_network: Module, min_action: float | int, max_action: float | int) Array | ndarray | bool | number | float | int

Computes a deterministic action with the policy network, adds white Gaussian noise with standard deviation state.noise, and clips the result to the range \([min\_action, max\_action]\).

Parameters:
  • state (DDPGState) – The state of the deep deterministic policy gradient agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • env_state (Array) – The current state of the environment.

  • a_network (nn.Module) – The policy network.

  • min_action (Scalar or Array) – The minimum value of the action.

  • max_action (Scalar or Array) – The maximum value of the action.

Returns:

Selected action.

Return type:

Scalar or Array
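A plain-Python sketch of this exploration scheme follows. It assumes the policy is a callable returning a list of action components and uses scalar bounds only (the method also accepts arrays); `ddpg_sample` is a hypothetical helper.

```python
import random
from typing import Callable, List, Sequence


def ddpg_sample(
    pi: Callable[[Sequence[float]], List[float]],
    env_state: Sequence[float],
    noise: float,
    min_action: float,
    max_action: float,
    rng: random.Random,
) -> List[float]:
    """Deterministic action + white Gaussian noise, clipped to [min_action, max_action]."""
    action = pi(env_state)  # deterministic output of the policy network
    noisy = [a + rng.gauss(0.0, noise) for a in action]  # independent N(0, noise^2) per dimension
    return [min(max(a, min_action), max_action) for a in noisy]


rng = random.Random(0)
policy = lambda s: [0.3, 2.0]  # toy policy whose second component exceeds the bounds
exact = ddpg_sample(policy, [0.0], noise=0.0, min_action=-1.0, max_action=1.0, rng=rng)
noisy = ddpg_sample(policy, [0.0], noise=0.5, min_action=-1.0, max_action=1.0, rng=rng)
```

With zero noise the sketch returns the clipped deterministic action `[0.3, 1.0]`; as `state.noise` decays during training, the sampled actions approach this deterministic behaviour.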

Proximal Policy Optimization (PPO) [discrete]

class PPOState(params: dict, net_state: dict, opt_state: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], rollout_memory: RolloutMemory, prev_env_states: Array | ndarray | bool | number, counter: int)

Bases: AgentState, Mapping

Container for the state of the PPO agent.

params

Parameters of the agent network.

Type:

dict

net_state

State of the agent network.

Type:

dict

opt_state

Optimizer state.

Type:

optax.OptState

rollout_memory

Rollout buffer storing the trajectories.

Type:

RolloutMemory

prev_env_states

Previous environment states.

Type:

Array

counter

Number of the current step during the rollout.

Type:

int

items() a set-like object providing a view on D's items
keys() a set-like object providing a view on D's keys
values() an object providing a view on D's values
class PPODiscrete(network: Module, obs_space_shape: Sequence[int | Any], act_space_size: int, optimizer: GradientTransformation = None, discount: float | int = 0.99, lambda_gae: float | int = 0.9, normalize_advantage: bool = True, clip_coef: float | int = 0.2, clip_value: bool = True, clip_grad: float | int = 0.5, entropy_coef: float | int = 0.01, value_coef: float | int = 0.5, rollout_length: int = 512, num_envs: int = 1, batch_size: int = 128, num_epochs: int = 4)

Bases: BaseAgent

Proximal Policy Optimization (PPO) agent [5]. This implementation uses the clipped surrogate objective. The policy and value functions should be represented by a single Flax module with two outputs: the action logits and the state value. The network should be able to process a batch of observations. The actions are sampled from a categorical distribution, while the value function is used to compute the advantages using Generalized Advantage Estimation (GAE) [6]. The agent is trained using mini-batch gradient descent. This agent follows the on-policy learning paradigm and is suitable for environments with discrete action spaces.

Parameters:
  • network (nn.Module) – Architecture of the PPO agent network.

  • obs_space_shape (Shape) – Shape of the observation space.

  • act_space_size (int) – Size of the action space.

  • optimizer (optax.GradientTransformation, optional) – Optimizer of the network. If None, the Adam optimizer with learning rate 3e-4 and \(\epsilon\) = 1e-5 is used.

  • discount (Scalar, default=0.99) – Discount factor. \(\gamma = 0.0\) means only the immediate reward is considered, \(\gamma = 1.0\) means future rewards are not discounted. \(0 \leq \gamma \leq 1\).

  • lambda_gae (Scalar, default=0.9) – GAE parameter. \(\lambda = 0.0\) reduces GAE to the one-step temporal-difference advantage, \(\lambda = 1.0\) gives the Monte Carlo advantage estimate. \(0 \leq \lambda \leq 1\).

  • normalize_advantage (bool, default=True) – If True, the advantages are normalized to have mean 0 and standard deviation 1.

  • clip_coef (Scalar, default=0.2) – Clipping coefficient for the surrogate objective, \(\epsilon\) in [5].

  • clip_value (bool, default=True) – If True, the loss for the value function is clipped.

  • clip_grad (Scalar, default=0.5) – If not None, the gradients are clipped to have a maximum norm of clip_grad.

  • entropy_coef (Scalar, default=0.01) – Coefficient for the entropy bonus.

  • value_coef (Scalar, default=0.5) – Coefficient for the value function loss.

  • rollout_length (int, default=512) – Length of the rollout buffer.

  • num_envs (int, default=1) – Number of parallel environments.

  • batch_size (int, default=128) – Size of the batch to be sampled from the rollout buffer.

  • num_epochs (int, default=4) – Number of update epochs to perform on the rollout buffer.

References

static parameter_space() Dict

Parameters of the agent constructor in Gymnasium format. Type of returned value is required to be gym.spaces.Dict or None. If None, the user must provide all parameters manually.

property update_observation_space: Dict

Observation space of the update method in Gymnasium format. Allows inferring missing observations using extensions and easily exporting the agent to TensorFlow Lite format. If None, the user must provide all parameters manually.

property sample_observation_space: Dict

Observation space of the sample method in Gymnasium format. Allows inferring missing observations using extensions and easily exporting the agent to TensorFlow Lite format. If None, the user must provide all parameters manually.

property action_space: MultiDiscrete

Action space of the agent in Gymnasium format.

static init(key: Array, num_envs: int, obs_space_shape: Sequence[int | Any], network: Module, optimizer: GradientTransformation, rb: RolloutBuffer) PPOState

Initializes the PPO network, optimizer and rollout buffer with given parameters. The first state of the environment is assumed to be a tensor of zeros.

Parameters:
  • key (PRNGKey) – A PRNG key used as the random key.

  • num_envs (int) – The number of parallel environments.

  • obs_space_shape (Shape) – The shape of the observation space.

  • network (nn.Module) – The agent network.

  • optimizer (optax.GradientTransformation) – The optimizer.

  • rb (RolloutBuffer) – The rollout buffer functions.

Returns:

Initial state of the PPO agent.

Return type:

PPOState

static loss_fn(params: dict, key: Array, net_state: dict, batch: tuple, network: Module, normalize_advantage: bool, clip_coef: float | int, clip_value: bool, entropy_coef: float | int, value_coef: float | int) tuple[float | int, dict]

Loss is the clipped surrogate objective with value function loss and entropy regularization:

\[\mathcal{L}(\theta) = \mathbb{E}_t \Big[ -\min\big( A_t r_t(\theta), A_t \, \mathrm{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \big) + c_v \, \mathcal{L}_v(\theta) - c_e \, \mathcal{H}[\pi_\theta](s_t) \Big]\]

where

  • \(r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{old}}(a_t \mid s_t)}\) is the probability ratio between the new and old policies,

  • \(A_t\) is the advantage estimate,

  • \(\epsilon\) is the clipping coefficient,

  • \(\mathcal{L}_v(\theta) = \tfrac{1}{2}\,(V_\theta(s_t) - R_t)^2\) is the value function loss, possibly clipped,

  • \(R_t\) is the discounted return,

  • \(\mathcal{H}[\pi_\theta](s_t)\) is the entropy of the policy at state \(s_t\),

  • \(c_v\) and \(c_e\) are coefficients for the value loss and entropy bonus, respectively.

Loss is calculated on a batch of transitions sampled from the rollout buffer.

Parameters:
  • params (dict) – The parameters of the agent network.

  • key (PRNGKey) – A PRNG key used as the random key.

  • net_state (dict) – The state of the agent network.

  • batch (tuple) – A batch of transitions from the rollout buffer.

  • network (nn.Module) – The agent network.

  • normalize_advantage (bool) – If True, the advantages are normalized to have mean 0 and standard deviation 1.

  • clip_coef (Scalar) – Clipping coefficient for the surrogate objective, \(\epsilon\) in [5].

  • clip_value (bool) – If True, the loss for the value function is clipped.

  • entropy_coef (Scalar) – Coefficient for the entropy bonus.

  • value_coef (Scalar) – Coefficient for the value function loss.

Returns:

The loss and the new state of the agent network.

Return type:

Tuple[Scalar, dict]
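The per-sample terms of the loss above can be sketched in plain Python, assuming the new and old log-probabilities, advantages, values, returns, and entropies have already been computed for the batch. For brevity the sketch omits advantage normalization and value clipping (normalize_advantage and clip_value); `ppo_loss` is a hypothetical helper, not the library's `loss_fn`.

```python
import math
from typing import Sequence


def ppo_loss(
    logp_new: Sequence[float],
    logp_old: Sequence[float],
    advantages: Sequence[float],
    values: Sequence[float],
    returns: Sequence[float],
    entropies: Sequence[float],
    clip_coef: float = 0.2,
    value_coef: float = 0.5,
    entropy_coef: float = 0.01,
) -> float:
    """Batch mean of the clipped surrogate loss with value and entropy terms."""
    total = 0.0
    for lp, lp_old, adv, v, ret, ent in zip(logp_new, logp_old, advantages, values, returns, entropies):
        ratio = math.exp(lp - lp_old)  # r_t(theta) = pi_theta(a|s) / pi_theta_old(a|s)
        clipped = min(max(ratio, 1.0 - clip_coef), 1.0 + clip_coef)
        policy_term = -min(adv * ratio, adv * clipped)  # pessimistic (clipped) surrogate
        value_term = 0.5 * (v - ret) ** 2  # unclipped value loss L_v
        total += policy_term + value_coef * value_term - entropy_coef * ent
    return total / len(advantages)


# The new policy doubles the probability of the action (ratio = 2).
gain = ppo_loss([math.log(2.0)], [0.0], [1.0], [0.0], [0.0], [0.0])
penalty = ppo_loss([math.log(2.0)], [0.0], [-1.0], [0.0], [0.0], [0.0])
```

With a positive advantage the objective stops rewarding ratios above \(1 + \epsilon\) (`gain` is -1.2, not -2). With a negative advantage the minimum keeps the unclipped term, so `penalty` is 2.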

static update(state: PPOState, key: Array, env_states: Array | ndarray | bool | number, actions: Array | ndarray | bool | number, rewards: float | int, terminals: bool, network: Module, step_fn: Callable, rb: RolloutBuffer, num_envs: int, rollout_length: int, batch_size: int, num_epochs: int) PPOState

Appends the transition to the on-policy rollout buffer. Once the rollout buffer reaches rollout_length steps, computes advantages and returns using GAE, shuffles and flattens the buffer, and performs multiple gradient updates using mini-batches.

Parameters:
  • state (PPOState) – The current state of the PPO agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • env_states (Array) – The current states of the environments.

  • actions (Array) – The actions taken by the agent.

  • rewards (Scalar) – The rewards received by the agent.

  • terminals (bool) – Whether the episodes have terminated.

  • network (nn.Module) – The agent network.

  • step_fn (Callable) – The function that performs a single gradient step on the agent network.

  • rb (RolloutBuffer) – The rollout buffer functions.

  • num_envs (int) – The number of parallel environments.

  • rollout_length (int) – The length of the rollout buffer.

  • batch_size (int) – The size of the batch sampled from the rollout buffer.

  • num_epochs (int) – The number of update epochs to perform on the rollout buffer.

Returns:

The updated state of the PPO agent.

Return type:

PPOState
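The advantage computation mentioned above can be sketched for a single environment's rollout. The sketch assumes `terminals[t]` marks that the episode ended after step `t` (so the next value is not bootstrapped) and that `last_value` is the value estimate of the state following the rollout; the library's buffer layout and conventions may differ, and `compute_gae` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple


def compute_gae(
    rewards: Sequence[float],
    values: Sequence[float],
    terminals: Sequence[bool],
    last_value: float,
    discount: float = 0.99,
    lambda_gae: float = 0.9,
) -> Tuple[List[float], List[float]]:
    """GAE advantages and returns for one environment's rollout (hypothetical helper)."""
    T = len(rewards)
    advantages = [0.0] * T
    gae = 0.0
    for t in reversed(range(T)):
        next_value = last_value if t == T - 1 else values[t + 1]
        not_done = 0.0 if terminals[t] else 1.0  # no bootstrapping past an episode end
        # One-step TD error: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + discount * next_value * not_done - values[t]
        # A_t = delta_t + gamma * lambda * A_{t+1}, reset at episode boundaries
        gae = delta + discount * lambda_gae * not_done * gae
        advantages[t] = gae
    returns = [adv + v for adv, v in zip(advantages, values)]  # R_t = A_t + V(s_t)
    return advantages, returns


# With gamma = lambda = 1 and zero values, advantages are reward-to-go sums.
adv, ret = compute_gae([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [False, True, False], 0.0, 1.0, 1.0)
```

Here the episode boundary after the second step splits the rollout, giving advantages `[2.0, 1.0, 1.0]` instead of `[3.0, 2.0, 1.0]`.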

static sample(state: PPOState, key: Array, env_states: Array | ndarray | bool | number, network: Module) Array | ndarray | bool | number

Samples actions from the categorical distribution defined by the logits computed by the agent network.

Parameters:
  • state (PPOState) – The state of the PPO agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • env_states (Array) – The current states of the environments.

  • network (nn.Module) – The agent network.

Returns:

Selected actions.

Return type:

Array
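Sampling from a categorical distribution defined by logits draws each action with probability \(\mathrm{softmax}(\text{logits})_a\). The plain-Python sketch below, with one row of logits per environment, illustrates that rule; the library works on JAX arrays, and `sample_categorical` is a hypothetical helper.

```python
import math
import random
from typing import List, Sequence


def sample_categorical(logits_batch: Sequence[Sequence[float]], rng: random.Random) -> List[int]:
    """Draws one action per row of logits with probabilities softmax(logits)."""
    actions = []
    for logits in logits_batch:
        m = max(logits)  # subtract the max for numerical stability
        weights = [math.exp(l - m) for l in logits]  # unnormalized softmax probabilities
        actions.append(rng.choices(range(len(logits)), weights=weights, k=1)[0])
    return actions


rng = random.Random(0)
actions = sample_categorical([[0.0, 1.0, 2.0]] * 4, rng)  # four parallel environments
```

A very large logit gives its action essentially all of the probability mass, so sampling then becomes effectively greedy.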

Proximal Policy Optimization (PPO) [continuous]

class PPOContinuous(network: Module, obs_space_shape: Sequence[int | Any], act_space_shape: Sequence[int | Any], min_action: Array | ndarray | bool | number | float | int, max_action: Array | ndarray | bool | number | float | int, optimizer: GradientTransformation = None, discount: float | int = 0.99, lambda_gae: float | int = 0.9, normalize_advantage: bool = True, clip_coef: float | int = 0.2, clip_value: bool = True, clip_grad: float | int = 0.5, entropy_coef: float | int = 0.01, value_coef: float | int = 0.5, rollout_length: int = 512, num_envs: int = 1, batch_size: int = 128, num_epochs: int = 4)

Bases: PPODiscrete

Proximal Policy Optimization (PPO) agent [5]. This implementation uses the clipped surrogate objective. The policy and value functions should be represented by a single Flax module with three outputs: the mean of the actions, the log standard deviation of the actions, and the state value. The network should be able to process a batch of observations. The actions are sampled from a diagonal Gaussian distribution, while the value function is used to compute the advantages using Generalized Advantage Estimation (GAE) [6]. The agent is trained using mini-batch gradient descent. This agent follows the on-policy learning paradigm and is suitable for environments with continuous action spaces.

Parameters:
  • network (nn.Module) – Architecture of the PPO agent network.

  • obs_space_shape (Shape) – Shape of the observation space.

  • act_space_shape (Shape) – Shape of the action space.

  • min_action (Scalar or Array) – Minimum action value.

  • max_action (Scalar or Array) – Maximum action value.

  • optimizer (optax.GradientTransformation, optional) – Optimizer of the network. If None, the Adam optimizer with learning rate 3e-4 and \(\epsilon\) = 1e-5 is used.

  • discount (Scalar, default=0.99) – Discount factor. \(\gamma = 0.0\) means only the immediate reward is considered, \(\gamma = 1.0\) means future rewards are not discounted. \(0 \leq \gamma \leq 1\).

  • lambda_gae (Scalar, default=0.9) – GAE parameter. \(\lambda = 0.0\) reduces GAE to the one-step temporal-difference advantage, \(\lambda = 1.0\) gives the Monte Carlo advantage estimate. \(0 \leq \lambda \leq 1\).

  • normalize_advantage (bool, default=True) – If True, the advantages are normalized to have mean 0 and standard deviation 1.

  • clip_coef (Scalar, default=0.2) – Clipping coefficient for the surrogate objective, \(\epsilon\) in [5].

  • clip_value (bool, default=True) – If True, the loss for the value function is clipped.

  • clip_grad (Scalar, default=0.5) – If not None, the gradients are clipped to have a maximum norm of clip_grad.

  • entropy_coef (Scalar, default=0.01) – Coefficient for the entropy bonus.

  • value_coef (Scalar, default=0.5) – Coefficient for the value function loss.

  • rollout_length (int, default=512) – Length of the rollout buffer.

  • num_envs (int, default=1) – Number of parallel environments.

  • batch_size (int, default=128) – Size of the batch to be sampled from the rollout buffer.

  • num_epochs (int, default=4) – Number of update epochs to perform on the rollout buffer.

static parameter_space() Dict

Parameters of the agent constructor in Gymnasium format. Type of returned value is required to be gym.spaces.Dict or None. If None, the user must provide all parameters manually.

property update_observation_space: Dict

Observation space of the update method in Gymnasium format. Allows inferring missing observations using extensions and easily exporting the agent to TensorFlow Lite format. If None, the user must provide all parameters manually.

property action_space: Box

Action space of the agent in Gymnasium format.

static loss_fn(params: dict, key: Array, net_state: dict, batch: tuple, network: Module, normalize_advantage: bool, clip_coef: float | int, clip_value: bool, entropy_coef: float | int, value_coef: float | int) tuple[float | int, dict]

Loss is the clipped surrogate objective with value function loss and entropy regularization:

\[\mathcal{L}(\theta) = \mathbb{E}_t \Big[ -\min\big( A_t r_t(\theta), A_t \, \mathrm{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \big) + c_v \, \mathcal{L}_v(\theta) - c_e \, \mathcal{H}[\pi_\theta](s_t) \Big]\]

where

  • \(r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{old}}(a_t \mid s_t)}\) is the probability ratio between the new and old policies,

  • \(A_t\) is the advantage estimate,

  • \(\epsilon\) is the clipping coefficient,

  • \(\mathcal{L}_v(\theta) = \tfrac{1}{2}\,(V_\theta(s_t) - R_t)^2\) is the value function loss, possibly clipped,

  • \(R_t\) is the discounted return,

  • \(\mathcal{H}[\pi_\theta](s_t)\) is the entropy of the policy at state \(s_t\),

  • \(c_v\) and \(c_e\) are coefficients for the value loss and entropy bonus, respectively.

Loss is calculated on a batch of transitions sampled from the rollout buffer.

Parameters:
  • params (dict) – The parameters of the agent network.

  • key (PRNGKey) – A PRNG key used as the random key.

  • net_state (dict) – The state of the agent network.

  • batch (tuple) – A batch of transitions from the rollout buffer.

  • network (nn.Module) – The agent network.

  • normalize_advantage (bool) – If True, the advantages are normalized to have mean 0 and standard deviation 1.

  • clip_coef (Scalar) – Clipping coefficient for the surrogate objective, \(\epsilon\) in [5].

  • clip_value (bool) – If True, the loss for the value function is clipped.

  • entropy_coef (Scalar) – Coefficient for the entropy bonus.

  • value_coef (Scalar) – Coefficient for the value function loss.

Returns:

The loss and the new state of the agent network.

Return type:

Tuple[Scalar, dict]
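In the continuous case, the ratio \(r_t(\theta)\) and the entropy \(\mathcal{H}\) come from a diagonal Gaussian with the network's mean and log standard deviation. The plain-Python sketch below shows the standard closed forms under that assumption; the helpers are hypothetical and not part of the library.

```python
import math
from typing import Sequence


def diag_gaussian_log_prob(action: Sequence[float], mean: Sequence[float], log_std: Sequence[float]) -> float:
    """log N(action; mean, diag(exp(log_std)^2)) as a sum over independent dimensions."""
    total = 0.0
    for a, mu, ls in zip(action, mean, log_std):
        std = math.exp(ls)
        total += -0.5 * ((a - mu) / std) ** 2 - ls - 0.5 * math.log(2.0 * math.pi)
    return total


def diag_gaussian_entropy(log_std: Sequence[float]) -> float:
    """H = sum_i (log sigma_i + 0.5 * log(2 * pi * e))."""
    return sum(ls + 0.5 * math.log(2.0 * math.pi * math.e) for ls in log_std)


# Probability ratio r_t(theta) between a new and an old policy for the same action
ratio = math.exp(
    diag_gaussian_log_prob([0.5], [0.5], [0.0]) - diag_gaussian_log_prob([0.5], [0.0], [0.0])
)
```

Because log-probabilities add across independent dimensions, the ratio for a multi-dimensional action is the product of per-dimension ratios.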

static update(state: PPOState, key: Array, env_states: Array | ndarray | bool | number, actions: Array | ndarray | bool | number, rewards: float | int, terminals: bool, network: Module, step_fn: Callable, rb: RolloutBuffer, num_envs: int, rollout_length: int, batch_size: int, num_epochs: int) PPOState

Appends the transition to the on-policy rollout buffer. Once the rollout buffer reaches rollout_length steps, computes advantages and returns using GAE, shuffles and flattens the buffer, and performs multiple gradient updates using mini-batches.

Parameters:
  • state (PPOState) – The current state of the PPO agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • env_states (Array) – The current states of the environments.

  • actions (Array) – The actions taken by the agent.

  • rewards (Scalar) – The rewards received by the agent.

  • terminals (bool) – Whether the episodes have terminated.

  • network (nn.Module) – The agent network.

  • step_fn (Callable) – The function that performs a single gradient step on the agent network.

  • rb (RolloutBuffer) – The rollout buffer functions.

  • num_envs (int) – The number of parallel environments.

  • rollout_length (int) – The length of the rollout buffer.

  • batch_size (int) – The size of the batch sampled from the rollout buffer.

  • num_epochs (int) – The number of update epochs to perform on the rollout buffer.

Returns:

The updated state of the PPO agent.

Return type:

PPOState

static sample(state: PPOState, key: Array, env_states: Array | ndarray | bool | number, network: Module, min_action: Array | ndarray | bool | number | float | int, max_action: Array | ndarray | bool | number | float | int) Array | ndarray | bool | number

Samples actions from the diagonal Gaussian policy defined by the mean and log standard deviation computed by the agent network.

Parameters:
  • state (PPOState) – The state of the PPO agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • env_states (Array) – The current states of the environments.

  • network (nn.Module) – The agent network.

  • min_action (Numeric) – Minimum action value.

  • max_action (Numeric) – Maximum action value.

Returns:

Selected actions.

Return type:

Array
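A diagonal Gaussian policy draws each action component independently as \(a_i \sim \mathcal{N}(\mu_i, \sigma_i^2)\) with \(\sigma_i = \exp(\log \sigma_i)\). The description above does not say how min_action and max_action are applied, so this plain-Python sketch clips to them purely as an illustrative assumption. `sample_diag_gaussian` is a hypothetical helper.

```python
import math
import random
from typing import List, Sequence


def sample_diag_gaussian(
    mean: Sequence[float],
    log_std: Sequence[float],
    min_action: float,
    max_action: float,
    rng: random.Random,
) -> List[float]:
    actions = []
    for mu, ls in zip(mean, log_std):
        a = rng.gauss(mu, math.exp(ls))  # a_i ~ N(mu_i, exp(log_std_i)^2), independent per dimension
        # Assumption for illustration only: clip to the action bounds.
        actions.append(min(max(a, min_action), max_action))
    return actions


rng = random.Random(0)
# A tiny standard deviation makes the sample almost equal to the mean.
actions = sample_diag_gaussian([0.5, 5.0], [-20.0, -20.0], -1.0, 1.0, rng)
```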

Evosax wrapper

class EvosaxState(es_state: State, population: dict, best_params: Params, fitness: Array | ndarray | bool | number, counter: int, terminals: Array | ndarray | bool | number)

Bases: AgentState, Mapping

Container for the state of the evosax agent.

es_state

The state of the evosax algorithm.

Type:

State

population

The current population of the evolution strategy algorithm.

Type:

dict

best_params

The best parameters found so far.

Type:

Params

fitness

The fitness values of the current population.

Type:

Array

counter

Number of the current step of the fitness evaluation.

Type:

int

terminals

Whether the episodes have terminated.

Type:

Array

items() a set-like object providing a view on D's items
keys() a set-like object providing a view on D's keys
values() an object providing a view on D's values
class Evosax(network: Module, evo_strategy: type, population_size: int, obs_space_shape: Sequence[int | Any], act_space_shape: Sequence[int | Any], evo_strategy_kwargs: dict = None, evo_strategy_default_params: dict = None, num_eval_steps: int = None)

Bases: BaseAgent

Evolution strategies (ES)-based agent using the evosax library [12]. This implementation maintains a population of candidate solutions (parameter vectors), evaluates them in parallel across environments, and updates the population by applying an evolutionary algorithm. Unlike gradient-based RL methods, this agent does not rely on backpropagation through the value or policy network. Instead, the network parameters are evolved using black-box optimization. This agent is suitable for environments with both discrete and continuous action spaces.

Note! The user is responsible for providing appropriate network output in the correct format (e.g., discrete actions should be sampled from logits with jax.random.categorical inside the network definition).

Note! This agent does not discount future rewards; therefore, the fitness is computed as a simple sum of rewards obtained during the evaluation phase.

Note! This agent is compatible only with distribution-based evolution strategies from the evosax library (see this list for available algorithms). Population-based methods (listed here) will be supported in future releases.

Parameters:
  • network (nn.Module) – Architecture of the agent network.

  • evo_strategy (type) – Evolution strategy class from evosax.

  • population_size (int) – Size of the population.

  • obs_space_shape (Shape) – Shape of the observation space.

  • act_space_shape (Shape) – Shape of the action space. For discrete action spaces, use (1,).

  • evo_strategy_kwargs (dict, default=None) – Parameters for the evolution strategy initialization. The population size and initial solution are set automatically.

  • evo_strategy_default_params (dict, default=None) – Custom default parameters for the evolution strategy. If None, the default parameters are used.

  • num_eval_steps (int, default=None) – Number of evaluation steps. If None, the evaluation runs until all episodes end.

References

static parameter_space() Dict

Parameters of the agent constructor in Gymnasium format. Type of returned value is required to be gym.spaces.Dict or None. If None, the user must provide all parameters manually.

property update_observation_space: Dict

Observation space of the update method in Gymnasium format. Allows inferring missing observations using extensions and easily exporting the agent to TensorFlow Lite format. If None, the user must provide all parameters manually.

property sample_observation_space: Dict

Observation space of the sample method in Gymnasium format. Allows inferring missing observations using extensions and easily exporting the agent to TensorFlow Lite format. If None, the user must provide all parameters manually.

property action_space: Box

Action space of the agent in Gymnasium format.

static init(key: Array, population_size: int, variables: dict, evo_strategy: EvolutionaryAlgorithm, evo_strategy_default_params: dict) EvosaxState

Initializes the evolution strategy state and the population. The fitness values, step counter, and terminals are initialized to zeros.

Parameters:
  • key (PRNGKey) – A PRNG key used as the random key.

  • population_size (int) – The size of the population.

  • variables (dict) – The initialized parameters of the agent network.

  • evo_strategy (EvolutionaryAlgorithm) – Initialized evosax evolution strategy.

  • evo_strategy_default_params (dict) – Custom default parameters for the evolution strategy.

Returns:

Initial state of the evosax agent.

Return type:

EvosaxState

static update(state: EvosaxState, key: Array, env_states: Array | ndarray | bool | number, actions: Array | ndarray | bool | number, rewards: float | int, terminals: bool, num_eval_steps: int, evo_strategy: EvolutionaryAlgorithm) EvosaxState

Updates the agent state after one evaluation step of the population. The method accumulates rewards into fitness values for each individual and tracks episode terminations. Once the evaluation is considered complete, either because all episodes have terminated (num_eval_steps=None) or because a fixed number of steps has been reached (num_eval_steps specified), the population is evolved. The evolution strategy’s tell method is called with the negative fitness values (as evosax minimizes the fitness), and a new population is generated using the ask method. The best parameters found so far are stored in the state.

Parameters:
  • state (EvosaxState) – The current state of the evosax agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • env_states (Array) – The current states of the environments.

  • actions (Array) – The actions taken by the agent.

  • rewards (Scalar) – The rewards received by the agent.

  • terminals (bool) – Whether the episodes have terminated.

  • num_eval_steps (int) – Number of evaluation steps. If None, the evaluation runs until all episodes end.

  • evo_strategy (EvolutionaryAlgorithm) – The evosax evolution strategy.

Returns:

The updated state of the evosax agent.

Return type:

EvosaxState
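The evaluation bookkeeping described above can be sketched in plain Python, leaving out the evosax calls themselves. The helper is hypothetical, and it assumes that once a member's episode has terminated, its later rewards are ignored. It shows the undiscounted fitness sum, both stopping rules, and the negation applied before `tell`.

```python
from typing import List, Optional, Sequence, Tuple


def accumulate_fitness(
    fitness: Sequence[float],
    done: Sequence[bool],
    rewards: Sequence[float],
    terminals: Sequence[bool],
    counter: int,
    num_eval_steps: Optional[int],
) -> Tuple[List[float], List[bool], int, bool]:
    """One evaluation step for the whole population (hypothetical helper)."""
    # Undiscounted sum; members whose episode already ended stop collecting reward.
    fitness = [f if d else f + r for f, d, r in zip(fitness, done, rewards)]
    done = [d or t for d, t in zip(done, terminals)]
    counter += 1
    if num_eval_steps is None:
        finished = all(done)  # run until every episode has terminated
    else:
        finished = counter >= num_eval_steps  # run for a fixed number of steps
    return fitness, done, counter, finished


# Two population members evaluated until both episodes end (num_eval_steps=None).
fitness, done, counter = [0.0, 0.0], [False, False], 0
fitness, done, counter, finished = accumulate_fitness(fitness, done, [1.0, 2.0], [False, True], counter, None)
# Member 0 is still running, so the evaluation is not finished yet.
fitness, done, counter, finished = accumulate_fitness(fitness, done, [1.0, 5.0], [True, False], counter, None)
# Member 1's reward of 5.0 is ignored because its episode already ended.
tell_values = [-f for f in fitness]  # evosax minimizes, so fitness is negated before `tell`
```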

static sample(state: EvosaxState, key: Array, env_states: Array | ndarray | bool | number, population_size: int, network: Module, params_format_fn: Callable) Array | ndarray | bool | number

Returns actions computed by the agents in the population. Note that the user is responsible for providing network output in the correct format.

Parameters:
  • state (EvosaxState) – The state of the evosax agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • env_states (Array) – The current states of the environments.

  • population_size (int) – The size of the population.

  • network (nn.Module) – The agent network.

  • params_format_fn (Callable) – Function that formats the flattened parameters of the population members to the original neural network parameter format.

Returns:

Selected actions.

Return type:

Array
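The params_format_fn argument maps a flat parameter vector of a population member back to the network's parameter structure. In JAX, the unravel function returned by `jax.flatten_util.ravel_pytree` could play this role. The plain-Python sketch below only illustrates the idea with nested lists; the layer names and shapes are hypothetical.

```python
import math
from typing import Callable, Dict, List, Sequence, Tuple, Union

Nested = Union[float, List["Nested"]]


def _reshape(flat: Sequence[float], shape: Tuple[int, ...]) -> Nested:
    # Row-major reshape of a flat slice into nested lists.
    if not shape:
        return flat[0]
    step = math.prod(shape[1:])
    return [_reshape(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def make_params_format_fn(shapes: Dict[str, Tuple[int, ...]]) -> Callable[[Sequence[float]], Dict[str, Nested]]:
    """Builds a hypothetical params_format_fn for the given parameter shapes."""
    def params_format_fn(flat: Sequence[float]) -> Dict[str, Nested]:
        out, offset = {}, 0
        for name, shape in shapes.items():
            size = math.prod(shape)
            out[name] = _reshape(flat[offset:offset + size], shape)
            offset += size
        return out
    return params_format_fn


format_fn = make_params_format_fn({"kernel": (2, 3), "bias": (3,)})
member_params = format_fn([float(i) for i in range(9)])  # one population member's flat vector
```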

Epsilon-greedy

class EGreedyState(Q: Array | ndarray | bool | number, N: Array | ndarray | bool | number, e: float | int)

Bases: AgentState, Mapping

Container for the state of the \(\epsilon\)-greedy agent.

Q

Action-value function estimates for each arm.

Type:

Array

N

Number of tries for each arm.

Type:

Array

e

Experiment rate (epsilon).

Type:

float

items() a set-like object providing a view on D's items
keys() a set-like object providing a view on D's keys
values() an object providing a view on D's values
class EGreedy(n_arms: int, e: float | int, e_min: float | int = 0.0, e_decay: float | int = 1.0, optimistic_start: float | int = 0.0, alpha: float | int = 0.0)

Bases: BaseAgent

Epsilon-greedy [7] agent with an optional optimistic start and optional exponential recency-weighted average updates. It selects a random action from the set of all actions \(\mathscr{A}\) with probability \(\epsilon\) (exploration); otherwise, it chooses the currently best action (exploitation). Epsilon can be decayed over time to shift the policy from exploration to exploitation.

Parameters:
  • n_arms (int) – Number of bandit arms. \(N \in \mathbb{N}_{+}\).

  • e (float) – Initial experiment rate (epsilon). \(\epsilon \in [0, 1]\).

  • e_min (float, default=0.0) – Minimum value of the experiment rate. \(\epsilon_{\min} \in [0, 1]\).

  • e_decay (float, default=1.0) – Decay factor for the experiment rate. \(\epsilon_{\text{decay}} \in [0, 1]\).

  • optimistic_start (float, default=0.0) – Initial value of the action-value estimates \(Q\) for every arm. Setting it above the expected rewards (optimistic start) encourages exploration in the early stages.

  • alpha (float, default=0.0) – If non-zero, exponential recency-weighted average is used to update \(Q\) values. \(\alpha \in [0, 1]\).

References

static parameter_space() Dict

Parameters of the agent constructor in Gymnasium format. Type of returned value is required to be gym.spaces.Dict or None. If None, the user must provide all parameters manually.

property update_observation_space: Dict

Observation space of the update method in Gymnasium format. Allows inferring missing observations using extensions and easily exporting the agent to TensorFlow Lite format. If None, the user must provide all parameters manually.

property sample_observation_space: Dict

Observation space of the sample method in Gymnasium format. Allows inferring missing observations using extensions and easily exporting the agent to TensorFlow Lite format. If None, the user must provide all parameters manually.

property action_space: Discrete

Action space of the agent in Gymnasium format.

static init(key: Array, n_arms: int, e: float | int, optimistic_start: float | int) EGreedyState

Creates and initializes an instance of the \(\epsilon\)-greedy agent for n_arms arms. The action-value estimates are set to optimistic_start, and the number of tries is set to one for each arm.

Parameters:
  • key (PRNGKey) – A PRNG key used as the random key.

  • n_arms (int) – Number of bandit arms.

  • e (float) – Initial experiment rate (epsilon).

  • optimistic_start (float) – Initial value of the action-value estimates; an optimistic value encourages exploration in the early stages.

Returns:

Initial state of the \(\epsilon\)-greedy agent.

Return type:

EGreedyState

static update(state: EGreedyState, key: Array, action: int, reward: float | int, alpha: float | int, e_decay: float | int, e_min: float | int) EGreedyState

In the stationary case, after receiving reward \(R_t\), the action-value estimate of the selected arm is updated as \(Q_{t + 1} = Q_t + \frac{1}{t} \lbrack R_t - Q_t \rbrack\), where \(t\) is the current number of tries of that arm, and the number of tries is then incremented. In the non-stationary case (\(\alpha > 0\)), the update follows the equation \(Q_{t + 1} = Q_t + \alpha \lbrack R_t - Q_t \rbrack\). The exploration rate \(\epsilon\) can be decayed over time following the equation \(\epsilon_{t + 1} = \max(\epsilon_{\min}, \epsilon_t \times \epsilon_{\text{decay}})\).

Parameters:
  • state (EGreedyState) – Current state of the agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • action (int) – Previously selected action.

  • reward (float) – Reward collected by the agent after taking the previous action.

  • alpha (float) – Exponential recency-weighted average factor (used when \(\alpha > 0\)).

  • e_decay (float) – Decay factor for the exploration rate.

  • e_min (float) – Minimum value of the exploration rate.

Returns:

Updated agent state.

Return type:

EGreedyState
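
The update rule above can be sketched in plain Python. This is a hedged illustration rather than the library's JAX code: the state is assumed to be a dict with the hypothetical keys `Q` (action-value estimates), `N` (tries per arm, starting at one) and `e` (current exploration rate).

```python
def egreedy_update(state: dict, action: int, reward: float,
                   alpha: float, e_decay: float, e_min: float) -> dict:
    """Sketch of the epsilon-greedy update; the dict keys are hypothetical."""
    Q = list(state["Q"])  # action-value estimates
    N = list(state["N"])  # tries per arm, initialized to one

    # Sample-average step 1/N(a) in the stationary case, constant alpha otherwise.
    step = alpha if alpha > 0 else 1.0 / N[action]
    Q[action] += step * (reward - Q[action])
    N[action] += 1

    # Decay the exploration rate, but never below e_min.
    e = max(e_min, state["e"] * e_decay)
    return {"Q": Q, "N": N, "e": e}


state = {"Q": [0.0, 0.0], "N": [1, 1], "e": 0.1}
state = egreedy_update(state, action=0, reward=1.0, alpha=0.0, e_decay=0.9, e_min=0.01)
```

Because the number of tries starts at one, the first reward of an arm fully replaces its optimistic initial estimate in the stationary case.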

static sample(state: EGreedyState, key: Array) int

Epsilon-greedy agent follows the policy:

\[\begin{split}A = \begin{cases} \operatorname*{argmax}_{a \in \mathscr{A}} Q(a) & \text{with probability } 1 - \epsilon , \\ \text{random action} & \text{with probability } \epsilon . \\ \end{cases}\end{split}\]
Parameters:
  • state (EGreedyState) – Current state of the agent.

  • key (PRNGKey) – A PRNG key used as the random key.

Returns:

Selected action.

Return type:

int
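
A minimal sketch of this policy, using Python's `random` module in place of JAX PRNG keys. Ties in the greedy branch go to the lowest arm index here, which is a choice of this sketch and not necessarily the library's behaviour.

```python
import random


def egreedy_sample(Q: list[float], e: float, rng: random.Random) -> int:
    """Sketch of the epsilon-greedy policy (not the library implementation)."""
    if rng.random() < e:
        # Explore: any arm, uniformly at random.
        return rng.randrange(len(Q))
    # Exploit: the arm with the highest estimate (first one on ties).
    return max(range(len(Q)), key=lambda a: Q[a])


rng = random.Random(42)
action = egreedy_sample([0.1, 0.9, 0.5], e=0.1, rng=rng)
```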

Exp3

class Exp3State(omega: Array | ndarray | bool | number)

Bases: AgentState, Mapping

Container for the state of the Exp3 agent.

omega

Preference for each arm.

Type:

Array

items() a set-like object providing a view on D's items
keys() a set-like object providing a view on D's keys
values() an object providing a view on D's values
class Exp3(n_arms: int, gamma: float | int, min_reward: float | int, max_reward: float | int)

Bases: BaseAgent

Basic Exp3 agent for stationary multi-armed bandit problems with exploration factor \(\gamma\). The higher the value of \(\gamma\), the more the agent explores. The implementation is inspired by the work of Auer et al. [8]. There are many variants of the Exp3 algorithm; more information can be found in the original paper.

Parameters:
  • n_arms (int) – Number of bandit arms. \(N \in \mathbb{N}_{+}\).

  • gamma (float) – Exploration factor. \(\gamma \in (0, 1]\).

  • min_reward (float) – Minimum possible reward.

  • max_reward (float) – Maximum possible reward.

References

static parameter_space() Dict

Parameters of the agent constructor in Gymnasium format. The returned value must be of type gym.spaces.Dict or None. If None, the user must provide all parameters manually.

property update_observation_space: Dict

Observation space of the update method in Gymnasium format. Allows missing observations to be inferred using extensions and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all parameters manually.

property sample_observation_space: Dict

Observation space of the sample method in Gymnasium format. Allows missing observations to be inferred using extensions and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all parameters manually.

property action_space: Discrete

Action space of the agent in Gymnasium format.

static init(key: Array, n_arms: int) Exp3State

Initializes the Exp3 agent state with uniform preference for each arm.

Parameters:
  • key (PRNGKey) – A PRNG key used as the random key.

  • n_arms (int) – Number of bandit arms.

Returns:

Initial state of the Exp3 agent.

Return type:

Exp3State

static update(state: Exp3State, key: Array, action: int, reward: float | int, gamma: float | int, min_reward: float | int, max_reward: float | int) Exp3State

Agent updates its preference for the selected arm \(a\) according to the following formula:

\[\omega_{t + 1}(a) = \omega_{t}(a) \exp \left( \frac{\gamma r}{\pi(a) K} \right)\]

where \(\omega_{t + 1}(a)\) is the preference of arm \(a\) at time \(t + 1\), \(\pi(a)\) is the probability of selecting arm \(a\), and \(K\) is the number of arms. The reward \(r\) is normalized to the range \([0, 1]\) using min_reward and max_reward. The exponential growth significantly increases the weights of good arms, so when the agent is used for a long time, it is important to ensure that the values of \(\omega\) do not exceed the maximum value of the floating-point type!

Parameters:
  • state (Exp3State) – Current state of the agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • action (int) – Previously selected action.

  • reward (float) – Reward collected by the agent after taking the previous action.

  • gamma (float) – Exploration factor.

  • min_reward (float) – Minimum possible reward.

  • max_reward (float) – Maximum possible reward.

Returns:

Updated agent state.

Return type:

Exp3State
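
The update formula can be sketched in plain Python as follows. Two assumptions of this illustration: the reward is normalized by min-max scaling with min_reward and max_reward, and \(\pi(a)\) is the mixture of the normalized preferences and the uniform distribution described for sample. It is not the library's code.

```python
import math


def exp3_probs(omega: list[float], gamma: float) -> list[float]:
    """Probability of each arm: (1 - gamma) * normalized weight + gamma / K."""
    K = len(omega)
    total = sum(omega)
    return [(1.0 - gamma) * w / total + gamma / K for w in omega]


def exp3_update(omega: list[float], action: int, reward: float, gamma: float,
                min_reward: float, max_reward: float) -> list[float]:
    """Sketch of the Exp3 preference update for the selected arm."""
    K = len(omega)
    # Assumed min-max normalization of the reward to [0, 1].
    r = (reward - min_reward) / (max_reward - min_reward)
    pi = exp3_probs(omega, gamma)[action]
    updated = list(omega)
    updated[action] = omega[action] * math.exp(gamma * r / (pi * K))
    return updated


omega = exp3_update([1.0, 1.0, 1.0], action=2, reward=5.0, gamma=0.1,
                    min_reward=0.0, max_reward=10.0)
```

Dividing the reward by \(\pi(a)\) compensates for how rarely an arm is chosen, which is also why the weights of good arms can grow quickly.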

static sample(state: Exp3State, key: Array, gamma: float | int) int

The Exp3 policy is stochastic. The algorithm chooses a random arm with probability \(\gamma\); otherwise, it draws arm \(a\) with probability \(\omega(a) / \sum_{b=1}^N \omega(b)\).

Parameters:
  • state (Exp3State) – Current state of the agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • gamma (float) – Exploration factor.

Returns:

Selected action.

Return type:

int
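
The two-stage draw described above can be sketched with Python's `random` module standing in for JAX PRNG keys (an illustration, not the library implementation):

```python
import random


def exp3_sample(omega: list[float], gamma: float, rng: random.Random) -> int:
    """Sketch of the Exp3 policy (not the library implementation)."""
    if rng.random() < gamma:
        # Exploration: uniform over all arms.
        return rng.randrange(len(omega))
    # Otherwise draw an arm with probability proportional to its preference.
    return rng.choices(range(len(omega)), weights=omega, k=1)[0]


rng = random.Random(0)
action = exp3_sample([1.0, 2.0, 7.0], gamma=0.1, rng=rng)
```

Overall, arm \(a\) is selected with probability \((1 - \gamma)\,\omega(a) / \sum_b \omega(b) + \gamma / N\).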

Softmax

class SoftmaxState(H: Array | ndarray | bool | number, r: float | int, n: int)

Bases: AgentState, Mapping

Container for the state of the Softmax agent.

H

Preference for each arm.

Type:

Array

r

Average of all obtained rewards \(\bar{R}\).

Type:

float

n

Step number.

Type:

int

items() a set-like object providing a view on D's items
keys() a set-like object providing a view on D's keys
values() an object providing a view on D's values
class Softmax(n_arms: int, lr: float | int, alpha: float | int = 0.0, tau: float | int = 1.0, multiplier: float | int = 1.0)

Bases: BaseAgent

Softmax agent with a baseline and an optional exponential recency-weighted average update. It learns a preference function \(H\), which indicates the preference for selecting one arm over the others. The policy can be controlled by the temperature parameter \(\tau\). The implementation is inspired by the work of Sutton and Barto [7]. Note: in some environments, this agent benefits considerably from enabling 64-bit mode in JAX!

Parameters:
  • n_arms (int) – Number of bandit arms. \(N \in \mathbb{N}_{+}\).

  • lr (float) – Step size. \(lr > 0\).

  • alpha (float, default=0.0) – If non-zero, exponential recency-weighted average is used to update \(\bar{R}\). \(\alpha \in [0, 1]\).

  • tau (float, default=1.0) – Temperature parameter. \(\tau > 0\).

  • multiplier (float, default=1.0) – Multiplier for the reward. \(multiplier > 0\).

static parameter_space() Dict

Parameters of the agent constructor in Gymnasium format. The returned value must be of type gym.spaces.Dict or None. If None, the user must provide all parameters manually.

property update_observation_space: Dict

Observation space of the update method in Gymnasium format. Allows missing observations to be inferred using extensions and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all parameters manually.

property sample_observation_space: Dict

Observation space of the sample method in Gymnasium format. Allows missing observations to be inferred using extensions and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all parameters manually.

property action_space: Discrete

Action space of the agent in Gymnasium format.

static init(key: Array, n_arms: int) SoftmaxState

Creates and initializes an instance of the Softmax agent for n_arms arms. The preferences \(H\) for each arm are set to zero, as is the average of all rewards \(\bar{R}\). The step number \(n\) is initialized to one.

Parameters:
  • key (PRNGKey) – A PRNG key used as the random key.

  • n_arms (int) – Number of bandit arms.

Returns:

Initial state of the Softmax agent.

Return type:

SoftmaxState

static update(state: SoftmaxState, key: Array, action: int, reward: float | int, lr: float | int, alpha: float | int, tau: float | int, multiplier: float | int) SoftmaxState

Preferences \(H\) are learned by stochastic gradient ascent: the softmax algorithm searches for the set of preferences that maximizes the expected reward \(\mathbb{E}[R]\). The update of \(H\) for each action \(a\) is calculated as:

\[H_{t + 1}(a) = H_t(a) + lr (R_t - \bar{R}_t)(\mathbb{1}_{A_t = a} - \pi_t(a)),\]

where \(\bar{R}_t\) is the average of all rewards up to but not including step \(t\) (by definition \(\bar{R}_1 = R_1\)) and \(\pi_t(a)\) is the probability of selecting action \(a\) under the softmax policy. The derivation of this formula can be found in [7].

In the stationary case, \(\bar{R}_t\) is calculated as \(\bar{R}_{t + 1} = \bar{R}_t + \frac{1}{t} \lbrack R_t - \bar{R}_t \rbrack\). To improve the algorithm's performance in the non-stationary case, we apply \(\bar{R}_{t + 1} = \bar{R}_t + \alpha \lbrack R_t - \bar{R}_t \rbrack\) with a constant step size \(\alpha\).

The reward \(R_t\) is multiplied by multiplier before the preferences are updated, which allows flexible reward scaling while keeping the algorithm's properties.

Parameters:
  • state (SoftmaxState) – Current state of the agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • action (int) – Previously selected action.

  • reward (float) – Reward collected by the agent after taking the previous action.

  • lr (float) – Step size.

  • alpha (float) – Exponential recency-weighted average factor (used when \(\alpha > 0\)).

  • tau (float) – Temperature parameter.

  • multiplier (float) – Multiplier for the reward.

Returns:

Updated agent state.

Return type:

SoftmaxState
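
A plain-Python sketch of one update step, not the library's code. It assumes a dict state with the hypothetical keys `H`, `r` and `n`, uses lr as the step size of the preference update and alpha only for the baseline \(\bar{R}\) (as the parameter descriptions state), and sets the baseline to the first reward when n equals one, following the definition \(\bar{R}_1 = R_1\).

```python
import math


def softmax_probs(H: list[float], tau: float) -> list[float]:
    """Softmax with temperature; subtracting the max avoids overflow."""
    m = max(H)
    exps = [math.exp((h - m) / tau) for h in H]
    total = sum(exps)
    return [x / total for x in exps]


def softmax_update(state: dict, action: int, reward: float, lr: float,
                   alpha: float, tau: float, multiplier: float) -> dict:
    """Sketch of the gradient-bandit update; the dict keys are hypothetical."""
    H, r_bar, n = list(state["H"]), state["r"], state["n"]
    reward *= multiplier
    if n == 1:
        r_bar = reward  # baseline at the first step equals the first reward
    pi = softmax_probs(H, tau)
    for a in range(len(H)):
        indicator = 1.0 if a == action else 0.0
        H[a] += lr * (reward - r_bar) * (indicator - pi[a])
    # Baseline: sample average (1/n) or constant step size alpha.
    step = alpha if alpha > 0 else 1.0 / n
    r_bar += step * (reward - r_bar)
    return {"H": H, "r": r_bar, "n": n + 1}
```

The selected arm's preference rises when the reward beats the baseline, and all other preferences move in the opposite direction in proportion to their probabilities.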

static sample(state: SoftmaxState, key: Array, tau: float | int) int

The policy of the Softmax algorithm is stochastic. The algorithm draws the next action from the softmax distribution. The probability of selecting action \(i\) is calculated as:

\[softmax(H)_i = \frac{\exp(H_i / \tau)}{\sum_{h \in H} \exp(h / \tau)} .\]
Parameters:
  • state (SoftmaxState) – Current state of the agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • tau (float) – Temperature parameter.

Returns:

Selected action.

Return type:

int
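
Drawing from the softmax distribution can be sketched with Python's `random` module in place of JAX PRNG keys. Subtracting the largest preference before exponentiating leaves the probabilities unchanged but prevents overflow.

```python
import math
import random


def softmax_sample(H: list[float], tau: float, rng: random.Random) -> int:
    """Draw an arm from softmax(H / tau) (sketch, not the library code)."""
    m = max(H)
    # exp((h - m) / tau) is proportional to exp(h / tau) and cannot overflow.
    weights = [math.exp((h - m) / tau) for h in H]
    return rng.choices(range(len(H)), weights=weights, k=1)[0]


rng = random.Random(0)
action = softmax_sample([0.5, 1.0, -0.2], tau=0.5, rng=rng)
```

A larger \(\tau\) flattens the distribution towards uniform, while \(\tau \to 0\) approaches greedy selection.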

Thompson sampling

class ThompsonSamplingState(alpha: Array | ndarray | bool | number, beta: Array | ndarray | bool | number)

Bases: AgentState, Mapping

Container for the state of the Thompson sampling agent.

alpha

Number of successful tries for each arm.

Type:

Array

beta

Number of failed tries for each arm.

Type:

Array

items() a set-like object providing a view on D's items
keys() a set-like object providing a view on D's keys
values() an object providing a view on D's values
class ThompsonSampling(n_arms: int, decay: float | int = 0.0)

Bases: BaseAgent

Contextual Bernoulli Thompson sampling agent with exponential smoothing. The implementation is inspired by the work of Krotov et al. [9]. Thompson sampling is based on a beta distribution with parameters related to the number of successful and failed attempts. Higher values of the parameters decrease the entropy of the distribution, while changing the ratio of the parameters shifts the expected value.

Parameters:
  • n_arms (int) – Number of bandit arms. \(N \in \mathbb{N}_{+}\).

  • decay (float, default=0.0) – Decay rate. If equal to zero, smoothing is not applied. \(w \geq 0\).

References

static parameter_space() Dict

Parameters of the agent constructor in Gymnasium format. The returned value must be of type gym.spaces.Dict or None. If None, the user must provide all parameters manually.

property update_observation_space: Dict

Observation space of the update method in Gymnasium format. Allows missing observations to be inferred using extensions and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all parameters manually.

property sample_observation_space: Dict

Observation space of the sample method in Gymnasium format. Allows missing observations to be inferred using extensions and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all parameters manually.

property action_space: Discrete

Action space of the agent in Gymnasium format.

static init(key: Array, n_arms: int) ThompsonSamplingState

Creates and initializes an instance of the Thompson sampling agent for n_arms arms. The \(\mathbf{\alpha}\) and \(\mathbf{\beta}\) vectors are set to zero, so the \(\operatorname{Beta}(1 + \mathbf{\alpha}(a), 1 + \mathbf{\beta}(a))\) distribution used for sampling starts as the uniform, non-informative prior \(\operatorname{Beta}(1, 1)\).

Parameters:
  • key (PRNGKey) – A PRNG key used as the random key.

  • n_arms (int) – Number of bandit arms.

Returns:

Initial state of the Thompson sampling agent.

Return type:

ThompsonSamplingState

static update(state: ThompsonSamplingState, key: Array, action: int, n_successful: int, n_failed: int, delta_time: float | int, decay: float | int) ThompsonSamplingState

Thompson sampling can be adapted to non-stationary environments by exponential smoothing of the vectors \(\mathbf{\alpha}\) and \(\mathbf{\beta}\), which increases the entropy of the distribution over time. Given the number of successful tries \(s\) (n_successful) and failed tries \(f\) (n_failed) for the selected action \(A\), we apply the following equations for each action \(a\):

\[\begin{split}\begin{gather} \mathbf{\alpha}_{t + 1}(a) = \mathbf{\alpha}_t(a) e^{-w \Delta t} + \mathbb{1}_{A = a} \cdot s , \\ \mathbf{\beta}_{t + 1}(a) = \mathbf{\beta}_t(a) e^{-w \Delta t} + \mathbb{1}_{A = a} \cdot f , \end{gather}\end{split}\]

where \(\Delta t\) is the time elapsed since the last action selection and \(w\) is the decay rate (for \(w = 0\), no smoothing is applied).

Parameters:
  • state (ThompsonSamplingState) – Current state of the agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • action (int) – Previously selected action.

  • n_successful (int) – Number of successful tries.

  • n_failed (int) – Number of failed tries.

  • delta_time (float) – Time elapsed since the last action selection.

  • decay (float) – Decay rate.

Returns:

Updated agent state.

Return type:

ThompsonSamplingState
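
A plain-Python sketch of the smoothed update. It assumes the smoothing factor \(e^{-w \Delta t}\), the form under which a decay of zero disables smoothing, as stated for the decay parameter. It is an illustration, not the library's code.

```python
import math


def ts_update(alpha: list[float], beta: list[float], action: int,
              n_successful: int, n_failed: int, delta_time: float,
              decay: float) -> tuple[list[float], list[float]]:
    """Sketch of the smoothed Beta-Bernoulli update (not the library code)."""
    # Assumed smoothing factor exp(-decay * delta_time); decay = 0 gives 1.
    factor = math.exp(-decay * delta_time)
    alpha = [a * factor for a in alpha]
    beta = [b * factor for b in beta]
    # Only the selected arm receives the new evidence.
    alpha[action] += n_successful
    beta[action] += n_failed
    return alpha, beta
```

Smoothing shrinks the counts of every arm, so arms that have not been tried recently drift back towards the uninformative prior.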

static sample(state: ThompsonSamplingState, key: Array, context: Array | ndarray | bool | number) int

The Thompson sampling policy is stochastic. The algorithm draws \(q_a\) from the distribution \(\operatorname{Beta}(1 + \mathbf{\alpha}(a), 1 + \mathbf{\beta}(a))\) for each arm \(a\). The next action is selected as

\[A = \operatorname*{argmax}_{a \in \mathscr{A}} q_a r_a ,\]

where \(r_a\) is the contextual information for arm \(a\), and \(\mathscr{A}\) is the set of all actions.

Parameters:
  • state (ThompsonSamplingState) – Current state of the agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • context (Array) – One-dimensional array of features for each arm.

Returns:

Selected action.

Return type:

int
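
The contextual selection rule can be sketched with `random.betavariate` standing in for JAX sampling. The context is assumed to be a plain list of per-arm values \(r_a\).

```python
import random


def ts_sample(alpha: list[float], beta: list[float], context: list[float],
              rng: random.Random) -> int:
    """Sketch of contextual Thompson sampling (not the library code)."""
    # One draw q_a ~ Beta(1 + alpha(a), 1 + beta(a)) per arm, scaled by r_a.
    scores = [rng.betavariate(1.0 + a, 1.0 + b) * r
              for a, b, r in zip(alpha, beta, context)]
    return max(range(len(scores)), key=lambda i: scores[i])


rng = random.Random(0)
action = ts_sample([3.0, 1.0], [1.0, 3.0], [1.0, 2.0], rng)
```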

Normal Thompson sampling

class NormalThompsonSamplingState(alpha: Array | ndarray | bool | number, beta: Array | ndarray | bool | number, lam: Array | ndarray | bool | number, mu: Array | ndarray | bool | number)

Bases: AgentState, Mapping

Container for the state of the normal Thompson sampling agent.

alpha

The concentration parameter of the inverse-gamma distribution.

Type:

Array

beta

The scale parameter of the inverse-gamma distribution.

Type:

Array

lam

The number of observations.

Type:

Array

mu

The mean of the normal distribution.

Type:

Array

items() a set-like object providing a view on D's items
keys() a set-like object providing a view on D's keys
values() an object providing a view on D's values
class NormalThompsonSampling(n_arms: int, alpha: float | int, beta: float | int, lam: float | int, mu: float | int)

Bases: BaseAgent

Normal Thompson sampling agent [11]. The normal-inverse-gamma distribution is a conjugate prior for the normal distribution with unknown mean and variance. The parameters of the distribution are updated after each observation. The mean of the normal distribution is sampled from the normal-inverse-gamma distribution and the action with the highest expected value is selected.

Parameters:
  • n_arms (int) – Number of bandit arms. \(N \in \mathbb{N}_{+}\).

  • alpha (float) – See also NormalThompsonSamplingState for interpretation. \(\alpha > 0\).

  • beta (float) – See also NormalThompsonSamplingState for interpretation. \(\beta > 0\).

  • lam (float) – See also NormalThompsonSamplingState for interpretation. \(\lambda > 0\).

  • mu (float) – See also NormalThompsonSamplingState for interpretation. \(\mu \in \mathbb{R}\).

References

static parameter_space() Dict

Parameters of the agent constructor in Gymnasium format. The returned value must be of type gym.spaces.Dict or None. If None, the user must provide all parameters manually.

property update_observation_space: Dict

Observation space of the update method in Gymnasium format. Allows missing observations to be inferred using extensions and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all parameters manually.

property sample_observation_space: Dict

Observation space of the sample method in Gymnasium format. Allows missing observations to be inferred using extensions and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all parameters manually.

property action_space: Space

Action space of the agent in Gymnasium format.

static init(key: Array, n_arms: int, alpha: float | int, beta: float | int, lam: float | int, mu: float | int) NormalThompsonSamplingState

Creates and initializes an instance of the normal Thompson sampling agent for n_arms arms and the given initial parameters for the prior distribution.

Parameters:
  • key (PRNGKey) – A PRNG key used as the random key.

  • n_arms (int) – Number of bandit arms.

  • alpha (float) – See also NormalThompsonSamplingState for interpretation.

  • beta (float) – See also NormalThompsonSamplingState for interpretation.

  • lam (float) – See also NormalThompsonSamplingState for interpretation.

  • mu (float) – See also NormalThompsonSamplingState for interpretation.

Returns:

Initial state of the normal Thompson sampling agent.

Return type:

NormalThompsonSamplingState

static update(state: NormalThompsonSamplingState, key: Array, action: int, reward: float | int) NormalThompsonSamplingState

Normal Thompson sampling update according to [11].

\[\begin{split}\begin{align} \alpha_{t + 1}(a) &= \alpha_t(a) + \frac{1}{2} \\ \beta_{t + 1}(a) &= \beta_t(a) + \frac{\lambda_t(a) (r_t(a) - \mu_t(a))^2}{2 (\lambda_t(a) + 1)} \\ \lambda_{t + 1}(a) &= \lambda_t(a) + 1 \\ \mu_{t + 1}(a) &= \frac{\mu_t(a) \lambda_t(a) + r_t(a)}{\lambda_t(a) + 1} \end{align}\end{split}\]
Parameters:
  • state (NormalThompsonSamplingState) – Current state of the agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • action (int) – Previously selected action.

  • reward (float) – Reward obtained after executing the action.

Returns:

Updated agent state.

Return type:

NormalThompsonSamplingState
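
The four update equations translate directly into plain Python for a single arm. This sketch only restates the formulas above; it is not the library's JAX code.

```python
def nig_update(alpha: float, beta: float, lam: float, mu: float,
               reward: float) -> tuple[float, float, float, float]:
    """Normal-inverse-gamma posterior update after one reward (sketch)."""
    new_alpha = alpha + 0.5
    # beta uses the old lam and mu, so compute it before they change.
    new_beta = beta + lam * (reward - mu) ** 2 / (2.0 * (lam + 1.0))
    new_lam = lam + 1.0
    new_mu = (mu * lam + reward) / (lam + 1.0)
    return new_alpha, new_beta, new_lam, new_mu


# Prior alpha=1, beta=1, lam=1, mu=0, then one reward of 2.0.
print(nig_update(1.0, 1.0, 1.0, 0.0, 2.0))  # (1.5, 2.0, 2.0, 1.0)
```

The new mean is a weighted average of the prior mean (weight \(\lambda\)) and the reward (weight one), which is why \(\lambda\) acts as a count of observations.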

static inverse_gamma(key: Array, concentration: Array | ndarray | bool | number, scale: Array | ndarray | bool | number) Array | ndarray | bool | number

Samples from the inverse gamma distribution. Implementation is based on the gamma distribution and the following dependence:

\[\begin{split}\begin{gather} X \sim \operatorname{Gamma}(\alpha, \beta) \\ \frac{1}{X} \sim \operatorname{Inverse-gamma}(\alpha, \frac{1}{\beta}) \end{gather}\end{split}\]
Parameters:
  • key (PRNGKey) – A PRNG key used as the random key.

  • concentration (Array) – The concentration parameter of the inverse-gamma distribution.

  • scale (Array) – The scale parameter of the inverse-gamma distribution.

Returns:

Sampled values from the inverse gamma distribution.

Return type:

Array
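
The same reciprocal relation can be sketched with Python's `random.gammavariate`, whose second argument is a scale parameter. Drawing \(X \sim \operatorname{Gamma}(\alpha, 1/\beta)\) and returning \(1/X\) therefore gives a sample from the inverse-gamma distribution with concentration \(\alpha\) and scale \(\beta\). This stands in for the library's JAX-based implementation.

```python
import random


def inverse_gamma(concentration: float, scale: float, rng: random.Random) -> float:
    """Sample Inverse-gamma(concentration, scale) via a reciprocal gamma draw."""
    # random.gammavariate(shape, scale): here the gamma scale is 1 / scale.
    x = rng.gammavariate(concentration, 1.0 / scale)
    return 1.0 / x


rng = random.Random(0)
samples = [inverse_gamma(5.0, 4.0, rng) for _ in range(20_000)]
mean = sum(samples) / len(samples)  # close to 4 / (5 - 1) = 1 for concentration > 1
```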

static sample(state: NormalThompsonSamplingState, key: Array) int

The normal Thompson sampling policy is stochastic. The algorithm draws \(q_a\) from the distribution \(\operatorname{Normal}(\mu(a), \operatorname{scale}(a)/\sqrt{\lambda(a)})\) for each arm \(a\) where \(\text{scale}(a)\) is sampled from the inverse gamma distribution with parameters \(\alpha(a)\) and \(\beta(a)\). The next action is selected as \(A = \operatorname*{argmax}_{a \in \mathscr{A}} q_a\), where \(\mathscr{A}\) is a set of all actions.

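Parameters:
  • state (NormalThompsonSamplingState) – Current state of the agent.

  • key (PRNGKey) – A PRNG key used as the random key.

Returns: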

Selected action.

Return type:

int

Log-normal Thompson sampling

class LogNormalThompsonSampling(n_arms: int, alpha: float | int, beta: float | int, lam: float | int, mu: float | int)

Bases: NormalThompsonSampling

Log-normal Thompson sampling agent. This algorithm is designed to handle positive rewards by transforming them into the log-space. For more details, refer to the documentation on NormalThompsonSampling.

static update(state: NormalThompsonSamplingState, key: Array, action: int, reward: float | int) NormalThompsonSamplingState

Log-normal Thompson sampling update. The update is analogous to the one in NormalThompsonSampling except that the reward is transformed into the log-space.

Parameters:
  • state (NormalThompsonSamplingState) – Current state of the agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • action (int) – Previously selected action.

  • reward (float) – Reward obtained after executing the action.

Returns:

Updated agent state.

Return type:

NormalThompsonSamplingState
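
Assuming the transform into the log-space is the natural logarithm, the log-normal update is the normal-inverse-gamma update applied to \(\ln r\). A self-contained sketch for a single arm:

```python
import math


def lognormal_update(alpha: float, beta: float, lam: float, mu: float,
                     reward: float) -> tuple[float, float, float, float]:
    """Normal-inverse-gamma update on ln(reward); reward must be positive."""
    x = math.log(reward)  # assumed transform into the log-space
    new_beta = beta + lam * (x - mu) ** 2 / (2.0 * (lam + 1.0))
    return alpha + 0.5, new_beta, lam + 1.0, (mu * lam + x) / (lam + 1.0)
```

Taking the logarithm is what restricts this agent to positive rewards.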

static sample(state: NormalThompsonSamplingState, key: Array) int

Action sampling is analogous to NormalThompsonSampling, except that the expected value of the log-normal distribution is computed instead of the expected value of the normal distribution.

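Parameters:
  • state (NormalThompsonSamplingState) – Current state of the agent.

  • key (PRNGKey) – A PRNG key used as the random key.

Returns: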

Selected action.

Return type:

int

Upper confidence bound (UCB)

class UCBState(R: Array | ndarray | bool | number, N: Array | ndarray | bool | number)

Bases: AgentState, Mapping

Container for the state of the UCB agent.

R

Sum of the rewards obtained for each arm.

Type:

Array

N

Number of tries for each arm.

Type:

Array

items() a set-like object providing a view on D's items
keys() a set-like object providing a view on D's keys
values() an object providing a view on D's values
class UCB(n_arms: int, c: float | int, gamma: float | int = 1.0)

Bases: BaseAgent

UCB agent with optional discounting. The main idea behind this algorithm is to introduce a preference factor in its policy, so that the selection of the next action is based on both the current estimate of the action-value function and the uncertainty of this estimate.

Parameters:
  • n_arms (int) – Number of bandit arms. \(N \in \mathbb{N}_{+}\).

  • c (float) – Degree of exploration. \(c \geq 0\).

  • gamma (float, default=1.0) – If less than one, a discounted UCB algorithm [10] is used. \(\gamma \in (0, 1]\).

References

static parameter_space() Dict

Parameters of the agent constructor in Gymnasium format. The returned value must be of type gym.spaces.Dict or None. If None, the user must provide all parameters manually.

property update_observation_space: Dict

Observation space of the update method in Gymnasium format. Allows missing observations to be inferred using extensions and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all parameters manually.

property sample_observation_space: Dict

Observation space of the sample method in Gymnasium format. Allows missing observations to be inferred using extensions and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all parameters manually.

property action_space: Space

Action space of the agent in Gymnasium format.

static init(key: Array, n_arms: int) UCBState

Creates and initializes an instance of the UCB agent for n_arms arms. The sum of the rewards is set to zero and the number of tries is set to one for each arm.

Parameters:
  • key (PRNGKey) – A PRNG key used as the random key.

  • n_arms (int) – Number of bandit arms.

Returns:

Initial state of the UCB agent.

Return type:

UCBState

static update(state: UCBState, key: Array, action: int, reward: float | int, gamma: float | int) UCBState

In the stationary case, the sum of the rewards for a given arm is increased by reward \(r\) obtained after step \(t\) and the number of tries for the corresponding arm is incremented. In the non-stationary case, the update follows the equations

\[\begin{split}\begin{gather} R_{t + 1}(a) = \mathbb{1}_{A_t = a} r + \gamma R_t(a) , \\ N_{t + 1}(a) = \mathbb{1}_{A_t = a} + \gamma N_t(a). \end{gather}\end{split}\]
Parameters:
  • state (UCBState) – Current state of the agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • action (int) – Previously selected action.

  • reward (float) – Reward collected by the agent after taking the previous action.

  • gamma (float) – Discount factor.

Returns:

Updated agent state.

Return type:

UCBState
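
A plain-Python sketch of the discounted update (not the library's code); with gamma equal to one it reduces to the stationary case of plain sums and counts.

```python
def ucb_update(R: list[float], N: list[float], action: int, reward: float,
               gamma: float) -> tuple[list[float], list[float]]:
    """Discounted UCB update (sketch); gamma = 1 gives the stationary case."""
    # Discount the statistics of all arms...
    R = [gamma * r for r in R]
    N = [gamma * n for n in N]
    # ...and add the new observation to the selected arm only.
    R[action] += reward
    N[action] += 1.0
    return R, N
```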

static sample(state: UCBState, key: Array, c: float | int) int

UCB agent follows the policy

\[A = \operatorname*{argmax}_{a \in \mathscr{A}} \left[ Q(a) + c \sqrt{\frac{\ln \left( {\sum_{a' \in \mathscr{A}}} N(a') \right) }{N(a)}} \right] .\]

where \(\mathscr{A}\) is the set of all actions and \(Q\) is calculated as \(Q(a) = \frac{R(a)}{N(a)}\). The second term of the sum acts as an upper confidence bonus on the value of \(Q\): the square root approximates the uncertainty of the estimate of \(Q\), and \(c\) scales the width of this confidence interval. Note that the UCB policy is deterministic (apart from choosing between several optimal actions).

Parameters:
  • state (UCBState) – Current state of the agent.

  • key (PRNGKey) – A PRNG key used as the random key.

  • c (float) – Degree of exploration.

Returns:

Selected action.

Return type:

int
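
The selection rule can be sketched in plain Python. Initializing the number of tries to one keeps \(N(a)\) positive, so no division by zero occurs. Ties go to the lowest index in this sketch, which may differ from the library's tie-breaking.

```python
import math


def ucb_sample(R: list[float], N: list[float], c: float) -> int:
    """Pick the arm with the largest upper confidence bound (sketch)."""
    log_total = math.log(sum(N))

    def bound(a: int) -> float:
        q = R[a] / N[a]  # action-value estimate Q(a)
        return q + c * math.sqrt(log_total / N[a])

    # Ties go to the lowest index here; the library may break them differently.
    return max(range(len(R)), key=bound)
```

With \(c = 0\) the rule is purely greedy; increasing \(c\) favours arms that have been tried less often.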

Random scheduler

class RandomSchedulerState

Bases: AgentState, Mapping

Random scheduler has no memory, thus the state is empty.

items() a set-like object providing a view on D's items
keys() a set-like object providing a view on D's keys
values() an object providing a view on D's values
class RandomScheduler(n_arms: int)

Bases: BaseAgent

Random scheduler with MAB interface. This scheduler picks an item at random.

Parameters:

n_arms (int) – Number of items to choose from. \(N \in \mathbb{N}_{+}\).

static parameter_space() Dict

Parameters of the agent constructor in Gymnasium format. The returned value must be of type gym.spaces.Dict or None. If None, the user must provide all parameters manually.

property update_observation_space: Dict

Observation space of the update method in Gymnasium format. Allows missing observations to be inferred using extensions and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all parameters manually.

property sample_observation_space: Dict

Observation space of the sample method in Gymnasium format. Allows missing observations to be inferred using extensions and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all parameters manually.

property action_space: Space

Action space of the agent in Gymnasium format.

static init(key: Array) RandomSchedulerState

Creates the state of the random scheduler, which is empty.

static update(state: RandomSchedulerState, key: Array) RandomSchedulerState

The random scheduler has no memory, so update returns an empty state.

static sample(state: RandomSchedulerState, key: Array, n_arms: int) int

Selects one of the n_arms items at random.

Round-robin scheduler

class RoundRobinSchedulerState(item: Array | ndarray | bool | number | float | int)

Bases: AgentState, Mapping

Container for the state of the round-robin scheduler.

item

Scheduled item.

Type:

Numeric

items() a set-like object providing a view on D's items
keys() a set-like object providing a view on D's keys
values() an object providing a view on D's values
class RoundRobinScheduler(n_arms: int, initial_item: int = 0)

Bases: BaseAgent

Round-robin scheduler with MAB interface. This scheduler picks items sequentially. Sampling is deterministic; update must be called to change the state.

Parameters:
  • n_arms (int) – Number of items to choose from. \(N \in \mathbb{N}_{+}\).

  • initial_item (int, default=0) – Initial item to start sampling from.
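
A minimal sketch of the scheduling loop. It assumes that update advances the item by one and wraps around after n_arms items; the function names are hypothetical and the state is reduced to a plain integer.

```python
def rr_sample(item: int) -> int:
    """Sampling only reads the state, so repeated calls return the same item."""
    return item


def rr_update(item: int, n_arms: int) -> int:
    """Advance to the next item, wrapping around (assumed behaviour)."""
    return (item + 1) % n_arms


item = 2  # e.g. initial_item = 2 with n_arms = 3
schedule = []
for _ in range(5):
    schedule.append(rr_sample(item))
    item = rr_update(item, n_arms=3)
print(schedule)  # [2, 0, 1, 2, 0]
```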

static parameter_space() Dict

Parameters of the agent constructor in Gymnasium format. The returned value must be of type gym.spaces.Dict or None. If None, the user must provide all parameters manually.

property update_observation_space: Dict

Observation space of the update method in Gymnasium format. Allows missing observations to be inferred using extensions and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all parameters manually.

property sample_observation_space: Dict

Observation space of the sample method in Gymnasium format. Allows missing observations to be inferred using extensions and makes it easy to export the agent to TensorFlow Lite format. If None, the user must provide all parameters manually.

property action_space: Space

Action space of the agent in Gymnasium format.

static init(key: Array, item: Array | ndarray | bool | number | float | int) RoundRobinSchedulerState

Creates the initial state of the scheduler with item as the scheduled item.

static update(state: RoundRobinSchedulerState, key: Array, n_arms: int) RoundRobinSchedulerState

Moves the schedule to the next of the n_arms items.

static sample(state: RoundRobinSchedulerState, key: Array) Array | ndarray | bool | number | float | int

Returns the currently scheduled item without changing the state.

Masked MAB

class MaskedState(agent_state: reinforced_lib.agents.base_agent.AgentState, mask: jax.Array | numpy.ndarray | numpy.bool | numpy.number)

Bases: AgentState, Mapping

items() a set-like object providing a view on D's items
keys() a set-like object providing a view on D's keys
values() an object providing a view on D's values
class Masked(agent: BaseAgent, mask: Array | ndarray | bool | number)

Bases: BaseAgent

Meta agent supporting dynamic changes in the number of arms.

This agent is highly experimental and should be used with extreme caution. In particular, it makes the following strong assumptions:

  • Each entry in the base agent state has its first dimension corresponding to an arm.

  • The base agent must be stochastic, because this agent uses rejection sampling to choose an unmasked action.

An example of using the agent can be found in test/experimental/test_masked.py.

Parameters:
  • agent (BaseAgent) – A MAB agent type whose actions are masked.

  • mask (Array) – Binary mask array with length equal to the number of arms. Positive entries denote masked actions.

static init(key: Array, *args: tuple, agent: BaseAgent, mask: Array | ndarray | bool | number, **kwargs: dict) MaskedState

Initialize the masked agent state given the mask and the base agent.

Parameters:
  • key (PRNGKey) – A PRNG key used as the random key.

  • args (tuple) – Positional arguments passed to the base agent init method.

  • agent (BaseAgent) – A base agent whose state is initialized.

  • mask (Array) – Binary mask array with length equal to the number of arms.

  • kwargs (dict) – Keyword arguments passed to the base agent init method.

Returns:

Initialized masked agent state.

Return type:

MaskedState

static update(state: MaskedState, key: Array, *args: tuple, agent: BaseAgent, **kwargs: dict) MaskedState

Update the base agent state. The entries corresponding to the masked actions are not updated.

Parameters:
  • state (MaskedState) – Current masked agent state.

  • key (PRNGKey) – A PRNG key used as the random key.

  • args (tuple) – Positional arguments passed to the base agent update method.

  • agent (BaseAgent) – A base agent whose state is updated.

  • kwargs (dict) – Keyword arguments passed to the base agent update method.

Returns:

Updated masked agent state.

Return type:

MaskedState

static sample(state: MaskedState, key: Array, *args, agent: BaseAgent, **kwargs) int

Sample an action from the base agent. If the sampled action is masked, resample until an unmasked action is found.

Parameters:
  • state (MaskedState) – Current masked agent state.

  • key (PRNGKey) – A PRNG key used as the random key.

  • args (tuple) – Positional arguments passed to the base agent sample method.

  • agent (BaseAgent) – A base agent whose action is sampled.

  • kwargs (dict) – Keyword arguments passed to the base agent sample method.

Returns:

Sampled action.

Return type:

int
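
The rejection-sampling idea behind sample can be sketched as follows. The base agent is replaced by a hypothetical sampler function, and the retry limit is an addition of this sketch; it shows why a deterministic base agent that keeps proposing a masked action cannot work.

```python
import random
from typing import Callable


def masked_sample(base_sample: Callable[[random.Random], int], mask: list[bool],
                  rng: random.Random, max_tries: int = 10_000) -> int:
    """Resample from the base sampler until an unmasked action comes up."""
    for _ in range(max_tries):
        action = base_sample(rng)
        if not mask[action]:  # positive (True) entries are masked
            return action
    raise RuntimeError("no unmasked action found; is the base agent stochastic?")


def uniform_base(rng: random.Random) -> int:
    """Hypothetical stochastic base agent over 4 arms."""
    return rng.randrange(4)


rng = random.Random(0)
action = masked_sample(uniform_base, [True, False, True, True], rng)  # only arm 1 is unmasked
```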