Coverage for o2/models/solution_tree.py: 82%
103 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-05-16 11:18 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-05-16 11:18 +0000
1import random
2from collections import OrderedDict
3from typing import TYPE_CHECKING, Optional, cast
5import numpy as np
6import rtree
8from o2.models.settings import Settings
9from o2.models.solution import Solution
10from o2.pareto_front import ParetoFront
11from o2.util.helper import hex_id
12from o2.util.indented_printer import print_l3
13from o2.util.logger import warn
14from o2.util.solution_dumper import SolutionDumper
16if TYPE_CHECKING:
17 from o2.actions.base_actions.base_action import BaseAction
class SolutionTree:
    """The SolutionTree class is a tree of solutions.

    It's used to get a list of possible base solutions that can be used to try
    new actions on. Those base states are sorted by their distance to the
    current Pareto Front (Nearest Neighbor).

    The points in the Pareto Front are therefore always chosen first.

    Under the hood it uses an R-tree to store the solutions, which allows for
    fast nearest-neighbor queries, as well as insertions / deletions.

    ``solution_lookup`` maps every solution id seen so far either to the live
    ``Solution`` (inserted into the R-tree and not yet popped) or to ``None``
    (discarded / already popped). R-tree items are keyed by
    ``int(solution.id, 16)`` and converted back with ``hex_id``.
    """

    def __init__(
        self,
    ) -> None:
        self.rtree = rtree.index.Index()
        # Insertion order is what `get_index_of_solution` reports.
        self.solution_lookup: OrderedDict[str, Optional[Solution]] = OrderedDict()

    def add_solution(self, solution: "Solution", archive: bool = True) -> None:
        """Add a solution to the tree.

        The solution is registered in the lookup and inserted into the R-tree at
        ``solution.point``. Unless it is the base solution, it is archived when
        ``archive`` is set and ``Settings.ARCHIVE_SOLUTIONS`` is enabled.
        """
        self.solution_lookup[solution.id] = solution
        self.rtree.insert(int(solution.id, 16), solution.point)
        if archive and not solution.is_base_solution and Settings.ARCHIVE_SOLUTIONS:
            solution.archive()

    def add_solution_as_discarded(self, solution: "Solution") -> None:
        """Add a solution to the tree as discarded.

        The id is recorded with a ``None`` value, so it counts as "already done"
        but can never be returned as a base solution. If the id currently maps to
        a live solution, that solution's R-tree entry is removed too. Otherwise
        nearest-neighbor queries would keep hitting an id whose lookup entry is
        ``None``.
        """
        existing = self.solution_lookup.get(solution.id)
        if existing is not None:
            self.rtree.delete(int(existing.id, 16), existing.point)
        self.solution_lookup[solution.id] = None
        if Settings.DUMP_DISCARDED_SOLUTIONS:
            SolutionDumper.instance.dump_solution(solution)

    def get_nearest_solution(
        self, pareto_front: ParetoFront, max_distance: float = float("inf")
    ) -> Optional["Solution"]:
        """Get the nearest solution to the given Pareto Front.

        This means the solution with the smallest distance to any solution in
        the given Pareto Front, with the distance being the euclidean distance
        of the evaluations on the time & cost axis.

        If there are still Pareto solutions left in the tree, it will usually
        return the most recent one: a Pareto solution still in the tree sits at
        distance 0 from its own query point, and the front is queried in
        reverse order. If there are no Pareto solutions left, it returns the
        remaining solution closest to the front.

        Returns ``None`` if the front is empty or if no valid solution lies
        within ``max_distance``.
        """
        nearest_solution: Optional[Solution] = None
        nearest_distance = float("inf")
        pareto_solutions = list(reversed(pareto_front.solutions))
        if not pareto_solutions:
            # Nothing to measure against. Also, np.array([]) is 1-D, not the
            # (n, dims) query array that `nearest_v` expects.
            return None
        error_count = 0
        valid_count = 0

        pareto_points = np.array([s.point for s in pareto_solutions])
        max_distances = np.full(len(pareto_solutions), max_distance)

        # One nearest neighbor per Pareto point, in reversed-front order.
        neighbors, _ = self.rtree.nearest_v(  # type: ignore
            mins=pareto_points,
            maxs=pareto_points,
            num_results=1,
            max_dists=max_distances,
        )
        print_l3(f"Found {len(neighbors)} neighbors in radius {max_distance:_.2f}.")
        for neighbor in neighbors:
            if error_count > 20:
                warn(f"Got too many None items from rtree! Returning None. {error_count} errors so far.")
                break
            if neighbor is None:
                warn(f"WARNING: Got None item from rtree. {error_count} errors so far.")
                error_count += 1
                continue
            item_id = hex_id(neighbor)
            if item_id not in self.solution_lookup:
                warn(
                    f"WARNING: Got non-existent solution from rtree ({item_id}). {error_count} errors so far."
                )
                error_count += 1
                continue
            solution = self.solution_lookup[item_id]
            if solution is None:
                # Defensive: `add_solution_as_discarded` and `remove_solution`
                # both delete live R-tree entries, so this should only happen if
                # a delete did not match the stored point.
                warn(f"WARNING: Got discarded solution from rtree ({item_id}). {error_count} errors so far.")
                error_count += 1
                continue
            # Early exit if we find a Pareto solution.
            # Because of `reversed`, it will be the most recent one.
            if solution in pareto_front.solutions:
                return solution

            distance = min(s.distance_to(solution) for s in pareto_front.solutions)
            # Sanity check that the distance is not larger than the max distance.
            if distance > max_distance:
                continue
            valid_count += 1
            if distance < nearest_distance:
                nearest_solution = solution
                nearest_distance = distance

        if nearest_solution is None:
            print_l3(
                f"NO nearest solution was found in tree. ({error_count} errors, "
                f"{len(neighbors)} neighbors, {len(pareto_front.solutions)} pareto solutions, "
                f"{self.total_solutions} solutions in tree, {self.solutions_left} solutions unexplored)"
            )
        else:
            print_l3(f"... of which {valid_count} were valid.")
        return nearest_solution

    @property
    def discarded_solutions(self) -> int:
        """Return the number of discarded / exhausted solutions."""
        return sum(1 for stored in self.solution_lookup.values() if stored is None)

    @property
    def solutions_left(self) -> int:
        """Return the number of untried solutions left in the tree."""
        return len(self.solution_lookup) - self.discarded_solutions

    @property
    def total_solutions(self) -> int:
        """Return the total number of solutions (tried + non-tried)."""
        return len(self.solution_lookup)

    def pop_nearest_solution(
        self, pareto_front: ParetoFront, max_distance: float = float("inf")
    ) -> Optional["Solution"]:
        """Pop the nearest solution to the given Pareto Front.

        Same selection as `get_nearest_solution`, but the found solution is
        also removed from the tree via `remove_solution`. Returns ``None`` if
        nothing was found.
        """
        nearest_solution = self.get_nearest_solution(pareto_front, max_distance=max_distance)
        if nearest_solution is not None:
            self.remove_solution(nearest_solution)
            print_l3(f"Popped solution ({nearest_solution.id})")
            if nearest_solution not in pareto_front.solutions:
                print_l3("Nearest solution is NOT in pareto front.")
            else:
                print_l3("Nearest solution is IN pareto front.")
        return nearest_solution

    def check_if_already_done(self, base_solution: "Solution", new_action: "BaseAction") -> bool:
        """Check if the given action has already been tried.

        An action counts as tried if the id obtained by hashing
        ``base_solution.actions + [new_action]`` is already a key in the lookup,
        whether live or discarded.
        """
        return Solution.hash_action_list(base_solution.actions + [new_action]) in self.solution_lookup

    def get_index_of_solution(self, solution: Solution) -> int:
        """Get the insertion index of the solution in the tree.

        Raises ``ValueError`` if the solution's id was never added.
        """
        return list(self.solution_lookup).index(solution.id)

    def get_solutions_near_to_pareto_front(
        self, pareto_front: ParetoFront, max_distance: float = float("inf")
    ) -> list["Solution"]:
        """Get a list of live solutions within ``max_distance`` of the Pareto front.

        First, all R-tree entries inside the per-Pareto-point bounding boxes are
        collected (lower corners clipped at 0). Then the candidates are filtered
        by their exact minimum distance to the front. Returns an empty list for
        an empty front.
        """
        if not pareto_front.solutions:
            # The bounding-box query needs at least one Pareto point.
            return []
        pareto_points = np.array([s.point for s in pareto_front.solutions])
        bounding_points_mins = (pareto_points - max_distance).clip(min=0)
        bounding_points_maxs = pareto_points + max_distance
        solution_ids, _ = self.rtree.intersection_v(bounding_points_mins, bounding_points_maxs)

        # The boxes of neighbouring Pareto points may overlap, so ids are deduplicated.
        solutions: list[Solution] = []
        for solution_id in set(solution_ids):
            candidate = self.solution_lookup.get(hex_id(solution_id))
            if candidate is not None:  # skip unknown and discarded ids
                solutions.append(candidate)

        # Filter out solutions that are farther than max_distance (the box is
        # only a coarse pre-filter).
        solutions = [
            s for s in solutions if min(s.distance_to(p) for p in pareto_front.solutions) <= max_distance
        ]

        return solutions

    def get_random_solution_near_to_pareto_front(
        self, pareto_front: ParetoFront, max_distance: float = float("inf")
    ) -> Optional["Solution"]:
        """Get a random solution near the Pareto front, or ``None`` if there is none.

        Unlike `pop_nearest_solution`, the chosen solution is NOT removed from
        the tree.
        """
        solutions = self.get_solutions_near_to_pareto_front(pareto_front, max_distance=max_distance)
        if not solutions:
            return None
        print_l3(f"Found {len(solutions)} solutions near pareto front.")
        random_solution = random.choice(solutions)
        return random_solution

    def remove_solution(self, solution: Solution) -> None:
        """Remove a solution from the tree.

        The solution is deleted from the R-tree and its lookup entry is set to
        ``None``, so it still counts as "already done".
        """
        self.rtree.delete(int(solution.id, 16), solution.point)
        self.solution_lookup[solution.id] = None

        if Settings.DUMP_DISCARDED_SOLUTIONS:
            SolutionDumper.instance.dump_solution(solution)