.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "gallery/plot_train_save_and_eval_agent.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_gallery_plot_train_save_and_eval_agent.py:


=======================
PPO on Wall Environment
=======================

This example shows how to train PPO on the Wall Environment.

.. GENERATED FROM PYTHON SOURCE LINES 13-14

Import required packages

.. GENERATED FROM PYTHON SOURCE LINES 14-22

.. code-block:: default


    import matplotlib.pyplot as plt

    from rlenv.envs.wall.core import AccentaEnv

    import torch as th
    from stable_baselines3 import PPO


.. GENERATED FROM PYTHON SOURCE LINES 23-24

Make the environment

.. GENERATED FROM PYTHON SOURCE LINES 24-28

.. code-block:: default


    env = AccentaEnv()


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /usr/local/lib/python3.8/site-packages/gym/spaces/box.py:73: UserWarning: WARN: Box bound precision lowered by casting to float32
      logger.warn(


.. GENERATED FROM PYTHON SOURCE LINES 29-30

Make the agent

.. GENERATED FROM PYTHON SOURCE LINES 30-41

.. code-block:: default


    # Custom actor (pi) and value function (vf) networks
    # of two layers of size 32 each with ReLU activation function
    # https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#custom-network-architecture
    policy_kwargs = dict(activation_fn=th.nn.ReLU,
                         net_arch=[dict(pi=[32, 32], vf=[32, 32])])

    # Create the agent
    model = PPO("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Using cpu device
    Wrapping the env with a `Monitor` wrapper
    Wrapping the env in a DummyVecEnv.
    /usr/local/lib/python3.8/site-packages/stable_baselines3/common/policies.py:458: UserWarning: As shared layers in the mlp_extractor are deprecated and will be removed in SB3 v1.8.0, you should now pass directly a dictionary and not a list (net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])
      warnings.warn(


.. GENERATED FROM PYTHON SOURCE LINES 42-43

Train the agent

.. GENERATED FROM PYTHON SOURCE LINES 43-47

.. code-block:: default


    model.learn(total_timesteps=10000)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    -----------------------------
    | time/              |      |
    |    fps             | 718  |
    |    iterations      | 1    |
    |    time_elapsed    | 2    |
    |    total_timesteps | 2048 |
    -----------------------------
    -----------------------------------------
    | time/                   |             |
    |    fps                  | 611         |
    |    iterations           | 2           |
    |    time_elapsed         | 6           |
    |    total_timesteps      | 4096        |
    | train/                  |             |
    |    approx_kl            | 0.097660676 |
    |    clip_fraction        | 0.475       |
    |    clip_range           | 0.2         |
    |    entropy_loss         | -4.23       |
    |    explained_variance   | 0.717       |
    |    learning_rate        | 0.0003      |
    |    loss                 | 3.49        |
    |    n_updates            | 10          |
    |    policy_gradient_loss | 0.028       |
    |    std                  | 0.981       |
    |    value_loss           | 8.25        |
    -----------------------------------------
    ------------------------------------------
    | time/                   |              |
    |    fps                  | 583          |
    |    iterations           | 3            |
    |    time_elapsed         | 10           |
    |    total_timesteps      | 6144         |
    | train/                  |              |
    |    approx_kl            | 0.0126673505 |
    |    clip_fraction        | 0.342        |
    |    clip_range           | 0.2          |
    |    entropy_loss         | -4.2         |
    |    explained_variance   | 0.0028       |
    |    learning_rate        | 0.0003       |
    |    loss                 | 126          |
    |    n_updates            | 20           |
    |    policy_gradient_loss | 0.0464       |
    |    std                  | 0.98         |
    |    value_loss           | 412          |
    ------------------------------------------
    -----------------------------------------
    | time/                   |             |
    |    fps                  | 570         |
    |    iterations           | 4           |
    |    time_elapsed         | 14          |
    |    total_timesteps      | 8192        |
    | train/                  |             |
    |    approx_kl            | 0.010383712 |
    |    clip_fraction        | 0.314       |
    |    clip_range           | 0.2         |
    |    entropy_loss         | -4.19       |
    |    explained_variance   | 0.271       |
    |    learning_rate        | 0.0003      |
    |    loss                 | 105         |
    |    n_updates            | 30          |
    |    policy_gradient_loss | 0.0202      |
    |    std                  | 0.98        |
    |    value_loss           | 855         |
    -----------------------------------------
    ----------------------------------------
    | rollout/                |            |
    |    ep_len_mean          | 8.75e+03   |
    |    ep_rew_mean          | -3.16e+04  |
    | time/                   |            |
    |    fps                  | 560        |
    |    iterations           | 5          |
    |    time_elapsed         | 18         |
    |    total_timesteps      | 10240      |
    | train/                  |            |
    |    approx_kl            | 0.01187706 |
    |    clip_fraction        | 0.305      |
    |    clip_range           | 0.2        |
    |    entropy_loss         | -4.19      |
    |    explained_variance   | 0.42       |
    |    learning_rate        | 0.0003     |
    |    loss                 | 149        |
    |    n_updates            | 40         |
    |    policy_gradient_loss | 0.0354     |
    |    std                  | 0.979      |
    |    value_loss           | 683        |
    ----------------------------------------


.. GENERATED FROM PYTHON SOURCE LINES 48-49

Save the agent (optional)

.. GENERATED FROM PYTHON SOURCE LINES 49-53

.. code-block:: default


    model.save("../rlagent/data/trained_model")


.. GENERATED FROM PYTHON SOURCE LINES 54-56

Load the agent (optional).
The ``policy_kwargs`` are loaded automatically.

.. GENERATED FROM PYTHON SOURCE LINES 56-61

.. code-block:: default


    del model
    model = PPO.load("../rlagent/data/trained_model", env=env)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Wrapping the env with a `Monitor` wrapper
    Wrapping the env in a DummyVecEnv.
    /usr/local/lib/python3.8/site-packages/stable_baselines3/common/policies.py:458: UserWarning: As shared layers in the mlp_extractor are deprecated and will be removed in SB3 v1.8.0, you should now pass directly a dictionary and not a list (net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])
      warnings.warn(


.. GENERATED FROM PYTHON SOURCE LINES 62-63

Assess the agent

.. GENERATED FROM PYTHON SOURCE LINES 63-69

.. code-block:: default


    score = AccentaEnv.eval(model)
    print(score)

    df = AccentaEnv.gen_one_episode(model)
    df.plot()
    plt.show()


.. image-sg:: /gallery/images/sphx_glr_plot_train_save_and_eval_agent_001.png
   :alt: plot train save and eval agent
   :srcset: /gallery/images/sphx_glr_plot_train_save_and_eval_agent_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /usr/local/lib/python3.8/site-packages/gym/spaces/box.py:73: UserWarning: WARN: Box bound precision lowered by casting to float32
      logger.warn(
    -32977.81519482677
    /usr/local/lib/python3.8/site-packages/gym/spaces/box.py:73: UserWarning: WARN: Box bound precision lowered by casting to float32
      logger.warn(


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 2 minutes 12.719 seconds)


.. _sphx_glr_download_gallery_plot_train_save_and_eval_agent.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_train_save_and_eval_agent.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_train_save_and_eval_agent.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
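
A note on the training log: ``model.learn(total_timesteps=10000)`` was requested, yet the last logged ``total_timesteps`` value is 10240. A likely explanation, stated here as an assumption rather than taken from the example itself, is that PPO collects experience in fixed-size rollouts (2048 steps per iteration with a single environment, which matches the 2048-step increments in the log) and stops only after a whole rollout pushes the count past the requested budget. The helper ``timesteps_actually_run`` below is hypothetical, written only to illustrate that arithmetic; it is not part of Stable-Baselines3 or ``rlenv``.

.. code-block:: python

    import math


    def timesteps_actually_run(total_timesteps: int,
                               n_steps: int = 2048,
                               n_envs: int = 1) -> int:
        """Round a timestep budget up to a whole number of rollouts.

        Assumes each iteration collects exactly n_steps * n_envs steps
        (n_steps=2048 is inferred from the increments in the log above).
        """
        rollout_size = n_steps * n_envs
        n_iterations = math.ceil(total_timesteps / rollout_size)
        return n_iterations * rollout_size


    # 10000 requested steps need 5 rollouts of 2048 steps each.
    print(timesteps_actually_run(10_000))  # 10240

Under this assumption, choosing ``total_timesteps`` as a multiple of the rollout size makes the requested and the logged counts agree.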