API module
RLib class
- class RLib(*, agent_type: type = None, agent_params: dict[str, any] = None, ext_type: type = None, ext_params: dict[str, any] = None, logger_types: type | list[type] = None, logger_sources: tuple[str, SourceType] | str | None | list[tuple[str, SourceType] | str | None] = None, logger_params: dict[str, any] = None, no_ext_mode: bool = False, auto_checkpoint: int = None, auto_checkpoint_path: str = None)
Main class of the library. Exposes a simple and intuitive interface to use the library.
- Parameters:
agent_type (type, optional) – Type of the selected agent. Must inherit from the BaseAgent class.
agent_params (dict, optional) – Parameters of the selected agent.
ext_type (type, optional) – Type of the selected extension. Must inherit from the BaseExt class.
ext_params (dict, optional) – Parameters of the selected extension.
logger_types (type or list[type], optional) – Types of the selected loggers. Must inherit from the BaseLogger class.
logger_sources (Source or list[Source], optional) – Sources to log.
logger_params (dict, optional) – Parameters of the selected loggers.
no_ext_mode (bool, default=False) – Pass observations directly to the agent (do not use the extensions).
auto_checkpoint (int, optional) – Automatically save the experiment every auto_checkpoint steps. If None, automatic checkpointing is disabled.
auto_checkpoint_path (str, optional, default=~) – Path to the directory where the automatic checkpoints will be saved.
- finish() None
Explicitly finalizes the library’s work. In particular, it finalizes the loggers.
- set_agent(agent_type: type, agent_params: dict = None) None
Initializes an agent of type agent_type with parameters agent_params. The agent type must inherit from the BaseAgent class. The agent type cannot be changed after the first agent instance has been initialized.
- Parameters:
agent_type (type) – Type of the selected agent. Must inherit from the BaseAgent class.
agent_params (dict, optional) – Parameters of the selected agent.
- set_ext(ext_type: type, ext_params: dict = None) None
Initializes an extension of type ext_type with parameters ext_params. The extension type must inherit from the BaseExt class. The extension type cannot be changed after the first agent instance has been initialized.
- Parameters:
ext_type (type) – Type of the selected extension. Must inherit from the BaseExt class.
ext_params (dict, optional) – Parameters of the selected extension.
- set_loggers(logger_types: type | list[type], logger_sources: tuple[str, SourceType] | str | None | list[tuple[str, SourceType] | str | None] = None, logger_params: dict[str, any] = None) None
Initializes loggers of types logger_types with parameters logger_params. The logger types must inherit from the BaseLogger class. The logger types cannot be changed after the first agent instance has been initialized.
logger_types and logger_sources can each be a single object or a list of objects; the function broadcasts them so that all loggers are connected to all sources. The logger_sources parameter specifies the sources to log. A source can be a name (e.g., "action") or a tuple containing the name and the SourceType (e.g., ("action", SourceType.OBSERVATION)). If the name alone is ambiguous (e.g., it occurs both as a metric and as an observation), the behavior depends on the implementation of the logger.
- Parameters:
logger_types (type or list[type]) – Types of the selected loggers.
logger_sources (Source or list[Source], optional) – Sources to log.
logger_params (dict, optional) – Parameters of the selected loggers.
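The broadcasting described above can be pictured as a Cartesian product of logger types and sources. The following is a minimal sketch of that idea only, not the library's implementation: the helper broadcast_loggers, the classes LoggerA and LoggerB, and the local SourceType enum are hypothetical stand-ins introduced for illustration.

```python
from enum import Enum
from itertools import product


class SourceType(Enum):
    """Local stand-in for the library's SourceType (assumption: only this member is shown)."""
    OBSERVATION = 0


class LoggerA:
    """Hypothetical logger type used only for this illustration."""


class LoggerB:
    """Hypothetical logger type used only for this illustration."""


def broadcast_loggers(logger_types, logger_sources=None):
    """Pair every logger type with every source (illustrative sketch)."""
    # A single object is treated as a one-element list.
    if not isinstance(logger_types, list):
        logger_types = [logger_types]
    if not isinstance(logger_sources, list):
        logger_sources = [logger_sources]
    # Every logger type is connected to every source.
    return list(product(logger_types, logger_sources))


pairs = broadcast_loggers(
    [LoggerA, LoggerB],
    ["action", ("action", SourceType.OBSERVATION)],
)
for logger_type, source in pairs:
    print(logger_type.__name__, source)
```

With two logger types and two sources, the sketch yields four (logger, source) pairs; a single type with a single source yields one pair.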
- property observation_space: Space
Returns the observation space of the selected extension (or agent, if no_ext_mode is set).
- Returns:
Observation space of the selected extension or agent.
- Return type:
gym.spaces.Space
- property action_space: Space
Returns the action space of the selected agent.
- Returns:
Action space of the selected agent.
- Return type:
gym.spaces.Space
- init(seed: int = 42) int
Initializes a new instance of the agent.
- Parameters:
seed (int, default=42) – Number used to initialize the JAX pseudo-random number generator.
- Returns:
Identifier of the created instance.
- Return type:
int
- sample(*args, agent_id: int = 0, is_training: bool = True, update_observations: dict | tuple | any = None, sample_observations: dict | tuple | any = None, **kwargs) any
Takes the extension state as an input, updates the agent state, and returns the next action selected by the agent. If no_ext_mode is disabled, observations are passed by args and kwargs (the observations must match the extension observation space). If no_ext_mode is enabled, observations must be passed by the update_observations and sample_observations parameters (the observations must match the agent’s update_observation_space and sample_observation_space). If there are no agent instances initialized, the method automatically initializes the first instance. If the is_training flag is set, the update and sample agent methods will be called. Otherwise, only the sample method will be called.
- Parameters:
*args (tuple) – Environment observations.
agent_id (int, default=0) – The identifier of the agent instance.
is_training (bool, default=True) – Flag indicating whether the agent state should be updated in this step.
update_observations (dict or tuple or any, optional) – Observations used when no_ext_mode is enabled (must match agent’s update_observation_space).
sample_observations (dict or tuple or any, optional) – Observations used when no_ext_mode is enabled (must match agent’s sample_observation_space).
**kwargs (dict) – Environment observations.
- Returns:
Action selected by the agent.
- Return type:
any
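The effect of the is_training flag can be sketched with a toy object. This is a simplified illustration under stated assumptions, not the library's code: CountingAgent, its update and sample signatures, and the step helper are hypothetical, and a real agent would carry state and a random key rather than counters.

```python
class CountingAgent:
    """Hypothetical stand-in for an agent; it only counts method calls."""

    def __init__(self):
        self.updates = 0
        self.samples = 0

    def update(self, reward):
        # Stand-in for updating the agent state from the update observations.
        self.updates += 1

    def sample(self, context):
        # Stand-in for selecting an action from the sample observations.
        self.samples += 1
        return 0


def step(agent, update_observations, sample_observations, is_training=True):
    """Call update and sample when training, otherwise only sample."""
    if is_training:
        agent.update(**update_observations)
    return agent.sample(**sample_observations)


agent = CountingAgent()
step(agent, {"reward": 1.0}, {"context": 0})                     # training step
step(agent, {"reward": 1.0}, {"context": 0}, is_training=False)  # evaluation step
print(agent.updates, agent.samples)  # 1 2
```

The evaluation step still returns an action, but it leaves the update counter unchanged, which mirrors the documented behavior of skipping update when is_training is not set.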
- save(path: str = None, *, agent_ids: int | list[int] = None) str
Saves the state of the experiment to a file in lz4 format. For each agent, both the state and the initialization parameters are saved. The extension and loggers settings are saved as well to fully reconstruct the experiment.
- Parameters:
path (str, optional) – Path to the checkpoint file. If none specified, saves to the default path. If the .pkl.lz4 suffix is not detected, it will be appended automatically.
agent_ids (int or list[int], optional) – The identifier of the agent instance(s) to save. If none specified, saves the state of all agents.
- Returns:
Path to the saved checkpoint file.
- Return type:
str
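The suffix handling for path can be illustrated as follows. This is a sketch of the documented behavior only; the helper name with_checkpoint_suffix is hypothetical, and the library may detect the suffix differently.

```python
def with_checkpoint_suffix(path: str, suffix: str = ".pkl.lz4") -> str:
    """Append the checkpoint suffix unless the path already ends with it."""
    return path if path.endswith(suffix) else path + suffix


print(with_checkpoint_suffix("checkpoints/run1"))
print(with_checkpoint_suffix("checkpoints/run1.pkl.lz4"))
```

Both calls produce the same path, so passing a name with or without the suffix refers to the same checkpoint file in this sketch.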
- static load(path: str, *, agent_params: dict[str, any] = None, ext_params: dict[str, any] = None, logger_types: type | list[type] = None, logger_sources: tuple[str, SourceType] | str | None | list[tuple[str, SourceType] | str | None] = None, logger_params: dict[str, any] = None) RLib
Loads the state of the experiment from a file in lz4 format.
- Parameters:
path (str) – Path to the checkpoint file.
agent_params (dict[str, any], optional) – Dictionary of altered agent parameters with their new values, by default None.
ext_params (dict[str, any], optional) – Dictionary of altered extension parameters with their new values, by default None.
logger_types (type or list[type], optional) – Types of the selected loggers. Must inherit from the BaseLogger class.
logger_sources (Source or list[Source], optional) – Sources to log.
logger_params (dict, optional) – Parameters of the selected loggers.
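One way to read "altered parameters with their new values" is as an override of the parameters stored in the checkpoint. The sketch below shows that interpretation with plain dictionaries; it is an assumption about the semantics, not the library's code, and the helper merge_params as well as the parameter names lr and gamma are hypothetical.

```python
from typing import Optional


def merge_params(saved: dict, altered: Optional[dict] = None) -> dict:
    """Return the saved parameters with any altered values applied on top."""
    # Keys present in `altered` replace the saved values; all others are kept.
    return {**saved, **(altered or {})}


saved_params = {"lr": 0.1, "gamma": 0.99}          # hypothetical saved parameters
merged = merge_params(saved_params, {"lr": 0.01})  # override a single value
print(merged)
```

Passing None (the documented default) leaves the saved parameters unchanged in this sketch.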
- log(name: str, value: any) None
Logs a custom value.
- Parameters:
name (str) – The name of the value to log.
value (any) – The value to log.
- to_tflite(path: str = None, *, agent_id: int = None, sample_only: bool = False) None
Converts the agent to a TensorFlow Lite model and saves it to a file.
- Parameters:
path (str, optional) – Path to the output file.
agent_id (int, optional) – The identifier of the agent instance to convert. If specified, the state of the selected agent will be saved.
sample_only (bool, default=False) – Flag indicating whether the method should save only the sample function.