Project author: Ending2015a

Project description:
A TF2.0 implementation of RL baselines.
Programming language: Python
Repository: git://github.com/Ending2015a/unstable_baselines.git
Created: 2021-02-13T10:55:54Z
Project community: https://github.com/Ending2015a/unstable_baselines

License: MIT License


Unstable Baselines (Early Access)

A Deep Reinforcement Learning codebase in TensorFlow 2.0 with a unified, flexible and highly customizable structure for fast prototyping.

| Features | Unstable Baselines | Stable-Baselines3 | OpenAI Baselines |
|---|---|---|---|
| State-of-the-art RL methods | :heavy_minus_sign: (1) | :heavy_check_mark: | :heavy_check_mark: |
| Documentation | :x: | :heavy_check_mark: | :x: |
| Custom callback (2) | :x: | :vomiting_face: | :heavy_minus_sign: |
| TensorFlow 2.0 support | :heavy_check_mark: | :x: | :x: |
| Clean, elegant code | :heavy_check_mark: | :x: | :x: |
| Easy to trace and customize | :heavy_check_mark: | :x: (3) | :x: (3) |
| Standalone implementations | :heavy_check_mark: | :heavy_minus_sign: | :x: (4) |

(1) Currently only supports DQN, C51, PPO, TD3, etc. We are still working on other algorithms.

(2) For example, in Stable-Baselines you need to write a disgusting custom callback to save the best-performing model :vomiting_face:, while in Unstable Baselines it is saved automatically (a rough sketch of that bookkeeping follows these notes).

(3) If you have traced Stable-baselines or OpenAI/baselines once, you’ll never do that again.

(4) The many cross-dependencies across algorithms make the code very hard to trace, for example baselines/common/policies.py, baselines/a2c/a2c.py, … Great job, OpenAI! :cat:
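
To make note (2) concrete, the kind of bookkeeping you would otherwise write by hand looks roughly like the sketch below. This is a generic illustration rather than code from either library; evaluate_mean_reward is a hypothetical helper standing in for whatever evaluation routine your framework provides.

  # Generic "keep the best checkpoint so far" bookkeeping that Unstable
  # Baselines handles automatically. `evaluate_mean_reward` is hypothetical.
  best_mean_reward = float('-inf')

  def on_evaluation(model, eval_env):
      global best_mean_reward
      mean_reward = evaluate_mean_reward(model, eval_env, n_episodes=10)
      if mean_reward > best_mean_reward:
          best_mean_reward = mean_reward
          model.save('./best_model')  # overwrite only when performance improves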

Documentation

We don’t have any documentation yet.

Installation

Basic requirements:

  • Python >= 3.6
  • TensorFlow (CPU/GPU) >= 2.3.0

You can install it from PyPI:

  $ pip install unstable_baselines

Or you can install the latest version from this repository:

  $ pip install git+https://github.com/Ending2015a/unstable_baselines.git@master
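
To verify the installation, a quick import check works with either install method (nothing repo-specific is assumed here beyond the package name used in the Quick Start below):

  # sanity check: both packages should import without errors
  import tensorflow as tf
  import unstable_baselines as ub

  print(tf.__version__)  # expect >= 2.3.0, per the requirements above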

Done! Now you can start training your own agents.

Algorithms

Model-free RL

| Algorithm | Box | Discrete | MultiDiscrete | MultiBinary |
|---|---|---|---|---|
| DQN | :x: | :heavy_check_mark: | :x: | :x: |
| PPO | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: |
| TD3 | :heavy_check_mark: | :x: | :x: | :x: |
| SD3 | :heavy_check_mark: | :x: | :x: | :x: |

Distributional RL

| Algorithm | Box | Discrete | MultiDiscrete | MultiBinary |
|---|---|---|---|---|
| C51 | :x: | :heavy_check_mark: | :x: | :x: |
| QRDQN | :x: | :heavy_check_mark: | :x: | :x: |
| IQN | :x: | :heavy_check_mark: | :x: | :x: |
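
The tables above list which gym action spaces each algorithm accepts. As a minimal sketch, you can inspect an environment's action space with plain gym calls before picking an algorithm (nothing here is specific to this repo):

  import gym

  # Discrete action space: works with DQN, C51, QRDQN, IQN, and PPO
  discrete_env = gym.make('CartPole-v0')
  print(discrete_env.action_space)    # Discrete(2)

  # Box (continuous) action space: works with PPO, TD3, and SD3
  continuous_env = gym.make('Pendulum-v0')
  print(continuous_env.action_space)  # Box(1,) (exact formatting depends on the gym version)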

Quick Start

This example shows how to train a PPO agent to play CartPole-v0. You can find the full script in example/cartpole/train_ppo.py.

First, import dependencies

  import gym
  import unstable_baselines as ub
  from unstable_baselines.algo.ppo import PPO

Create environments for training and evaluation

  # create environments
  env = ub.envs.VecEnv([gym.make('CartPole-v0') for _ in range(10)])
  eval_env = gym.make('CartPole-v0')

Create a PPO model and train it

  model = PPO(
      env,
      learning_rate=1e-3,
      gamma=0.8,
      batch_size=128,
      n_steps=500
  ).learn(    # train for 20000 steps
      20000,
      verbose=1
  )

Save and load the trained model

  model.save('./my_ppo_model')
  model = PPO.load('./my_ppo_model')

Evaluate the training results

  model.eval(eval_env, 20, 200, render=True)
  # don't forget to close the environments!
  env.close()
  eval_env.close()

More examples:

Update Logs

  • 2021.05.22: Add benchmarks
  • 2021.04.27: Update to framework v2: supports saving/loading the best-performing checkpoints.