Coverage for o2/actions/base_actions/modify_size_rule_base_action.py: 85%

52 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-05-16 11:18 +0000

1from abc import ABC, abstractmethod 

2from dataclasses import dataclass, replace 

3 

4from typing_extensions import override 

5 

6from o2.actions.base_actions.base_action import ( 

7 RateSelfReturnType, 

8) 

9from o2.actions.base_actions.batching_rule_base_action import ( 

10 BatchingRuleBaseAction, 

11 BatchingRuleBaseActionParamsType, 

12) 

13from o2.models.self_rating import RATING 

14from o2.models.solution import Solution 

15from o2.models.state import State 

16from o2.models.timetable import ( 

17 BatchingRule, 

18 Distribution, 

19 rule_is_size, 

20) 

21from o2.store import Store 

22from o2.util.logger import warn 

23 

24 

class ModifySizeRuleBaseActionParamsType(BatchingRuleBaseActionParamsType):
    """Parameters accepted by ModifySizeRuleBaseAction.

    Extends the generic batching-rule parameters with the amount by which
    the selected size firing rule's value is changed.
    """

    # Signed delta added to the current size value; a negative value shrinks it.
    size_increment: int

29 

30 

@dataclass(frozen=True)
class ModifySizeRuleBaseAction(BatchingRuleBaseAction, ABC, str=False):
    """ModifySizeRuleBaseAction will modify the size of a BatchingRule.

    The action looks up the firing rule selected by ``params["rule"]``. If it
    is a size rule, the action adds ``params["size_increment"]`` to its value
    and swaps the new rule into the timetable via
    ``timetable.replace_firing_rule``. That call also receives
    ``params["duration_fn"]``, which defaults to ``"1"``.

    This is intended to affect the size distribution and the duration
    distribution of the rule, as well as the firing rule itself. The
    distribution updates are delegated to ``replace_firing_rule``, which is
    not visible here.
    """

    # NOTE(review): `str=False` is a class keyword presumably consumed by the
    # base action's `__init_subclass__` — verify its meaning there.

    params: ModifySizeRuleBaseActionParamsType

    @override
    def apply(self, state: State, enable_prints: bool = True) -> State:
        """Return a state whose selected size rule is shifted by ``size_increment``.

        Args:
            state: The state whose timetable should be modified.
            enable_prints: Not used by this implementation; kept for
                compatibility with the base action's signature.

        Returns:
            A copy of ``state`` (via ``dataclasses.replace``) carrying the
            updated timetable. The unchanged ``state`` is returned when:
            - the batching rule or firing rule cannot be found (a warning is
              logged);
            - the firing rule is not a size rule; or
            - the resulting size would be 1 or smaller.
        """
        timetable = state.timetable
        rule_selector = self.params["rule"]
        duration_fn = self.params.get("duration_fn", "1")

        _, batching_rule = timetable.get_batching_rule(rule_selector)
        if batching_rule is None:
            warn(f"BatchingRule not found for {rule_selector}")
            return state

        firing_rule = batching_rule.get_firing_rule(rule_selector)
        if firing_rule is None:
            warn(f"FiringRule not found for {rule_selector}")
            return state

        # Only size-based firing rules can be resized by this action.
        if not rule_is_size(firing_rule):
            return state

        new_size = int(firing_rule.value) + self.params["size_increment"]
        # We don't allow size 1 (or anything smaller), as a batch of a single
        # item basically means no batching.
        if new_size <= 1:
            return state

        new_firing_rule = replace(firing_rule, value=new_size)

        # Use the same timetable instance we looked the rule up in.
        new_timetable = timetable.replace_firing_rule(
            rule_selector, new_firing_rule, duration_fn=duration_fn
        )

        return replace(state, timetable=new_timetable)

    def get_dominant_distribution(self, old_rule: BatchingRule) -> Distribution:
        """Find the size distribution with the highest probability.

        Returns the entry of ``old_rule.size_distrib`` with the largest
        ``value``. This presumes ``value`` holds the entry's probability —
        confirm against the ``Distribution`` model.

        Raises:
            ValueError: If ``old_rule.size_distrib`` is empty (from ``max``).
        """
        return max(
            old_rule.size_distrib,
            key=lambda distribution: distribution.value,
        )

    @override
    @staticmethod
    @abstractmethod
    def rate_self(store: Store, input: Solution) -> RateSelfReturnType:
        """Rate this action for the given solution; concrete subclasses implement it."""
        pass

    @staticmethod
    def get_default_rating() -> RATING:
        """Return the default rating for this action (``RATING.MEDIUM``)."""
        return RATING.MEDIUM

90 

91 

class ModifySizeRuleAction(ModifySizeRuleBaseAction):
    """Concrete ModifySizeRuleBaseAction that resizes a BatchingRule.

    All resizing behaviour is inherited. This class provides no
    self-rating of its own.
    """

    @override
    @staticmethod
    def rate_self(store: Store, input: Solution) -> RateSelfReturnType:
        """Self-rating is unsupported for this action; always raises NotImplementedError."""
        message = "Not implemented"
        raise NotImplementedError(message)