Coverage for o2/ppo_utils/ppo_input.py: 100%

101 statements  

coverage.py v7.6.12, created at 2025-05-16 11:18 +0000

from dataclasses import dataclass
from typing import Optional

import numpy as np
from gymnasium import spaces
from sklearn.preprocessing import MinMaxScaler

from o2.actions.base_actions.add_datetime_rule_base_action import (
    AddDateTimeRuleAction,
    AddDateTimeRuleBaseActionParamsType,
)
from o2.actions.base_actions.base_action import BaseAction
from o2.actions.base_actions.shift_datetime_rule_base_action import (
    ShiftDateTimeRuleAction,
    ShiftDateTimeRuleBaseActionParamsType,
)
from o2.actions.batching_actions.modify_large_ready_wt_of_significant_rule_action import (
    ModifyLargeReadyWtOfSignificantRuleAction,
    ModifyLargeReadyWtOfSignificantRuleActionParamsType,
)
from o2.actions.batching_actions.modify_size_of_significant_rule_action import (
    ModifySizeOfSignificantRuleAction,
    ModifySizeOfSignificantRuleActionParamsType,
)
from o2.actions.batching_actions.remove_date_time_rule_action import (
    RemoveDateTimeRuleAction,
    RemoveDateTimeRuleActionParamsType,
)
from o2.models.days import DAYS
from o2.models.settings import Settings
from o2.models.timetable import RULE_TYPE
from o2.models.timetable.time_period import TimePeriod
from o2.store import Store



@dataclass(frozen=True)
class PPOInput:
    """The PPOInput is used as the input for the PPO model.

    The following columns/features are defined:

    Continuous features:
        Task:
            - waiting_times, per task
            - idle_time, per task
            - percentage of the waiting time that is caused by batching, per task
            - fixed_cost, per task
            - percentage of task instances that have either a waiting or an idle time, per task

        Resource:
            - waiting_times, per resource
            - available_time, per resource
            - utilization, per resource

        Batching rules:
            - avg_batch_size, per task

    Discrete features:
        - Number of task enablements, per task, per day
        - Number of resources, per task
        - Number of tasks, per resource

    The following actions are available:
        - Add Batching DateTime Rule, per day, per task
        - Shift Batching DateTime Rule forward, per day, per task
        - Shift Batching DateTime Rule backward, per day, per task
        - Remove Batching DateTime Rule, per day, per task
        - Add 1h to large waiting time rule, per task
        - Remove 1h from large waiting time rule, per task
        - Add 1h to ready waiting time rule, per task
        - Remove 1h from ready waiting time rule, per task
        - Increase batch size, per task
        - Decrease batch size, per task

    Action masking is used to disable the invalid actions at each step.
    """

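    # Illustrative usage sketch (not part of the original module): a masked-PPO
    # Gymnasium environment would typically wire these helpers together roughly
    # like this, given a prepared `store`:
    #
    #     observation_space = PPOInput.get_observation_space(store)
    #     actions = PPOInput.get_actions_from_store(store)
    #     action_space = PPOInput.get_action_space_from_actions(actions)
    #     state = PPOInput.get_state_from_store(store)
    #     action_mask = PPOInput.get_action_mask_from_actions(actions)
    #
    # The mask is what allows the agent to sample only valid actions per step.
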

    @staticmethod
    def get_observation_space(store: Store) -> spaces.Dict:
        """Get the observation space based on the current state of the store."""
        task_ids = store.current_timetable.get_task_ids()
        num_tasks = len(task_ids)
        num_days = len(DAYS)
        resources = store.current_timetable.get_all_resources()
        num_resources = len(resources)
        num_cases = Settings.NUMBER_OF_CASES or store.current_timetable.total_cases

        high = num_cases * 2

        return spaces.Dict(
            {
                # Continuous resource-related observations
                "resource_waiting_times": spaces.Box(low=0, high=1, shape=(1, num_resources)),
                "resource_available_time": spaces.Box(low=0, high=1, shape=(1, num_resources)),
                "resource_utilization": spaces.Box(low=0, high=1, shape=(1, num_resources)),
                # Discrete resource-related observations
                "resource_num_tasks": spaces.Box(low=0, high=high, shape=(1, num_resources)),
                # Continuous task-related observations
                "task_waiting_times": spaces.Box(low=0, high=1, shape=(1, num_tasks)),
                "task_idle_time": spaces.Box(low=0, high=1, shape=(1, num_tasks)),
                "task_batching_waiting_times_percentage": spaces.Box(low=0, high=1, shape=(1, num_tasks)),
                # Discrete task-related observations
                "task_execution_percentage_with_wt_or_it": spaces.Box(
                    low=0, high=num_cases, shape=(1, num_tasks)
                ),
                "task_enablements_per_day": spaces.Box(
                    low=0, high=num_cases, shape=(1, num_tasks * num_days)
                ),
                "task_num_resources": spaces.Box(low=0, high=num_resources, shape=(1, num_tasks)),
                "task_average_batch_size": spaces.Box(low=0, high=num_resources, shape=(1, num_tasks)),
            }
        )

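    # Illustrative consistency check (not part of the original module): every
    # Box above uses a (1, n) row shape, and the feature builders below return
    # arrays of the same shape via _clean_np_array. Assuming the base and
    # current timetable expose the same resources (see the TODO in
    # _get_resource_features), a quick sanity check could look like:
    #
    #     space = PPOInput.get_observation_space(store)
    #     state = PPOInput.get_state_from_store(store)
    #     assert all(space[key].shape == state[key].shape for key in state)
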

    @staticmethod
    def _get_task_features(store: Store) -> dict[str, np.ndarray]:
        task_ids = store.current_timetable.get_task_ids()
        kpis = store.current_evaluation.task_kpis
        evaluation = store.current_evaluation

        # MinMaxScaler for continuous features
        scaler_waiting_time = MinMaxScaler()
        scaler_idle_time = MinMaxScaler()

        # Continuous features
        waiting_times = np.array(
            [kpis[task_id].waiting_time.total if task_id in kpis else 0 for task_id in task_ids]
        )

        batching_waiting_times = np.array(
            [evaluation.total_batching_waiting_time_per_task.get(task_id, 0) for task_id in task_ids]
        )

        batching_waiting_times_percentage = np.array(
            [
                (batching_waiting_times[i] / waiting_times[i]) if waiting_times[i] != 0 else 0
                for i in range(len(task_ids))
            ]
        )

        idle_times = np.array(
            [kpis[task_id].idle_time.total if task_id in kpis else 0 for task_id in task_ids]
        )

        # Scale continuous features across tasks. MinMaxScaler normalizes per
        # column, so the values are passed as a single column of samples and
        # flattened back into a 1-D array afterwards.
        waiting_times = scaler_waiting_time.fit_transform(waiting_times.reshape(-1, 1)).flatten()
        idle_times = scaler_idle_time.fit_transform(idle_times.reshape(-1, 1)).flatten()

        # Discrete features
        task_execution_count_with_wt_or_it_dict = evaluation.task_execution_count_with_wt_or_it
        task_execution_percentage_with_wt_or_it_ = np.array(
            [
                task_execution_count_with_wt_or_it_dict[task_id] / evaluation.task_execution_counts[task_id]
                if task_id in task_execution_count_with_wt_or_it_dict
                else 0
                for task_id in task_ids
            ]
        )

        task_enablements_per_day = np.array(
            [
                len(evaluation.task_enablement_weekdays[task_id][day])
                if (
                    task_id in evaluation.task_enablement_weekdays
                    and day in evaluation.task_enablement_weekdays[task_id]
                )
                else 0
                for task_id in task_ids
                for day in DAYS
            ]
        )

        number_of_resources = np.array(
            [
                len(store.current_timetable.get_resources_assigned_to_task(task_id))
                if task_id in task_ids
                else 0
                for task_id in task_ids
            ]
        )

        average_batch_size = np.array(
            [evaluation.avg_batch_size_per_task.get(task_id, 1) for task_id in task_ids]
        )

        return {
            "task_waiting_times": PPOInput._clean_np_array(waiting_times),
            "task_idle_time": PPOInput._clean_np_array(idle_times),
            "task_batching_waiting_times_percentage": PPOInput._clean_np_array(
                batching_waiting_times_percentage
            ),
            "task_execution_percentage_with_wt_or_it": PPOInput._clean_np_array(
                task_execution_percentage_with_wt_or_it_
            ),
            "task_enablements_per_day": PPOInput._clean_np_array(task_enablements_per_day),
            "task_num_resources": PPOInput._clean_np_array(number_of_resources),
            "task_average_batch_size": PPOInput._clean_np_array(average_batch_size),
        }

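    # Background for the scaling above (illustrative, self-contained snippet):
    # MinMaxScaler normalizes per column, so a 1-D feature vector has to be
    # passed as a column of samples, not as a single row.
    #
    #     import numpy as np
    #     from sklearn.preprocessing import MinMaxScaler
    #
    #     values = np.array([10.0, 20.0, 40.0])
    #     MinMaxScaler().fit_transform(values.reshape(-1, 1)).flatten()
    #     # -> array([0.        , 0.33333333, 1.        ])
    #     MinMaxScaler().fit_transform(values.reshape(1, -1))
    #     # -> array([[0., 0., 0.]])  (each column is a single-sample feature)
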

    @staticmethod
    def _get_resource_features(store: Store) -> dict[str, np.ndarray]:
        # TODO: This only uses base resources
        resources = store.base_timetable.get_all_resources()
        evaluation = store.current_evaluation

        # MinMaxScaler for continuous features
        scaler_waiting_time = MinMaxScaler()
        scaler_available_time = MinMaxScaler()
        scaler_utilization = MinMaxScaler()
        scaler_hourly_cost = MinMaxScaler()

        # Continuous features
        waiting_times_dict = evaluation.total_batching_waiting_time_per_resource
        waiting_times = np.array(
            [
                waiting_times_dict[resource.id]  # noqa: SIM401
                if resource.id in waiting_times_dict
                # TODO: 0 might not be the best default value
                else 0
                for resource in resources
            ]
        )

        available_times = np.array(
            [
                evaluation.resource_kpis[resource.id].available_time
                if resource.id in evaluation.resource_kpis
                else 0
                for resource in resources
            ]
        )

        utilizations = np.array(
            [
                evaluation.resource_kpis[resource.id].utilization
                if resource.id in evaluation.resource_kpis
                else 0
                for resource in resources
            ]
        )

        hourly_costs = np.array([resource.cost_per_hour for resource in resources])

        # Scale continuous features across resources (samples go in rows, so the
        # 1-D arrays are reshaped to a single column before scaling).
        waiting_times = scaler_waiting_time.fit_transform(waiting_times.reshape(-1, 1)).flatten()
        available_times = scaler_available_time.fit_transform(available_times.reshape(-1, 1)).flatten()
        utilizations = scaler_utilization.fit_transform(utilizations.reshape(-1, 1)).flatten()
        hourly_costs = scaler_hourly_cost.fit_transform(hourly_costs.reshape(-1, 1)).flatten()

        # Discrete features
        number_of_tasks = np.array([len(resource.assigned_tasks) for resource in resources])

        return {
            "resource_waiting_times": PPOInput._clean_np_array(waiting_times),
            "resource_available_time": PPOInput._clean_np_array(available_times),
            "resource_utilization": PPOInput._clean_np_array(utilizations),
            "resource_num_tasks": PPOInput._clean_np_array(number_of_tasks),
        }


    @staticmethod
    def get_state_from_store(store: Store) -> dict:
        """Get the input for the PPO model based on the current state of the store."""
        resource_features = PPOInput._get_resource_features(store)
        task_features = PPOInput._get_task_features(store)

        return {
            **resource_features,
            **task_features,
        }

    @staticmethod
    def get_action_space_from_actions(
        actions: list[Optional[BaseAction]],
    ) -> spaces.Discrete:
        """Get the action space based on the actions."""
        return spaces.Discrete(len(actions))


    @staticmethod
    def get_actions_from_store(store: Store) -> list[Optional[BaseAction]]:
        """Get the candidate actions based on the current state of the store."""
        # TODO: This only uses base resources
        resources = store.base_timetable.get_all_resources()  # type: ignore # noqa: F841
        current_timetable = store.current_timetable  # type: ignore # noqa: F841

        actions: list[Optional[BaseAction]] = []

        # Add 1h to large waiting time rule (or create), per task
        # Remove 1h from large waiting time rule (or remove), per task
        # Add 1h to ready waiting time rule (or create), per task
        # Remove 1h from ready waiting time rule (or remove), per task
        # Increase batch size (or create), per task
        # Decrease batch size (or remove), per task
        # Add Batching DateTime Rule, per day, per task
        # Shift Batching DateTime Rule forward, per day, per task
        # Shift Batching DateTime Rule backward, per day, per task
        # Remove Batching DateTime Rule, per day, per task

        for task_id in store.current_timetable.get_task_ids():
            constraints = store.constraints.get_batching_size_rule_constraints(task_id)
            duration_fn = "size" if not constraints else constraints[0].duration_fn
            actions.append(
                ModifySizeOfSignificantRuleAction(
                    ModifySizeOfSignificantRuleActionParamsType(
                        task_id=task_id,
                        change_size=1,
                        duration_fn=duration_fn,
                    )
                )
            )
            actions.append(
                ModifySizeOfSignificantRuleAction(
                    ModifySizeOfSignificantRuleActionParamsType(
                        task_id=task_id,
                        change_size=-1,
                        duration_fn=duration_fn,
                    )
                )
            )
            actions.append(
                ModifyLargeReadyWtOfSignificantRuleAction(
                    ModifyLargeReadyWtOfSignificantRuleActionParamsType(
                        task_id=task_id,
                        type=RULE_TYPE.LARGE_WT,
                        change_wt=1,
                        duration_fn=duration_fn,
                    )
                )
            )
            actions.append(
                ModifyLargeReadyWtOfSignificantRuleAction(
                    ModifyLargeReadyWtOfSignificantRuleActionParamsType(
                        task_id=task_id,
                        type=RULE_TYPE.LARGE_WT,
                        change_wt=-1,
                        duration_fn=duration_fn,
                    )
                )
            )
            actions.append(
                ModifyLargeReadyWtOfSignificantRuleAction(
                    ModifyLargeReadyWtOfSignificantRuleActionParamsType(
                        task_id=task_id,
                        type=RULE_TYPE.READY_WT,
                        change_wt=1,
                        duration_fn=duration_fn,
                    )
                )
            )
            actions.append(
                ModifyLargeReadyWtOfSignificantRuleAction(
                    ModifyLargeReadyWtOfSignificantRuleActionParamsType(
                        task_id=task_id,
                        type=RULE_TYPE.READY_WT,
                        change_wt=-1,
                        duration_fn=duration_fn,
                    )
                )
            )
            for day in DAYS:
                # TODO: Is a default value for start and end okay?
                actions.append(
                    AddDateTimeRuleAction(
                        params=AddDateTimeRuleBaseActionParamsType(
                            task_id=task_id,
                            time_period=TimePeriod.from_start_end(day=day, start=12, end=13),
                            duration_fn=duration_fn,
                        )
                    )
                )
                actions.append(
                    ShiftDateTimeRuleAction(
                        params=ShiftDateTimeRuleBaseActionParamsType(
                            task_id=task_id, day=day, add_to_start=-1, add_to_end=1
                        )
                    )
                )
                actions.append(
                    ShiftDateTimeRuleAction(
                        params=ShiftDateTimeRuleBaseActionParamsType(
                            task_id=task_id, day=day, add_to_start=1, add_to_end=-1
                        )
                    )
                )
                actions.append(
                    RemoveDateTimeRuleAction(
                        params=RemoveDateTimeRuleActionParamsType(task_id=task_id, day=day)
                    )
                )

        return [
            action
            if (
                action is not None
                and store.is_tabu(action) is False
                and (
                    Settings.disable_action_validity_check
                    or action.check_if_valid(store, mark_no_change_as_invalid=True)
                )
            )
            else None
            for action in actions
        ]

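    # Layout note (illustrative, assuming DAYS lists the seven weekdays): each
    # task contributes 6 batching size/waiting-time actions followed by 4
    # date-time actions per day, i.e. 6 + 4 * 7 = 34 entries per task. Entries
    # that are tabu or invalid are replaced by None, which is what the action
    # mask below is built from.
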

    @staticmethod
    def get_action_mask_from_actions(actions: list[Optional[BaseAction]]) -> np.ndarray:
        """Get the action mask based on the actions."""
        mask = np.array([action is not None for action in actions])
        return mask

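    # Illustrative sketch (not part of the original module): combining the mask
    # with the Discrete action space to sample only valid action indices, using
    # plain numpy:
    #
    #     import numpy as np
    #
    #     actions = PPOInput.get_actions_from_store(store)
    #     mask = PPOInput.get_action_mask_from_actions(actions)
    #     action_index = np.random.choice(np.flatnonzero(mask))
    #     chosen_action = actions[action_index]
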

    @staticmethod
    def _clean_np_array(array: np.ndarray) -> np.ndarray:
        """Replace NaN/inf values with 0 and reshape to a (1, n) row vector."""
        return np.nan_to_num(array, nan=0, posinf=0, neginf=0).reshape(1, -1)
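
    # For reference (illustrative): _clean_np_array maps NaN and +/-inf to 0 and
    # forces the (1, n) row shape expected by the observation space, e.g.
    #
    #     PPOInput._clean_np_array(np.array([1.0, np.nan, np.inf]))
    #     # -> array([[1., 0., 0.]])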