Coverage for o2/actions/base_actions/add_size_rule_base_action.py: 59%

54 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-05-16 11:18 +0000

1from abc import ABC, abstractmethod 

2from dataclasses import dataclass, replace 

3 

4from sympy import Symbol, lambdify 

5from typing_extensions import Required, override 

6 

7from o2.actions.base_actions.base_action import ( 

8 BaseAction, 

9 BaseActionParamsType, 

10 RateSelfReturnType, 

11) 

12from o2.models.rule_selector import RuleSelector 

13from o2.models.self_rating import RATING 

14from o2.models.solution import Solution 

15from o2.models.state import State 

16from o2.models.timetable import ( 

17 RULE_TYPE, 

18 BatchingRule, 

19 Distribution, 

20 FiringRule, 

21) 

22from o2.store import Store 

23from o2.util.helper import select_variants 

24 

25 

class AddSizeRuleBaseActionParamsType(BaseActionParamsType):
    """Parameters for AddSizeRuleBaseAction (and its subclasses).

    Keys:
        size: Requested batch size. ``AddSizeRuleBaseAction.apply`` rejects
            values below 1 and raises a value of 1 to 2.
        task_id: Id of the task whose batching rule is created or extended.
        duration_fn: Expression in the symbol ``size`` (compiled with sympy's
            ``lambdify``) giving the duration for a batch of that size.
    """

    size: Required[int]
    task_id: Required[str]
    duration_fn: Required[str]

32 

33 

@dataclass(frozen=True)
class AddSizeRuleBaseAction(BaseAction, ABC, str=False):
    """AddSizeRuleBaseAction will add a size-based BatchingRule to a task.

    This will affect the size distribution and the duration distribution of the
    rule, as well as the firing rule:

    - If the task has no batching rule yet, a fresh rule firing at
      ``size >= batching_size`` is created and appended to the timetable.
    - Otherwise the first batching rule of the task is extended. Its size and
      duration distributions get an entry for ``batching_size``, replacing any
      entry with the same key. A new OR-clause ``[size >= batching_size]`` is
      appended to its firing rules.
    """

    params: AddSizeRuleBaseActionParamsType

    @override
    def apply(self, state: State, enable_prints: bool = True) -> State:
        """Return a state whose timetable contains the added size rule.

        Args:
            state: State whose timetable is used as the starting point.
            enable_prints: Not used by this implementation; kept for the
                BaseAction interface.

        Returns:
            The state produced by ``state.replace_timetable(...)`` (fresh rule)
            or ``dataclasses.replace(state, timetable=...)`` (extended rule).

        Raises:
            ValueError: If ``params["size"]`` is smaller than 1.
        """
        new_size = self.params["size"]
        task_id = self.params["task_id"]
        duration_fn = self.params.get("duration_fn", "1")

        # Validate first so invalid input fails before any expression
        # compilation or timetable lookup happens.
        if new_size < 1:
            raise ValueError(f"Size must be at least 1, got {new_size}")

        # Compile the duration expression (function of the symbol ``size``).
        duration_lambda = lambdify(Symbol("size"), duration_fn)

        # Make sure the size is at least 2, because 1 just means no batching
        batching_size = max(new_size, 2)
        size_key = str(batching_size)

        timetable = state.timetable
        batching_rules = timetable.get_batching_rules_for_task(task_id)

        # Create fully fresh rule
        if not batching_rules:
            new_batching_rule = BatchingRule.from_task_id(
                task_id=task_id,
                firing_rules=[FiringRule.gte(RULE_TYPE.SIZE, batching_size)],
                duration_fn=duration_fn,
            )
            return state.replace_timetable(batch_processing=timetable.batch_processing + [new_batching_rule])

        # Add OR case to existing rule
        existing_rule = batching_rules[0]
        # TODO: Check if a single size rule already exists, and if so, replace it
        new_batching_rule = replace(
            existing_rule,
            # Integrate in existing size distribution; the new entry for this
            # size replaces any previous entry with the same key.
            size_distrib=[Distribution(key=size_key, value=1.0)]
            + [distrib for distrib in existing_rule.size_distrib if distrib.key != size_key],
            duration_distrib=[Distribution(key=size_key, value=duration_lambda(batching_size))]
            + [distrib for distrib in existing_rule.duration_distrib if distrib.key != size_key],
            # Each inner list is one clause; appending a new inner list adds an
            # OR alternative to the existing firing conditions.
            firing_rules=existing_rule.firing_rules + [[FiringRule.gte(RULE_TYPE.SIZE, batching_size)]],
        )
        new_timetable = timetable.replace_batching_rule(
            RuleSelector.from_batching_rule(new_batching_rule),
            new_batching_rule,
        )
        return replace(state, timetable=new_timetable)

    @override
    @staticmethod
    @abstractmethod
    def rate_self(store: Store, input: Solution) -> RateSelfReturnType:
        """Rate possible applications of this action; implemented by subclasses."""
        pass

    @staticmethod
    def get_default_rating() -> RATING:
        """Return the default rating for this action."""
        return RATING.MEDIUM

113 

114 

class AddSizeRuleAction(AddSizeRuleBaseAction):
    """AddSizeRuleAction will add a BatchingRule."""

    @override
    @staticmethod
    def rate_self(store: Store, input: Solution) -> RateSelfReturnType:
        """Yield one size-2 AddSizeRuleAction candidate per selected task.

        Tasks come from ``store.current_timetable.get_task_ids()`` filtered
        through ``select_variants``. Each candidate uses the task's duration
        function from ``store.constraints`` and is rated ``RATING.VERY_LOW``.
        The ``input`` solution is not used by this implementation.
        """
        task_ids = store.current_timetable.get_task_ids()

        for task_id in select_variants(store, task_ids):
            duration_fn = store.constraints.get_duration_fn_for_task(task_id)
            yield (
                RATING.VERY_LOW,
                AddSizeRuleAction(
                    AddSizeRuleBaseActionParamsType(
                        task_id=task_id,
                        size=2,
                        duration_fn=duration_fn,
                    )
                ),
            )