PPO on Wall Environment

This example shows how to train a PPO agent on the Wall Environment.

Import required packages

import matplotlib.pyplot as plt
from rlenv.envs.wall.core import AccentaEnv

import torch as th
from stable_baselines3 import PPO

Make the environment

env = AccentaEnv()
/usr/local/lib/python3.8/site-packages/gym/spaces/box.py:73: UserWarning: WARN: Box bound precision lowered by casting to float32
  logger.warn(
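
Before building the agent, it can help to inspect the environment's spaces. A minimal check, assuming AccentaEnv exposes the standard Gym interface:

# Inspect the observation and action spaces (standard Gym attributes)
print(env.observation_space)
print(env.action_space)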

Make the agent

# Custom actor (pi) and value function (vf) networks
# of two layers of size 32 each with a ReLU activation function
# https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#custom-network-architecture
policy_kwargs = dict(activation_fn=th.nn.ReLU,
                     net_arch=[dict(pi=[32, 32], vf=[32, 32])])

# Create the agent
model = PPO("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
/usr/local/lib/python3.8/site-packages/stable_baselines3/common/policies.py:458: UserWarning: As shared layers in the mlp_extractor are deprecated and will be removed in SB3 v1.8.0, you should now pass directly a dictionary and not a list (net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])
  warnings.warn(
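
As the deprecation warning above suggests, recent SB3 versions expect the architecture as a plain dictionary rather than a list. An equivalent, non-deprecated sketch:

# Equivalent architecture specification without the deprecated list wrapper (SB3 v1.8+)
policy_kwargs = dict(activation_fn=th.nn.ReLU,
                     net_arch=dict(pi=[32, 32], vf=[32, 32]))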

Train the agent

model.learn(total_timesteps=10000)
-----------------------------
| time/              |      |
|    fps             | 718  |
|    iterations      | 1    |
|    time_elapsed    | 2    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 611         |
|    iterations           | 2           |
|    time_elapsed         | 6           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.097660676 |
|    clip_fraction        | 0.475       |
|    clip_range           | 0.2         |
|    entropy_loss         | -4.23       |
|    explained_variance   | 0.717       |
|    learning_rate        | 0.0003      |
|    loss                 | 3.49        |
|    n_updates            | 10          |
|    policy_gradient_loss | 0.028       |
|    std                  | 0.981       |
|    value_loss           | 8.25        |
-----------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 583          |
|    iterations           | 3            |
|    time_elapsed         | 10           |
|    total_timesteps      | 6144         |
| train/                  |              |
|    approx_kl            | 0.0126673505 |
|    clip_fraction        | 0.342        |
|    clip_range           | 0.2          |
|    entropy_loss         | -4.2         |
|    explained_variance   | 0.0028       |
|    learning_rate        | 0.0003       |
|    loss                 | 126          |
|    n_updates            | 20           |
|    policy_gradient_loss | 0.0464       |
|    std                  | 0.98         |
|    value_loss           | 412          |
------------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 570         |
|    iterations           | 4           |
|    time_elapsed         | 14          |
|    total_timesteps      | 8192        |
| train/                  |             |
|    approx_kl            | 0.010383712 |
|    clip_fraction        | 0.314       |
|    clip_range           | 0.2         |
|    entropy_loss         | -4.19       |
|    explained_variance   | 0.271       |
|    learning_rate        | 0.0003      |
|    loss                 | 105         |
|    n_updates            | 30          |
|    policy_gradient_loss | 0.0202      |
|    std                  | 0.98        |
|    value_loss           | 855         |
-----------------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 8.75e+03   |
|    ep_rew_mean          | -3.16e+04  |
| time/                   |            |
|    fps                  | 560        |
|    iterations           | 5          |
|    time_elapsed         | 18         |
|    total_timesteps      | 10240      |
| train/                  |            |
|    approx_kl            | 0.01187706 |
|    clip_fraction        | 0.305      |
|    clip_range           | 0.2        |
|    entropy_loss         | -4.19      |
|    explained_variance   | 0.42       |
|    learning_rate        | 0.0003     |
|    loss                 | 149        |
|    n_updates            | 40         |
|    policy_gradient_loss | 0.0354     |
|    std                  | 0.979      |
|    value_loss           | 683        |
----------------------------------------

<stable_baselines3.ppo.ppo.PPO object at 0x7f619a4614f0>
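
10000 timesteps is a short training run; you can call model.learn again to continue training, or measure the current policy with SB3's built-in evaluation helper. A minimal sketch (the episode count is arbitrary, and episodes in this environment are long, around 8750 steps, so evaluation can take a while):

from stable_baselines3.common.evaluation import evaluate_policy

# Average episodic reward over a few evaluation episodes
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=3)
print(f"mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")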

Save the agent (optional)

model.save("../rlagent/data/trained_model")

Load the agent (optional); the policy_kwargs are loaded automatically

del model
model = PPO.load("../rlagent/data/trained_model", env=env)
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
/usr/local/lib/python3.8/site-packages/stable_baselines3/common/policies.py:458: UserWarning: As shared layers in the mlp_extractor are deprecated and will be removed in SB3 v1.8.0, you should now pass directly a dictionary and not a list (net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])
  warnings.warn(
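
The loaded model can be queried directly. A minimal sketch, assuming the classic single-environment Gym reset API:

# Ask the loaded policy for an action on a fresh observation
obs = env.reset()
action, _states = model.predict(obs, deterministic=True)
print(action)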

Assess the agent

score = AccentaEnv.eval(model)
print(score)

df = AccentaEnv.gen_one_episode(model)
df.plot()
plt.show()
[Figure: one simulated episode generated by AccentaEnv.gen_one_episode, plotted with df.plot()]

-32977.81519482677
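
The score returned by AccentaEnv.eval can also be cross-checked with a manual rollout. A minimal sketch, assuming the classic Gym step API where step returns (obs, reward, done, info):

# Accumulate the reward of one episode by stepping the environment manually
obs = env.reset()
done = False
total_reward = 0.0
while not done:
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    total_reward += reward
print(total_reward)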

Total running time of the script: (2 minutes 12.719 seconds)

Gallery generated by Sphinx-Gallery