Utils

This module contains a collection of utility functions that are used throughout the library.

JAX

gradient_step(objective: any, loss_params: tuple, opt_state: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], optimizer: GradientTransformation, loss_fn: Callable) tuple[any, any, Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], float | int]

Performs a gradient step on the objective with respect to the loss_fn function. loss_fn should return a tuple (loss, aux), where loss is the value to be minimized and aux is an auxiliary value to be returned (can be None).

Parameters:
  • objective (any) – Objective to be optimized.

  • loss_params (tuple) – Parameters to pass to loss_fn.

  • opt_state (optax.OptState) – Optimizer state.

  • optimizer (optax.GradientTransformation) – Optimizer to use for gradient step.

  • loss_fn (Callable) – Function that returns the loss to be minimized. Can return additional values as well.

Returns:

out – Tuple containing the updated objective, the auxiliary value returned by loss_fn, the updated optimizer state, and the loss value.

Return type:

tuple[any, any, optax.OptState, Scalar]
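The (loss, aux) contract of the loss function can be illustrated without JAX or optax. The following is a minimal plain-Python sketch, not the library's implementation. The names `sgd_step` and `finite_difference_grad`, the scalar objective, the calling convention `fn(objective, *loss_params)`, and the fixed learning rate standing in for an optax optimizer are all assumptions made for illustration.

```python
def loss_fn(w, xs, ys):
    """Mean squared error of the model y = w * x; aux carries the residuals."""
    residuals = [w * x - y for x, y in zip(xs, ys)]
    loss = sum(r * r for r in residuals) / len(residuals)
    return loss, residuals


def finite_difference_grad(fn, w, params, eps=1e-6):
    """Central-difference estimate of d(loss)/dw (stand-in for automatic differentiation)."""
    loss_up, _ = fn(w + eps, *params)
    loss_down, _ = fn(w - eps, *params)
    return (loss_up - loss_down) / (2 * eps)


def sgd_step(objective, loss_params, learning_rate, fn):
    """One plain gradient-descent step; returns (new objective, aux, loss)."""
    loss, aux = fn(objective, *loss_params)
    grad = finite_difference_grad(fn, objective, loss_params)
    return objective - learning_rate * grad, aux, loss


w = 0.0
data = ([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])  # generated by y = 2 * x
for _ in range(50):
    w, aux, loss = sgd_step(w, data, 0.05, loss_fn)
```

After the loop, `w` should be close to 2, the slope that generated the data, and `aux` holds the residuals from the last evaluation.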

init(model: Module, key: PRNGKey, *x: Any) tuple[dict, dict]

Initializes the Flax model.

Parameters:
  • model (nn.Module) – Model to be initialized.

  • key (PRNGKey) – A PRNG key used as the random key.

  • x (any) – Input to the model.

Returns:

Tuple containing the parameters and the state of the model.

Return type:

tuple[dict, dict]

forward(model: Module, params: dict, state: dict, key: PRNGKey, *x: Any) tuple[Array | ndarray | bool | number, dict]

Performs a forward pass through the Flax model. Note: by default, the model is provided with two random key streams: one for the dropout layers and one for the user. This ensures that the dropout layers always have a dedicated key stream, while the user can use the custom key for any other purpose. The custom key is available inside the model by calling self.make_rng('rlib').

Parameters:
  • model (nn.Module) – Model to be used for forward pass.

  • params (dict) – Parameters of the model.

  • state (dict) – State of the network.

  • key (PRNGKey) – A PRNG key used as the random key.

  • x (any) – Input to the model.

Returns:

Tuple containing the output of the model and the updated state.

Return type:

tuple[Array, dict]

Experience replay

class ExperienceReplay(init: Callable, append: Callable, sample: Callable, is_ready: Callable)

Container for experience replay buffer functions.

init

Function that initializes the replay buffer.

Type:

Callable

append

Function that appends new values to the replay buffer.

Type:

Callable

sample

Function that samples a batch from the replay buffer.

Type:

Callable

is_ready

Function that checks if the replay buffer is ready to be sampled.

Type:

Callable

items()

Returns a set-like object providing a view on the container's items.

keys()

Returns a set-like object providing a view on the container's keys.

values()

Returns an object providing a view on the container's values.

class ReplayBuffer(states: Array | ndarray | bool | number, actions: Array | ndarray | bool | number, rewards: Array | ndarray | bool | number, terminals: Array | ndarray | bool | number, next_states: Array | ndarray | bool | number, size: int, ptr: int)

Dataclass containing the replay buffer values. The replay buffer is implemented as a circular buffer.

states

Array containing the states.

Type:

Array

actions

Array containing the actions.

Type:

Array

rewards

Array containing the rewards.

Type:

Array

terminals

Array containing the terminal flags.

Type:

Array

next_states

Array containing the next states.

Type:

Array

size

Current size of the replay buffer.

Type:

int

ptr

Current pointer of the replay buffer.

Type:

int

items()

Returns a set-like object providing a view on the container's items.

keys()

Returns a set-like object providing a view on the container's keys.

values()

Returns an object providing a view on the container's values.

experience_replay(buffer_size: int, batch_size: int, obs_space_shape: Sequence[int | Any], act_space_shape: Sequence[int | Any]) ExperienceReplay

Experience replay buffer used for off-policy learning. Improves the stability of the learning process by reducing the correlation between the samples and enables an agent to learn from past experiences.

Parameters:
  • buffer_size (int) – Maximum size of the replay buffer.

  • batch_size (int) – Size of the batch to be sampled from the replay buffer.

  • obs_space_shape (Shape) – Shape of the observation space.

  • act_space_shape (Shape) – Shape of the action space.

Returns:

out – Container for experience replay buffer functions.

Return type:

ExperienceReplay
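The circular-buffer behaviour behind the replay buffer (a write pointer that wraps around and a size that saturates at the capacity) can be sketched in plain Python. This is an illustration under assumptions, not the library's JAX implementation: the dictionary layout, the helper names, the readiness rule (at least one full batch stored), and uniform sampling with replacement are all hypothetical choices.

```python
import random


def make_buffer(buffer_size):
    """Empty circular buffer: fixed-size storage, a write pointer, and a fill count."""
    return {"data": [None] * buffer_size, "size": 0, "ptr": 0}


def append(buffer, transition):
    """Write at ptr, overwriting the oldest entry once the buffer is full."""
    capacity = len(buffer["data"])
    buffer["data"][buffer["ptr"]] = transition
    buffer["ptr"] = (buffer["ptr"] + 1) % capacity
    buffer["size"] = min(buffer["size"] + 1, capacity)
    return buffer


def is_ready(buffer, batch_size):
    """Assumed readiness rule: at least one full batch is stored."""
    return buffer["size"] >= batch_size


def sample(buffer, batch_size, rng):
    """Draw a batch uniformly at random (with replacement) from the filled part."""
    indices = [rng.randrange(buffer["size"]) for _ in range(batch_size)]
    return [buffer["data"][i] for i in indices]


buffer = make_buffer(buffer_size=3)
for step in range(5):
    # Each transition is (state, action, reward, terminal, next_state).
    buffer = append(buffer, (step, 0, 1.0, False, step + 1))
batch = sample(buffer, batch_size=2, rng=random.Random(0))
```

With a capacity of 3 and 5 appended transitions, the two oldest transitions are overwritten, so only the transitions from steps 2, 3, and 4 remain available for sampling.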

Particle filter

class ParticleFilterState(positions: Array | ndarray | bool | number, logit_weights: Array | ndarray | bool | number)

Bases: AgentState, Mapping

Container for the state of the particle filter agent.

positions

Positions of the particles.

Type:

Array

logit_weights

Unnormalized log weights of the particles.

Type:

Array

items()

Returns a set-like object providing a view on the container's items.

keys()

Returns a set-like object providing a view on the container's keys.

values()

Returns an object providing a view on the container's values.

simple_resample(operands: tuple[ParticleFilterState, Array]) ParticleFilterState

Samples new particle positions from a categorical distribution with particle weights, then sets all weights equal.

Parameters:

operands (tuple[ParticleFilterState, PRNGKey]) – Tuple containing the filter state and a PRNG key.

Returns:

Updated filter state.

Return type:

ParticleFilterState
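The resampling step described above can be sketched in plain Python as follows. This is a minimal illustration, not the library's JAX code: the helper names, the use of lists instead of a ParticleFilterState, and `random.Random` standing in for a PRNG key are assumptions.

```python
import math
import random


def softmax(logits):
    """Normalize unnormalized log weights into probabilities (numerically stable)."""
    peak = max(logits)
    weights = [math.exp(logit - peak) for logit in logits]
    total = sum(weights)
    return [w / total for w in weights]


def resample(positions, logit_weights, rng):
    """Draw len(positions) particles with probability proportional to weight, then equalize the weights."""
    probabilities = softmax(logit_weights)
    new_positions = rng.choices(positions, weights=probabilities, k=len(positions))
    new_logit_weights = [0.0] * len(positions)  # equal logits mean equal weights
    return new_positions, new_logit_weights


positions = [0.1, 0.5, 0.9]
logit_weights = [-1000.0, 0.0, -1000.0]  # practically all mass on the middle particle
new_positions, new_logit_weights = resample(positions, logit_weights, random.Random(0))
```

Because practically all of the weight sits on the middle particle, every resampled particle lands at its position, and the weights are reset to a uniform distribution.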

effective_sample_size(state: ParticleFilterState, threshold: float | int = 0.5) bool

Calculates the effective sample size (ESS) [2]. If the ESS is smaller than the number of samples times the threshold, resampling is necessary.

Parameters:
  • state (ParticleFilterState) – Current state of the filter.

  • threshold (float, default=0.5) – Threshold value used to decide whether resampling is necessary, \(thr \in (0, 1)\).

Returns:

Whether resampling should be performed.

Return type:

bool
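As an illustration of the criterion, the sketch below uses the common definition of the effective sample size, the reciprocal of the sum of squared normalized weights. That formula is an assumption here, since the cited reference is not reproduced in this section; the function name `needs_resampling` and the list-based input are also hypothetical.

```python
import math


def needs_resampling(logit_weights, threshold=0.5):
    """Return True when ESS = 1 / sum(w_i ** 2) drops below N * threshold."""
    peak = max(logit_weights)
    weights = [math.exp(logit - peak) for logit in logit_weights]
    total = sum(weights)
    normalized = [w / total for w in weights]
    ess = 1.0 / sum(w * w for w in normalized)
    return ess < len(normalized) * threshold


# Uniform weights: ESS equals the number of particles, so no resampling is needed.
uniform = needs_resampling([0.0, 0.0, 0.0, 0.0])
# One dominant particle: ESS is about 1, which is below 4 * 0.5.
degenerate = needs_resampling([0.0, -1000.0, -1000.0, -1000.0])
```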

References

simple_transition(state: ParticleFilterState, key: Array, scale: float | int, *args) ParticleFilterState

Performs simple movement of the particle positions according to a normal distribution with \(\mu = 0\) and \(\sigma = scale\).

Parameters:
  • state (ParticleFilterState) – Current state of the filter.

  • key (PRNGKey) – A PRNG key used as the random key.

  • scale (float) – Scale of the random movement of particles. \(scale > 0\).

Returns:

Updated filter state.

Return type:

ParticleFilterState

linear_transition(state: ParticleFilterState, key: Array, scale: float | int, delta_time: float | int) ParticleFilterState

Performs movement of the particle positions according to a normal distribution with \(\mu = 0\) and \(\sigma = scale \cdot \Delta t\), where \(\Delta t\) is the time elapsed since the last update.

Parameters:
  • state (ParticleFilterState) – Current state of the filter.

  • key (PRNGKey) – A PRNG key used as the random key.

  • scale (float) – Scale of the random movement of particles. \(scale > 0\).

  • delta_time (float) – Time elapsed since the last update.

Returns:

Updated filter state.

Return type:

ParticleFilterState

affine_transition(state: ParticleFilterState, key: Array, scale: Array | ndarray | bool | number, delta_time: float | int) ParticleFilterState

Performs movement of the particle positions according to a normal distribution with \(\mu = 0\) and \(\sigma = scale_0 \cdot \Delta t + scale_1\), where \(\Delta t\) is the time elapsed since the last update.

Parameters:
  • state (ParticleFilterState) – Current state of the filter.

  • key (PRNGKey) – A PRNG key used as the random key.

  • scale (Array) – Scale of the random movement of particles. \(scale_0, scale_1 > 0\).

  • delta_time (float) – Time elapsed since the last update.

Returns:

Updated filter state.

Return type:

ParticleFilterState
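The three transition functions differ only in how the standard deviation of the Gaussian noise is computed. The following plain-Python sketch compares them; the helper names and the list-based positions are assumptions, and `random.Random` stands in for a PRNG key.

```python
import random


def simple_sigma(scale):
    """sigma = scale"""
    return scale


def linear_sigma(scale, delta_time):
    """sigma = scale * delta_time"""
    return scale * delta_time


def affine_sigma(scale, delta_time):
    """sigma = scale[0] * delta_time + scale[1]"""
    return scale[0] * delta_time + scale[1]


def move(positions, sigma, rng):
    """Add zero-mean Gaussian noise with standard deviation sigma to each particle."""
    return [p + rng.gauss(0.0, sigma) for p in positions]


rng = random.Random(0)
positions = [0.0, 1.0, 2.0]
moved = move(positions, affine_sigma((2.0, 0.1), delta_time=0.5), rng)
```

With `scale = (2.0, 0.1)` and `delta_time = 0.5`, the affine variant yields a standard deviation of 1.1, while the linear variant with `scale = 2.0` yields 1.0.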

class ParticleFilter(initial_distribution_fn: ~typing.Callable, positions_shape: ~typing.Sequence[int | ~typing.Any], weights_shape: ~typing.Sequence[int | ~typing.Any], scale: ~jax.Array | ~numpy.ndarray | ~numpy.bool | ~numpy.number | float | int, observation_fn: ~typing.Callable[[~reinforced_lib.utils.particle_filter.ParticleFilterState, any], ~reinforced_lib.utils.particle_filter.ParticleFilterState], resample_fn: ~typing.Callable[[tuple[~reinforced_lib.utils.particle_filter.ParticleFilterState, ~jax.Array]], ~reinforced_lib.utils.particle_filter.ParticleFilterState] = <function simple_resample>, resample_criterion_fn: ~typing.Callable[[~reinforced_lib.utils.particle_filter.ParticleFilterState], bool] = <function effective_sample_size>, transition_fn: ~typing.Callable[[~reinforced_lib.utils.particle_filter.ParticleFilterState, ~jax.Array, ~jax.Array | ~numpy.ndarray | ~numpy.bool | ~numpy.number | float | int, float | int], ~reinforced_lib.utils.particle_filter.ParticleFilterState] = <function simple_transition>)

Bases: object

Particle filter (sequential Monte Carlo) algorithm estimating the internal environment state given noisy or partial observations.

Parameters:
  • initial_distribution_fn (callable) –

    Function that samples the initial particle positions; takes two positional arguments:
    • key: a PRNG key used as a random key (PRNGKey).

    • shape: shape of the sample (Shape).

    Returns the initial particle positions (Array).

  • positions_shape (Shape) – Shape of the particle positions array.

  • weights_shape (Shape) – Shape of the particle weights array.

  • scale (Array) – Scale of the random movement of the particles.

  • observation_fn (callable) –

    Function that updates particles based on an observation from the environment; takes two positional arguments:
    • state: the state of the filter (ParticleFilterState).

    • observation: an observation from the environment (any).

    Returns the updated state of the filter (ParticleFilterState).

  • resample_fn (callable, default=particle_filter.simple_resample) –

    Function that performs resampling of the particles; takes one positional argument:
    • operands: a tuple containing the filter state and a PRNG key (tuple[ParticleFilterState, PRNGKey]).

    Returns the updated state of the filter (ParticleFilterState).

  • resample_criterion_fn (callable, default=particle_filter.effective_sample_size) –

    Function that checks whether resampling is necessary; takes one positional argument:
    • state: the state of the filter (ParticleFilterState).

    Returns whether resampling should be performed (bool).

  • transition_fn (callable, default=particle_filter.simple_transition) –

    Function that updates the particle positions; takes four positional arguments:
    • state: the state of the filter (ParticleFilterState).

    • key: a PRNG key used as a random key (PRNGKey).

    • scale: scale of the random movement of the particles (Array).

    • delta_time: the time elapsed since the last update (float).

    Returns the updated state of the filter (ParticleFilterState).
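To make the observation_fn contract more concrete, the sketch below shows one possible observation function that reweights particles with a Gaussian likelihood. It is a hypothetical example, not a function shipped with the library: the name `gaussian_observation_fn`, the `noise_std` parameter, and the dictionary used in place of a ParticleFilterState are assumptions.

```python
def gaussian_observation_fn(state, observation, noise_std=0.5):
    """Hypothetical observation_fn: add each particle's Gaussian log-likelihood to its log weight.

    The normalization constant is dropped because the log weights are unnormalized.
    """
    new_logits = [
        logit - 0.5 * ((observation - position) / noise_std) ** 2
        for position, logit in zip(state["positions"], state["logit_weights"])
    ]
    return {"positions": list(state["positions"]), "logit_weights": new_logits}


state = {"positions": [0.0, 1.0, 2.0], "logit_weights": [0.0, 0.0, 0.0]}
updated = gaussian_observation_fn(state, observation=1.0)
```

Particles far from the observation receive lower log weights, while the positions are left unchanged; moving the particles is the job of the transition function.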

static init(key: Array, initial_distribution_fn: Callable, positions_shape: Sequence[int | Any], weights_shape: Sequence[int | Any]) ParticleFilterState

Creates and initializes an instance of the particle filter.

Parameters:
  • key (PRNGKey) – A PRNG key used as the random key.

  • initial_distribution_fn (callable) –

    Function that samples the initial particle positions; takes two positional arguments:
    • key: a PRNG key used as a random key (PRNGKey).

    • shape: shape of the sample (Shape).

    Returns the initial particle positions (Array).

  • positions_shape (Shape) – Shape of the particle positions array.

  • weights_shape (Shape) – Shape of the particle weights array.

Returns:

Initial state of the particle filter.

Return type:

ParticleFilterState

static update(state: ParticleFilterState, key: Array, observation_fn: Callable[[ParticleFilterState, any], ParticleFilterState], observation: any, resample_fn: Callable[[tuple[ParticleFilterState, Array]], ParticleFilterState], resample_criterion_fn: Callable[[ParticleFilterState], bool], transition_fn: Callable[[ParticleFilterState, Array, Array | ndarray | bool | number | float | int, float | int], ParticleFilterState], delta_time: float | int, scale: Array | ndarray | bool | number | float | int) ParticleFilterState

Updates the state of the filter based on an observation from the environment, then performs resampling (if necessary) and transition of the particles.

Parameters:
  • state (ParticleFilterState) – Current state of the filter.

  • key (PRNGKey) – A PRNG key used as the random key.

  • observation_fn (callable) –

    Function that updates particles based on an observation from the environment; takes two positional arguments:
    • state: the state of the filter (ParticleFilterState).

    • observation: an observation from the environment (any).

    Returns the updated state of the filter (ParticleFilterState).

  • observation (any) – An observation from the environment.

  • resample_fn (callable, default=particle_filter.simple_resample) –

    Function that performs resampling of the particles; takes one positional argument:
    • operands: a tuple containing the filter state and a PRNG key (tuple[ParticleFilterState, PRNGKey]).

    Returns the updated state of the filter (ParticleFilterState).

  • resample_criterion_fn (callable, default=particle_filter.effective_sample_size) –

    Function that checks whether resampling is necessary; takes one positional argument:
    • state: the state of the filter (ParticleFilterState).

    Returns whether resampling should be performed (bool).

  • transition_fn (callable, default=particle_filter.simple_transition) –

    Function that updates the particle positions; takes four positional arguments:
    • state: the state of the filter (ParticleFilterState).

    • key: a PRNG key used as a random key (PRNGKey).

    • scale: scale of the random movement of the particles (Array).

    • delta_time: the time elapsed since the last update (float).

    Returns the updated state of the filter (ParticleFilterState).

  • delta_time (float) – Time difference between the current and the previous observation.

  • scale (Array) – Scale of the random movement of the particles.

Returns:

Updated filter state.

Return type:

ParticleFilterState
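The order of operations in the update (observation, conditional resampling, transition) can be sketched with plain Python callables. This only illustrates the control flow stated in the description; it is not the library's implementation. The `rng` argument is a placeholder for the PRNG key, and the stub callbacks are hypothetical.

```python
def update(state, rng, observation_fn, observation, resample_fn,
           resample_criterion_fn, transition_fn, delta_time, scale):
    """Observation update, conditional resampling, then transition."""
    state = observation_fn(state, observation)
    if resample_criterion_fn(state):
        state = resample_fn((state, rng))
    return transition_fn(state, rng, scale, delta_time)


# Stub callbacks that only record the order in which they are invoked.
calls = []


def observe(state, observation):
    calls.append("observe")
    return state


def resample(operands):
    calls.append("resample")
    state, _ = operands
    return state


def transition(state, rng, scale, delta_time):
    calls.append("transition")
    return state


update({}, None, observe, 1.0, resample, lambda s: True, transition, 0.1, 1.0)
```

Passing a criterion that returns False skips the resampling step, so only the observation and transition callbacks run.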

static sample(state: ParticleFilterState, key: Array) Array | ndarray | bool | number | float | int

Samples the estimated environment state from a categorical distribution with particle weights.

Parameters:
  • state (ParticleFilterState) – Current state of the filter.

  • key (PRNGKey) – A PRNG key used as the random key.

Returns:

Estimated environment state.

Return type:

Array
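Drawing an estimate from the categorical distribution defined by the particle weights can be sketched in plain Python as follows. The function name `sample_state`, the list-based inputs, and `random.Random` standing in for a PRNG key are assumptions made for illustration.

```python
import math
import random


def sample_state(positions, logit_weights, rng):
    """Pick one particle position with probability proportional to exp(logit weight)."""
    peak = max(logit_weights)  # subtract the maximum for numerical stability
    weights = [math.exp(logit - peak) for logit in logit_weights]
    return rng.choices(positions, weights=weights, k=1)[0]


positions = [0.1, 0.5, 0.9]
logit_weights = [-1000.0, -1000.0, 0.0]  # practically all mass on the last particle
estimate = sample_state(positions, logit_weights, random.Random(0))
```

Here the last particle carries practically all of the weight, so the sampled estimate is its position.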