Utils
This module contains a collection of utility functions that are used throughout the library.
JAX
- gradient_step(objective: any, loss_params: tuple, opt_state: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], optimizer: GradientTransformation, loss_fn: Callable) tuple[any, any, Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], float | int]
Performs a gradient step on the objective with respect to the grad_loss_fn function. grad_loss_fn should return a tuple of (loss, aux), where loss is the value to be minimized and aux is an auxiliary value to be returned (can be None).
- Parameters:
objective (any) – Objective to be optimized.
loss_params (tuple) – Parameters to pass to loss_fn.
opt_state (optax.OptState) – Optimizer state.
optimizer (optax.GradientTransformation) – Optimizer to use for gradient step.
loss_fn (Callable) – Function that returns the loss to be minimized. Can return additional values as well.
- Returns:
out – Tuple containing the updated objective and optimizer state, as well as the loss value.
- Return type:
tuple[any, any, optax.OptState, Scalar]
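The (loss, aux) contract described above can be illustrated with a minimal plain-Python sketch. It is not the library's implementation: a central finite difference stands in for automatic differentiation, a fixed-learning-rate SGD update stands in for the optax optimizer (so no optimizer state is tracked), and the names numerical_grad, sgd_step, and quadratic_loss, the call convention loss_fn(objective, *loss_params), and the order of the returned values are all assumptions made for illustration.

```python
from typing import Callable, Tuple


def numerical_grad(f: Callable[[float], float], x: float, eps: float = 1e-6) -> float:
    """Central finite difference; a stand-in for automatic differentiation."""
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


def sgd_step(objective: float, loss_params: tuple, loss_fn: Callable,
             lr: float = 0.1) -> Tuple[float, object, float]:
    """One descent step on a scalar objective (hypothetical helper)."""
    # loss_fn returns (loss, aux); only the loss is differentiated.
    loss, aux = loss_fn(objective, *loss_params)
    grad = numerical_grad(lambda w: loss_fn(w, *loss_params)[0], objective)
    return objective - lr * grad, aux, loss


def quadratic_loss(w: float, target: float) -> Tuple[float, float]:
    """Squared error as the loss, with the residual returned as aux."""
    residual = w - target
    return residual ** 2, residual


# One step from w = 0 towards target = 3 moves w in the direction of the target.
new_w, aux, loss = sgd_step(0.0, (3.0,), quadratic_loss)
```

The auxiliary value is passed through untouched, which is why a loss function can use it to report diagnostics without affecting the gradient.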
- init(model: Module, key: PRNGKey, *x: Any) tuple[dict, dict]
Initializes the flax model.
- Parameters:
model (nn.Module) – Model to be initialized.
key (PRNGKey) – A PRNG key used as the random key.
x (any) – Input to the model.
- Returns:
Tuple containing the parameters and the state of the model.
- Return type:
tuple[dict, dict]
- forward(model: Module, params: dict, state: dict, key: PRNGKey, *x: Any) tuple[Array | ndarray | bool | number, dict]
Forward pass through the flax model. Note: by default, the model is provided with two random key streams: one for the dropout layers and one for the user. This ensures that dropout is always initialized with the same random key and that the user can use the custom key for any other purpose. The custom key is available in the model by calling self.make_rng('rlib').
- Parameters:
model (nn.Module) – Model to be used for forward pass.
params (dict) – Parameters of the model.
state (dict) – State of the network.
key (PRNGKey) – A PRNG key used as the random key.
x (any) – Input to the model.
- Returns:
Tuple containing the output of the model and the updated state.
- Return type:
tuple[Array, dict]
Experience replay
- class ExperienceReplay(init: Callable, append: Callable, sample: Callable, is_ready: Callable)
Container for experience replay buffer functions.
- init
Function that initializes the replay buffer.
- Type:
Callable
- append
Function that appends new values to the replay buffer.
- Type:
Callable
- sample
Function that samples a batch from the replay buffer.
- Type:
Callable
- is_ready
Function that checks if the replay buffer is ready to be sampled.
- Type:
Callable
- items() a set-like object providing a view on D's items
- keys() a set-like object providing a view on D's keys
- values() an object providing a view on D's values
- class ReplayBuffer(states: Array | ndarray | bool | number, actions: Array | ndarray | bool | number, rewards: Array | ndarray | bool | number, terminals: Array | ndarray | bool | number, next_states: Array | ndarray | bool | number, size: int, ptr: int)
Dataclass containing the replay buffer values. The replay buffer is implemented as a circular buffer.
- states
Array containing the states.
- Type:
Array
- actions
Array containing the actions.
- Type:
Array
- rewards
Array containing the rewards.
- Type:
Array
- terminals
Array containing the terminal flags.
- Type:
Array
- next_states
Array containing the next states.
- Type:
Array
- size
Current size of the replay buffer.
- Type:
int
- ptr
Current pointer of the replay buffer.
- Type:
int
- items() a set-like object providing a view on D's items
- keys() a set-like object providing a view on D's keys
- values() an object providing a view on D's values
- experience_replay(buffer_size: int, batch_size: int, obs_space_shape: Sequence[int | Any], act_space_shape: Sequence[int | Any]) ExperienceReplay
Experience replay buffer used for off-policy learning. Improves the stability of the learning process by reducing the correlation between the samples and enables an agent to learn from past experiences.
- Parameters:
buffer_size (int) – Maximum size of the replay buffer.
batch_size (int) – Size of the batch to be sampled from the replay buffer.
obs_space_shape (Shape) – Shape of the observation space.
act_space_shape (Shape) – Shape of the action space.
- Returns:
out – Container for experience replay buffer functions.
- Return type:
ExperienceReplay
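The container-of-functions design and the circular buffer can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the library's code: the names Buffer, Replay, and make_replay are hypothetical, transitions are stored as opaque values in one list rather than in separate state, action, reward, terminal, and next-state arrays, and the readiness rule (at least batch_size stored transitions) is an assumption.

```python
import random
from typing import Callable, NamedTuple


class Buffer(NamedTuple):
    transitions: list  # fixed-capacity storage
    size: int          # number of valid entries
    ptr: int           # index of the next slot to write


class Replay(NamedTuple):
    init: Callable
    append: Callable
    sample: Callable
    is_ready: Callable


def make_replay(buffer_size: int, batch_size: int) -> Replay:
    def init() -> Buffer:
        return Buffer([None] * buffer_size, 0, 0)

    def append(buf: Buffer, transition) -> Buffer:
        storage = list(buf.transitions)
        storage[buf.ptr] = transition  # overwrites the oldest entry once full
        return Buffer(storage,
                      min(buf.size + 1, buffer_size),
                      (buf.ptr + 1) % buffer_size)

    def sample(buf: Buffer, rng: random.Random) -> list:
        # Uniform sampling with replacement over the valid entries.
        return [buf.transitions[rng.randrange(buf.size)] for _ in range(batch_size)]

    def is_ready(buf: Buffer) -> bool:
        return buf.size >= batch_size  # assumed readiness rule

    return Replay(init, append, sample, is_ready)


er = make_replay(buffer_size=3, batch_size=2)
buf = er.init()
for t in range(4):  # the fourth append wraps around and overwrites slot 0
    buf = er.append(buf, t)
batch = er.sample(buf, random.Random(0))
```

Each function returns a new buffer value instead of mutating its argument, which mirrors the functional style that JAX transformations expect.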
Particle filter
- class ParticleFilterState(positions: Array | ndarray | bool | number, logit_weights: Array | ndarray | bool | number)
Bases: AgentState, Mapping
Container for the state of the particle filter agent.
- positions
Positions of the particles.
- Type:
Array
- logit_weights
Unnormalized log weights of the particles.
- Type:
Array
- items() a set-like object providing a view on D's items
- keys() a set-like object providing a view on D's keys
- values() an object providing a view on D's values
- simple_resample(operands: tuple[ParticleFilterState, Array]) ParticleFilterState
Samples new particle positions from a categorical distribution with particle weights, then sets all weights equal.
- Parameters:
operands (tuple[ParticleFilterState, PRNGKey]) – Tuple containing the filter state and a PRNG key.
- Returns:
Updated filter state.
- Return type:
ParticleFilterState
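A minimal plain-Python sketch of this kind of resampling is given below. It assumes that the weights are stored as unnormalized logits and that "equal weights" means resetting every logit to zero; the function name resample and the use of the standard random module in place of a JAX PRNG key are illustrative choices, not the library's API.

```python
import math
import random
from typing import List, Tuple


def resample(positions: List[float], logit_weights: List[float],
             rng: random.Random) -> Tuple[List[float], List[float]]:
    """Draw new positions in proportion to the weights, then equalize the weights."""
    # Subtract the maximum logit before exponentiating for numerical stability.
    m = max(logit_weights)
    weights = [math.exp(l - m) for l in logit_weights]
    new_positions = rng.choices(positions, weights=weights, k=len(positions))
    return new_positions, [0.0] * len(positions)


# All of the weight sits on the middle particle, so every draw selects it.
positions, logits = resample([0.0, 1.0, 2.0], [-1000.0, 0.0, -1000.0], random.Random(0))
```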
- effective_sample_size(state: ParticleFilterState, threshold: float | int = 0.5) bool
Calculates the effective sample size [2] (ESS). If the ESS is smaller than the number of samples times the threshold, resampling is necessary.
- Parameters:
state (ParticleFilterState) – Current state of the filter.
threshold (float, default=0.5) – Threshold value used to decide if a resampling is necessary. \(thr \in (0, 1)\).
- Returns:
Information whether a resampling should be performed.
- Return type:
bool
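The criterion can be sketched in plain Python using the commonly used ESS estimate, the squared sum of the weights divided by the sum of the squared weights. Whether the library computes the ESS with exactly this formula is an assumption, and the names ess and should_resample are hypothetical.

```python
import math
from typing import List


def ess(logit_weights: List[float]) -> float:
    """Effective sample size estimate: (sum of w)^2 / (sum of w^2)."""
    m = max(logit_weights)  # stabilizes the exponentials
    w = [math.exp(l - m) for l in logit_weights]
    return sum(w) ** 2 / sum(x * x for x in w)


def should_resample(logit_weights: List[float], threshold: float = 0.5) -> bool:
    """True if the ESS falls below threshold times the number of particles."""
    return ess(logit_weights) < threshold * len(logit_weights)


uniform = [0.0, 0.0, 0.0, 0.0]              # ESS equals the particle count, 4
degenerate = [0.0, -1000.0, -1000.0, -1000.0]  # ESS collapses to 1
decisions = (should_resample(uniform), should_resample(degenerate))
```

With uniform weights every particle contributes equally and no resampling is triggered; when one particle holds all the weight, the ESS drops to 1 and the criterion fires.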
References
- simple_transition(state: ParticleFilterState, key: Array, scale: float | int, *args) ParticleFilterState
Performs simple movement of the particle positions according to a normal distribution with \(\mu = 0\) and \(\sigma = scale\).
- Parameters:
state (ParticleFilterState) – Current state of the filter.
key (PRNGKey) – A PRNG key used as the random key.
scale (float) – Scale of the random movement of particles. \(scale > 0\).
- Returns:
Updated filter state.
- Return type:
ParticleFilterState
- linear_transition(state: ParticleFilterState, key: Array, scale: float | int, delta_time: float | int) ParticleFilterState
Performs movement of the particle positions according to a normal distribution with \(\mu = 0\) and \(\sigma = scale \cdot \Delta t\), where \(\Delta t\) is the time elapsed since the last update.
- Parameters:
state (ParticleFilterState) – Current state of the filter.
key (PRNGKey) – A PRNG key used as the random key.
scale (float) – Scale of the random movement of particles. \(scale > 0\).
delta_time (float) – Time elapsed since the last update.
- Returns:
Updated filter state.
- Return type:
ParticleFilterState
- affine_transition(state: ParticleFilterState, key: Array, scale: Array | ndarray | bool | number, delta_time: float | int) ParticleFilterState
Performs movement of the particle positions according to a normal distribution with \(\mu = 0\) and \(\sigma = scale_0 \cdot \Delta t + scale_1\), where \(\Delta t\) is the time elapsed since the last update.
- Parameters:
state (ParticleFilterState) – Current state of the filter.
key (PRNGKey) – A PRNG key used as the random key.
scale (Array) – Scale of the random movement of particles. \(scale_0, scale_1 > 0\).
delta_time (float) – Time elapsed since the last update.
- Returns:
Updated filter state.
- Return type:
ParticleFilterState
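The three transition functions differ only in how the standard deviation of the Gaussian noise is derived from scale and the elapsed time. A plain-Python sketch of that difference follows; the helper names (sigma_simple, sigma_linear, sigma_affine, move) are hypothetical, and the standard random module replaces the JAX PRNG key.

```python
import random
from typing import List, Sequence


def sigma_simple(scale: float) -> float:
    return scale                              # sigma = scale


def sigma_linear(scale: float, delta_time: float) -> float:
    return scale * delta_time                 # sigma = scale * dt


def sigma_affine(scale: Sequence[float], delta_time: float) -> float:
    return scale[0] * delta_time + scale[1]   # sigma = scale_0 * dt + scale_1


def move(positions: List[float], sigma: float, rng: random.Random) -> List[float]:
    """Add zero-mean Gaussian noise with standard deviation sigma to each particle."""
    return [p + rng.gauss(0.0, sigma) for p in positions]


rng = random.Random(0)
moved = move([0.0, 1.0, 2.0], sigma_affine((2.0, 0.25), 0.5), rng)
```

The affine variant keeps a nonzero noise floor (scale_1) even when two updates arrive at almost the same time, whereas the linear variant lets the noise vanish as the elapsed time approaches zero.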
- class ParticleFilter(initial_distribution_fn: Callable, positions_shape: Sequence[int | Any], weights_shape: Sequence[int | Any], scale: Array | ndarray | bool | number | float | int, observation_fn: Callable[[ParticleFilterState, any], ParticleFilterState], resample_fn: Callable[[tuple[ParticleFilterState, Array]], ParticleFilterState] = simple_resample, resample_criterion_fn: Callable[[ParticleFilterState], bool] = effective_sample_size, transition_fn: Callable[[ParticleFilterState, Array, Array | ndarray | bool | number | float | int, float | int], ParticleFilterState] = simple_transition)
Bases: object
Particle filter (sequential Monte Carlo) algorithm estimating the internal environment state given noisy or partial observations.
- Parameters:
initial_distribution_fn (callable) –
- Function that samples the initial particle positions; takes two positional arguments:
key: a PRNG key used as a random key (PRNGKey).
shape: shape of the sample (Shape).
Returns the initial particle positions (Array).
positions_shape (Shape) – Shape of the particle positions array.
weights_shape (Shape) – Shape of the particle weights array.
scale (Array) – Scale of the random movement of the particles.
observation_fn (callable) –
- Function that updates particles based on an observation from the environment; takes two positional arguments:
state: the state of the filter (ParticleFilterState).
observation: an observation from the environment (any).
Returns the updated state of the filter (ParticleFilterState).
resample_fn (callable, default=particle_filter.simple_resample) –
- Function that performs resampling of the particles; takes one positional argument:
operands: a tuple containing the filter state and a PRNG key (tuple[ParticleFilterState, PRNGKey]).
Returns the updated state of the filter (ParticleFilterState).
resample_criterion_fn (callable, default=particle_filter.effective_sample_size) –
- Function that checks if a resampling is necessary; takes one positional argument:
state: the state of the filter (ParticleFilterState).
Returns information whether a resampling should be performed (bool).
transition_fn (callable, default=particle_filter.simple_transition) –
- Function that updates the particle positions; takes four positional arguments:
state: the state of the filter (ParticleFilterState).
key: a PRNG key used as a random key (PRNGKey).
scale: scale of the random movement of the particles (Array).
delta_time: the time elapsed since the last update (float).
Returns the updated state of the filter (ParticleFilterState).
- static init(key: Array, initial_distribution_fn: Callable, positions_shape: Sequence[int | Any], weights_shape: Sequence[int | Any]) ParticleFilterState
Creates and initializes an instance of the particle filter.
- Parameters:
key (PRNGKey) – A PRNG key used as the random key.
initial_distribution_fn (callable) –
- Function that samples the initial particle positions; takes two positional arguments:
key: a PRNG key used as a random key (PRNGKey).
shape: shape of the sample (Shape).
Returns the initial particle positions (Array).
positions_shape (Shape) – Shape of the particle positions array.
weights_shape (Shape) – Shape of the particle weights array.
- Returns:
Initial state of the particle filter.
- Return type:
ParticleFilterState
- static update(state: ParticleFilterState, key: Array, observation_fn: Callable[[ParticleFilterState, any], ParticleFilterState], observation: any, resample_fn: Callable[[tuple[ParticleFilterState, Array]], ParticleFilterState], resample_criterion_fn: Callable[[ParticleFilterState], bool], transition_fn: Callable[[ParticleFilterState, Array, Array | ndarray | bool | number | float | int, float | int], ParticleFilterState], delta_time: float | int, scale: Array | ndarray | bool | number | float | int) ParticleFilterState
Updates the state of the filter based on an observation from the environment, then performs resampling (if necessary) and transition of the particles.
- Parameters:
state (ParticleFilterState) – Current state of the filter.
key (PRNGKey) – A PRNG key used as the random key.
observation_fn (callable) –
- Function that updates particles based on an observation from the environment; takes two positional arguments:
state: the state of the filter (ParticleFilterState).
observation: an observation from the environment (any).
Returns the updated state of the filter (ParticleFilterState).
observation (any) – An observation from the environment.
resample_fn (callable, default=particle_filter.simple_resample) –
- Function that performs resampling of the particles; takes one positional argument:
operands: a tuple containing the filter state and a PRNG key (tuple[ParticleFilterState, PRNGKey]).
Returns the updated state of the filter (ParticleFilterState).
resample_criterion_fn (callable, default=particle_filter.effective_sample_size) –
- Function that checks if a resampling is necessary; takes one positional argument:
state: the state of the filter (ParticleFilterState).
Returns information whether a resampling should be performed (bool).
transition_fn (callable, default=particle_filter.simple_transition) –
- Function that updates the particle positions; takes four positional arguments:
state: the state of the filter (ParticleFilterState).
key: a PRNG key used as a random key (PRNGKey).
scale: scale of the random movement of the particles (Array).
delta_time: the time elapsed since the last update (float).
Returns the updated state of the filter (ParticleFilterState).
delta_time (float) – Time difference between the current and the previous observation.
scale (Array) – Scale of the random movement of the particles.
- Returns:
Updated filter state.
- Return type:
ParticleFilterState
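The order of operations in an update (observation, conditional resampling, transition) can be sketched in plain Python as follows. This is a simplified illustration, not the library's implementation: the Gaussian observation model with standard deviation noise, the ESS-based criterion, the linear transition, and the function name update_sketch are all assumptions, and the standard random module replaces the JAX PRNG key.

```python
import math
import random
from typing import List, Tuple


def update_sketch(positions: List[float], logits: List[float], observation: float,
                  rng: random.Random, scale: float, delta_time: float,
                  noise: float = 1.0, threshold: float = 0.5) -> Tuple[List[float], List[float]]:
    # 1. Observation step: add an (assumed) Gaussian log-likelihood to each logit.
    logits = [l - 0.5 * ((observation - p) / noise) ** 2
              for p, l in zip(positions, logits)]

    # 2. Resample only if the effective sample size is too small.
    m = max(logits)
    w = [math.exp(l - m) for l in logits]
    ess = sum(w) ** 2 / sum(x * x for x in w)
    if ess < threshold * len(positions):
        positions = rng.choices(positions, weights=w, k=len(positions))
        logits = [0.0] * len(positions)

    # 3. Transition step: Gaussian move with sigma = scale * delta_time.
    positions = [p + rng.gauss(0.0, scale * delta_time) for p in positions]
    return positions, logits


# A sharp observation at 3.0 collapses the weights onto the middle particle,
# which triggers resampling; scale = 0 disables the random movement.
collapsed = update_sketch([0.0, 3.0, 6.0], [0.0, 0.0, 0.0], 3.0,
                          random.Random(0), scale=0.0, delta_time=1.0, noise=0.1)

# Two equally likely particles keep the ESS at 2, so no resampling occurs.
kept = update_sketch([0.0, 1.0], [0.0, 0.0], 0.5,
                     random.Random(0), scale=0.0, delta_time=1.0)
```

Resampling only when the criterion fires avoids discarding particle diversity on every step, which is the motivation for passing the criterion as a separate function.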
- static sample(state: ParticleFilterState, key: Array) Array | ndarray | bool | number | float | int
Samples the estimated environment state from a categorical distribution with particle weights.
- Parameters:
state (ParticleFilterState) – Current state of the filter.
key (PRNGKey) – A PRNG key used as the random key.
- Returns:
Estimated environment state.
- Return type:
Array