Coverage for o2/actions/base_actions/shift_datetime_rule_base_action.py: 93%
56 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-05-16 11:18 +0000
1from abc import ABC
2from dataclasses import dataclass, replace
4from typing_extensions import NotRequired, override
6from o2.actions.base_actions.base_action import (
7 BaseActionParamsType,
8 RateSelfReturnType,
9)
10from o2.actions.base_actions.batching_rule_base_action import (
11 BatchingRuleBaseAction,
12)
13from o2.models.days import DAY
14from o2.models.solution import Solution
15from o2.models.state import State
16from o2.models.timetable import (
17 BatchingRule,
18 Distribution,
19 rule_is_daily_hour,
20)
21from o2.store import Store
22from o2.util.logger import info
class ShiftDateTimeRuleBaseActionParamsType(BaseActionParamsType):
    """Parameters accepted by ``ShiftDateTimeRuleBaseAction``.

    Keys:
        task_id: Identifier of the task whose rule is shifted.
        day: Day of the week whose daily-hour window is shifted.
        add_to_start: Optional. Hours to extend the window at its start
            (e.g. 1 = start one hour earlier, -1 = drop the first hour).
        add_to_end: Optional. Hours to extend the window at its end
            (e.g. 1 = end one hour later, -1 = drop the last hour).
    """

    task_id: str
    day: DAY
    # Hours added before the current start; a negative value trims the start.
    add_to_start: NotRequired[int]
    # Hours added after the current end; a negative value trims the end.
    add_to_end: NotRequired[int]
@dataclass(frozen=True)
class ShiftDateTimeRuleBaseAction(BatchingRuleBaseAction, ABC, str=False):
    """ShiftDateTimeRuleBaseAction will shift a day of week and time of day rule.

    The action selects the longest daily-hour firing window for
    ``params["task_id"]`` on ``params["day"]``. It moves the window's lower
    bound earlier by ``add_to_start`` hours and its upper bound later by
    ``add_to_end`` hours. Negative values shrink the window.
    """

    params: ShiftDateTimeRuleBaseActionParamsType

    @override
    def apply(self, state: State, enable_prints: bool = True) -> State:
        """Return a copy of ``state`` whose timetable holds the shifted rule.

        Args:
            state: State whose timetable is modified (via ``replace``, so the
                input object itself is not mutated).
            enable_prints: When True, log the new bounds via ``info``.

        Returns:
            The new state. The original ``state`` is returned unchanged when:
            - no daily-hour window exists for the task/day;
            - the batching rule cannot be resolved;
            - either bound is not a daily-hour rule;
            - the shifted (and clamped) window would be empty.
        """
        # Valid hour-of-day bounds; shifted values are clamped into this range.
        day_start_hour, day_end_hour = 0, 24

        timetable = state.timetable
        task_id = self.params["task_id"]
        add_to_start = self.params.get("add_to_start", 0)
        add_to_end = self.params.get("add_to_end", 0)

        best_selector = timetable.get_longest_time_period_for_daily_hour_firing_rules(
            task_id, self.params["day"]
        )

        if best_selector is None:
            # TODO: Here we should add a new rule
            return state

        # Modify Start / End
        _, lower_bound_selector, upper_bound_selector = best_selector
        batching_rule = lower_bound_selector.get_batching_rule_from_state(state)
        if batching_rule is None:
            return state
        lower_bound_rule = lower_bound_selector.get_firing_rule_from_state(state)
        upper_bound_rule = upper_bound_selector.get_firing_rule_from_state(state)
        if not rule_is_daily_hour(lower_bound_rule):
            return state
        if not rule_is_daily_hour(upper_bound_rule):
            return state

        # Shift, then clamp into the day so no bound is < 0 or > 24.
        new_lower_bound = min(
            max(lower_bound_rule.value - add_to_start, day_start_hour), day_end_hour
        )
        new_upper_bound = min(
            max(upper_bound_rule.value + add_to_end, day_start_hour), day_end_hour
        )
        # Shrinking too far (or clamping) can collapse/invert the window;
        # such a rule could never fire, so keep the state as it is.
        if new_lower_bound >= new_upper_bound:
            return state

        new_lower_bound_rule = replace(lower_bound_rule, value=new_lower_bound)
        new_upper_bound_rule = replace(upper_bound_rule, value=new_upper_bound)
        # skip_merge=True on the first replacement mirrors the original flow:
        # only the final replacement is allowed to merge.
        new_batching_rule = batching_rule.replace_firing_rule(
            lower_bound_selector, new_lower_bound_rule, skip_merge=True
        ).replace_firing_rule(upper_bound_selector, new_upper_bound_rule)
        timetable = timetable.replace_batching_rule(lower_bound_selector, new_batching_rule)

        if enable_prints:
            info(
                f"\t\t>> Modifying rule {lower_bound_selector.id()} "
                f"to new time bounds: {new_lower_bound} -> {new_upper_bound}"
            )

        return replace(state, timetable=timetable)

    def get_dominant_distribution(self, old_rule: BatchingRule) -> Distribution:
        """Find the size distribution with the highest probability.

        Args:
            old_rule: Batching rule whose ``size_distrib`` entries are compared
                by their ``value`` attribute.

        Returns:
            The entry of ``old_rule.size_distrib`` with the largest ``value``;
            the first such entry wins on ties (``max`` semantics).

        Raises:
            ValueError: If ``old_rule.size_distrib`` is empty (raised by ``max``).
        """
        return max(
            old_rule.size_distrib,
            key=lambda distribution: distribution.value,
        )
class ShiftDateTimeRuleAction(ShiftDateTimeRuleBaseAction):
    """Concrete action that shifts a day-of-week / time-of-day rule.

    Self-rating is not available for this action yet.
    """

    @override
    @staticmethod
    def rate_self(
        store: Store, input: Solution
    ) -> RateSelfReturnType[ShiftDateTimeRuleBaseAction]:
        """Rating is not implemented for this action.

        Raises:
            NotImplementedError: On every call.
        """
        message = "Not implemented"
        raise NotImplementedError(message)