Coverage for o2/ppo_utils/ppo_input.py: 100%
101 statements
« prev ^ index » next — coverage.py v7.6.12, created at 2025-05-16 11:18 +0000
1from dataclasses import dataclass
2from typing import Optional
4import numpy as np
5from gymnasium import spaces
6from sklearn.preprocessing import MinMaxScaler
8from o2.actions.base_actions.add_datetime_rule_base_action import (
9 AddDateTimeRuleAction,
10 AddDateTimeRuleBaseActionParamsType,
11)
12from o2.actions.base_actions.base_action import BaseAction
13from o2.actions.base_actions.shift_datetime_rule_base_action import (
14 ShiftDateTimeRuleAction,
15 ShiftDateTimeRuleBaseActionParamsType,
16)
17from o2.actions.batching_actions.modify_large_ready_wt_of_significant_rule_action import (
18 ModifyLargeReadyWtOfSignificantRuleAction,
19 ModifyLargeReadyWtOfSignificantRuleActionParamsType,
20)
21from o2.actions.batching_actions.modify_size_of_significant_rule_action import (
22 ModifySizeOfSignificantRuleAction,
23 ModifySizeOfSignificantRuleActionParamsType,
24)
25from o2.actions.batching_actions.remove_date_time_rule_action import (
26 RemoveDateTimeRuleAction,
27 RemoveDateTimeRuleActionParamsType,
28)
29from o2.models.days import DAYS
30from o2.models.settings import Settings
31from o2.models.timetable import RULE_TYPE
32from o2.models.timetable.time_period import TimePeriod
33from o2.store import Store
@dataclass(frozen=True)
class PPOInput:
    """Input builder for the PPO model.

    The following columns/features are defined:

    Continuous features:
        Task:
        - waiting_times, per task
        - idle_time, per task
        - percentage of waiting that is batching-caused, per task
        - fixed_cost, per task
        - percentage of task instances that have a waiting or idle time, per task

        Resource:
        - waiting_times, per resource
        - available_time, per resource
        - utilization, per resource

        Batching rules:
        - avg_batch_size, per task

    Discrete features:
        - number of task enablements, per task, per day
        - number of resources, per task
        - number of tasks, per resource

    Also we'll have the following actions:
        - Add Batching DateTime Rule, per day, per task
        - Shift Batching DateTime Rule forward, per task
        - Shift Batching DateTime Rule backward, per task
        - Remove Batching DateTime Rule, per task
        - Add 1h to large waiting time rule, per task
        - Remove 1h from large waiting time rule, per task
        - Add 1h to ready waiting time rule, per task
        - Remove 1h from ready waiting time rule, per task
        - Increase batch size, per task
        - Decrease batch size, per task

    We'll use action masking to disable the invalid actions per step.
    """
    @staticmethod
    def get_observation_space(store: Store) -> spaces.Dict:
        """Get the observation space based on the current state of the store.

        Shapes are (1, n) row vectors matching the arrays produced by
        `_get_task_features` / `_get_resource_features`.
        """
        task_ids = store.current_timetable.get_task_ids()
        num_tasks = len(task_ids)
        num_days = len(DAYS)
        resources = store.current_timetable.get_all_resources()
        num_resources = len(resources)
        # Fall back to the timetable's total case count when no explicit setting is given.
        num_cases = Settings.NUMBER_OF_CASES or store.current_timetable.total_cases

        # Generous upper bound for count-like observations (e.g. tasks per resource).
        high = num_cases * 2

        return spaces.Dict(
            {
                # Continuous resource-related observations (min-max scaled to [0, 1])
                "resource_waiting_times": spaces.Box(low=0, high=1, shape=(1, num_resources)),
                "resource_available_time": spaces.Box(low=0, high=1, shape=(1, num_resources)),
                "resource_utilization": spaces.Box(low=0, high=1, shape=(1, num_resources)),
                # Discrete resource-related observations
                "resource_num_tasks": spaces.Box(low=0, high=high, shape=(1, num_resources)),
                # Continuous task-related observations (min-max scaled / ratios in [0, 1])
                "task_waiting_times": spaces.Box(low=0, high=1, shape=(1, num_tasks)),
                "task_idle_time": spaces.Box(low=0, high=1, shape=(1, num_tasks)),
                "task_batching_waiting_times_percentage": spaces.Box(low=0, high=1, shape=(1, num_tasks)),
                # Discrete task-related observations
                # NOTE(review): the matching feature is computed as a ratio in [0, 1]
                # (count / total executions); high=num_cases looks overly loose — confirm.
                "task_execution_percentage_with_wt_or_it": spaces.Box(
                    low=0, high=num_cases, shape=(1, num_tasks)
                ),
                # One entry per (task, day) pair, flattened.
                "task_enablements_per_day": spaces.Box(
                    low=0, high=num_cases, shape=(1, num_tasks * num_days)
                ),
                "task_num_resources": spaces.Box(low=0, high=num_resources, shape=(1, num_tasks)),
                # NOTE(review): batch size is not obviously bounded by the number of
                # resources — verify this upper bound is intended.
                "task_average_batch_size": spaces.Box(low=0, high=num_resources, shape=(1, num_tasks)),
            }
        )
117 @staticmethod
118 def _get_task_features(store: Store) -> dict[str, np.ndarray]:
119 task_ids = store.current_timetable.get_task_ids()
120 kpis = store.current_evaluation.task_kpis
121 evaluation = store.current_evaluation
123 # MinMaxScaler for continuous features
124 scaler_waiting_time = MinMaxScaler()
125 scaler_idle_time = MinMaxScaler()
127 # Continuous features
128 waiting_times = np.array(
129 [kpis[task_id].waiting_time.total if task_id in kpis else 0 for task_id in task_ids]
130 )
132 batching_waiting_times = np.array(
133 [evaluation.total_batching_waiting_time_per_task.get(task_id, 0) for task_id in task_ids]
134 )
136 batching_waiting_times_percentage = np.array(
137 [
138 (batching_waiting_times[i] / waiting_times[i]) if waiting_times[i] != 0 else 0
139 for i in range(len(task_ids))
140 ]
141 )
143 idle_times = np.array(
144 [kpis[task_id].idle_time.total if task_id in kpis else 0 for task_id in task_ids]
145 )
147 # Scale continuous features
148 waiting_times = scaler_waiting_time.fit_transform(waiting_times.reshape(1, -1))
149 idle_times = scaler_idle_time.fit_transform(idle_times.reshape(1, -1))
151 # Discrete features
152 task_execution_count_with_wt_or_it_dict = evaluation.task_execution_count_with_wt_or_it
153 task_execution_percentage_with_wt_or_it_ = np.array(
154 [
155 task_execution_count_with_wt_or_it_dict[task_id] / evaluation.task_execution_counts[task_id]
156 if task_id in task_execution_count_with_wt_or_it_dict
157 else 0
158 for task_id in task_ids
159 ]
160 )
162 task_enablements_per_day = np.array(
163 [
164 len(evaluation.task_enablement_weekdays[task_id][day])
165 if (
166 task_id in evaluation.task_enablement_weekdays
167 and day in evaluation.task_enablement_weekdays[task_id]
168 )
169 else 0
170 for task_id in task_ids
171 for day in DAYS
172 ]
173 )
175 number_of_resources = np.array(
176 [
177 len(store.current_timetable.get_resources_assigned_to_task(task_id))
178 if task_id in task_ids
179 else 0
180 for task_id in task_ids
181 ]
182 )
184 average_batch_size = np.array(
185 [evaluation.avg_batch_size_per_task.get(task_id, 1) for task_id in task_ids]
186 )
188 return {
189 "task_waiting_times": PPOInput._clean_np_array(waiting_times),
190 "task_idle_time": PPOInput._clean_np_array(idle_times),
191 "task_batching_waiting_times_percentage": PPOInput._clean_np_array(
192 batching_waiting_times_percentage
193 ),
194 "task_execution_percentage_with_wt_or_it": PPOInput._clean_np_array(
195 task_execution_percentage_with_wt_or_it_
196 ),
197 "task_enablements_per_day": PPOInput._clean_np_array(task_enablements_per_day),
198 "task_num_resources": PPOInput._clean_np_array(number_of_resources),
199 "task_average_batch_size": PPOInput._clean_np_array(average_batch_size),
200 }
202 @staticmethod
203 def _get_resource_features(store: Store) -> dict[str, np.ndarray]:
204 # TODO: This only uses base resources
205 resources = store.base_timetable.get_all_resources()
206 evaluation = store.current_evaluation
208 # MinMaxScaler for continuous features
209 scaler_waiting_time = MinMaxScaler()
210 scaler_available_time = MinMaxScaler()
211 scaler_utilization = MinMaxScaler()
212 scaler_hourly_cost = MinMaxScaler()
214 # Continuous features
215 waiting_times_dict = evaluation.total_batching_waiting_time_per_resource
216 waiting_times = np.array(
217 [
218 waiting_times_dict[resource.id] # noqa: SIM401
219 if resource.id in waiting_times_dict
220 # TODO: 0 might not be the best default value
221 else 0
222 for resource in resources
223 ]
224 )
226 available_times = np.array(
227 [
228 evaluation.resource_kpis[resource.id].available_time
229 if resource.id in evaluation.resource_kpis
230 else 0
231 for resource in resources
232 ]
233 )
235 utilizations = np.array(
236 [
237 evaluation.resource_kpis[resource.id].utilization
238 if resource.id in evaluation.resource_kpis
239 else 0
240 for resource in resources
241 ]
242 )
244 hourly_costs = np.array([resource.cost_per_hour for resource in resources])
246 # Scale continuous features
247 waiting_times = scaler_waiting_time.fit_transform(waiting_times.reshape(1, -1))
248 available_times = scaler_available_time.fit_transform(available_times.reshape(1, -1))
249 utilizations = scaler_utilization.fit_transform(utilizations.reshape(1, -1))
250 hourly_costs = scaler_hourly_cost.fit_transform(hourly_costs.reshape(1, -1))
252 # Discrete features
253 number_of_tasks = np.array([len(resource.assigned_tasks) for resource in resources])
255 return {
256 "resource_waiting_times": PPOInput._clean_np_array(waiting_times),
257 "resource_available_time": PPOInput._clean_np_array(available_times),
258 "resource_utilization": PPOInput._clean_np_array(utilizations),
259 "resource_num_tasks": PPOInput._clean_np_array(number_of_tasks),
260 }
262 @staticmethod
263 def get_state_from_store(store: Store) -> dict:
264 """Get the input for the PPO model based on the current state of the store."""
265 resource_features = PPOInput._get_resource_features(store)
266 task_features = PPOInput._get_task_features(store)
268 return {
269 **resource_features,
270 **task_features,
271 }
273 @staticmethod
274 def get_action_space_from_actions(
275 actions: list[Optional[BaseAction]],
276 ) -> spaces.Discrete:
277 """Get the action space based on the actions."""
278 return spaces.Discrete(len(actions))
280 @staticmethod
281 def get_actions_from_store(store: Store) -> list[Optional[BaseAction]]:
282 """Get the action based on the index."""
283 # TODO: This only uses base resources
284 resources = store.base_timetable.get_all_resources() # type: ignore # noqa: F841
285 current_timetable = store.current_timetable # type: ignore # noqa: F841
287 actions: list[Optional[BaseAction]] = []
289 # Add 1h to large waiting time rule (or create), per task
290 # Remove 1h from large waiting time rule (or remove), per task
291 # Add 1h to ready waiting time rule (or create), per task
292 # Remove 1h from ready waiting time rule (or remove), per task
293 # Increase batch size (or create), per task
294 # Decrease batch size (or remove), per task
295 # Add Batching DateTime Rule, per day, per task,
296 # Shift Batching DateTime Rule forward, per day, per task
297 # Shift Batching DateTime Rule backward, per day, per task
298 # Remove Batching DateTime Rule, per day, per task
300 for task_id in store.current_timetable.get_task_ids():
301 constraints = store.constraints.get_batching_size_rule_constraints(task_id)
302 duration_fn = "size" if not constraints else constraints[0].duration_fn
303 actions.append(
304 ModifySizeOfSignificantRuleAction(
305 ModifySizeOfSignificantRuleActionParamsType(
306 task_id=task_id,
307 change_size=1,
308 duration_fn=duration_fn,
309 )
310 )
311 )
312 actions.append(
313 ModifySizeOfSignificantRuleAction(
314 ModifySizeOfSignificantRuleActionParamsType(
315 task_id=task_id,
316 change_size=-1,
317 duration_fn=duration_fn,
318 )
319 )
320 )
321 actions.append(
322 ModifyLargeReadyWtOfSignificantRuleAction(
323 ModifyLargeReadyWtOfSignificantRuleActionParamsType(
324 task_id=task_id,
325 type=RULE_TYPE.LARGE_WT,
326 change_wt=1,
327 duration_fn=duration_fn,
328 )
329 )
330 )
331 actions.append(
332 ModifyLargeReadyWtOfSignificantRuleAction(
333 ModifyLargeReadyWtOfSignificantRuleActionParamsType(
334 task_id=task_id,
335 type=RULE_TYPE.LARGE_WT,
336 change_wt=-1,
337 duration_fn=duration_fn,
338 )
339 )
340 )
341 actions.append(
342 ModifyLargeReadyWtOfSignificantRuleAction(
343 ModifyLargeReadyWtOfSignificantRuleActionParamsType(
344 task_id=task_id,
345 type=RULE_TYPE.READY_WT,
346 change_wt=1,
347 duration_fn=duration_fn,
348 )
349 )
350 )
351 actions.append(
352 ModifyLargeReadyWtOfSignificantRuleAction(
353 ModifyLargeReadyWtOfSignificantRuleActionParamsType(
354 task_id=task_id,
355 type=RULE_TYPE.READY_WT,
356 change_wt=-1,
357 duration_fn=duration_fn,
358 )
359 )
360 )
361 for day in DAYS:
362 # TODO: Is a default value for start and end okay?
363 actions.append(
364 AddDateTimeRuleAction(
365 params=AddDateTimeRuleBaseActionParamsType(
366 task_id=task_id,
367 time_period=TimePeriod.from_start_end(day=day, start=12, end=13),
368 duration_fn=duration_fn,
369 )
370 )
371 )
372 actions.append(
373 ShiftDateTimeRuleAction(
374 params=ShiftDateTimeRuleBaseActionParamsType(
375 task_id=task_id, day=day, add_to_start=-1, add_to_end=1
376 )
377 )
378 )
379 actions.append(
380 ShiftDateTimeRuleAction(
381 params=ShiftDateTimeRuleBaseActionParamsType(
382 task_id=task_id, day=day, add_to_start=1, add_to_end=-1
383 )
384 )
385 )
386 actions.append(
387 RemoveDateTimeRuleAction(
388 params=RemoveDateTimeRuleActionParamsType(task_id=task_id, day=day)
389 )
390 )
392 return [
393 action
394 if (
395 action is not None
396 and store.is_tabu(action) is False
397 and (
398 Settings.disable_action_validity_check
399 or action.check_if_valid(store, mark_no_change_as_invalid=True)
400 )
401 )
402 else None
403 for action in actions
404 ]
406 @staticmethod
407 def get_action_mask_from_actions(actions: list[Optional[BaseAction]]) -> np.ndarray:
408 """Get the action mask based on the actions."""
409 mask = np.array([action is not None for action in actions])
410 return mask
412 @staticmethod
413 def _clean_np_array(array: np.ndarray) -> np.ndarray:
414 """Clean the numpy array."""
415 return np.nan_to_num(array, nan=0, posinf=0, neginf=0).reshape(1, -1)