Coverage for o2/optimizer.py: 91%
141 statements
coverage.py v7.6.12, created at 2025-05-16 11:18 +0000
import concurrent.futures
import time
import traceback
from collections.abc import Generator

from o2.actions.base_actions.base_action import BaseAction
from o2.agents.agent import (
    Agent,
    NoActionsLeftError,
    NoNewBaseSolutionFoundError,
)
from o2.agents.simulated_annealing_agent import SimulatedAnnealingAgent
from o2.agents.tabu_agent import TabuAgent
from o2.models.settings import AgentType, Settings
from o2.models.solution import Solution
from o2.pareto_front import FRONT_STATUS
from o2.simulation_runner import SimulationRunner
from o2.store import SolutionTry, Store
from o2.util.indented_printer import print_l0, print_l1, print_l2, print_l3, print_l4
from o2.util.logger import STATS_LOG_LEVEL
from o2.util.solution_dumper import SolutionDumper


class Optimizer:
    """The Optimizer class is the main class that runs the optimization process."""

    def __init__(self, store: Store) -> None:
        """Initialize the optimizer."""
        self.settings = store.settings
        self.max_iter = store.settings.max_iterations
        self.max_non_improving_iter = store.settings.max_non_improving_actions
        self.max_solutions = store.settings.max_solutions or float("inf")
        self.max_parallel = store.settings.MAX_THREADS_ACTION_EVALUATION
        self.running_avg_time = 0
        if not Settings.DISABLE_PARALLEL_EVALUATION and Settings.MAX_THREADS_ACTION_EVALUATION > 1:
            self.executor = concurrent.futures.ProcessPoolExecutor(
                max_workers=Settings.MAX_THREADS_ACTION_EVALUATION
            )
        self.agent: Agent = self._init_agent(store)
        if self.settings.log_to_tensor_board:
            from o2.util.tensorboard_helper import TensorBoardHelper

            TensorBoardHelper(self.agent, store.name)

    def _init_agent(self, store: Store) -> Agent:
        """Initialize the agent for the optimization task."""
        if self.settings.agent == AgentType.TABU_SEARCH:
            return TabuAgent(store)
        elif self.settings.agent == AgentType.PROXIMAL_POLICY_OPTIMIZATION:
            from o2.agents.ppo_agent import PPOAgent

            return PPOAgent(store)
        elif self.settings.agent == AgentType.PROXIMAL_POLICY_OPTIMIZATION_RANDOM:
            from o2.agents.ppo_agent_random import PPOAgentRandom

            return PPOAgentRandom(store)
        elif self.settings.agent == AgentType.SIMULATED_ANNEALING:
            return SimulatedAnnealingAgent(store)
        elif self.settings.agent == AgentType.TABU_SEARCH_RANDOM:
            from o2.agents.tabu_agent_random import TabuAgentRandom

            return TabuAgentRandom(store)
        elif self.settings.agent == AgentType.SIMULATED_ANNEALING_RANDOM:
            from o2.agents.simulated_annealing_agent_random import (
                SimulatedAnnealingAgentRandom,
            )

            return SimulatedAnnealingAgentRandom(store)
        raise ValueError(f"Unknown agent type: {self.settings.agent}")

    def solve(self) -> None:
        """Run the optimizer and print the result."""
        store = self.agent.store
        print_l1(
            f"Initial evaluation: {store.base_solution.evaluation}",
            log_level=STATS_LOG_LEVEL,
        )
        generator = self.get_iteration_generator(yield_on_non_acceptance=True)
        for _ in generator:
            if self.settings.log_to_tensor_board:
                from o2.util.tensorboard_helper import TensorBoardHelper

                # Just iterate through the generator to run it
                TensorBoardHelper.instance.tensor_board_iteration_callback(store.solution)

        if not Settings.DISABLE_PARALLEL_EVALUATION and Settings.MAX_THREADS_ACTION_EVALUATION > 1:
            self.executor.shutdown()

        SimulationRunner.close_executor()

        # Final write to tensorboard
        if self.settings.log_to_tensor_board:
            from o2.util.tensorboard_helper import TensorBoardHelper

            TensorBoardHelper.instance.tensor_board_iteration_callback(store.solution, write_everything=True)

        self._print_result()

    def get_iteration_generator(
        self, yield_on_non_acceptance: bool = False
    ) -> Generator[Solution, None, None]:
        """Run the optimizer and yield optimal Solutions.

        NOTE: You usually want to use the `solve` method instead of this
        method, but if you want to process the Solutions as they come,
        you can use this method.
        """
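        # Usage sketch (the Store setup and `handle_solution` callback below are
        # assumptions for illustration, not part of this module):
        #
        #     optimizer = Optimizer(store)
        #     for solution in optimizer.get_iteration_generator():
        #         handle_solution(solution)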
        for it in range(self.max_iter):
            start_time = time.time()
            if Settings.DUMP_DISCARDED_SOLUTIONS or Settings.ARCHIVE_SOLUTIONS:
                SolutionDumper.instance.iteration = it
            if self.settings.log_to_tensor_board:
                from o2.util.tensorboard_helper import TensorBoardHelper

                TensorBoardHelper.instance.iteration += 1

            try:
                if self.max_non_improving_iter <= 0:
                    print_l1("Maximum non improving iterations reached!", log_level=STATS_LOG_LEVEL)
                    break

                if self.max_solutions <= 0:
                    print_l1("Maximum number of solutions reached!", log_level=STATS_LOG_LEVEL)
                    break

                max_solutions_setting = self.settings.max_solutions or float("inf")
                solution_no = max_solutions_setting - self.max_solutions

                msg = f"{self.settings.agent.name} - Iteration {it + 1}/{self.max_iter}"
                msg += (
                    f" (Solution {solution_no}/{max_solutions_setting})"
                    if max_solutions_setting != float("inf")
                    else ""
                )
                print_l0(msg)

                actions_to_perform = self.agent.select_actions()
                if actions_to_perform is None or len(actions_to_perform) == 0:
                    print_l1("Optimization finished, no actions to perform.")
                    break
                print_l1(f"Running {len(actions_to_perform)} actions...")
                start_time = time.time()

                solutions = self._execute_actions_parallel(actions_to_perform)
                print_l2(f"Simulation took {time.time() - start_time:_.2f}s")

                chosen_tries, not_chosen_tries = self.agent.process_many_solutions(solutions)

                self.max_solutions -= len(chosen_tries) + len(not_chosen_tries)

                if len(chosen_tries) == 0:
                    print_l1("No action improved the evaluation")
                    self.max_non_improving_iter -= len(actions_to_perform)
                    for _, solution in not_chosen_tries:
                        print_str = f"{solution.id}: {repr(solution.last_action)}"
                        if not solution.is_valid:
                            print_str = f"[INVALID] {print_str}"
                        print_l2(print_str)
                        if yield_on_non_acceptance:
                            yield solution
                else:
                    if len(not_chosen_tries) > 0:
                        print_l2("Actions NOT chosen:")
                        for _, solution in not_chosen_tries:
                            print_l3(f"{solution.id}: {repr(solution.last_action)}")
                            self.max_non_improving_iter -= 1
                            if yield_on_non_acceptance:
                                yield solution
                    print_l2("Actions chosen:")
                    for status, solution in chosen_tries:
                        self.max_non_improving_iter = self.settings.max_non_improving_actions
                        print_l3(f"{solution.id}: {repr(solution.last_action)}")
                        if status == FRONT_STATUS.IN_FRONT:
                            print_l4(
                                "Pareto front CONTAINS new evaluation.",
                                log_level=STATS_LOG_LEVEL,
                            )
                            print_l4(
                                f"{solution.id}: {Settings.get_pareto_x_label()}: "
                                f"{solution.pareto_x:_.2f}; {Settings.get_pareto_y_label()}: "
                                f"{solution.pareto_y:_.2f}",
                                log_level=STATS_LOG_LEVEL,
                            )
                        elif status == FRONT_STATUS.IS_DOMINATED:
                            print_l4(
                                "Pareto front IS DOMINATED by new evaluation.",
                                log_level=STATS_LOG_LEVEL,
                            )
                            print_l4(
                                f"{solution.id}: {Settings.get_pareto_x_label()}: "
                                f"{solution.pareto_x:_.2f}; {Settings.get_pareto_y_label()}: "
                                f"{solution.pareto_y:_.2f}",
                                log_level=STATS_LOG_LEVEL,
                            )
                        yield solution
                print_l1(f"Non improving actions left: {self.max_non_improving_iter}")
            except NoActionsLeftError:
                print_l1("No actions left to perform.")
                break
            except NoNewBaseSolutionFoundError:
                print_l1("No new base solution found.")
                break
            except Exception as e:
                print_l1(f"Error in iteration: {e}")
                print_l1(traceback.format_exc())
                if self.settings.throw_on_iteration_errors:
                    # re-raising the exception to stop the optimization
                    raise
                continue
            self._print_time_estimate(it, start_time)
    def _print_result(self):
        store = self.agent.store
        print_l0("Final result:")
        print_l1(
            f"Best evaluation: \t{store.current_evaluation}",
            log_level=STATS_LOG_LEVEL,
        )
        print_l1(
            f"Base evaluation: \t{store.base_evaluation}",
            log_level=STATS_LOG_LEVEL,
        )
        print_l1("Modifications:")
        for action in store.base_solution.actions:
            print_l2(repr(action))

    def _print_time_estimate(self, it: int, start_time: float):
        time_taken = time.time() - start_time
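        # Incremental mean: fold this iteration's duration into the running
        # average over the `it + 1` iterations completed so far.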
        self.running_avg_time = (self.running_avg_time * it + time_taken) / (it + 1)
        estimated_time_left = self.running_avg_time * (self.max_iter - it)
        est_hours = estimated_time_left // 3600
        est_minutes = (estimated_time_left % 3600) // 60
        est_seconds = estimated_time_left % 60
        print_l1(
            f"Iteration took {time_taken:_.2f}s (avg: {self.running_avg_time:_.2f}s, "
            f"est. {est_hours:.0f}h {est_minutes:.0f}m {est_seconds:.0f}s left)"
        )

    def _execute_actions_parallel(self, actions_to_perform: list[BaseAction]) -> list[Solution]:
        """Execute the given actions in parallel and return the results.

        The results are sorted so that the most impactful actions come first,
        allowing the store to process them in that order.
        Computing the results does not modify the state of the store.
        """
        store = self.agent.store
        solution_tries: list[SolutionTry] = []

        if not Settings.DISABLE_PARALLEL_EVALUATION and Settings.MAX_THREADS_ACTION_EVALUATION > 1:
            futures: list[concurrent.futures.Future[Solution]] = []
            for action in actions_to_perform:
                futures.append(self.executor.submit(Solution.from_parent, store.solution, action))

            for future in concurrent.futures.as_completed(futures):
                try:
                    new_solution = future.result()
                    solution_tries.append(self.agent.try_solution(new_solution))
                except Exception as e:
                    print_l1(f"Error evaluating actions: {e}")
        else:
            for action in actions_to_perform:
                solution_try = self.agent.try_solution(Solution.from_parent(store.solution, action))
                solution_tries.append(solution_try)

        # Sort tries with dominating ones first
        solution_tries.sort(
            key=lambda x: -1
            if x[0] == FRONT_STATUS.IS_DOMINATED
            else 1
            if x[0] == FRONT_STATUS.DOMINATES
            else 0
        )
        return list(map(lambda x: x[1], solution_tries))
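A minimal usage sketch for the class above; the `store` value is assumed to be a fully initialized `o2.store.Store`, whose construction (model, settings, base solution) happens outside this file:

from o2.optimizer import Optimizer
from o2.store import Store

store: Store = ...  # assumed to be built elsewhere in the project

optimizer = Optimizer(store)
optimizer.solve()  # runs the full optimization loop and prints the final result

For incremental processing, `get_iteration_generator()` can be iterated directly instead of calling `solve()`, as noted in its docstring.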