Coverage for o2/actions/batching_actions/modify_size_rule_by_duration_fn_action.py: 24%
82 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-05-16 11:18 +0000
1from collections import defaultdict
3from o2.actions.base_actions.add_size_rule_base_action import (
4 AddSizeRuleAction,
5 AddSizeRuleBaseActionParamsType,
6)
7from o2.actions.base_actions.base_action import (
8 RateSelfReturnType,
9)
10from o2.actions.base_actions.modify_size_rule_base_action import (
11 ModifySizeRuleBaseAction,
12 ModifySizeRuleBaseActionParamsType,
13)
14from o2.models.rule_selector import RuleSelector
15from o2.models.self_rating import RATING
16from o2.models.solution import Solution
17from o2.models.timetable import RULE_TYPE
18from o2.store import Store
19from o2.util.helper import select_variants
class ModifySizeRuleByDurationFnActionParamsType(ModifySizeRuleBaseActionParamsType):
    """Parameters for the duration-fn based size-rule modification actions.

    Adds no fields to ``ModifySizeRuleBaseActionParamsType``; it only gives the
    ``params`` of the duration-fn actions their own named type.
    """
class ModifyBatchSizeIfNoDurationImprovementAction(ModifySizeRuleBaseAction):
    """An Action to modify size batching rules based on the duration fn.

    If batch size decrement does not increase Duration (looking at the duration fn)
    => Decrease Batch Size.
    - Candidates are ordered by how much the decrement reduces the duration
      (largest reduction first).
    """

    params: ModifySizeRuleByDurationFnActionParamsType

    @staticmethod
    def rate_self(
        store: "Store", input: "Solution"
    ) -> RateSelfReturnType["ModifyBatchSizeIfNoDurationImprovementAction"]:
        """Generate a best set of parameters & self-evaluates this action.

        For every firing size rule of every task that has a batching size
        constraint, the constraint's duration fn is evaluated at the current
        size and at ``size - 1``. A rule whose decrement does not increase the
        duration is proposed with ``size_increment=-1``. Proposals are yielded
        in order of largest duration reduction first.

        Args:
            store: Source of the constraints (size-rule constraints, duration fns)
                and passed on to ``select_variants``.
            input: The solution whose state / timetable is inspected.

        Yields:
            ``(RATING.LOW, ModifyBatchSizeIfNoDurationImprovementAction)`` tuples.
        """
        constraints = store.constraints
        # Rate the solution we were handed (consistent with
        # ModifySizeRuleByDurationFnCostImpactAction), rather than whatever
        # solution the store currently holds.
        state = input.state
        timetable = state.timetable

        # Duration reduction (old - new) -> rule selectors achieving it.
        rule_selectors_by_duration: dict[float, list[RuleSelector]] = defaultdict(list)

        for task_id in timetable.get_task_ids():
            firing_rule_selectors = timetable.get_firing_rule_selectors_for_task(
                task_id, rule_type=RULE_TYPE.SIZE
            )
            if not firing_rule_selectors:
                continue

            size_constraints = constraints.get_batching_size_rule_constraints(task_id)
            if not size_constraints:
                continue
            # Only the first size constraint of the task is considered.
            size_constraint = size_constraints[0]

            for firing_rule_selector in firing_rule_selectors:
                firing_rule = firing_rule_selector.get_firing_rule_from_state(state)
                if firing_rule is None:
                    continue
                size = firing_rule.value
                # Decrementing a size of 1 (or less) would yield a batch size
                # below 1, and would evaluate the duration fn at 0 or below.
                if size <= 1:
                    continue

                new_duration = size_constraint.duration_fn_lambda(size - 1)
                old_duration = size_constraint.duration_fn_lambda(size)
                if new_duration <= old_duration:
                    rule_selectors_by_duration[old_duration - new_duration].append(
                        firing_rule_selector
                    )

        # Largest duration reduction first.
        sorted_duration_deltas = sorted(rule_selectors_by_duration, reverse=True)

        for duration_delta in sorted_duration_deltas:
            rule_selectors = rule_selectors_by_duration[duration_delta]
            for rule_selector in select_variants(store, rule_selectors):
                duration_fn = constraints.get_duration_fn_for_task(
                    rule_selector.batching_rule_task_id
                )
                yield (
                    RATING.LOW,
                    ModifyBatchSizeIfNoDurationImprovementAction(
                        ModifySizeRuleByDurationFnActionParamsType(
                            rule=rule_selector,
                            size_increment=-1,
                            duration_fn=duration_fn,
                        )
                    ),
                )
class ModifySizeRuleByDurationFnCostImpactAction(ModifySizeRuleBaseAction):
    """An Action to modify size batching rules based on the duration fn.

    If batch size increment does not increase Duration => Increase Batch Size.
    If a task has no firing size rule yet and going from size 1 to size 2 does
    not increase Duration => Add a size rule with size 2.
    - Candidates are ordered by the duration reduction of that size change
      (largest first). NOTE(review): despite the class name ("CostImpact"),
      the ordering uses the duration fn, not a cost fn.

    NOTE: We do NOT limit the number of results here, because this action is
    also a fallback of sorts, so we make sure that every size rule is incremented
    if sensible.
    """

    params: ModifySizeRuleByDurationFnActionParamsType

    @staticmethod
    def rate_self(
        store: "Store", input: "Solution"
    ) -> RateSelfReturnType["ModifySizeRuleByDurationFnCostImpactAction | AddSizeRuleAction"]:
        """Generate a best set of parameters & self-evaluates this action.

        Args:
            store: Source of the constraints (size-rule constraints, duration fns)
                and passed on to ``select_variants``.
            input: The solution whose state / timetable is inspected.

        Yields:
            ``(RATING.LOW, action)`` tuples, where ``action`` is either a
            ``ModifySizeRuleByDurationFnCostImpactAction`` incrementing an
            existing size rule by 1, or an ``AddSizeRuleAction`` adding a size
            rule of size 2 to a task that has no firing size rule.
        """
        constraints = store.constraints
        state = input.state
        timetable = state.timetable

        # Duration reduction (old - new) -> rule selectors achieving it.
        # A selector with firing_rule_index=None stands for "add a new rule".
        rule_selectors_by_duration: dict[float, list[RuleSelector]] = defaultdict(list)

        for task_id in timetable.get_task_ids():
            size_constraints = constraints.get_batching_size_rule_constraints(task_id)
            if not size_constraints:
                continue
            # Only the first size constraint of the task is considered.
            size_constraint = size_constraints[0]

            firing_rule_selectors = timetable.get_firing_rule_selectors_for_task(
                task_id, rule_type=RULE_TYPE.SIZE
            )
            # Without a firing size rule, consider adding a new one (size 1 -> 2).
            if not firing_rule_selectors:
                old_duration = size_constraint.duration_fn_lambda(1)
                new_duration = size_constraint.duration_fn_lambda(2)
                if new_duration <= old_duration:
                    rule_selectors_by_duration[old_duration - new_duration].append(
                        RuleSelector(batching_rule_task_id=task_id, firing_rule_index=None)
                    )
                continue

            for firing_rule_selector in firing_rule_selectors:
                firing_rule = firing_rule_selector.get_firing_rule_from_state(state)
                if firing_rule is None:
                    continue
                size = firing_rule.value

                old_duration = size_constraint.duration_fn_lambda(size)
                new_duration = size_constraint.duration_fn_lambda(size + 1)
                if new_duration <= old_duration:
                    rule_selectors_by_duration[old_duration - new_duration].append(
                        firing_rule_selector
                    )

        # Largest duration reduction first.
        sorted_duration_deltas = sorted(rule_selectors_by_duration, reverse=True)

        for duration_delta in sorted_duration_deltas:
            rule_selectors = rule_selectors_by_duration[duration_delta]
            for rule_selector in select_variants(store, rule_selectors):
                duration_fn = constraints.get_duration_fn_for_task(
                    rule_selector.batching_rule_task_id
                )
                if rule_selector.firing_rule_index is not None:
                    yield (
                        RATING.LOW,
                        ModifySizeRuleByDurationFnCostImpactAction(
                            ModifySizeRuleByDurationFnActionParamsType(
                                rule=rule_selector,
                                size_increment=1,
                                duration_fn=duration_fn,
                            )
                        ),
                    )
                else:
                    yield (
                        RATING.LOW,
                        AddSizeRuleAction(
                            AddSizeRuleBaseActionParamsType(
                                task_id=rule_selector.batching_rule_task_id,
                                size=2,
                                duration_fn=duration_fn,
                            )
                        ),
                    )