Coverage for o2/actions/base_actions/shift_datetime_rule_base_action.py: 93%

56 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-05-16 11:18 +0000

1from abc import ABC 

2from dataclasses import dataclass, replace 

3 

4from typing_extensions import NotRequired, override 

5 

6from o2.actions.base_actions.base_action import ( 

7 BaseActionParamsType, 

8 RateSelfReturnType, 

9) 

10from o2.actions.base_actions.batching_rule_base_action import ( 

11 BatchingRuleBaseAction, 

12) 

13from o2.models.days import DAY 

14from o2.models.solution import Solution 

15from o2.models.state import State 

16from o2.models.timetable import ( 

17 BatchingRule, 

18 Distribution, 

19 rule_is_daily_hour, 

20) 

21from o2.store import Store 

22from o2.util.logger import info 

23 

24 

class ShiftDateTimeRuleBaseActionParamsType(BaseActionParamsType):
    """Parameter for ShiftDateTimeRuleBaseAction.

    Keys:
        task_id: Id of the task whose daily-hour batching window is shifted.
        day: Day whose daily-hour window is shifted.
        add_to_start: Optional; hours to move the window start earlier.
            Treated as 0 when omitted.
        add_to_end: Optional; hours to move the window end later.
            Treated as 0 when omitted.
    """

    task_id: str
    day: DAY
    # Hours to move the start of the window earlier:
    # 1 = start one hour earlier (window grows), -1 = start one hour later (window shrinks).
    add_to_start: NotRequired[int]
    # Hours to move the end of the window later:
    # 1 = end one hour later (window grows), -1 = end one hour earlier (window shrinks).
    add_to_end: NotRequired[int]

36 

37 

@dataclass(frozen=True)
class ShiftDateTimeRuleBaseAction(BatchingRuleBaseAction, ABC, str=False):
    """ShiftDateTimeRuleBaseAction will shift a day of week and time of day rule.

    For ``params["task_id"]`` on ``params["day"]``, the action takes the
    longest daily-hour window returned by
    ``timetable.get_longest_time_period_for_daily_hour_firing_rules``. It then
    moves the window's lower bound by ``-add_to_start`` hours and its upper
    bound by ``+add_to_end`` hours. The resulting bounds are clamped to
    ``[MIN_DAILY_HOUR, MAX_DAILY_HOUR]``.
    """

    # Valid range for daily-hour firing rule values. Unannotated, so the
    # dataclass does not treat these as fields; subclasses may override them.
    # NOTE(review): range taken from the former "< 0 or > 24" TODO — confirm
    # against the simulator's daily_hour semantics.
    MIN_DAILY_HOUR = 0
    MAX_DAILY_HOUR = 24

    params: ShiftDateTimeRuleBaseActionParamsType

    @classmethod
    def _clamp_daily_hour(cls, hour: int) -> int:
        """Clamp ``hour`` into ``[cls.MIN_DAILY_HOUR, cls.MAX_DAILY_HOUR]``."""
        return max(cls.MIN_DAILY_HOUR, min(cls.MAX_DAILY_HOUR, hour))

    @override
    def apply(self, state: State, enable_prints: bool = True) -> State:
        """Shift the task's longest daily-hour window and return the new state.

        Args:
            state: State whose timetable provides the rule to shift.
            enable_prints: When true, log the modification via ``info``.

        Returns:
            ``dataclasses.replace(state, timetable=...)`` with the shifted
            rule. Returns ``state`` itself in these cases:
            - no matching window exists;
            - the batching rule cannot be resolved;
            - either bound is not a daily-hour rule;
            - the clamped window would be inverted (lower > upper).
        """
        timetable = state.timetable
        task_id = self.params["task_id"]
        add_to_start = self.params.get("add_to_start", 0)
        add_to_end = self.params.get("add_to_end", 0)

        best_selector = timetable.get_longest_time_period_for_daily_hour_firing_rules(
            task_id, self.params["day"]
        )

        if best_selector is None:
            # TODO: Here we should add a new rule
            return state

        # Resolve the rules bounding the window.
        _, lower_bound_selector, upper_bound_selector = best_selector
        batching_rule = lower_bound_selector.get_batching_rule_from_state(state)
        if batching_rule is None:
            return state
        lower_bound_rule = lower_bound_selector.get_firing_rule_from_state(state)
        upper_bound_rule = upper_bound_selector.get_firing_rule_from_state(state)
        if not rule_is_daily_hour(lower_bound_rule):
            return state
        if not rule_is_daily_hour(upper_bound_rule):
            return state

        # A positive add_to_start moves the start earlier; a positive
        # add_to_end moves the end later. Clamp so neither bound leaves the
        # valid daily-hour range.
        new_lower_bound = self._clamp_daily_hour(lower_bound_rule.value - add_to_start)
        new_upper_bound = self._clamp_daily_hour(upper_bound_rule.value + add_to_end)
        if new_lower_bound > new_upper_bound:
            # Shrinking one bound past the other would yield a window that can
            # never fire; treat it like the other inapplicable cases above.
            return state

        new_lower_bound_rule = replace(lower_bound_rule, value=new_lower_bound)
        new_upper_bound_rule = replace(upper_bound_rule, value=new_upper_bound)
        # Only the first replacement passes skip_merge=True; the second call
        # uses the default.
        new_batching_rule = batching_rule.replace_firing_rule(
            lower_bound_selector, new_lower_bound_rule, skip_merge=True
        ).replace_firing_rule(upper_bound_selector, new_upper_bound_rule)
        timetable = timetable.replace_batching_rule(lower_bound_selector, new_batching_rule)

        if enable_prints:
            info(
                f"\t\t>> Modifying rule {lower_bound_selector.id()} "
                f"to new time bounds: {new_lower_bound} -> {new_upper_bound}"
            )

        return replace(state, timetable=timetable)

    def get_dominant_distribution(self, old_rule: BatchingRule) -> Distribution:
        """Find the size distribution with the highest probability.

        Args:
            old_rule: Batching rule whose ``size_distrib`` entries are compared
                by their ``value``.

        Returns:
            The entry of ``old_rule.size_distrib`` with the largest ``value``.
            The first such entry wins on ties.

        Raises:
            ValueError: If ``old_rule.size_distrib`` is empty (from ``max``).
        """
        return max(
            old_rule.size_distrib,
            key=lambda distribution: distribution.value,
        )

95 

class ShiftDateTimeRuleAction(ShiftDateTimeRuleBaseAction):
    """ShiftDateTimeRuleAction will shift a day of week and time of day rule.

    Uses the shifting logic inherited from ShiftDateTimeRuleBaseAction.
    Self-rating is not implemented yet: ``rate_self`` always raises.
    """

    @override
    @staticmethod
    def rate_self(
        store: Store, input: Solution
    ) -> RateSelfReturnType[ShiftDateTimeRuleBaseAction]:
        """Always raise ``NotImplementedError``; no rating logic exists yet."""
        message = "Not implemented"
        raise NotImplementedError(message)