找回密码
 立即注册
首页 业界区 业界 stable_baseline3 快速入门(二): 训练自定义游戏,构建G ...

stable_baseline3 快速入门(二): 训练自定义游戏,构建Gymnasium训练环境

吕颐然 3 天前
简介

Gymnasium 为强化学习提供了一个标准化的API,它定义了 Agent 应该如何观察世界、如何做出动作以及如何获得奖励,不管是游戏,还是工业设备,只需要满足Gymnasium标准都能使用同一套代码进行训练。
认识Gymnasium

使用stable_baseline3只需要定义好Gymnasium环境,关注训练的奖励机制,将重点放在业务的开发上而不是复杂的算法。
Gymnasium提供了几个核心的api:
方法功能返回值reset()将环境重置为初始状态,开始新回合。obs, infostep(action)环境向前推进一步,执行动作。obs, reward, terminated, truncated, inforender()可视化环境(根据 render_mode 渲染图像或弹出窗口)。视配置而定(通常无或为 np.array)close()释放环境资源(关闭窗口、清理内存)。无其中的各个返回值的含义:

  • observation (Object): 当前状态的描述。例如敌人,玩家的位置,玩家的状态等
  • reward (Float): 上一步动作获得的奖励
  • terminated (Bool): 是否由于任务逻辑结束。例如:到达终点、掉进岩浆等
  • truncated (Bool): 是否由于外部限制结束。例如:达到最大步数 500 步
  • info (Dict): 辅助诊断信息,模型训练通常不用,用于用户自定义调试或记录额外统计。
手动构建环境

案例

案例描述:利用pygame构建一个简单的游戏,躲避掉落方块,利用构建的奖励机制,进行强化学习。
[code]import gymnasium as gymfrom gymnasium import spacesimport numpy as npimport pygameimport randomimport cv2import osfrom stable_baselines3 import PPOfrom stable_baselines3.common.callbacks import CheckpointCallbackfrom stable_baselines3.common.env_checker import check_envclass MyEnv(gym.Env):    def __init__(self, render_mode=None):        super(MyEnv, self).__init__()        #初始化参数        self.width = 400        self.height = 300        self.player_size = 30        self.enemy_size = 30        self.render_mode = render_mode        self.action_space = spaces.Discrete(3)        self.observation_space = spaces.Box(            low=0, high=255, shape=(84, 84, 3), dtype=np.uint8        )        pygame.init()        if self.render_mode == "human":            self.screen = pygame.display.set_mode((self.width, self.height))                self.canvas = pygame.Surface((self.width, self.height))        self.font = pygame.font.SysFont("monospace", 15)    def reset(self, seed=None, options=None):        super().reset(seed=seed)        self.player_x = self.width // 2 - self.player_size // 2        self.player_y = self.height - self.player_size - 10        self.enemies = []        self.score = 0        self.frame_count = 0        self.current_speed = 5        self.spawn_rate = 30        return self._get_obs(), {}    def step(self, action):        reward = 0        terminated = False        truncated = False        move_speed = 8        if action == 1 and self.player_x > 0: #             self.player_x -= move_speed            reward -= 0.05        if action == 2 and self.player_x < self.width - self.player_size:            self.player_x += move_speed            reward -= 0.05        self.frame_count += 1        level = self.score // 5        self.current_speed = 5 + level        self.spawn_rate = 30 - level * 2        spawn_rate = max(10, 30 - level)        if self.frame_count >= spawn_rate:            self.frame_count = 0            enemy_x = random.randint(0, self.width - self.enemy_size)            self.enemies.append([enemy_x, 0]) # [x, y]        for enemy in self.enemies:            enemy[1] += self.current_speed                        player_rect = pygame.Rect(self.player_x, self.player_y, self.player_size, self.player_size)            enemy_rect = pygame.Rect(enemy[0], enemy[1], self.enemy_size, self.enemy_size)                        if player_rect.colliderect(enemy_rect):                reward = -10                 terminated = True            elif enemy[1] > self.height:                self.enemies.remove(enemy)                self.score += 1                reward = 1                 if not terminated:            if self.score > 100:                reward += 0.01            reward += 0.01        obs = self._get_obs()        if self.render_mode == "human":            self._render_window()        return obs, reward, terminated, truncated, {}    def _get_obs(self):        self.canvas.fill((0, 0, 0))        pygame.draw.rect(self.canvas, (50, 150, 255), (self.player_x, self.player_y, self.player_size, self.player_size))                for enemy in self.enemies:            pygame.draw.rect(self.canvas, (255, 50, 50), (enemy[0], enemy[1], self.enemy_size, self.enemy_size))        img_array = pygame.surfarray.array3d(self.canvas)        img_array = np.transpose(img_array, (1, 0, 2))        obs = cv2.resize(img_array, (84, 84), interpolation=cv2.INTER_AREA)        return obs.astype(np.uint8)    def _render_window(self):        self.screen.blit(self.canvas, (0, 0))        text = self.font.render(f"Score: {self.score}", True, (255, 255, 255))        self.screen.blit(text, (10, 10))        pygame.display.flip()        for event in pygame.event.get():            if event.type == pygame.QUIT:                pygame.quit()def train():    log_dir = "logs/DodgeGame"    os.makedirs(log_dir, exist_ok=True)    env = MyEnv()    check_env(env)    print("环境检查通过...")    model_path = "models/dodge_ai.zip"    if not os.path.exists(model_path):        print("
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

相关推荐

您需要登录后才可以回帖 登录 | 立即注册