Coverage for o2/agents/ppo_agent.py: 91%
113 statements

import os

# Disable GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

from typing import TYPE_CHECKING, Optional

import gymnasium as gym
import numpy as np
import torch as th
from stable_baselines3.common.utils import obs_as_tensor
from typing_extensions import override

from o2.actions.base_actions.base_action import BaseAction
from o2.agents.agent import Agent, NoActionsLeftError
from o2.agents.tabu_agent import TabuAgent
from o2.models.solution import Solution
from o2.pareto_front import FRONT_STATUS
from o2.ppo_utils.ppo_env import PPOEnv
from o2.ppo_utils.ppo_input import PPOInput
from o2.store import SolutionTry, Store
from o2.util.indented_printer import print_l1

if TYPE_CHECKING:
    from numpy import ndarray


class PPOAgent(Agent):
    """Selects the best action to take next, based on the current state of the store."""

    @override
    def __init__(self, store: Store) -> None:
        super().__init__(store)
        from sb3_contrib import MaskablePPO

        if store.settings.ppo_use_existing_model:
            self.model = MaskablePPO.load(store.settings.ppo_model_path)
        else:
            env: gym.Env = self.get_env()
            self.model = MaskablePPO(
                "MultiInputPolicy",
                env,
                verbose=1,
                # tensorboard_log="./logs/progress_tensorboard/",
                clip_range=0.2,
                # TODO make learning rate smarter
                # learning_rate=linear_schedule(3e-4),
                n_steps=1 * store.settings.ppo_steps_per_iteration,  # Multiple of 50
                batch_size=round(0.5 * store.settings.ppo_steps_per_iteration),  # Divisor of 50
                gamma=1,
            )
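
        # _setup_learn() is normally invoked inside model.learn(); it is called
        # directly here because this agent collects experience itself (see
        # _result_callback below) instead of letting MaskablePPO run its own
        # training loop. Among other things, it resets the environment and
        # initializes self.model._last_obs / _last_episode_starts, which the
        # manual rollout collection below relies on. Note that it is a private
        # stable-baselines3 helper.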
        self.model._setup_learn(
            store.settings.max_iterations,
            callback=None,
            reset_num_timesteps=True,
            tb_log_name="PPO",
            progress_bar=False,
        )
        self.last_actions: ndarray
        self.last_values: ndarray
        self.log_probs: ndarray
        self.last_action_mask: ndarray

    @override
    def select_actions(self) -> Optional[list[BaseAction]]:
        """Select the next action to take.

        Unlike agents that return up to cpu_count actions for parallel
        evaluation, this agent queries the PPO policy for a single action.

        If no actions are possible for the current base evaluation, a
        NoActionsLeftError is raised so that a new base evaluation can be
        selected.
        """
        action_from_store = PPOInput.get_actions_from_store(self.store)
        action_count = len([a for a in action_from_store if a is not None])
        if action_count == 0:
            # TODO: We need to reset the env here
            raise NoActionsLeftError()
        else:
            print_l1(f"Choosing best action out of {action_count} possible actions.")
        action_mask = PPOInput.get_action_mask_from_actions(action_from_store)
        self.last_action_mask = action_mask
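
        # The block below mirrors one step of stable-baselines3's
        # collect_rollouts(): the maskable policy's forward pass returns the
        # sampled action, the value estimate and the action log-probability,
        # with invalid actions excluded via the action mask (the action_masks
        # keyword is part of sb3_contrib's maskable policy API).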
        # Collect a single step
        with th.no_grad():
            obs_tensor = obs_as_tensor(self.model._last_obs, self.model.device)  # type: ignore
            actions, values, log_probs = self.model.policy(obs_tensor, action_masks=action_mask)
        self.last_actions = actions
        self.last_values = values
        self.log_probs = log_probs

        [action_index] = actions.cpu().numpy()

        if action_index is None:
            raise ValueError("Model did not return an action index.")
        action = action_from_store[action_index]
        if action is None:
            return None
        return [action]

    @override
    def find_new_base_solution(self, proposed_solution_try: Optional[SolutionTry] = None) -> Solution:
        """Select a new base solution."""
        return TabuAgent(self.store).find_new_base_solution(proposed_solution_try)

    @override
    def process_many_solutions(
        self, solutions: list[Solution]
    ) -> tuple[list[SolutionTry], list[SolutionTry]]:
        chosen_tries, not_chosen_tries = super().process_many_solutions(solutions)
        self._result_callback(chosen_tries, not_chosen_tries)
        return chosen_tries, not_chosen_tries
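
    # Each call to process_many_solutions() produces exactly one PPO timestep:
    # _result_callback() turns the evaluation result into (observation, reward,
    # done) via step_info_from_try(), stores the transition in the rollout
    # buffer together with the action/value/log-prob cached in select_actions(),
    # and triggers a PPO update once the buffer is full.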

    def _result_callback(self, chosen_tries: list[SolutionTry], not_chosen_tries: list[SolutionTry]) -> None:
        """Handle the result of the evaluation."""
        result = chosen_tries[0] if chosen_tries else not_chosen_tries[0]

        new_obs, reward, done = self.step_info_from_try(result)
        rewards = [reward]
        done_values = np.array([1 if done else 0])
        actions = self.last_actions.reshape(-1, 1)
        log_probs = self.log_probs
        action_masks = self.last_action_mask

        # Add collected data to the rollout buffer
        self.model.rollout_buffer.add(
            self.model._last_obs,
            actions,
            rewards,
            self.model._last_episode_starts or done_values,
            self.last_values,
            log_probs,
            action_masks=action_masks,
        )
        self.model._last_obs = new_obs
        self.model._last_episode_starts = done_values  # type: ignore
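
        # Once n_steps transitions have been collected, this replicates the tail
        # of stable-baselines3's collect_rollouts()/learn() cycle: bootstrap the
        # value of the latest observation, let the buffer compute GAE returns
        # and advantages, run one PPO optimization pass, then reset the buffer
        # for the next batch of steps.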
        # Train if the buffer is full
        if self.model.rollout_buffer.full:
            print_l1("Rollout buffer full, training...")
            with th.no_grad():
                last_values = self.model.policy.predict_values(obs_as_tensor(new_obs, self.model.device))
            self.model.rollout_buffer.compute_returns_and_advantage(
                last_values=last_values,
                dones=done_values,  # type: ignore
            )
            self.model.train()
            self.model.rollout_buffer.reset()

        # If the episode is done, select a new base solution
        if done_values[0]:
            tmp_agent = TabuAgent(self.store)
            while True:
                self.store.solution = tmp_agent._select_new_base_evaluation(reinsert_current_solution=False)
                actions = PPOInput.get_actions_from_store(self.store)
                action_count = len([a for a in actions if a is not None])
                if action_count > 0:
                    break
                else:
                    print_l1("Still no actions available for next step, selecting new base solution again.")
    def get_env(self) -> gym.Env:
        """Get the environment for the PPO agent."""
        return PPOEnv(self.store, max_steps=self.store.settings.ppo_steps_per_iteration)

    def update_state(self) -> None:
        """Update the state of the agent."""
        self.actions = PPOInput.get_actions_from_store(self.store)
        self.action_space = PPOInput.get_action_space_from_actions(self.actions)
        self.observation_space = PPOInput.get_observation_space(self.store)
        self.state = PPOInput.get_state_from_store(self.store)

    def step_info_from_try(self, solution_try: SolutionTry) -> tuple[dict, float, bool]:
        """Get the step info from the given SolutionTry."""
        # TODO Improve scores based on how good/bad the solution is
        status, _ = solution_try
        if status == FRONT_STATUS.INVALID:
            reward = -1
        elif status == FRONT_STATUS.IN_FRONT:
            reward = 1
        elif status == FRONT_STATUS.IS_DOMINATED:
            reward = 10
        else:
            reward = -1
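        # Reward scheme (interpretation, based on the FRONT_STATUS names):
        # IS_DOMINATED is presumably the case where the existing Pareto front
        # is dominated by the new solution, hence the largest reward; IN_FRONT
        # adds a non-dominated solution to the front; INVALID and any remaining
        # status are penalized with -1.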

        done = False

        self.update_state()

        print_l1(f"Done. Reward: {reward}")
        action_count = len([a for a in self.actions if a is not None])
        if action_count == 0:
            print_l1("No actions available for next step")
            done = True
        else:
            print_l1(f"{action_count} actions available for next step")

        return self.state, reward, done
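

# Illustrative driver loop (a sketch, not part of this module). The step that
# turns selected actions into Solution objects lives elsewhere in o2 and is
# only hinted at here as a hypothetical `evaluate()` helper:
#
#     store = ...  # a fully initialized o2.store.Store
#     agent = PPOAgent(store)
#     while not finished:  # some external stopping criterion
#         try:
#             actions = agent.select_actions()
#         except NoActionsLeftError:
#             store.solution = agent.find_new_base_solution()
#             continue
#         if actions is None:
#             continue
#         solutions = evaluate(store, actions)  # hypothetical evaluation step
#         agent.process_many_solutions(solutions)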