Coverage for o2/agents/ppo_agent.py: 91%

113 statements  

coverage.py v7.6.12, created at 2025-05-16 11:18 +0000

import os

# Disable GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

from typing import TYPE_CHECKING, Optional

import gymnasium as gym
import numpy as np
import torch as th
from stable_baselines3.common.utils import obs_as_tensor
from typing_extensions import override

from o2.actions.base_actions.base_action import BaseAction
from o2.agents.agent import Agent, NoActionsLeftError
from o2.agents.tabu_agent import TabuAgent
from o2.models.solution import Solution
from o2.pareto_front import FRONT_STATUS
from o2.ppo_utils.ppo_env import PPOEnv
from o2.ppo_utils.ppo_input import PPOInput
from o2.store import SolutionTry, Store
from o2.util.indented_printer import print_l1

if TYPE_CHECKING:
    from numpy import ndarray


class PPOAgent(Agent):
    """Selects the best action to take next, based on the current state of the store."""

    @override
    def __init__(self, store: Store) -> None:
        super().__init__(store)
        from sb3_contrib import MaskablePPO

        if store.settings.ppo_use_existing_model:
            self.model = MaskablePPO.load(store.settings.ppo_model_path)
        else:
            env: gym.Env = self.get_env()
            self.model = MaskablePPO(
                "MultiInputPolicy",
                env,
                verbose=1,
                # tensorboard_log="./logs/progress_tensorboard/",
                clip_range=0.2,
                # TODO make learning rate smarter
                # learning_rate=linear_schedule(3e-4),
                n_steps=1 * store.settings.ppo_steps_per_iteration,  # Multiple of 50
                batch_size=round(0.5 * store.settings.ppo_steps_per_iteration),  # Divisor of 50
                gamma=1,
            )
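            # gamma=1 keeps returns undiscounted; n_steps and batch_size are tied to
            # ppo_steps_per_iteration, so the rollout buffer fills (and training runs)
            # once per ppo_steps_per_iteration evaluated steps.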

        self.model._setup_learn(
            store.settings.max_iterations,
            callback=None,
            reset_num_timesteps=True,
            tb_log_name="PPO",
            progress_bar=False,
        )
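        # _setup_learn() is normally invoked inside model.learn(); calling it directly
        # initializes _last_obs, _last_episode_starts and the timestep counters so that
        # transitions can be collected one step at a time (see select_actions and
        # _result_callback below).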

        self.last_actions: ndarray
        self.last_values: ndarray
        self.log_probs: ndarray
        self.last_action_mask: ndarray

    @override
    def select_actions(self) -> Optional[list[BaseAction]]:
        """Select the next action to take.

        The PPO policy is queried for a single action; invalid actions are excluded
        via the action mask.

        If the possible options for the current base evaluation are exhausted,
        it raises NoActionsLeftError so that a new base evaluation can be chosen.
        """
        action_from_store = PPOInput.get_actions_from_store(self.store)
        action_count = len([a for a in action_from_store if a is not None])
        if action_count == 0:
            # TODO: We need to reset the env here
            raise NoActionsLeftError()
        else:
            print_l1(f"Choosing best action out of {action_count} possible actions.")
        action_mask = PPOInput.get_action_mask_from_actions(action_from_store)
        self.last_action_mask = action_mask

        # Collect a single step
        with th.no_grad():
            obs_tensor = obs_as_tensor(self.model._last_obs, self.model.device)  # type: ignore
            actions, values, log_probs = self.model.policy(obs_tensor, action_masks=action_mask)
        self.last_actions = actions
        self.last_values = values
        self.log_probs = log_probs
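        # Cache the policy outputs; _result_callback() writes them into the rollout
        # buffer once the chosen action has been evaluated and a reward is available.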

        [action_index] = actions.cpu().numpy()

        if action_index is None:
            raise ValueError("Model did not return an action index.")
        action = action_from_store[action_index]
        if action is None:
            return None
        return [action]

    @override
    def find_new_base_solution(self, proposed_solution_try: Optional[SolutionTry] = None) -> Solution:
        """Select a new base solution."""
        return TabuAgent(self.store).find_new_base_solution(proposed_solution_try)

    @override
    def process_many_solutions(
        self, solutions: list[Solution]
    ) -> tuple[list[SolutionTry], list[SolutionTry]]:
        chosen_tries, not_chosen_tries = super().process_many_solutions(solutions)
        self._result_callback(chosen_tries, not_chosen_tries)
        return chosen_tries, not_chosen_tries

    def _result_callback(self, chosen_tries: list[SolutionTry], not_chosen_tries: list[SolutionTry]) -> None:
        """Handle the result of the evaluation."""
        result = chosen_tries[0] if chosen_tries else not_chosen_tries[0]

        new_obs, reward, done = self.step_info_from_try(result)
        rewards = [reward]
        done_values = np.array([1 if done else 0])
        actions = self.last_actions.reshape(-1, 1)
        log_probs = self.log_probs
        action_masks = self.last_action_mask

        # Add collected data to the rollout buffer
        self.model.rollout_buffer.add(
            self.model._last_obs,
            actions,
            rewards,
            self.model._last_episode_starts or done_values,
            self.last_values,
            log_probs,
            action_masks=action_masks,
        )
        self.model._last_obs = new_obs
        self.model._last_episode_starts = done_values  # type: ignore

        # Train if the buffer is full
        if self.model.rollout_buffer.full:
            print_l1("Rollout buffer full, training...")
            with th.no_grad():
                last_values = self.model.policy.predict_values(obs_as_tensor(new_obs, self.model.device))
            self.model.rollout_buffer.compute_returns_and_advantage(
                last_values=last_values,
                dones=done_values,  # type: ignore
            )
            self.model.train()
            self.model.rollout_buffer.reset()
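            # Note: model.train() is MaskablePPO's gradient-update step, which consumes the
            # rollout buffer collected above; it is not torch.nn.Module.train().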

        # If the episode is done, select a new base solution
        if done_values[0]:
            tmp_agent = TabuAgent(self.store)
            while True:
                self.store.solution = tmp_agent._select_new_base_evaluation(reinsert_current_solution=False)
                actions = PPOInput.get_actions_from_store(self.store)
                action_count = len([a for a in actions if a is not None])
                if action_count > 0:
                    break
                else:
                    print_l1("Still no actions available for next step, selecting new base solution again.")

    def get_env(self) -> gym.Env:
        """Get the environment for the PPO agent."""
        return PPOEnv(self.store, max_steps=self.store.settings.ppo_steps_per_iteration)

    def update_state(self) -> None:
        """Update the state of the agent."""
        self.actions = PPOInput.get_actions_from_store(self.store)
        self.action_space = PPOInput.get_action_space_from_actions(self.actions)
        self.observation_space = PPOInput.get_observation_space(self.store)
        self.state = PPOInput.get_state_from_store(self.store)

    def step_info_from_try(self, solution_try: SolutionTry) -> tuple[dict, float, bool]:
        """Get the step info from the given SolutionTry."""
        # TODO Improve scores based on how good/bad the solution is
        status, _ = solution_try
        if status == FRONT_STATUS.INVALID:
            reward = -1
        elif status == FRONT_STATUS.IN_FRONT:
            reward = 1
        elif status == FRONT_STATUS.IS_DOMINATED:
            reward = 10
        else:
            reward = -1
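        # Reward sketch (see TODO above): INVALID -> -1, IN_FRONT -> 1, IS_DOMINATED -> 10,
        # anything else -> -1. The large IS_DOMINATED reward suggests the status refers to
        # the existing front being dominated by the new solution (an assumption, not
        # verified here).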

        done = False

        self.update_state()

        print_l1(f"Done. Reward: {reward}")
        action_count = len([a for a in self.actions if a is not None])
        if action_count == 0:
            print_l1("No actions available for next step")
            done = True
        else:
            print_l1(f"{action_count} actions available for next step")

        return self.state, reward, done
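
For context, the sketch below shows one way this agent could be driven from an optimization loop. It is a minimal illustration, not the actual o2 optimizer: evaluate_actions is a hypothetical helper that applies the selected actions and returns the resulting Solution objects.

# Minimal driver sketch. Assumption: evaluate_actions() is a hypothetical helper that
# applies the selected actions and returns the resulting Solution objects; the real
# optimizer loop in o2 may differ.
from o2.agents.agent import NoActionsLeftError
from o2.agents.ppo_agent import PPOAgent
from o2.store import Store


def run_ppo_optimization(store: Store, evaluate_actions) -> None:
    agent = PPOAgent(store)
    for _ in range(store.settings.max_iterations):
        try:
            actions = agent.select_actions()
        except NoActionsLeftError:
            # No actions left for the current base solution: pick a new base and retry.
            store.solution = agent.find_new_base_solution()
            continue
        if not actions:
            continue
        solutions = evaluate_actions(store, actions)
        # Updates the Pareto front and feeds the evaluated step back into the PPO rollout buffer.
        agent.process_many_solutions(solutions)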