Coverage for o2/actions/batching_actions/modify_large_ready_wt_of_significant_rule_action.py: 81%
48 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-05-16 11:18 +0000
1from typing import Literal
3from typing_extensions import Required, override
5from o2.actions.base_actions.base_action import (
6 BaseAction,
7 BaseActionParamsType,
8 RateSelfReturnType,
9)
10from o2.models.rule_selector import RuleSelector
11from o2.models.solution import Solution
12from o2.models.state import State
13from o2.models.timetable import (
14 RULE_TYPE,
15 BatchingRule,
16 FiringRule,
17)
18from o2.store import Store
class ModifyLargeReadyWtOfSignificantRuleActionParamsType(BaseActionParamsType):
    """Parameters consumed by ModifyLargeReadyWtOfSignificantRuleAction.

    All four keys are declared ``Required``. The action reads them as follows:

    * ``task_id`` selects the batching rules to inspect and is the task a newly
      created batching rule is attached to.
    * ``type`` is compared against each firing rule's ``attribute``; only
      ``RULE_TYPE.LARGE_WT`` and ``RULE_TYPE.READY_WT`` are permitted.
    * ``change_wt`` is the signed amount added to the selected rule's waiting time.
    * ``duration_fn`` is forwarded to ``BatchingRule.from_task_id`` when a new
      batching rule has to be created.
    """

    # Id of the task whose batching rules are modified.
    task_id: Required[str]
    # Which waiting-time attribute the targeted firing rule must have.
    type: Required[Literal[RULE_TYPE.LARGE_WT, RULE_TYPE.READY_WT]]
    change_wt: Required[int]
    """How much to change the wt of the rule by; positive for increase, negative for decrease."""
    # Duration function handed to a newly created batching rule.
    duration_fn: Required[str]
class ModifyLargeReadyWtOfSignificantRuleAction(BaseAction):
    """Shift the waiting-time threshold of a task's most significant batching rule.

    The action looks at the task's batching rules for and-groups that consist of
    exactly one firing rule on the configured waiting-time attribute
    (``params["type"]``) and moves that rule's threshold by ``params["change_wt"]``.
    When no suitable firing rule exists, a new batching rule is appended instead.
    """

    params: ModifyLargeReadyWtOfSignificantRuleActionParamsType

    @override
    def apply(self, state: State, enable_prints: bool = True) -> State:
        """Return a new state with the waiting-time rule adjusted or a new rule added.

        Selection of the "significant" rule:

        * only and-groups holding exactly one firing rule are considered;
        * the firing rule must have the requested attribute and satisfy
          ``is_gt_or_gte``;
        * the rule's value is floor-divided by 3600 and the shifted result must lie
          in ``1..23`` (presumably seconds converted to hours - confirm against the
          timetable model);
        * for a positive ``change_wt`` the candidate with the smallest wt wins, for a
          negative ``change_wt`` the one with the largest wt; ties keep the first
          candidate encountered.

        Args:
            state: State whose timetable is inspected; it is not modified in place
                by this method, a replaced state is returned.
            enable_prints: Accepted for interface compatibility; not used here.

        Returns:
            The state produced by ``state.replace_timetable`` - either with the
            selected firing rule replaced, or with one additional batching rule.
        """
        seconds_per_hour = 3600

        timetable = state.timetable
        task_id = self.params["task_id"]
        delta_hours = self.params["change_wt"]
        duration_fn = self.params.get("duration_fn")
        wt_attribute = self.params["type"]

        increasing = delta_hours > 0
        decreasing = delta_hours < 0

        # Smallest wt single-rule and-group wins when increasing,
        # largest wt single-rule and-group wins when decreasing.
        selected_rule = None
        selected_hours = float("inf") if increasing else -float("inf")

        for batching_rule in timetable.get_batching_rules_for_task(task_id):
            for and_index, and_rules in enumerate(batching_rule.firing_rules):
                # Only and-groups made of exactly one firing rule qualify.
                if len(and_rules) != 1:
                    continue

                candidate = and_rules[0]
                # TODO: We should also support lte
                if not (candidate.attribute == wt_attribute and candidate.is_gt_or_gte):
                    continue

                hours = int(candidate.value) // seconds_per_hour
                # Skip candidates whose shifted value would leave the 1..23 range.
                if not 1 <= hours + delta_hours <= 23:
                    continue

                # NOTE(review): with change_wt == 0 neither flag is set, so nothing is
                # ever selected and the fallback below adds a rule with threshold 0.
                if (increasing and hours < selected_hours) or (decreasing and hours > selected_hours):
                    selected_rule = RuleSelector.from_batching_rule(batching_rule, (and_index, 0))
                    selected_hours = hours

        if selected_rule is not None:
            # NOTE(review): the replacement is always built with FiringRule.gte, and the
            # new value is derived from the floored wt, so any remainder below 3600 in
            # the old value is dropped - confirm both are intended.
            updated_timetable = timetable.replace_firing_rule(
                rule_selector=selected_rule,
                new_firing_rule=FiringRule.gte(wt_attribute, (selected_hours + delta_hours) * seconds_per_hour),
            )
            return state.replace_timetable(timetable=updated_timetable)

        # If no significant rule is found, add a new one
        fallback_rule = BatchingRule.from_task_id(
            task_id,
            firing_rules=[
                FiringRule.gte(wt_attribute, abs(delta_hours) * seconds_per_hour),
                FiringRule.lte(wt_attribute, 24 * 60 * 60),
            ],
            duration_fn=duration_fn,
        )
        return state.replace_timetable(batch_processing=timetable.batch_processing + [fallback_rule])

    @override
    @staticmethod
    def rate_self(store: Store, input: Solution) -> RateSelfReturnType:
        """Self-rating is not available for this action.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Not implemented")