Train a PPO agent on Accenta’s environment

This example shows how to train a PPO agent on the AccentaEnv environment with Stable-Baselines3, then save the trained model and reload it.

Import required packages

from rlenv.envs.wall.core import AccentaEnv

import torch as th
from stable_baselines3 import PPO
from rlagent.data import ppo_trained_model_example_path

Make the environment

env = AccentaEnv()

Make the agent’s policy

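# Custom actor (pi) and value function (vf) networks,
# each with two hidden layers of 32 units and ReLU activations.
# https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#custom-network-architecture
# Note: Stable-Baselines3 1.8.0 and later deprecate the list-wrapped form used here
# in favour of a plain dict, i.e. net_arch=dict(pi=[32, 32], vf=[32, 32]).
policy_kwargs = dict(activation_fn=th.nn.ReLU,
                     net_arch=[dict(pi=[32, 32], vf=[32, 32])])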
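
To make the architecture above concrete, the following sketch lists the dense-layer shapes implied by pi=[32, 32] and vf=[32, 32] and counts their weights and biases. It is a simplified illustration, not Stable-Baselines3's internal code: the helper names are made up for this example, and OBS_DIM and N_ACTIONS are hypothetical values, since the observation and action sizes of AccentaEnv are not stated here.

```python
def mlp_layer_shapes(input_dim, hidden_sizes, output_dim):
    """Return (n_in, n_out) for each dense layer of a plain MLP."""
    sizes = [input_dim, *hidden_sizes, output_dim]
    return list(zip(sizes[:-1], sizes[1:]))


def count_parameters(shapes):
    """Count weights plus biases over all dense layers."""
    return sum(n_in * n_out + n_out for n_in, n_out in shapes)


OBS_DIM = 8    # hypothetical observation size, not taken from AccentaEnv
N_ACTIONS = 4  # hypothetical number of policy outputs

# Actor: observation -> 32 -> 32 -> action outputs
pi_shapes = mlp_layer_shapes(OBS_DIM, [32, 32], N_ACTIONS)
# Critic: observation -> 32 -> 32 -> scalar state value
vf_shapes = mlp_layer_shapes(OBS_DIM, [32, 32], 1)

print("actor layers: ", pi_shapes, count_parameters(pi_shapes))
print("critic layers:", vf_shapes, count_parameters(vf_shapes))
```

With networks this small, the parameter count is dominated by the 32 x 32 hidden layer, so changing the observation size has only a modest effect on model size.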

Make the agent

model = PPO("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)

Train the agent

model.learn(total_timesteps=10000)
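
PPO collects experience in fixed-size rollouts, so the number of environment steps actually run can exceed total_timesteps. The sketch below works through that arithmetic. It assumes a single environment and what we understand to be the Stable-Baselines3 PPO defaults (n_steps=2048, batch_size=64, n_epochs=10); check these against your installed version. The ppo_budget helper is hypothetical and not part of the library.

```python
import math


def ppo_budget(total_timesteps, n_steps=2048, n_envs=1, batch_size=64, n_epochs=10):
    """Estimate rollouts, environment steps, and minibatch updates for a PPO run."""
    rollout_size = n_steps * n_envs
    # Training stops only after a full rollout pushes the step count past the target.
    n_rollouts = math.ceil(total_timesteps / rollout_size)
    updates_per_rollout = n_epochs * math.ceil(rollout_size / batch_size)
    return {
        "n_rollouts": n_rollouts,
        "env_steps": n_rollouts * rollout_size,
        "minibatch_updates": n_rollouts * updates_per_rollout,
    }


budget = ppo_budget(10000)
print(budget)
```

Under these assumptions, a request for 10000 timesteps is rounded up to five full rollouts.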

Save the trained agent

model.save(ppo_trained_model_example_path())

Reload the trained agent

del model
model = PPO.load(ppo_trained_model_example_path(), env=env)
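
After reloading, a natural check is to roll the agent out for a few episodes and average the returns. The sketch below shows the shape of such a loop using stand-ins so that it runs on its own: ToyEnv and ToyModel are hypothetical classes, not part of rlenv or Stable-Baselines3. It assumes the classic Gym interface (reset returns an observation, step returns four values); Gymnasium-based environments return (obs, info) from reset and five values from step, so adapt accordingly.

```python
class ToyEnv:
    """Hypothetical stand-in for AccentaEnv (classic Gym-style API assumed)."""

    def __init__(self, horizon=5):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return 0.0

    def step(self, action):
        self.t += 1
        reward = 1.0 if action == 1 else 0.0
        done = self.t >= self.horizon
        return float(self.t), reward, done, {}


class ToyModel:
    """Hypothetical stand-in mimicking the (action, state) return of model.predict."""

    def predict(self, obs, deterministic=True):
        return 1, None


def evaluate(model, env, n_episodes=3):
    """Run full episodes and return the mean undiscounted return."""
    returns = []
    for _ in range(n_episodes):
        obs = env.reset()
        done = False
        total = 0.0
        while not done:
            action, _state = model.predict(obs, deterministic=True)
            obs, reward, done, _info = env.step(action)
            total += reward
        returns.append(total)
    return sum(returns) / len(returns)


mean_return = evaluate(ToyModel(), ToyEnv(), n_episodes=3)
print(mean_return)
```

With the real objects, the same loop would be called as evaluate(model, env), provided the environment follows the interface assumed here.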
