简介
Gymnasium 为强化学习提供了一个标准化的 API,它定义了 Agent 应该如何观察世界、如何做出动作以及如何获得奖励。无论是游戏还是工业设备,只要实现了 Gymnasium 标准接口,就能使用同一套代码进行训练。
认识Gymnasium
使用 stable_baselines3 只需要定义好 Gymnasium 环境,关注训练的奖励机制,将重点放在业务的开发上而不是复杂的算法。
Gymnasium 提供了几个核心的 API:
| 方法 | 功能 | 返回值 |
| --- | --- | --- |
| `reset()` | 将环境重置为初始状态,开始新回合。 | `obs, info` |
| `step(action)` | 环境向前推进一步,执行动作。 | `obs, reward, terminated, truncated, info` |
| `render()` | 可视化环境(根据 render_mode 渲染图像或弹出窗口)。 | 视配置而定(通常无或为 np.array) |
| `close()` | 释放环境资源(关闭窗口、清理内存)。 | 无 |

其中的各个返回值的含义:
- observation (Object): 当前状态的描述。例如敌人,玩家的位置,玩家的状态等
- reward (Float): 上一步动作获得的奖励
- terminated (Bool): 是否由于任务逻辑结束。例如:到达终点、掉进岩浆等
- truncated (Bool): 是否由于外部限制结束。例如:达到最大步数 500 步
- info (Dict): 辅助诊断信息,模型训练通常不用,用于用户自定义调试或记录额外统计。
手动构建环境
案例
案例描述:利用pygame构建一个简单的游戏,躲避掉落方块,利用构建的奖励机制,进行强化学习。
[code]import gymnasium as gymfrom gymnasium import spacesimport numpy as npimport pygameimport randomimport cv2import osfrom stable_baselines3 import PPOfrom stable_baselines3.common.callbacks import CheckpointCallbackfrom stable_baselines3.common.env_checker import check_envclass MyEnv(gym.Env): def __init__(self, render_mode=None): super(MyEnv, self).__init__() #初始化参数 self.width = 400 self.height = 300 self.player_size = 30 self.enemy_size = 30 self.render_mode = render_mode self.action_space = spaces.Discrete(3) self.observation_space = spaces.Box( low=0, high=255, shape=(84, 84, 3), dtype=np.uint8 ) pygame.init() if self.render_mode == "human": self.screen = pygame.display.set_mode((self.width, self.height)) self.canvas = pygame.Surface((self.width, self.height)) self.font = pygame.font.SysFont("monospace", 15) def reset(self, seed=None, options=None): super().reset(seed=seed) self.player_x = self.width // 2 - self.player_size // 2 self.player_y = self.height - self.player_size - 10 self.enemies = [] self.score = 0 self.frame_count = 0 self.current_speed = 5 self.spawn_rate = 30 return self._get_obs(), {} def step(self, action): reward = 0 terminated = False truncated = False move_speed = 8 if action == 1 and self.player_x > 0: # self.player_x -= move_speed reward -= 0.05 if action == 2 and self.player_x < self.width - self.player_size: self.player_x += move_speed reward -= 0.05 self.frame_count += 1 level = self.score // 5 self.current_speed = 5 + level self.spawn_rate = 30 - level * 2 spawn_rate = max(10, 30 - level) if self.frame_count >= spawn_rate: self.frame_count = 0 enemy_x = random.randint(0, self.width - self.enemy_size) self.enemies.append([enemy_x, 0]) # [x, y] for enemy in self.enemies: enemy[1] += self.current_speed player_rect = pygame.Rect(self.player_x, self.player_y, self.player_size, self.player_size) enemy_rect = pygame.Rect(enemy[0], enemy[1], self.enemy_size, self.enemy_size) if player_rect.colliderect(enemy_rect): reward = -10 terminated 
= True elif enemy[1] > self.height: self.enemies.remove(enemy) self.score += 1 reward = 1 if not terminated: if self.score > 100: reward += 0.01 reward += 0.01 obs = self._get_obs() if self.render_mode == "human": self._render_window() return obs, reward, terminated, truncated, {} def _get_obs(self): self.canvas.fill((0, 0, 0)) pygame.draw.rect(self.canvas, (50, 150, 255), (self.player_x, self.player_y, self.player_size, self.player_size)) for enemy in self.enemies: pygame.draw.rect(self.canvas, (255, 50, 50), (enemy[0], enemy[1], self.enemy_size, self.enemy_size)) img_array = pygame.surfarray.array3d(self.canvas) img_array = np.transpose(img_array, (1, 0, 2)) obs = cv2.resize(img_array, (84, 84), interpolation=cv2.INTER_AREA) return obs.astype(np.uint8) def _render_window(self): self.screen.blit(self.canvas, (0, 0)) text = self.font.render(f"Score: {self.score}", True, (255, 255, 255)) self.screen.blit(text, (10, 10)) pygame.display.flip() for event in pygame.event.get(): if event.type == pygame.QUIT: pygame.quit()def train(): log_dir = "logs/DodgeGame" os.makedirs(log_dir, exist_ok=True) env = MyEnv() check_env(env) print("环境检查通过...") model_path = "models/dodge_ai.zip" if not os.path.exists(model_path): print("
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作! |