Parameters of the agent constructor in Gymnasium format. Type of returned value is required to
be gym.spaces.Dict or None. If None, the user must provide all parameters manually.
Observation space of the update method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Observation space of the sample method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Deep Q-learning agent [1] with \(\epsilon\)-greedy exploration and experience replay buffer. The agent uses
a deep neural network to approximate the Q-value function. The Q-network is trained to minimize the Bellman
error. This agent follows the off-policy learning paradigm and is suitable for environments with discrete action
spaces.
Parameters:
q_network (nn.Module) – Architecture of the Q-network.
obs_space_shape (Shape) – Shape of the observation space.
act_space_size (int) – Size of the action space.
optimizer (optax.GradientTransformation, optional) – Optimizer of the Q-network. If None, the Adam optimizer with learning rate 1e-3 is used.
experience_replay_buffer_size (int, default=10000) – Size of the experience replay buffer.
experience_replay_batch_size (int, default=64) – Batch size of the samples from the experience replay buffer.
experience_replay_steps (int, default=5) – Number of experience replay steps per update.
discount (Scalar, default=0.99) – Discount factor. \(\gamma = 0.0\) means that only the immediate reward counts, \(\gamma = 1.0\) means that future rewards are not discounted. \(0 \leq \gamma \leq 1\)
Parameters of the agent constructor in Gymnasium format. Type of returned value is required to
be gym.spaces.Dict or None. If None, the user must provide all parameters manually.
Observation space of the update method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Observation space of the sample method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Initializes the Q-network, optimizer and experience replay buffer with given parameters.
The first state of the environment is assumed to be a tensor of zeros.
Parameters:
key (PRNGKey) – A PRNG key used as the random key.
obs_space_shape (Shape) – The shape of the observation space.
q_network (nn.Module) – The Q-network.
optimizer (optax.GradientTransformation) – The optimizer.
Loss is the mean squared Bellman error \(\mathcal{L}(\theta) = \mathbb{E}_{s, a, r, s'} \left[ \left( r +
\gamma \max_{a'} Q(s', a') - Q(s, a) \right)^2 \right]\) where \(s\) is the current state, \(a\) is the
current action, \(r\) is the reward, \(s'\) is the next state, \(\gamma\) is the discount factor,
\(Q(s, a)\) is the Q-value of the state-action pair. Loss can be calculated on a batch of transitions.
Parameters:
params (dict) – The parameters of the Q-network.
key (PRNGKey) – A PRNG key used as the random key.
net_state (dict) – The state of the Q-network.
params_target (dict) – The parameters of the target Q-network.
net_state_target (dict) – The state of the target Q-network.
batch (tuple) – A batch of transitions from the experience replay buffer.
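The loss above can be illustrated with a minimal plain-Python sketch. It is not the library's JAX implementation: the batch layout `(s, a, r, terminal, s_next)`, the callables `q` and `q_target`, and the masking of the bootstrap term on terminal transitions are assumptions made for illustration.

```python
def bellman_mse(batch, q, q_target, discount=0.99):
    """Mean squared Bellman error over a batch of transitions.

    batch: list of (s, a, r, terminal, s_next) tuples (hypothetical layout).
    q, q_target: callables mapping a state to a list of Q-values, one per action.
    """
    total = 0.0
    for s, a, r, terminal, s_next in batch:
        # Bootstrap from the best next action unless the episode has ended
        # (terminal masking is an assumption, it is not part of the formula).
        bootstrap = 0.0 if terminal else discount * max(q_target(s_next))
        target = r + bootstrap
        total += (target - q(s)[a]) ** 2
    return total / len(batch)


# Toy check with a tabular "network": two states, two actions.
table = {0: [1.0, 2.0], 1: [0.0, 4.0]}
q = lambda s: table[s]
batch = [(0, 1, 1.0, False, 1), (1, 0, 0.5, True, 0)]
loss = bellman_mse(batch, q, q, discount=0.5)
```

In the toy batch, the first transition has target 1 + 0.5 * 4 = 3 against Q(s, a) = 2, and the second is terminal with target 0.5 against Q(s, a) = 0.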
Appends the transition to the experience replay buffer and performs experience_replay_steps steps.
Each step consists of sampling a batch of transitions from the experience replay buffer, calculating the loss
using the loss_fn function and performing a gradient step on the Q-network. The \(\epsilon\)-greedy
parameter is decayed by epsilon_decay.
Parameters:
state (DQNState) – The current state of the deep Q-learning agent.
key (PRNGKey) – A PRNG key used as the random key.
env_state (Array) – The current state of the environment.
action (Array) – The action taken by the agent.
reward (Scalar) – The reward received by the agent.
terminal (bool) – Whether the episode has terminated.
step_fn (Callable) – The function that performs a single gradient step on the Q-network.
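The update procedure described above can be sketched in plain Python as follows. This is only an illustration under stated assumptions: the buffer is a `collections.deque`, batches are sampled with replacement, and `step_fn` is any callable that consumes a batch. The library's own buffer and sampling scheme may differ.

```python
import random
from collections import deque


def update(buffer, transition, step_fn, rng, batch_size=64, replay_steps=5):
    """Store one transition, then run `replay_steps` steps on sampled batches."""
    buffer.append(transition)
    for _ in range(replay_steps):
        # Sampling with replacement lets a batch be drawn from a short buffer
        # (an assumption of this sketch).
        batch = rng.choices(list(buffer), k=batch_size)
        step_fn(batch)


buffer = deque(maxlen=10000)  # the oldest transitions are dropped when full
calls = []                    # stands in for a gradient step; records batches
update(buffer, ("s", 0, 1.0, False, "s_next"), calls.append,
       random.Random(0), batch_size=4)
```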
Samples a random action with probability \(\epsilon\) and the greedy action with probability
\(1 - \epsilon\). The greedy action is the action with the highest Q-value.
Parameters:
state (DQNState) – The state of the deep Q-learning agent.
key (PRNGKey) – A PRNG key used as the random key.
env_state (Array) – The current state of the environment.
q_network (nn.Module) – The Q-network.
act_space_size (int) – The size of the action space.
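A minimal plain-Python sketch of the \(\epsilon\)-greedy rule described above. The function name and the use of `random.Random` in place of a PRNG key are assumptions for illustration only.

```python
import random


def epsilon_greedy(q_values, epsilon, rng):
    """Return a random action with probability epsilon, else the greedy one."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    # Greedy action: index of the highest Q-value.
    return max(range(len(q_values)), key=lambda a: q_values[a])


rng = random.Random(0)
greedy = epsilon_greedy([0.1, 0.7, 0.2], epsilon=0.0, rng=rng)
explored = [epsilon_greedy([0.1, 0.7, 0.2], epsilon=1.0, rng=rng)
            for _ in range(200)]
```

With `epsilon=0.0` the call always returns the greedy action; with `epsilon=1.0` every action is drawn uniformly at random.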
Double deep Q-learning agent [2] with \(\epsilon\)-greedy exploration and experience replay buffer. The agent
uses two Q-networks to stabilize the learning process and avoid overestimation of the Q-values. The main Q-network
is trained to minimize the Bellman error. The target Q-network is updated with a soft update. This agent follows
the off-policy learning paradigm and is suitable for environments with discrete action spaces.
Parameters:
q_network (nn.Module) – Architecture of the Q-networks.
obs_space_shape (Shape) – Shape of the observation space.
act_space_size (int) – Size of the action space.
optimizer (optax.GradientTransformation, optional) – Optimizer of the Q-networks. If None, the Adam optimizer with learning rate 1e-3 is used.
experience_replay_buffer_size (int, default=10000) – Size of the experience replay buffer.
experience_replay_batch_size (int, default=64) – Batch size of the samples from the experience replay buffer.
experience_replay_steps (int, default=5) – Number of experience replay steps per update.
discount (Scalar, default=0.99) – Discount factor. \(\gamma = 0.0\) means that only the immediate reward counts, \(\gamma = 1.0\) means that future rewards are not discounted. \(0 \leq \gamma \leq 1\)
Parameters of the agent constructor in Gymnasium format. Type of returned value is required to
be gym.spaces.Dict or None. If None, the user must provide all parameters manually.
Observation space of the update method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Observation space of the sample method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Initializes the Q-networks, optimizer and experience replay buffer with given parameters.
The first state of the environment is assumed to be a tensor of zeros.
Parameters:
key (PRNGKey) – A PRNG key used as the random key.
obs_space_shape (Shape) – The shape of the observation space.
q_network (nn.Module) – The Q-network.
optimizer (optax.GradientTransformation) – The optimizer.
Loss is the mean squared Bellman error \(\mathcal{L}(\theta) = \mathbb{E}_{s, a, r, s'} \left[ \left( r +
\gamma \max_{a'} Q'(s', a') - Q(s, a) \right)^2 \right]\) where \(s\) is the current state, \(a\) is the
current action, \(r\) is the reward, \(s'\) is the next state, \(\gamma\) is the discount factor,
\(Q(s, a)\) is the Q-value of the main Q-network, \(Q'(s', a')\) is the Q-value of the target
Q-network. Loss can be calculated on a batch of transitions.
Parameters:
params (dict) – The parameters of the Q-network.
key (PRNGKey) – A PRNG key used as the random key.
state (DDQNState) – The state of the double deep Q-learning agent.
batch (tuple) – A batch of transitions from the experience replay buffer.
Appends the transition to the experience replay buffer and performs experience_replay_steps steps.
Each step consists of sampling a batch of transitions from the experience replay buffer, calculating the loss
using the loss_fn function, performing a gradient step on the main Q-network, and soft updating the target
Q-network. Soft update of the parameters is defined as \(\theta_{target} = \tau \theta + (1 - \tau) \theta_{target}\).
The \(\epsilon\)-greedy parameter is decayed by epsilon_decay.
Parameters:
state (DDQNState) – The current state of the double deep Q-learning agent.
key (PRNGKey) – A PRNG key used as the random key.
env_state (Array) – The current state of the environment.
action (Array) – The action taken by the agent.
reward (Scalar) – The reward received by the agent.
terminal (bool) – Whether the episode has terminated.
step_fn (Callable) – The function that performs a single gradient step on the Q-network.
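The soft update \(\theta_{target} = \tau \theta + (1 - \tau) \theta_{target}\) can be sketched in plain Python. Parameters are represented here as a flat dict of floats, which is an assumption for illustration; the library operates on network parameter trees.

```python
def soft_update(params, target_params, tau):
    """Return tau * params + (1 - tau) * target_params, entry by entry."""
    return {
        name: tau * params[name] + (1.0 - tau) * target_params[name]
        for name in params
    }


main = {"w": 1.0, "b": -2.0}
target = {"w": 0.0, "b": 0.0}
target = soft_update(main, target, tau=0.25)
```

A small \(\tau\) makes the target network trail the main network slowly, while \(\tau = 1\) copies the main parameters outright.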
Samples a random action with probability \(\epsilon\) and the greedy action with probability
\(1 - \epsilon\) using the main Q-network. The greedy action is the action with the highest Q-value.
Parameters:
state (DDQNState) – The state of the double deep Q-learning agent.
key (PRNGKey) – A PRNG key used as the random key.
env_state (Array) – The current state of the environment.
q_network (nn.Module) – The Q-network.
act_space_size (int) – The size of the action space.
Deep expected SARSA agent with temperature parameter \(\tau\) and experience replay buffer. The agent uses
a deep neural network to approximate the Q-value function. The Q-network is trained to minimize the Bellman
error. This agent follows the on-policy learning paradigm and is suitable for environments with discrete action
spaces.
Parameters:
q_network (nn.Module) – Architecture of the Q-network.
obs_space_shape (Shape) – Shape of the observation space.
act_space_size (int) – Size of the action space.
optimizer (optax.GradientTransformation, optional) – Optimizer of the Q-network. If None, the Adam optimizer with learning rate 1e-3 is used.
experience_replay_buffer_size (int, default=10000) – Size of the experience replay buffer.
experience_replay_batch_size (int, default=64) – Batch size of the samples from the experience replay buffer.
experience_replay_steps (int, default=5) – Number of experience replay steps per update.
discount (Scalar, default=0.99) – Discount factor. \(\gamma = 0.0\) means that only the immediate reward counts, \(\gamma = 1.0\) means that future rewards are not discounted. \(0 \leq \gamma \leq 1\)
tau (Scalar, default=1.0) – Temperature parameter. As \(\tau \to 0\) the policy becomes greedy (no exploration), as \(\tau \to \infty\) it becomes uniformly random. \(\tau > 0\)
Parameters of the agent constructor in Gymnasium format. Type of returned value is required to
be gym.spaces.Dict or None. If None, the user must provide all parameters manually.
Observation space of the update method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Observation space of the sample method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Initializes the Q-network, optimizer and experience replay buffer with given parameters.
The first state of the environment is assumed to be a tensor of zeros.
Parameters:
key (PRNGKey) – A PRNG key used as the random key.
obs_space_shape (Shape) – The shape of the observation space.
q_network (nn.Module) – The Q-network.
optimizer (optax.GradientTransformation) – The optimizer.
Loss is the mean squared Bellman error \(\mathcal{L}(\theta) = \mathbb{E}_{s, a, r, s'} \left[ \left( r +
\gamma \sum_{a'} \pi(a'|s') Q(s', a') - Q(s, a) \right)^2 \right]\) where \(s\) is the current state,
\(a\) is the current action, \(r\) is the reward, \(s'\) is the next state, \(\gamma\) is
the discount factor, \(Q(s, a)\) is the Q-value of the state-action pair. Loss can be calculated on a batch
of transitions.
Parameters:
params (dict) – The parameters of the Q-network.
key (PRNGKey) – A PRNG key used as the random key.
net_state (dict) – The state of the Q-network.
params_target (dict) – The parameters of the target Q-network.
net_state_target (dict) – The state of the target Q-network.
batch (tuple) – A batch of transitions from the experience replay buffer.
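The expected SARSA target \(r + \gamma \sum_{a'} \pi(a'|s') Q(s', a')\) can be sketched in plain Python. The sketch assumes that \(\pi\) is a softmax (Boltzmann) policy over the Q-values with temperature \(\tau\); the source names only a temperature parameter, so the exact form of the policy is an assumption.

```python
import math


def softmax_policy(q_values, tau):
    """Boltzmann policy: pi(a|s) proportional to exp(Q(s, a) / tau)."""
    m = max(q_values)  # subtract the max for numerical stability
    weights = [math.exp((q - m) / tau) for q in q_values]
    total = sum(weights)
    return [w / total for w in weights]


def expected_sarsa_target(reward, next_q_values, discount, tau):
    """r + gamma * sum over a' of pi(a'|s') * Q(s', a')."""
    pi = softmax_policy(next_q_values, tau)
    return reward + discount * sum(p * q for p, q in zip(pi, next_q_values))


target = expected_sarsa_target(1.0, [2.0, 2.0], discount=0.5, tau=1.0)
near_greedy = expected_sarsa_target(0.0, [0.0, 1.0], discount=1.0, tau=0.01)
```

With a very small temperature the expectation approaches the maximum Q-value, so the target approaches the Q-learning target.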
Appends the transition to the experience replay buffer and performs experience_replay_steps steps.
Each step consists of sampling a batch of transitions from the experience replay buffer, calculating the loss
using the loss_fn function and performing a gradient step on the Q-network.
Parameters:
state (ExpectedSarsaState) – The current state of the deep expected SARSA agent.
key (PRNGKey) – A PRNG key used as the random key.
env_state (Array) – The current state of the environment.
action (Array) – The action taken by the agent.
reward (Scalar) – The reward received by the agent.
terminal (bool) – Whether the episode has terminated.
q_network (nn.Module) – The Q-network.
step_fn (Callable) – The function that performs a single gradient step on the Q-network.
Deep deterministic policy gradient [3][4] agent with white Gaussian noise exploration and experience replay
buffer. The agent simultaneously learns a Q-function and a policy. The Q-function is updated using the Bellman
equation. The policy is learned using the gradient of the Q-function with respect to the policy parameters
to maximize the Q-value. The agent uses two Q-networks and two policy networks to stabilize the learning process
and avoid overestimation. The target networks are updated with a soft update. This agent follows the off-policy
learning paradigm and is suitable for environments with continuous action spaces.
Parameters:
q_network (nn.Module) – Architecture of the Q-networks.
The input to the network should be two tensors of observations and actions respectively.
a_network (nn.Module) – Architecture of the policy networks.
obs_space_shape (Shape) – Shape of the observation space.
act_space_shape (Shape) – Shape of the action space.
min_action (Scalar or Array) – Minimum action value.
max_action (Scalar or Array) – Maximum action value.
q_optimizer (optax.GradientTransformation, optional) – Optimizer of the Q-networks. If None, the Adam optimizer with learning rate 1e-3 is used.
a_optimizer (optax.GradientTransformation, optional) – Optimizer of the policy networks. If None, the Adam optimizer with learning rate 1e-3 is used.
experience_replay_buffer_size (int, default=10000) – Size of the experience replay buffer.
experience_replay_batch_size (int, default=64) – Batch size of the samples from the experience replay buffer.
experience_replay_steps (int, default=5) – Number of experience replay steps per update.
discount (Scalar, default=0.99) – Discount factor. \(\gamma = 0.0\) means that only the immediate reward counts, \(\gamma = 1.0\) means that future rewards are not discounted. \(0 \leq \gamma \leq 1\)
Parameters of the agent constructor in Gymnasium format. Type of returned value is required to
be gym.spaces.Dict or None. If None, the user must provide all parameters manually.
Observation space of the update method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Observation space of the sample method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Initializes the Q-networks and the policy networks, optimizers, and experience replay buffer.
The first state of the environment is assumed to be a tensor of zeros.
Parameters:
key (PRNGKey) – A PRNG key used as the random key.
obs_space_shape (Shape) – The shape of the observation space.
act_space_shape (Shape) – The shape of the action space.
q_network (nn.Module) – The Q-network.
a_network (nn.Module) – The policy network.
q_optimizer (optax.GradientTransformation) – The Q-network optimizer.
a_optimizer (optax.GradientTransformation) – The policy network optimizer.
Loss is the mean squared Bellman error \(\mathcal{L}(\theta) = \mathbb{E}_{s, a, r, s'} \left[ \left( r
+ \gamma Q'(s', \pi'(s')) - Q(s, a) \right)^2 \right]\) where \(s\) is the current state, \(a\)
is the current action, \(r\) is the reward, \(s'\) is the next state, \(\gamma\) is the discount
factor, \(Q(s, a)\) is the Q-value of the main Q-network, \(Q'(s', \pi'(s'))\) is the Q-value of the target
Q-network, and \(\pi'(s')\) is the action of the target policy network. The policy network parameters
are considered as fixed. Loss can be calculated on a batch of transitions.
Parameters:
q_params (dict) – The parameters of the Q-network.
key (PRNGKey) – A PRNG key used as the random key.
state (DDPGState) – The state of the deep deterministic policy gradient agent.
batch (tuple) – A batch of transitions from the experience replay buffer.
The policy network is updated using the gradient of the Q-network to maximize the Q-value of the current state
and action \(\max_{\theta} \mathbb{E}_{s, a} \left[ Q(s, \pi_{\theta}(s)) \right]\). Q-network parameters are
considered as fixed. The policy network can be updated on a batch of transitions.
Parameters:
a_params (dict) – The parameters of the policy network.
key (PRNGKey) – A PRNG key used as the random key.
state (DDPGState) – The state of the deep deterministic policy gradient agent.
batch (tuple) – A batch of transitions from the experience replay buffer.
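The two objectives described above can be sketched in plain Python with toy callables in place of networks. The function names and the one-dimensional toy Q-function are hypothetical and serve only to show which network evaluates which term.

```python
def critic_target(reward, next_state, q_target, pi_target, discount=0.99):
    """r + gamma * Q'(s', pi'(s')): the target policy picks the next action."""
    return reward + discount * q_target(next_state, pi_target(next_state))


def actor_objective(states, q, pi):
    """Mean of Q(s, pi(s)); the policy parameters are tuned to maximize it."""
    return sum(q(s, pi(s)) for s in states) / len(states)


# Toy 1-D problem: Q peaks when the action equals the state.
q = lambda s, a: -(a - s) ** 2
pi_good = lambda s: s        # always picks the maximizing action
pi_bad = lambda s: s + 1.0   # always off by one

y = critic_target(1.0, 2.0, q, pi_good, discount=0.5)
good = actor_objective([0.0, 1.0], q, pi_good)
bad = actor_objective([0.0, 1.0], q, pi_bad)
```

The policy that matches the peak of Q obtains the higher objective, which is the direction in which the policy gradient pushes the policy parameters.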
Appends the transition to the experience replay buffer and performs experience_replay_steps steps.
Each step consists of sampling a batch of transitions from the experience replay buffer, calculating the
Q-network loss and the policy network loss using q_loss_fn and a_loss_fn respectively, performing
a gradient step on both networks, and soft updating the target networks. Soft update of the parameters
is defined as \(\theta_{target} = \tau \theta + (1 - \tau) \theta_{target}\). The noise parameter is
decayed by noise_decay.
Parameters:
state (DDPGState) – The current state of the deep deterministic policy gradient agent.
key (PRNGKey) – A PRNG key used as the random key.
env_state (Array) – The current state of the environment.
action (Array) – The action taken by the agent.
reward (Scalar) – The reward received by the agent.
terminal (bool) – Whether the episode has terminated.
q_step_fn (Callable) – The function that performs a single gradient step on the Q-network.
a_step_fn (Callable) – The function that performs a single gradient step on the policy network.
Calculates a deterministic action using the policy network. Then adds white Gaussian noise with standard
deviation state.noise to the action and clips it to the range \([min\_action, max\_action]\).
Parameters:
state (DDPGState) – The state of the deep deterministic policy gradient agent.
key (PRNGKey) – A PRNG key used as the random key.
env_state (Array) – The current state of the environment.
a_network (nn.Module) – The policy network.
min_action (Scalar or Array) – The minimum value of the action.
max_action (Scalar or Array) – The maximum value of the action.
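The exploration step described above (add Gaussian noise, then clip) can be sketched in plain Python. The function name is hypothetical, scalar bounds are assumed, and `random.Random` stands in for the PRNG key.

```python
import random


def noisy_action(action, noise_std, min_action, max_action, rng):
    """Add zero-mean Gaussian noise to each action component, then clip."""
    return [
        min(max(a + rng.gauss(0.0, noise_std), min_action), max_action)
        for a in action
    ]


rng = random.Random(0)
samples = [noisy_action([0.9, -0.9], 0.5, -1.0, 1.0, rng) for _ in range(100)]
no_noise = noisy_action([0.3], 0.0, -1.0, 1.0, rng)
```

Clipping guarantees that the perturbed action stays inside the valid range even when the deterministic action is close to a bound.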
Proximal Policy Optimization (PPO) agent [5]. This implementation uses the clipped surrogate objective.
The policy and value functions should be represented by a single Flax module with two outputs: the action logits
and the state value. The network should be able to process a batch of observations. The actions are sampled
from a categorical distribution, while the value function is used to compute the advantages using Generalized
Advantage Estimation (GAE) [6]. The agent is trained using mini-batch gradient descent. This agent follows
the on-policy learning paradigm and is suitable for environments with discrete action spaces.
Parameters:
network (nn.Module) – Architecture of the PPO agent network.
obs_space_shape (Shape) – Shape of the observation space.
act_space_size (int) – Size of the action space.
optimizer (optax.GradientTransformation, optional) – Optimizer of the network. If None, the Adam optimizer with learning rate 3e-4 and \(\epsilon\) = 1e-5 is used.
discount (Scalar, default=0.99) – Discount factor. \(\gamma = 0.0\) means that only the immediate reward counts, \(\gamma = 1.0\) means that future rewards are not discounted. \(0 \leq \gamma \leq 1\)
lambda_gae (Scalar, default=0.9) – GAE parameter. \(\lambda = 0.0\) reduces to the one-step TD advantage, \(\lambda = 1.0\) means pure Monte Carlo advantage. \(0 \leq \lambda \leq 1\)
normalize_advantage (bool, default=True) – If True, the advantages are normalized to have mean 0 and standard deviation 1.
clip_coef (Scalar, default=0.2) – Clipping coefficient for the surrogate objective, \(\epsilon\) in [5].
clip_value (bool, default=True) – If True, the loss for the value function is clipped.
clip_grad (Scalar, default=0.5) – If not None, the gradients are clipped to have a maximum norm of clip_grad.
entropy_coef (Scalar, default=0.01) – Coefficient for the entropy bonus.
value_coef (Scalar, default=0.5) – Coefficient for the value function loss.
rollout_length (int, default=512) – Length of the rollout buffer.
num_envs (int, default=1) – Number of parallel environments.
batch_size (int, default=128) – Size of the batch to be sampled from the rollout buffer.
num_epochs (int, default=4) – Number of update epochs to perform on the rollout buffer.
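The clipped surrogate objective controlled by clip_coef can be sketched per sample in plain Python. This is a generic illustration of the objective from [5], not the library's implementation; the function name is hypothetical.

```python
import math


def clipped_surrogate(log_prob_new, log_prob_old, advantage, clip_coef=0.2):
    """Per-sample PPO objective: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    ratio = math.exp(log_prob_new - log_prob_old)
    clipped = min(max(ratio, 1.0 - clip_coef), 1.0 + clip_coef)
    return min(ratio * advantage, clipped * advantage)


# Ratio 1.1 lies inside the clipping range, ratio 1.8 is capped at 1.2.
inside = clipped_surrogate(math.log(0.55), math.log(0.5), advantage=1.0)
capped = clipped_surrogate(math.log(0.9), math.log(0.5), advantage=1.0)
# With a negative advantage, a ratio of 0.5 is clipped at 0.8.
negative = clipped_surrogate(math.log(0.25), math.log(0.5), advantage=-1.0)
```

The objective is maximized during training; the clipping removes the incentive to move the probability ratio far from 1 in a single update.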
Parameters of the agent constructor in Gymnasium format. Type of returned value is required to
be gym.spaces.Dict or None. If None, the user must provide all parameters manually.
Observation space of the update method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Observation space of the sample method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Initializes the PPO network, optimizer and rollout buffer with given parameters.
The first state of the environment is assumed to be a tensor of zeros.
Parameters:
key (PRNGKey) – A PRNG key used as the random key.
num_envs (int) – The number of parallel environments.
obs_space_shape (Shape) – The shape of the observation space.
network (nn.Module) – The agent network.
optimizer (optax.GradientTransformation) – The optimizer.
rb (RolloutBuffer) – The rollout buffer functions.
Appends the transition to the on-policy rollout buffer. Once the rollout buffer reaches rollout_length
steps, computes advantages and returns using GAE, shuffles and flattens the buffer, and performs multiple
gradient updates using mini-batches.
Parameters:
state (PPOState) – The current state of the PPO agent.
key (PRNGKey) – A PRNG key used as the random key.
env_states (Array) – The current states of the environments.
actions (Array) – The actions taken by the agent.
rewards (Scalar) – The rewards received by the agent.
terminals (bool) – Whether the episodes have terminated.
network (nn.Module) – The agent network.
step_fn (Callable) – The function that performs a single gradient step on the agent network.
rb (RolloutBuffer) – The rollout buffer functions.
rollout_length (int) – The length of the rollout buffer.
num_envs (int) – The number of parallel environments.
batch_size (int) – The size of the batch sampled from the rollout buffer.
num_epochs (int) – The number of update epochs to perform on the rollout buffer.
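The advantage computation mentioned above can be sketched in plain Python for a single environment. This follows the standard GAE recursion from [6]; the function signature and the convention that `terminals[t]` marks the end of an episode after step `t` are assumptions of this sketch.

```python
def gae(rewards, values, last_value, terminals, discount=0.99, lambda_gae=0.9):
    """Generalized Advantage Estimation for one environment's rollout.

    terminals[t] is True when the episode ended after step t (assumed
    flag convention). Returns (advantages, returns).
    """
    advantages = [0.0] * len(rewards)
    next_value, running = last_value, 0.0
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if terminals[t] else 1.0
        # One-step TD error; no bootstrapping across episode boundaries.
        delta = rewards[t] + discount * next_value * not_done - values[t]
        running = delta + discount * lambda_gae * not_done * running
        advantages[t] = running
        next_value = values[t]
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


adv_mc, ret_mc = gae([1.0, 1.0], [0.0, 0.0], 0.0, [False, True],
                     discount=0.5, lambda_gae=1.0)
adv_td, _ = gae([1.0, 1.0], [0.0, 0.0], 0.0, [False, True],
                discount=0.5, lambda_gae=0.0)
```

With \(\lambda = 1\) and zero value estimates, the first advantage equals the discounted Monte Carlo return 1 + 0.5 * 1; with \(\lambda = 0\) each advantage is just the one-step TD error.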
Proximal Policy Optimization (PPO) agent [5]. This implementation uses the clipped surrogate objective.
The policy and value functions should be represented by a single Flax module with three outputs: the actions,
the log standard deviation of the actions, and the state value. The network should be able to process a batch
of observations. The actions are sampled from a diagonal Gaussian distribution, while the value function is used
to compute the advantages using Generalized Advantage Estimation (GAE) [6]. The agent is trained using mini-batch
gradient descent. This agent follows the on-policy learning paradigm and is suitable for environments with
continuous action spaces.
Parameters:
network (nn.Module) – Architecture of the PPO agent network.
obs_space_shape (Shape) – Shape of the observation space.
act_space_shape (Shape) – Shape of the action space.
min_action (Scalar or Array) – Minimum action value.
max_action (Scalar or Array) – Maximum action value.
optimizer (optax.GradientTransformation, optional) – Optimizer of the network. If None, the Adam optimizer with learning rate 3e-4 and \(\epsilon\) = 1e-5 is used.
discount (Scalar, default=0.99) – Discount factor. \(\gamma = 0.0\) means that only the immediate reward counts, \(\gamma = 1.0\) means that future rewards are not discounted. \(0 \leq \gamma \leq 1\)
lambda_gae (Scalar, default=0.9) – GAE parameter. \(\lambda = 0.0\) reduces to the one-step TD advantage, \(\lambda = 1.0\) means pure Monte Carlo advantage. \(0 \leq \lambda \leq 1\)
normalize_advantage (bool, default=True) – If True, the advantages are normalized to have mean 0 and standard deviation 1.
clip_coef (Scalar, default=0.2) – Clipping coefficient for the surrogate objective, \(\epsilon\) in [5].
clip_value (bool, default=True) – If True, the loss for the value function is clipped.
clip_grad (Scalar, default=0.5) – If not None, the gradients are clipped to have a maximum norm of clip_grad.
entropy_coef (Scalar, default=0.01) – Coefficient for the entropy bonus.
value_coef (Scalar, default=0.5) – Coefficient for the value function loss.
rollout_length (int, default=512) – Length of the rollout buffer.
num_envs (int, default=1) – Number of parallel environments.
batch_size (int, default=128) – Size of the batch to be sampled from the rollout buffer.
num_epochs (int, default=4) – Number of update epochs to perform on the rollout buffer.
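For the continuous variant, the log-probability of an action under a diagonal Gaussian policy can be sketched in plain Python. The sketch assumes that the network outputs a mean and a log standard deviation per action dimension; the function name is hypothetical.

```python
import math


def diag_gaussian_log_prob(action, mean, log_std):
    """Log-density of `action` under independent Gaussians per dimension."""
    total = 0.0
    for a, m, ls in zip(action, mean, log_std):
        z = (a - m) / math.exp(ls)  # standardized distance from the mean
        total += -0.5 * z * z - ls - 0.5 * math.log(2.0 * math.pi)
    return total


# Two-dimensional standard normal evaluated at its mean.
lp = diag_gaussian_log_prob([0.0, 0.0], [0.0, 0.0], [0.0, 0.0])
```

Because the dimensions are independent, the joint log-probability is the sum of the per-dimension terms, which is the quantity that enters the PPO probability ratio.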
Parameters of the agent constructor in Gymnasium format. Type of returned value is required to
be gym.spaces.Dict or None. If None, the user must provide all parameters manually.
Observation space of the update method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Appends the transition to the on-policy rollout buffer. Once the rollout buffer reaches rollout_length
steps, computes advantages and returns using GAE, shuffles and flattens the buffer, and performs multiple
gradient updates using mini-batches.
Parameters:
state (PPOState) – The current state of the PPO agent.
key (PRNGKey) – A PRNG key used as the random key.
env_states (Array) – The current states of the environments.
actions (Array) – The actions taken by the agent.
rewards (Scalar) – The rewards received by the agent.
terminals (bool) – Whether the episodes have terminated.
network (nn.Module) – The agent network.
step_fn (Callable) – The function that performs a single gradient step on the agent network.
rb (RolloutBuffer) – The rollout buffer functions.
rollout_length (int) – The length of the rollout buffer.
num_envs (int) – The number of parallel environments.
batch_size (int) – The size of the batch sampled from the rollout buffer.
num_epochs (int) – The number of update epochs to perform on the rollout buffer.
Evolution strategies (ES)-based agent using the evosax library [12]. This implementation maintains a population
of candidate solutions (parameter vectors), evaluates them in parallel across environments, and updates the
population by applying an evolutionary algorithm. Unlike gradient-based RL methods, this agent does not rely
on backpropagation through the value or policy network. Instead, the network parameters are evolved using
black-box optimization. This agent is suitable for environments with both discrete and continuous action spaces.
Note! The user is responsible for providing appropriate network output in the correct format (e.g., discrete
actions should be sampled from logits with jax.random.categorical inside the network definition).
Note! This agent does not discount future rewards, therefore, the fitness is computed as a simple sum of
rewards obtained during the evaluation phase.
Note! This agent is compatible only with distribution-based evolution strategies from the evosax library
(see the evosax documentation for the list of available algorithms). Population-based methods
will be supported in future releases.
Parameters:
network (nn.Module) – Architecture of the agent network.
evo_strategy (type) – Evolution strategy class from evosax.
population_size (int) – Size of the population.
obs_space_shape (Shape) – Shape of the observation space.
act_space_shape (Shape, default=(1,)) – Shape of the action space. For discrete action spaces, use (1,).
evo_strategy_kwargs (dict, default=None) – Parameters for the evolution strategy initialization. The population size and initial solution are set automatically.
evo_strategy_default_params (dict, default=None) – Custom default parameters for the evolution strategy. If None, the default parameters are used.
num_eval_steps (int, default=None) – Number of evaluation steps. If None, the evaluation runs until all episodes end.
Parameters of the agent constructor in Gymnasium format. Type of returned value is required to
be gym.spaces.Dict or None. If None, the user must provide all parameters manually.
Observation space of the update method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Observation space of the sample method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Updates the agent state after one evaluation step of the population. The method accumulates rewards
into fitness values for each individual and tracks episode terminations. Once the evaluation is considered
complete, either because all episodes have terminated (num_eval_steps=None) or because a fixed number
of steps has been reached (num_eval_steps specified), the population is evolved. The evolution strategy’s
tell method is called with the negative fitness values (as evosax minimizes the fitness), and a new population
is generated using the ask method. The best parameters found so far are stored in the state.
Parameters:
state (EvosaxState) – The current state of the evosax agent.
key (PRNGKey) – A PRNG key used as the random key.
env_states (Array) – The current states of the environments.
actions (Array) – The actions taken by the agent.
rewards (Scalar) – The rewards received by the agent.
terminals (bool) – Whether the episodes have terminated.
num_eval_steps (int) – Number of evaluation steps. If None, the evaluation runs until all episodes end.
evo_strategy (EvolutionaryAlgorithm) – The evosax evolution strategy.
key (PRNGKey) – A PRNG key used as the random key.
env_states (Array) – The current state of the environment.
population_size (int) – The size of the population.
network (nn.Module) – The agent network.
params_format_fn (Callable) – Function that converts the flattened parameters of the population members back to the original neural network parameter format.
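The fitness bookkeeping performed during the evaluation phase can be sketched in plain Python. The sketch does not call evosax; it assumes that rewards received after a member's episode has terminated are ignored, and the function name is hypothetical.

```python
def accumulate_fitness(fitness, rewards, done, terminals):
    """Add step rewards for members whose episode is still running."""
    fitness = [f if d else f + r for f, r, d in zip(fitness, rewards, done)]
    done = [d or t for d, t in zip(done, terminals)]
    return fitness, done


fitness, done = [0.0, 0.0], [False, False]
fitness, done = accumulate_fitness(fitness, [1.0, 2.0], done, [True, False])
fitness, done = accumulate_fitness(fitness, [5.0, 3.0], done, [False, True])

# The strategy minimizes, so the undiscounted reward sums are negated
# before they are reported back to it.
to_minimize = [-f for f in fitness]
```

Once every entry of `done` is true (or the fixed number of evaluation steps has been reached), the negated fitness values are what the strategy's tell method receives.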
Epsilon-greedy [7] agent with an optimistic start behavior and optional exponential recency-weighted average update.
It selects a random action from a set of all actions \(\mathscr{A}\) with probability
\(\epsilon\) (exploration), otherwise it chooses the currently best action (exploitation).
Epsilon can be decayed over time to shift the policy from exploration to exploitation.
Parameters:
n_arms (int) – Number of bandit arms. \(N \in \mathbb{N}_{+}\).
Parameters of the agent constructor in Gymnasium format. Type of returned value is required to
be gym.spaces.Dict or None. If None, the user must provide all parameters manually.
Observation space of the update method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Observation space of the sample method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Creates and initializes an instance of the \(\epsilon\)-greedy agent for n_arms arms. Action-value function estimates are
set to the optimistic_start value and the number of tries is set to one for each arm.
Parameters:
key (PRNGKey) – A PRNG key used as the random key.
n_arms (int) – Number of bandit arms.
e (float) – Initial exploration rate (epsilon).
optimistic_start (float) – Interpreted as the optimistic start to encourage exploration in the early stages.
In the stationary case, the action-value estimate for a given arm is updated as
\(Q_{t + 1} = Q_t + \frac{1}{t} \lbrack R_t - Q_t \rbrack\) after receiving reward \(R_t\) at step
\(t\) and the number of tries for the corresponding arm is incremented. In the non-stationary case,
the update follows the equation \(Q_{t + 1} = Q_t + \alpha \lbrack R_t - Q_t \rbrack\).
Exploration rate \(\epsilon\) can be decayed over time following the equation
\(\epsilon_{t + 1} = \max(\epsilon_{\min}, \epsilon_t \times \epsilon_{\text{decay}})\).
Parameters:
state (EGreedyState) – Current state of the agent.
key (PRNGKey) – A PRNG key used as the random key.
action (int) – Previously selected action.
reward (float) – Reward collected by the agent after taking the previous action.
alpha (float) – Exponential recency-weighted average factor (used when \(\alpha > 0\)).
e_decay (float) – Decay factor for the exploration rate.
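The update and selection rules described above can be sketched in plain Python. This is an illustrative sketch and not the library implementation: the function names, the list-based state, and the `e_min` argument are assumptions, and the stationary step size \(1/t\) is interpreted here as one over the number of tries of the updated arm.

```python
import random


def egreedy_update(q, n, e, action, reward, alpha=0.0, e_decay=1.0, e_min=0.0):
    """Return updated (q, n, e) after observing `reward` for `action`."""
    q, n = list(q), list(n)
    if alpha > 0:
        # Non-stationary case: constant step size alpha.
        q[action] += alpha * (reward - q[action])
    else:
        # Stationary case: step size 1 / (tries of this arm), an assumed reading of 1/t.
        q[action] += (reward - q[action]) / n[action]
    n[action] += 1
    # Decay epsilon, but never below e_min.
    e = max(e_min, e * e_decay)
    return q, n, e


def egreedy_sample(q, e, rng=random):
    """Explore with probability e, otherwise pick the arm with the highest estimate."""
    if rng.random() < e:
        return rng.randrange(len(q))
    return max(range(len(q)), key=q.__getitem__)
```

With an initial try count of one, the first reward for an arm fully replaces its optimistic estimate under this reading of the stationary rule.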
Basic Exp3 agent for stationary multi-armed bandit problems with exploration factor \(\gamma\). The higher
the value, the more the agent explores. The implementation is inspired by the work of Auer et al. [8]. There
are many variants of the Exp3 algorithm; more information can be found in the original paper.
Parameters:
n_arms (int) – Number of bandit arms. \(N \in \mathbb{N}_{+}\).
Parameters of the agent constructor in Gymnasium format. Type of returned value is required to
be gym.spaces.Dict or None. If None, the user must provide all parameters manually.
Observation space of the update method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Observation space of the sample method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
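After each reward, the agent updates its preference for the selected arm \(a\). Assuming the standard Exp3 rule of Auer et al. [8], which matches the symbols defined in the following description, the update can be written as:

\[
\omega_{t + 1}(a) = \omega_t(a) \exp \left( \frac{\gamma r}{\pi(a) K} \right)
\]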
where \(\omega_{t + 1}(a)\) is the preference of arm \(a\) at time \(t + 1\), \(\pi(a)\) is
the probability of selecting arm \(a\), and \(K\) is the number of arms. The reward \(r\) is
normalized to the range \([0, 1]\). The exponential growth significantly increases the weight of good arms,
so when the agent is used for a long time it is important to ensure that the values of \(\omega\) do not exceed
the maximum value of the floating-point type.
The Exp3 policy is stochastic. The algorithm chooses a random arm with probability \(\gamma\), otherwise it
draws arm \(a\) with probability \(\omega(a) / \sum_{b=1}^N \omega(b)\).
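The policy above is a mixture of a uniform draw and a weight-proportional draw, so the selection probability of arm \(a\) is \((1 - \gamma)\,\omega(a) / \sum_b \omega(b) + \gamma / N\). The sketch below illustrates this in plain Python. The function names and list-based state are assumptions, and the update uses the standard Exp3 multiplicative rule, which is assumed rather than taken from the library.

```python
import math
import random


def exp3_probs(omega, gamma):
    """Mixture of weight-proportional selection and uniform exploration."""
    total = sum(omega)
    k = len(omega)
    return [(1.0 - gamma) * w / total + gamma / k for w in omega]


def exp3_sample(omega, gamma, rng=random):
    """Draw an arm index according to the Exp3 probabilities."""
    return rng.choices(range(len(omega)), weights=exp3_probs(omega, gamma))[0]


def exp3_update(omega, gamma, action, reward):
    """Multiplicative update of the selected arm (standard Exp3 form, assumed)."""
    pi = exp3_probs(omega, gamma)
    omega = list(omega)
    # `reward` is expected to be normalized to [0, 1].
    omega[action] *= math.exp(gamma * reward / (pi[action] * len(omega)))
    return omega
```

Because the probabilities depend only on the ratios of the weights, dividing all weights by their maximum leaves the policy unchanged, which is one way to keep \(\omega\) within floating-point range.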
Softmax agent with baseline and optional exponential recency-weighted average update. It learns a preference
function \(H\), which indicates a preference of selecting one arm over others. The policy can be
controlled by the temperature parameter \(\tau\). The implementation is inspired by the work of Sutton
and Barto [7]. Note: in some environments, this agent benefits considerably from 64-bit JAX mode.
Parameters:
n_arms (int) – Number of bandit arms. \(N \in \mathbb{N}_{+}\).
lr (float) – Step size. \(lr > 0\).
alpha (float, default=0.0) – If non-zero, exponential recency-weighted average is used to update \(\bar{R}\). \(\alpha \in [0, 1]\).
tau (float, default=1.0) – Temperature parameter. \(\tau > 0\).
multiplier (float, default=1.0) – Multiplier for the reward. \(multiplier > 0\).
Parameters of the agent constructor in Gymnasium format. Type of returned value is required to
be gym.spaces.Dict or None. If None, the user must provide all parameters manually.
Observation space of the update method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Observation space of the sample method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Creates and initializes an instance of the Softmax agent for n_arms arms. Preferences \(H\) for each arm
are set to zero, as well as the average of all rewards \(\bar{R}\). The step number \(n\) is
initialized to one.
Parameters:
key (PRNGKey) – A PRNG key used as the random key.
Preferences \(H\) can be learned by stochastic gradient ascent. The softmax algorithm searches
for such a set of preferences that maximizes the expected reward \(\mathbb{E}[R]\).
The updates of \(H\) for each action \(a\) are calculated as:
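For reference, the gradient-bandit update given by Sutton and Barto [7] has the following form. Writing the step size as \(\mathit{lr}\) (to match the parameter list), the selected action as \(A_t\), and the probability of selecting action \(a\) at step \(t\) as \(\pi_t(a)\) is a notational assumption:

\[
\begin{aligned}
H_{t + 1}(A_t) &= H_t(A_t) + \mathit{lr} \, (R_t - \bar{R}_t)(1 - \pi_t(A_t)), \\
H_{t + 1}(a) &= H_t(a) - \mathit{lr} \, (R_t - \bar{R}_t) \, \pi_t(a) \quad \text{for all } a \neq A_t,
\end{aligned}
\]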
where \(\bar{R}_t\) is the average of all rewards up to but not including step \(t\)
(by definition \(\bar{R}_1 = R_1\)). The derivation of the given formula can be found in [7].
In the stationary case, \(\bar{R}_t\) can be calculated as
\(\bar{R}_{t + 1} = \bar{R}_t + \frac{1}{t} \lbrack R_t - \bar{R}_t \rbrack\). To improve the
algorithm’s performance in the non-stationary case, we apply
\(\bar{R}_{t + 1} = \bar{R}_t + \alpha \lbrack R_t - \bar{R}_t \rbrack\) with a constant
step size \(\alpha\).
Reward \(R_t\) is multiplied by multiplier before updating preferences to allow for
more flexible reward scaling while keeping the algorithm’s properties.
Parameters:
state (SoftmaxState) – Current state of the agent.
key (PRNGKey) – A PRNG key used as the random key.
action (int) – Previously selected action.
reward (float) – Reward collected by the agent after taking the previous action.
lr (float) – Step size.
alpha (float) – Exponential recency-weighted average factor (used when \(\alpha > 0\)).
The policy of the Softmax algorithm is stochastic. The algorithm draws the next action from the softmax
distribution. The probability of selecting action \(i\) is calculated as:
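The standard temperature-scaled softmax, which this description appears to follow (the exact expression is an assumption), is

\[
\pi(i) = \frac{\exp(H(i) / \tau)}{\sum_{h = 1}^{N} \exp(H(h) / \tau)}
\]

A minimal plain-Python sketch of this sampling step follows. The function names are illustrative and not part of the library API.

```python
import math
import random


def softmax_probs(h, tau=1.0):
    """Temperature-scaled softmax over the preferences h."""
    m = max(h)  # subtracting the maximum avoids overflow and does not change the result
    exps = [math.exp((x - m) / tau) for x in h]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_sample(h, tau=1.0, rng=random):
    """Draw an arm index from the softmax distribution."""
    return rng.choices(range(len(h)), weights=softmax_probs(h, tau))[0]
```

A larger \(\tau\) flattens the distribution toward uniform, while a smaller \(\tau\) concentrates it on the arm with the highest preference.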
Contextual Bernoulli Thompson sampling agent with exponential smoothing. The implementation is inspired by the
work of Krotov et al. [9]. Thompson sampling is based on a beta distribution with parameters related to the number
of successful and failed attempts. Higher values of the parameters decrease the entropy of the distribution while
changing the ratio of the parameters shifts the expected value.
Parameters:
n_arms (int) – Number of bandit arms. \(N \in \mathbb{N}_{+}\).
decay (float, default=0.0) – Decay rate. If equal to zero, smoothing is not applied. \(w \geq 0\).
Parameters of the agent constructor in Gymnasium format. Type of returned value is required to
be gym.spaces.Dict or None. If None, the user must provide all parameters manually.
Observation space of the update method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Observation space of the sample method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Creates and initializes an instance of the Thompson sampling agent for n_arms arms. The \(\mathbf{\alpha}\)
and \(\mathbf{\beta}\) vectors are set to zero to create a non-informative prior distribution.
The last_decay is also set to zero.
Parameters:
key (PRNGKey) – A PRNG key used as the random key.
Thompson sampling can be adjusted to non-stationary environments by exponential smoothing of values of
vectors \(\mathbf{\alpha}\) and \(\mathbf{\beta}\) which increases the entropy of a distribution
over time. Given a result of trial \(s\), we apply the following equations for each action \(a\):
The Thompson sampling policy is stochastic. The algorithm draws \(q_a\) from the distribution
\(\operatorname{Beta}(1 + \mathbf{\alpha}(a), 1 + \mathbf{\beta}(a))\) for each arm \(a\).
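The next action is selected as \(A = \operatorname*{argmax}_{a \in \mathscr{A}} q_a\), where \(\mathscr{A}\) is a set of all actions.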
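The sampling step can be sketched in plain Python as one beta draw per arm followed by taking the arm with the largest draw. The function name and the list-based `alpha` and `beta` arguments are illustrative assumptions, not the library API.

```python
import random


def thompson_sample(alpha, beta, rng=random):
    """Draw q_a ~ Beta(1 + alpha[a], 1 + beta[a]) per arm and return the argmax."""
    q = [rng.betavariate(1.0 + a, 1.0 + b) for a, b in zip(alpha, beta)]
    return max(range(len(q)), key=q.__getitem__)
```

With both vectors at zero every arm is drawn from \(\operatorname{Beta}(1, 1)\), the uniform distribution on \([0, 1]\), so all arms are initially equally likely to be chosen.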
Normal Thompson sampling agent [11]. The normal-inverse-gamma distribution is a conjugate prior for the normal
distribution with unknown mean and variance. The parameters of the distribution are updated after each observation.
The mean of the normal distribution is sampled from the normal-inverse-gamma distribution and the action with
the highest expected value is selected.
Parameters:
n_arms (int) – Number of bandit arms. \(N \in \mathbb{N}_{+}\).
alpha (float) – See also NormalThompsonSamplingState for interpretation. \(\alpha > 0\).
beta (float) – See also NormalThompsonSamplingState for interpretation. \(\beta > 0\).
lam (float) – See also NormalThompsonSamplingState for interpretation. \(\lambda > 0\).
mu (float) – See also NormalThompsonSamplingState for interpretation. \(\mu \in \mathbb{R}\).
Parameters of the agent constructor in Gymnasium format. Type of returned value is required to
be gym.spaces.Dict or None. If None, the user must provide all parameters manually.
Observation space of the update method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Observation space of the sample method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
The normal Thompson sampling policy is stochastic. The algorithm draws \(q_a\) from the distribution
\(\operatorname{Normal}(\mu(a), \operatorname{scale}(a)/\sqrt{\lambda(a)})\) for each arm \(a\) where
\(\text{scale}(a)\) is sampled from the inverse gamma distribution with parameters \(\alpha(a)\) and
\(\beta(a)\). The next action is selected as \(A = \operatorname*{argmax}_{a \in \mathscr{A}} q_a\),
where \(\mathscr{A}\) is a set of all actions.
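The sampling rule can be sketched in plain Python by following the description above literally. Two assumptions are made here: the second argument of \(\operatorname{Normal}\) is treated as a standard deviation, and \(\beta(a)\) is treated as the scale parameter of the inverse gamma distribution. The function name and list-based arguments are illustrative only.

```python
import math
import random


def normal_thompson_sample(mu, lam, alpha, beta, rng=random):
    """Draw q_a per arm as described in the text and return the argmax."""
    q = []
    for m, l, a, b in zip(mu, lam, alpha, beta):
        # If X ~ Gamma(shape=a, scale=1/b), then 1/X ~ InverseGamma(shape=a, scale=b).
        scale = 1.0 / rng.gammavariate(a, 1.0 / b)
        # Assumed reading: scale / sqrt(lambda) is the standard deviation.
        q.append(rng.gauss(m, scale / math.sqrt(l)))
    return max(range(len(q)), key=q.__getitem__)
```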
Log-normal Thompson sampling agent. This algorithm is designed to handle positive rewards by transforming
them into the log-space. For more details, refer to the documentation on NormalThompsonSampling.
Log-normal Thompson sampling update. The update is analogous to the one in NormalThompsonSampling except
that the reward is transformed into the log-space.
Sampling actions is analogous to the one in NormalThompsonSampling except that the expected value of
the log-normal distribution is computed instead of the expected value of the normal distribution.
UCB agent with optional discounting. The main idea behind this algorithm is to introduce a preference factor
in its policy, so that the selection of the next action is based on both the current estimation of the
action-value function and the uncertainty of this estimation.
Parameters:
n_arms (int) – Number of bandit arms. \(N \in \mathbb{N}_{+}\).
c (float) – Degree of exploration. \(c \geq 0\).
gamma (float, default=1.0) – If less than one, a discounted UCB algorithm [10] is used. \(\gamma \in (0, 1]\).
Parameters of the agent constructor in Gymnasium format. Type of returned value is required to
be gym.spaces.Dict or None. If None, the user must provide all parameters manually.
Observation space of the update method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Observation space of the sample method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Creates and initializes an instance of the UCB agent for n_arms arms. The sum of the rewards is set to zero
and the number of tries is set to one for each arm.
Parameters:
key (PRNGKey) – A PRNG key used as the random key.
In the stationary case, the sum of the rewards for a given arm is increased by reward \(r\) obtained after
step \(t\) and the number of tries for the corresponding arm is incremented. In the non-stationary case,
the update follows the equations
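Assuming the discounted UCB recursion of [10], with \(\mathbb{1}_{A_t = a}\) equal to one for the selected arm and zero otherwise (the notation is an assumption), the update can be written as:

\[
\begin{aligned}
R_{t + 1}(a) &= \mathbb{1}_{A_t = a} \, r + \gamma R_t(a), \\
N_{t + 1}(a) &= \mathbb{1}_{A_t = a} + \gamma N_t(a).
\end{aligned}
\]

The policy then selects the arm with the highest upper confidence bound. In the usual UCB form, which is again an assumption about the exact expression, the next action is

\[
A = \operatorname*{argmax}_{a \in \mathscr{A}} \left[ Q(a) + c \sqrt{\frac{\ln \left( \sum_{a' \in \mathscr{A}} N(a') \right)}{N(a)}} \right],
\]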
where \(\mathscr{A}\) is a set of all actions and \(Q\) is calculated as \(Q(a) = \frac{R(a)}{N(a)}\).
The second term of the sum acts as an upper bound on the value of \(Q\): \(c\) controls the width of the
confidence interval, and the square root approximates the uncertainty of the \(Q\) estimate.
Note that the UCB policy is deterministic (apart from breaking ties between several optimal actions).
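The update and the selection rule can be sketched together in plain Python. The sketch assumes the common UCB form \(Q(a) + c \sqrt{\ln(\sum_{a'} N(a')) / N(a)}\) and a discounted update in which both the reward sums and the try counts are multiplied by \(\gamma\) before the selected arm is incremented. The function names and list-based state are illustrative, not the library API.

```python
import math


def ucb_update(r_sum, n, action, reward, gamma=1.0):
    """Discount all arms by gamma, then add the reward and one try to `action`."""
    new_r = [gamma * x + (reward if a == action else 0.0) for a, x in enumerate(r_sum)]
    new_n = [gamma * x + (1.0 if a == action else 0.0) for a, x in enumerate(n)]
    return new_r, new_n


def ucb_sample(r_sum, n, c):
    """Return the arm maximizing Q(a) plus the exploration bonus."""
    total = sum(n)
    scores = [r / k + c * math.sqrt(math.log(total) / k) for r, k in zip(r_sum, n)]
    # max() returns the first maximal index, so ties are broken deterministically here.
    return max(range(len(scores)), key=scores.__getitem__)
```

With `gamma=1.0` the update reduces to the stationary case: the reward sum of the selected arm grows by the reward and its try count by one.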
Parameters of the agent constructor in Gymnasium format. Type of returned value is required to
be gym.spaces.Dict or None. If None, the user must provide all parameters manually.
Observation space of the update method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Observation space of the sample method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Parameters of the agent constructor in Gymnasium format. Type of returned value is required to
be gym.spaces.Dict or None. If None, the user must provide all parameters manually.
Observation space of the update method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Observation space of the sample method in Gymnasium format. Allows missing observations to be
inferred using extensions and the agent to be easily exported to TensorFlow Lite format.
If None, the user must provide all parameters manually.
Meta agent that supports dynamically changing the number of arms.
This agent is highly experimental and should be used with extreme caution.
In particular, this agent makes the following strong assumptions:
Each entry in the base agent state has the first dimension corresponding to an arm.
The base agent must be stochastic, as this agent uses rejection sampling to choose a possible action.
Example usage of the agent can be found in the test test/experimental/test_masked.py.
Parameters:
agent (BaseAgent) – Type of the MAB agent whose actions are masked.
mask (Array) – Binary mask array of length equal to the number of arms. Positive entries are the masked actions.
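The rejection sampling mentioned above can be illustrated with a short plain-Python sketch: the base policy is sampled repeatedly until it returns an arm that is not masked. The function name, the `sample_fn` callable, and the `max_tries` safeguard are hypothetical and do not reflect the library API.

```python
import random


def masked_sample(sample_fn, mask, rng=random, max_tries=10_000):
    """Resample from a stochastic base policy until an unmasked arm is drawn."""
    if all(mask):
        raise ValueError("all arms are masked, no action can be selected")
    for _ in range(max_tries):
        action = sample_fn(rng)
        # Positive mask entries mark arms that must not be selected.
        if not mask[action]:
            return action
    raise RuntimeError("no unmasked action was drawn within max_tries attempts")
```

This also shows why the base agent must be stochastic: a deterministic policy that proposes a masked arm would propose it again on every retry.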