Stable Baselines3 是一个 Python 库,旨在简化强化学习 (RL) 算法的实现。它基于 PyTorch 构建,提供了流行的 RL 算法(如 PPO(近端策略优化)、DQN(深度 Q 网络)和 SAC(软 Actor-Critic))的预构建、优化实现。该库抽象掉了底层细节,使开发人员能够专注于训练和评估 RL 代理。它与 OpenAI Gym 环境无缝集成,使用户能够在 CartPole 或 Atari 游戏等标准化任务上训练代理。主要功能包括支持并行训练、超参数自定义以及用于在训练期间监控和保存模型的工具。
典型的工作流程包括三个步骤:定义环境、选择算法和训练代理。例如,使用 PPO 算法,开发人员首先使用 gym.make('CartPole-v1')
创建一个 Gym 环境。接下来,他们使用 PPO('MlpPolicy', env, verbose=1)
初始化模型,以指定策略网络(例如多层感知器)和环境。调用 model.learn(total_timesteps=10_000)
启动训练过程,代理与环境交互、收集经验并更新其策略以最大化奖励。该库会自动处理数据收集、神经网络更新和日志记录。可以添加回调来保存检查点或定期评估代理,并且可以保存和重新加载经过训练的模型以进行部署。
Stable Baselines3 还为高级用例提供自定义。开发人员可以通过覆盖策略类或使用 policy_kwargs
参数来修改神经网络架构。例如,更改策略网络中的层数或激活函数非常简单。该库支持向量化环境(通过 VecEnv
)进行并行训练,从而加快数据收集速度。可以添加预处理包装器(例如,用于规范化观察)来处理特定于环境的怪癖。此外,HER(后见之明经验回放)等工具通过重新标记失败的经验来帮助解决稀疏奖励问题。虽然该库简化了常见任务,但它仍然可以访问底层控件,使其在原型设计和生产中都具有灵活性。