Coverage for o2/models/solution_tree.py: 82%

103 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-05-16 11:18 +0000

1import random 

2from collections import OrderedDict 

3from typing import TYPE_CHECKING, Optional, cast 

4 

5import numpy as np 

6import rtree 

7 

8from o2.models.settings import Settings 

9from o2.models.solution import Solution 

10from o2.pareto_front import ParetoFront 

11from o2.util.helper import hex_id 

12from o2.util.indented_printer import print_l3 

13from o2.util.logger import warn 

14from o2.util.solution_dumper import SolutionDumper 

15 

16if TYPE_CHECKING: 

17 from o2.actions.base_actions.base_action import BaseAction 

18 

19 

class SolutionTree:
    """The SolutionTree class is a tree of solutions.

    It's used to get a list of possible base solutions, that can be used to try
    new actions on. Those base states are sorted by their distance to the
    current Pareto Front (Nearest Neighbor).

    The points in the pareto front are of course therefore always chosen first.

    Under the hood it uses a RTree to store the solutions, which allows for fast
    nearest neighbor queries, as well as insertions / deletions.

    ``solution_lookup`` maps every solution id seen so far (a hex string) to
    its Solution while it is still available for exploration, or to ``None``
    once it has been removed / discarded. The rtree indexes the available
    solutions under ``int(solution.id, 16)``.
    """

    def __init__(
        self,
    ) -> None:
        self.rtree = rtree.index.Index()
        # Insertion-ordered, so get_index_of_solution reflects the order in
        # which ids were first added.
        self.solution_lookup: OrderedDict[str, Optional[Solution]] = OrderedDict()

    def add_solution(self, solution: "Solution", archive: bool = True) -> None:
        """Add a solution to the tree.

        Args:
            solution: The solution to index and make available for exploration.
            archive: When True (and ``Settings.ARCHIVE_SOLUTIONS`` is set), call
                ``solution.archive()`` for non-base solutions.
        """
        already_indexed = self.solution_lookup.get(solution.id) is not None
        self.solution_lookup[solution.id] = solution
        # Avoid a duplicate rtree entry for an id that is already live:
        # remove_solution deletes only one entry, and a leftover duplicate
        # would later surface as a "discarded solution" hit.
        if not already_indexed:
            self.rtree.insert(int(solution.id, 16), solution.point)
        if archive and not solution.is_base_solution and Settings.ARCHIVE_SOLUTIONS:
            solution.archive()

    def add_solution_as_discarded(self, solution: "Solution") -> None:
        """Add a solution to the tree as discarded.

        The id is recorded (so check_if_already_done sees it), but the solution
        will not be returned by the nearest/nearby queries.
        """
        existing = self.solution_lookup.get(solution.id)
        if existing is not None:
            # The id is currently live in the rtree; drop that entry as well,
            # otherwise the rtree keeps returning an id whose lookup is None.
            self.rtree.delete(int(existing.id, 16), existing.point)
        self.solution_lookup[solution.id] = None
        if Settings.DUMP_DISCARDED_SOLUTIONS:
            SolutionDumper.instance.dump_solution(solution)

    def get_nearest_solution(
        self, pareto_front: "ParetoFront", max_distance: float = float("inf")
    ) -> Optional["Solution"]:
        """Get the nearest solution to the given Pareto Front.

        This means the solution with the smallest distance to any solution in the
        given Pareto Front. With the distance being the euclidean distance of the
        evaluations on the time & cost axis.

        For every Pareto point (newest first) the rtree is asked for its nearest
        stored solution within ``max_distance``. If one of those hits is itself
        a Pareto solution, it is returned immediately, so while Pareto solutions
        are left in the tree it will usually return the most recent one. If
        there are no pareto solutions left, it returns the non-Pareto hit with
        the smallest distance to the front (not farther than ``max_distance``).

        Args:
            pareto_front: The current Pareto front to measure distances against.
            max_distance: Search radius around each Pareto point.

        Returns:
            The chosen solution, or None if the front is empty or no hit
            qualified.
        """
        pareto_solutions = list(reversed(pareto_front.solutions))
        if not pareto_solutions:
            # Nothing to measure against: min() below would raise ValueError on
            # an empty front, and nearest_v presumably expects an (n, dims)
            # point array rather than an empty 1-D one.
            return None

        nearest_solution: Optional[Solution] = None
        nearest_distance = float("inf")
        error_count = 0
        valid_count = 0

        pareto_points = np.array([s.point for s in pareto_solutions])
        max_distances = np.array([max_distance] * len(pareto_solutions))

        neighbors, _ = self.rtree.nearest_v(  # type: ignore
            mins=pareto_points,
            maxs=pareto_points,
            num_results=1,
            max_dists=max_distances,
        )
        print_l3(f"Found {len(neighbors)} neighbors in radius {max_distance:_.2f}.")
        for neighbor in neighbors:
            if error_count > 20:
                warn(f"Got too many None items from rtree! Returning None. {error_count} errors so far.")
                break
            if neighbor is None:
                warn(f"WARNING: Got None item from rtree. {error_count} errors so far.")
                error_count += 1
                continue
            item_id = hex_id(neighbor)
            if item_id not in self.solution_lookup:
                warn(
                    f"WARNING: Got non-existent solution from rtree ({item_id}). {error_count} errors so far."
                )
                error_count += 1
                continue
            solution = self.solution_lookup[item_id]
            if solution is None:
                warn(f"WARNING: Got discarded solution from rtree ({item_id}). {error_count} errors so far.")
                # Stale entry: the id is still indexed although its lookup is
                # None. It cannot be deleted here, because rtree.delete needs
                # the point, which is no longer stored for discarded ids.
                error_count += 1
                continue
            # Early exit if we find a pareto solution.
            # Because of reversed, it will be the most recent
            if solution in pareto_front.solutions:
                return solution

            distance = min(s.distance_to(solution) for s in pareto_front.solutions)
            # Sanity check, that the distance is smaller than the max distance.
            if distance > max_distance:
                continue
            valid_count += 1
            if distance < nearest_distance:
                nearest_solution = solution
                nearest_distance = distance

        if nearest_solution is None:
            print_l3(
                f"NO nearest solution was found in tree. ({error_count} errors, "
                f"{len(neighbors)} neighbors, {len(pareto_front.solutions)} pareto solutions, "
                f"{self.total_solutions} solutions in tree, {self.solutions_left} solutions unexplored)"
            )
        else:
            print_l3(f"... of which {valid_count} were valid.")
        return nearest_solution

    @property
    def discarded_solutions(self) -> int:
        """Return the number of discarded / exhausted solutions."""
        return sum(1 for stored in self.solution_lookup.values() if stored is None)

    @property
    def solutions_left(self) -> int:
        """Return the number of untried solutions left in the tree."""
        return len(self.solution_lookup) - self.discarded_solutions

    @property
    def total_solutions(self) -> int:
        """Return the total number of solutions (tried + non-tried)."""
        return len(self.solution_lookup)

    def pop_nearest_solution(
        self, pareto_front: "ParetoFront", max_distance: float = float("inf")
    ) -> Optional["Solution"]:
        """Pop the nearest solution to the given Pareto Front.

        Uses get_nearest_solution and, if a solution is found, removes it from
        the tree (via remove_solution) before returning it.
        """
        nearest_solution = self.get_nearest_solution(pareto_front, max_distance=max_distance)
        if nearest_solution is not None:
            self.remove_solution(nearest_solution)
            print_l3(f"Popped solution ({nearest_solution.id})")
            if nearest_solution not in pareto_front.solutions:
                print_l3("Nearest solution is NOT in pareto front.")
            else:
                print_l3("Nearest solution is IN pareto front.")
        return nearest_solution

    def check_if_already_done(self, base_solution: "Solution", new_action: "BaseAction") -> bool:
        """Check if the given action has already been tried.

        True if the hash of ``base_solution.actions + [new_action]`` is a known
        id, whether that solution is still available or already discarded.
        """
        return Solution.hash_action_list(base_solution.actions + [new_action]) in self.solution_lookup

    def get_index_of_solution(self, solution: "Solution") -> int:
        """Get the index of the solution in the tree.

        The index is the position of ``solution.id`` in insertion order.
        Raises ValueError if the id has never been added.
        """
        return list(self.solution_lookup).index(solution.id)

    def get_solutions_near_to_pareto_front(
        self, pareto_front: "ParetoFront", max_distance: float = float("inf")
    ) -> list["Solution"]:
        """Get a list of solutions near the pareto front.

        Candidates are gathered with a bounding-box query of half-width
        ``max_distance`` around every Pareto point, then filtered to those whose
        distance to the closest Pareto solution is at most ``max_distance``.
        Discarded ids are skipped. Returns an empty list for an empty front.
        """
        pareto_solutions = list(pareto_front.solutions)
        if not pareto_solutions:
            return []

        pareto_points = np.array([s.point for s in pareto_solutions])
        # Lower bounds are clipped at 0 — presumably all coordinates are
        # non-negative (time / cost); confirm if that ever changes.
        bounding_points_mins = (pareto_points - max_distance).clip(min=0)
        bounding_points_maxs = pareto_points + max_distance
        solution_ids, _ = self.rtree.intersection_v(bounding_points_mins, bounding_points_maxs)

        candidates: list[Solution] = []
        for solution_id in set(solution_ids):
            candidate = self.solution_lookup.get(hex_id(solution_id))
            if candidate is not None:
                candidates.append(candidate)

        # Filter out solutions, that are distanced more than max_distance
        # (the box query above over-approximates the search radius).
        return [
            s for s in candidates if min(s.distance_to(p) for p in pareto_solutions) <= max_distance
        ]

    def get_random_solution_near_to_pareto_front(
        self, pareto_front: "ParetoFront", max_distance: float = float("inf")
    ) -> Optional["Solution"]:
        """Get a random solution near the pareto front.

        Picks uniformly (``random.choice``) from
        get_solutions_near_to_pareto_front; returns None if there are none.
        """
        solutions = self.get_solutions_near_to_pareto_front(pareto_front, max_distance=max_distance)
        if not solutions:
            return None
        print_l3(f"Found {len(solutions)} solutions near pareto front.")
        random_solution = random.choice(solutions)
        return random_solution

    def remove_solution(self, solution: "Solution") -> None:
        """Remove a solution from the tree.

        Deletes its rtree entry and marks its id as discarded (lookup -> None),
        so the id still counts as already done.
        """
        self.rtree.delete(int(solution.id, 16), solution.point)
        self.solution_lookup[solution.id] = None

        if Settings.DUMP_DISCARDED_SOLUTIONS:
            SolutionDumper.instance.dump_solution(solution)