Coverage for o2/actions/batching_actions/modify_size_of_significant_rule_action.py: 98%
46 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-05-16 11:18 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-05-16 11:18 +0000
1from typing_extensions import Required, override
3from o2.actions.base_actions.base_action import (
4 BaseAction,
5 BaseActionParamsType,
6 RateSelfReturnType,
7)
8from o2.actions.base_actions.modify_size_rule_base_action import (
9 ModifySizeRuleAction,
10 ModifySizeRuleBaseActionParamsType,
11)
12from o2.models.rule_selector import RuleSelector
13from o2.models.solution import Solution
14from o2.models.state import State
15from o2.models.timetable import (
16 RULE_TYPE,
17 BatchingRule,
18 FiringRule,
19 rule_is_size,
20)
21from o2.store import Store
class ModifySizeOfSignificantRuleActionParamsType(BaseActionParamsType):
    """Parameters consumed by ModifySizeOfSignificantRuleAction.

    All three keys are declared as required entries of the params mapping.
    """

    # Id of the task whose batching rules are inspected / extended.
    task_id: Required[str]

    # Signed delta applied to the selected size rule.
    change_size: Required[int]
    """How much to change the size of the rule by; positive for increase, negative for decrease."""

    # Duration function forwarded to the modified or newly created batching rule.
    duration_fn: Required[str]
class ModifySizeOfSignificantRuleAction(BaseAction):
    """Change the size of the most significant size-only rule of a task.

    Only firing-rule groups that consist of exactly one rule, and where that
    rule is a size rule, are considered. When increasing (``change_size > 0``)
    the rule with the smallest size is chosen; when decreasing
    (``change_size < 0``) the rule with the largest size is chosen. Rules whose
    resulting size would drop below 1 are ignored. If no rule qualifies, a new
    batching rule is appended to the timetable instead.
    """

    params: ModifySizeOfSignificantRuleActionParamsType

    @override
    def apply(self, state: State, enable_prints: bool = True) -> State:
        """Return a new state with the selected size rule changed (or a new rule added).

        Args:
            state: State whose timetable is inspected.
            enable_prints: Forwarded to ``ModifySizeRuleAction.apply`` when an
                existing rule is modified.

        Returns:
            The state produced by ``ModifySizeRuleAction`` for the selected rule,
            or the state with an additional batching rule if none was selected.
        """
        timetable = state.timetable
        task_id = self.params["task_id"]
        change_size = self.params["change_size"]
        duration_fn = self.params.get("duration_fn", None)

        candidates = timetable.get_batching_rules_for_task(task_id)

        growing = change_size > 0
        shrinking = change_size < 0

        # Growing: look for the smallest size-only and-rule.
        # Shrinking: look for the largest size-only and-rule.
        best_selector = None
        best_size = float("inf") if growing else -float("inf")

        for rule in candidates:
            for group_index, group in enumerate(rule.firing_rules):
                # Only groups made of exactly one firing rule are eligible.
                if len(group) != 1:
                    continue
                candidate = group[0]
                if not rule_is_size(candidate):
                    continue
                current_size = int(candidate.value)
                # Skip rules that would end up with a size below 1.
                if current_size + change_size < 1:
                    continue
                # Strict comparisons: on ties the first rule encountered wins.
                improves = (growing and current_size < best_size) or (
                    shrinking and current_size > best_size
                )
                if improves:
                    best_selector = RuleSelector.from_batching_rule(rule, (group_index, 0))
                    best_size = current_size

        if best_selector is not None:
            # Delegate the actual size change to the generic size-rule action.
            delegate_params = ModifySizeRuleBaseActionParamsType(
                rule=best_selector,
                size_increment=change_size,
                duration_fn=duration_fn,
            )
            return ModifySizeRuleAction(delegate_params).apply(state, enable_prints=enable_prints)

        # No significant rule was found, so add a new one.
        # TODO: We need to find the min size from the constraints
        fallback_size = min(max(1 + change_size, 1), 2)
        size_rule = FiringRule.gte(RULE_TYPE.SIZE, fallback_size)
        added_rule = BatchingRule.from_task_id(
            task_id=task_id,
            firing_rules=[size_rule],
            duration_fn=duration_fn,
        )
        return state.replace_timetable(batch_processing=timetable.batch_processing + [added_rule])

    @override
    @staticmethod
    def rate_self(store: Store, input: Solution) -> RateSelfReturnType:
        """Self-rating is not provided by this action; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")