Coverage for o2/actions/base_actions/add_datetime_rule_base_action.py: 94%

54 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-05-16 11:18 +0000

1from abc import ABC, abstractmethod 

2from dataclasses import dataclass, replace 

3 

4from typing_extensions import Required, override 

5 

6from o2.actions.base_actions.base_action import ( 

7 BaseActionParamsType, 

8 RateSelfReturnType, 

9) 

10from o2.actions.base_actions.batching_rule_base_action import ( 

11 BatchingRuleBaseAction, 

12) 

13from o2.models.rule_selector import RuleSelector 

14from o2.models.self_rating import RATING 

15from o2.models.settings import Settings 

16from o2.models.solution import Solution 

17from o2.models.state import State 

18from o2.models.timetable import ( 

19 RULE_TYPE, 

20 BatchingRule, 

21 FiringRule, 

22) 

23from o2.models.timetable.time_period import TimePeriod 

24from o2.store import Store 

25from o2.util.logger import info 

26 

27 

class AddDateTimeRuleBaseActionParamsType(BaseActionParamsType):
    """Parameter for AddDateTimeRuleBaseAction.

    Keys:
        task_id: Id of the task whose batching rule is created or extended.
        time_period: Window whose weekday (``from_``) and hour bounds
            (``begin_time_hour`` / ``end_time_hour``) become firing rules.
        duration_fn: Passed as ``duration_fn`` to ``BatchingRule.from_task_id``
            when the task has no batching rule yet.
    """

    # Task whose batching rule is created or extended.
    task_id: Required[str]
    # Weekday and hour window used to build the new firing rules.
    time_period: Required[TimePeriod]
    # Only used when a brand-new BatchingRule is created for the task.
    # NOTE(review): declared Required, but AddDateTimeRuleBaseAction.apply
    # reads it via .get(...), so a missing key is tolerated at runtime.
    duration_fn: Required[str]

34 

35 

@dataclass(frozen=True)
class AddDateTimeRuleBaseAction(BatchingRuleBaseAction, ABC, str=False):
    """AddDateTimeRuleBaseAction will add a new day of week and time of day rule.

    The added OR-rule fires on ``time_period.from_`` between
    ``time_period.begin_time_hour`` (inclusive) and
    ``time_period.end_time_hour`` (exclusive). When
    ``Settings.ADD_SIZE_RULE_TO_NEW_RULES`` is set, it also requires a batch
    size of at least 2.

    Subclasses must implement ``rate_self``.
    """

    params: AddDateTimeRuleBaseActionParamsType

    @override
    def apply(self, state: State, enable_prints: bool = True) -> State:
        """Return a new state with the day/time firing rule added for the task.

        If the task has no batching rule yet, a new ``BatchingRule`` is built
        from the firing rules (using ``duration_fn`` from the params) and
        appended to the timetable's batch processing. Otherwise the firing
        rules are added to the task's first existing batching rule.

        Args:
            state: The state to derive the new state from.
            enable_prints: Whether to log the rule being added.

        Returns:
            The updated state.
        """
        timetable = state.timetable
        task_id = self.params["task_id"]
        time_period = self.params["time_period"]
        # Read leniently: a missing key yields None rather than a KeyError.
        duration_fn = self.params.get("duration_fn")

        existing_task_rules = timetable.get_batching_rules_for_task(task_id)

        new_or_rule = [
            FiringRule.eq(RULE_TYPE.WEEK_DAY, time_period.from_),
            FiringRule.gte(RULE_TYPE.DAILY_HOUR, time_period.begin_time_hour),
            FiringRule.lt(RULE_TYPE.DAILY_HOUR, time_period.end_time_hour),
        ]
        if Settings.ADD_SIZE_RULE_TO_NEW_RULES:
            new_or_rule.append(FiringRule.gte(RULE_TYPE.SIZE, 2))

        # A rule is added on both paths below, so log once up front rather
        # than only when an existing rule is extended.
        if enable_prints:
            info(
                f"\t\t>> Adding rule for {task_id} on {time_period.from_} "
                f"from {time_period.begin_time} to {time_period.end_time}"
            )

        if not existing_task_rules:
            new_batching_rule = BatchingRule.from_task_id(
                task_id=task_id,
                firing_rules=new_or_rule,
                duration_fn=duration_fn,
            )
            return state.replace_timetable(batch_processing=timetable.batch_processing + [new_batching_rule])

        # TODO: Allow combining rules, e.g. extending date range
        # Only the first batching rule for the task is extended; any further
        # rules for the same task are left untouched.
        rule = existing_task_rules[0]
        updated_rule = rule.add_firing_rules(new_or_rule)

        return replace(
            state,
            timetable=timetable.replace_batching_rule(
                RuleSelector.from_batching_rule(rule),
                updated_rule,
            ),
        )

    @override
    @staticmethod
    @abstractmethod
    def rate_self(store: Store, input: Solution) -> RateSelfReturnType:
        """Rate this action for ``input``; concrete subclasses must implement it."""

    @staticmethod
    def get_default_rating() -> RATING:
        """Return the default rating for this action."""
        return RATING.MEDIUM

96 

97 

class AddDateTimeRuleAction(AddDateTimeRuleBaseAction):
    """AddDateTimeRuleAction will add a new day of week and time of day rule.

    Inherits ``apply`` unchanged from AddDateTimeRuleBaseAction. Self-rating is
    not supported: ``rate_self`` always raises ``NotImplementedError``.
    """

    params: AddDateTimeRuleBaseActionParamsType

    @override
    @staticmethod
    def rate_self(store: Store, input: Solution) -> RateSelfReturnType:
        """Not supported for this action.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("rate_self is not implemented")