Source code for rlagent.agent

from stable_baselines3 import PPO
from rlagent.data import ppo_trained_model_example_path

from rlenv.envs.wall.core import AccentaEnv

def load_agent():
    """
    This function is mandatory for the ENS Challenge Data evaluation
    platform to evaluate your agent. It serves as a single, common entry
    point for all groups to load your trained agent.

    Do not delete it or change its signature (i.e. do not add an argument
    to it), otherwise your agent will be rejected by the ENS evaluation
    platform.

    Here, an example is provided that loads a PPO agent whose weights are
    read from the file returned by the
    rlagent.data.ppo_trained_model_example_path() function. You have to
    adapt the content of this function to load your **pre-trained** agent.

    WARNING: this function must return a pre-trained model (e.g. saved in
    a file included in the git repository, downloaded from the internet,
    etc.). You should not train a model here because there is a timeout
    on the execution time of this function on the ENS Challenge Data
    evaluation platform. Normally you would load a pre-trained model
    file, instantiate the model, and return it from this function.

    Example code for training an agent is available in the examples
    directory. You can use it as inspiration to define and train your
    own agent.

    Returns
    -------
    Gym agent
        The pre-trained model to be evaluated on the ENS Challenge Data
        platform.
    """
    env = AccentaEnv()

    # Load the agent (in this example, the policy_kwargs are
    # automatically loaded from the file returned by
    # `rlagent.data.ppo_trained_model_example_path()`)
    model = PPO.load(ppo_trained_model_example_path(), env=env, device="cpu")

    return model
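The pattern the docstring describes — train offline, ship a saved artifact, and only deserialize it inside `load_agent()` — can be sketched with plain Python, independent of stable-baselines3. The `SimpleAgent` class, its JSON weight file, and the linear policy below are all hypothetical illustrations, not part of the challenge API:

```python
import json
import tempfile
from pathlib import Path


class SimpleAgent:
    """Hypothetical linear agent used to illustrate the save/load pattern."""

    def __init__(self, weights):
        self.weights = weights

    def predict(self, observation):
        # Toy linear policy: weighted sum of the observation features.
        return sum(w * o for w, o in zip(self.weights, observation))

    def save(self, path):
        # Serialize the parameters to a file shipped with the repository.
        Path(path).write_text(json.dumps({"weights": self.weights}))

    @classmethod
    def load(cls, path):
        # Deserialize parameters and rebuild the agent without training.
        params = json.loads(Path(path).read_text())
        return cls(params["weights"])


# "Training" happens offline; only the saved artifact is committed.
model_path = Path(tempfile.mkdtemp()) / "agent.json"
SimpleAgent([0.5, -1.0, 2.0]).save(model_path)


def load_agent():
    # Only deserialization at evaluation time -- no training, so the
    # function stays well within the platform's execution timeout.
    return SimpleAgent.load(model_path)


agent = load_agent()
print(agent.predict([1.0, 1.0, 1.0]))  # 0.5 - 1.0 + 2.0 = 1.5
```

`PPO.save()` / `PPO.load()` play the same roles as `save()` / `load()` here: the expensive training step produces a file once, and the evaluation platform only ever pays the cost of reading it back.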