Coverage for o2/actions/batching_actions/modify_size_of_significant_rule_action.py: 98%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-05-16 11:18 +0000

1from typing_extensions import Required, override 

2 

3from o2.actions.base_actions.base_action import ( 

4 BaseAction, 

5 BaseActionParamsType, 

6 RateSelfReturnType, 

7) 

8from o2.actions.base_actions.modify_size_rule_base_action import ( 

9 ModifySizeRuleAction, 

10 ModifySizeRuleBaseActionParamsType, 

11) 

12from o2.models.rule_selector import RuleSelector 

13from o2.models.solution import Solution 

14from o2.models.state import State 

15from o2.models.timetable import ( 

16 RULE_TYPE, 

17 BatchingRule, 

18 FiringRule, 

19 rule_is_size, 

20) 

21from o2.store import Store 

22 

23 

class ModifySizeOfSignificantRuleActionParamsType(BaseActionParamsType):
    """Parameter for ModifySizeOfSignificantRuleAction.

    Keys:
        task_id: Id of the task whose batching rules are searched.
        change_size: How much to change the size of the rule by; positive
            for increase, negative for decrease.
        duration_fn: Duration function passed along to the rule that is
            modified or newly created.
    """

    task_id: Required[str]
    # Signed delta applied to the selected size rule (> 0 grows, < 0 shrinks).
    change_size: Required[int]
    # NOTE(review): declared Required, but the action reads this key with
    # `.get("duration_fn", None)` — confirm whether it should be NotRequired.
    duration_fn: Required[str]

31 

32 

class ModifySizeOfSignificantRuleAction(BaseAction):
    """ModifySizeOfSignificantRuleAction will modify the size of the most significant BatchingRule.

    Which rule counts as "most significant" depends on the sign of
    ``change_size``:

    * ``change_size > 0``: the single-condition size rule with the *smallest*
      size is grown.
    * ``change_size < 0``: the single-condition size rule with the *largest*
      size is shrunk.

    Candidates whose size would fall below 1 after the change are ignored.
    If no candidate exists, a new batching rule with a single
    ``size >= new_size`` firing rule is appended to the timetable instead.
    """

    params: ModifySizeOfSignificantRuleActionParamsType

    @override
    def apply(self, state: State, enable_prints: bool = True) -> State:
        """Resize (or create) the most significant size rule for ``task_id``.

        Args:
            state: State whose timetable is inspected.
            enable_prints: Forwarded to ``ModifySizeRuleAction.apply``.

        Returns:
            The state returned by ``ModifySizeRuleAction`` when a candidate
            rule was found; otherwise ``state.replace_timetable(...)`` with one
            new batching rule appended to ``batch_processing``.
        """
        timetable = state.timetable
        task_id = self.params["task_id"]
        change_size = self.params["change_size"]
        # NOTE(review): the params type declares `duration_fn` as Required, yet it
        # is read defensively here; confirm which contract is intended.
        duration_fn = self.params.get("duration_fn", None)

        significant_rule = self._find_significant_rule(
            timetable.get_batching_rules_for_task(task_id),
            change_size,
        )

        if significant_rule is None:
            # No resizable rule exists, so add a new one.
            # TODO: We need to find the min size from the constraints
            # This evaluates to 2 when growing (change_size > 0), otherwise 1.
            # NOTE(review): for change_size <= 0 this still appends a size-1 rule,
            # which does not shrink anything — confirm this is intended.
            new_size = min(max(1 + change_size, 1), 2)
            new_batching_rule = BatchingRule.from_task_id(
                task_id=task_id,
                firing_rules=[FiringRule.gte(RULE_TYPE.SIZE, new_size)],
                duration_fn=duration_fn,
            )
            return state.replace_timetable(batch_processing=timetable.batch_processing + [new_batching_rule])

        return ModifySizeRuleAction(
            ModifySizeRuleBaseActionParamsType(
                rule=significant_rule,
                size_increment=change_size,
                duration_fn=duration_fn,
            )
        ).apply(state, enable_prints=enable_prints)

    @staticmethod
    def _find_significant_rule(batching_rules, change_size: int) -> "RuleSelector | None":
        """Locate the size rule that should be resized by ``change_size``.

        Only entries of ``batching_rule.firing_rules`` that hold exactly one
        firing rule (treated as single-condition AND-groups), where that rule
        is a size rule, are considered. Ties keep the first candidate seen,
        because the comparisons are strict.

        Args:
            batching_rules: Batching rules of the task, as returned by
                ``timetable.get_batching_rules_for_task``.
            change_size: Signed size delta; its sign selects smallest (> 0)
                or largest (< 0). With 0, no candidate is ever selected.

        Returns:
            A ``RuleSelector`` pointing at ``(group_index, 0)`` of the chosen
            batching rule, or ``None`` when there is no candidate.
        """
        significant_rule = None
        # Sentinel that any real size beats in the direction we are searching.
        significant_size = float("inf") if change_size > 0 else float("-inf")

        for batching_rule in batching_rules:
            for group_index, and_rules in enumerate(batching_rule.firing_rules):
                # Skip empty and multi-condition groups; only a lone size condition is resized.
                if len(and_rules) != 1:
                    continue
                firing_rule = and_rules[0]
                if not rule_is_size(firing_rule):
                    continue

                size = int(firing_rule.value)
                # A batch size below 1 is never valid, so such a change is not applicable.
                if size + change_size < 1:
                    continue

                is_more_significant = (change_size > 0 and size < significant_size) or (
                    change_size < 0 and size > significant_size
                )
                if is_more_significant:
                    significant_rule = RuleSelector.from_batching_rule(batching_rule, (group_index, 0))
                    significant_size = size

        return significant_rule

    @override
    @staticmethod
    def rate_self(store: Store, input: Solution) -> RateSelfReturnType:
        """Not supported for this action; always raises ``NotImplementedError``."""
        raise NotImplementedError("Not implemented")