Coverage for o2/actions/batching_actions/modify_large_ready_wt_of_significant_rule_action.py: 81%

48 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-05-16 11:18 +0000

1from typing import Literal 

2 

3from typing_extensions import Required, override 

4 

5from o2.actions.base_actions.base_action import ( 

6 BaseAction, 

7 BaseActionParamsType, 

8 RateSelfReturnType, 

9) 

10from o2.models.rule_selector import RuleSelector 

11from o2.models.solution import Solution 

12from o2.models.state import State 

13from o2.models.timetable import ( 

14 RULE_TYPE, 

15 BatchingRule, 

16 FiringRule, 

17) 

18from o2.store import Store 

19 

20 

class ModifyLargeReadyWtOfSignificantRuleActionParamsType(BaseActionParamsType):
    """Parameters accepted by ModifyLargeReadyWtOfSignificantRuleAction."""

    # Task whose batching rules are inspected (and extended if no rule qualifies).
    task_id: Required[str]
    # Waiting-time attribute the targeted firing rule must check.
    type: Required[Literal[RULE_TYPE.LARGE_WT, RULE_TYPE.READY_WT]]
    # Number of hours to shift the waiting-time threshold by:
    # positive values increase it, negative values decrease it.
    change_wt: Required[int]
    # Duration function forwarded to a newly created BatchingRule.
    # NOTE(review): declared Required, yet the action reads it via
    # .get("duration_fn", None) -- confirm whether it should be NotRequired.
    duration_fn: Required[str]

29 

30 

class ModifyLargeReadyWtOfSignificantRuleAction(BaseAction):
    """Shift the waiting-time threshold of the most significant large/ready-WT firing rule.

    The action inspects the batching rules of ``params["task_id"]`` and looks for
    AND-groups made of exactly one firing rule that performs a greater-than(-or-equal)
    check on ``params["type"]``. The chosen rule's threshold is moved by
    ``params["change_wt"]`` hours. If nothing qualifies, a new batching rule for the
    task is appended instead.
    """

    params: ModifyLargeReadyWtOfSignificantRuleActionParamsType

    @override
    def apply(self, state: State, enable_prints: bool = True) -> State:
        """Return a new state with the selected waiting-time rule adjusted.

        Candidate selection:
          * only single-condition AND-groups whose rule targets ``params["type"]``
            and satisfies ``is_gt_or_gte`` are considered;
          * a candidate's hour value is ``int(rule.value) // 3600``;
          * candidates whose adjusted hour (hour + change_wt) falls outside 1..23
            are skipped;
          * for ``change_wt > 0`` the smallest hour wins, for ``change_wt < 0`` the
            largest wins; for ``change_wt == 0`` no candidate is ever chosen.

        If a candidate is found, its firing rule is replaced with
        ``FiringRule.gte(type, (hour + change_wt) * 3600)``. Otherwise a new
        BatchingRule for the task is appended with firing rules
        ``gte(type, abs(change_wt) * 3600)`` and ``lte(type, 24 * 60 * 60)``.

        Args:
            state: The state whose timetable is modified (a new state is returned).
            enable_prints: Accepted for interface compatibility; unused here.

        Returns:
            The updated State.
        """
        timetable = state.timetable
        task_id = self.params["task_id"]
        change_wt = self.params["change_wt"]
        duration_fn = self.params.get("duration_fn", None)
        rule_type = self.params["type"]

        increasing = change_wt > 0
        decreasing = change_wt < 0

        # Start from the "worst" possible value so the first valid candidate wins.
        best_selector = None
        best_wt = float("inf") if increasing else float("-inf")

        for batching_rule in timetable.get_batching_rules_for_task(task_id):
            for or_index, and_group in enumerate(batching_rule.firing_rules):
                # Only AND-groups consisting of exactly one condition are considered.
                if len(and_group) != 1:
                    continue
                candidate = and_group[0]
                # TODO: lte-style rules are not supported yet.
                if not (candidate.attribute == rule_type and candidate.is_gt_or_gte):
                    continue

                hours = int(candidate.value) // 3600
                target_hours = hours + change_wt
                # The adjusted threshold must stay within the 1h..23h window.
                if target_hours < 1 or target_hours > 23:
                    continue

                is_better = (increasing and hours < best_wt) or (decreasing and hours > best_wt)
                if is_better:
                    best_selector = RuleSelector.from_batching_rule(batching_rule, (or_index, 0))
                    best_wt = hours

        if best_selector is not None:
            updated_timetable = timetable.replace_firing_rule(
                rule_selector=best_selector,
                new_firing_rule=FiringRule.gte(rule_type, (best_wt + change_wt) * 3600),
            )
            # NOTE(review): this branch passes timetable=, while the fallback below
            # passes batch_processing= -- confirm State.replace_timetable accepts both.
            return state.replace_timetable(timetable=updated_timetable)

        # No existing rule qualified: append a fresh batching rule for the task.
        fresh_rule = BatchingRule.from_task_id(
            task_id,
            firing_rules=[
                FiringRule.gte(rule_type, abs(change_wt) * 3600),
                FiringRule.lte(rule_type, 24 * 60 * 60),
            ],
            duration_fn=duration_fn,
        )
        return state.replace_timetable(batch_processing=timetable.batch_processing + [fresh_rule])

    @override
    @staticmethod
    def rate_self(store: Store, input: Solution) -> RateSelfReturnType:
        """Rating is not supported for this action; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")