Coverage for o2/actions/base_actions/add_ready_large_wt_rule_base_action.py: 89%
64 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-05-16 11:18 +0000
1from abc import ABC, abstractmethod
2from dataclasses import dataclass, replace
3from typing import Literal
5from typing_extensions import NotRequired, Required, override
7from o2.actions.base_actions.base_action import (
8 BaseActionParamsType,
9 RateSelfReturnType,
10)
11from o2.actions.base_actions.batching_rule_base_action import (
12 BatchingRuleBaseAction,
13)
14from o2.models.rule_selector import RuleSelector
15from o2.models.self_rating import RATING
16from o2.models.solution import Solution
17from o2.models.state import State
18from o2.models.timetable import (
19 RULE_TYPE,
20 BatchingRule,
21 FiringRule,
22)
23from o2.store import Store
24from o2.util.logger import info
class AddReadyLargeWTRuleBaseActionParamsType(BaseActionParamsType):
    """Parameters for AddReadyLargeWTRuleBaseAction.

    Keys:
        task_id: ID of the task whose batching rule is added or modified.
        waiting_time: Waiting-time threshold; the action asserts it is
            at most ``24 * 60 * 60`` (i.e. 24 hours, assuming seconds).
        type: The waiting-time attribute the firing rule targets, either
            ``RULE_TYPE.LARGE_WT`` or ``RULE_TYPE.READY_WT``.
        duration_fn: Optional duration function for the batching rule.
    """

    task_id: Required[str]
    waiting_time: Required[int]
    type: Required[Literal[RULE_TYPE.LARGE_WT, RULE_TYPE.READY_WT]]
    duration_fn: NotRequired[str]
@dataclass(frozen=True)
class AddReadyLargeWTRuleBaseAction(BatchingRuleBaseAction, ABC, str=False):
    """Base action that adds (or relaxes) a waiting-time firing rule for a task.

    The firing rule targets the attribute given by ``params["type"]``
    (``RULE_TYPE.LARGE_WT`` or ``RULE_TYPE.READY_WT``) and is expressed as the
    AND-pair ``type >= waiting_time`` / ``type <= 24 * 60 * 60``.
    Subclasses must implement ``rate_self``.
    """

    params: AddReadyLargeWTRuleBaseActionParamsType

    @override
    def apply(self, state: State, enable_prints: bool = True) -> State:
        """Return a new state with the waiting-time rule applied to the task.

        - If the task has no batching rule, a new ``BatchingRule`` is created
          with the firing rules ``type >= waiting_time`` and ``type <= 24h``.
        - Otherwise only the task's first batching rule is considered. If one
          of its OR-clauses is exactly ``(type >= x AND type <= 24h)``:
          when ``x <= waiting_time`` the state is returned unchanged, else the
          lower bound is replaced by ``waiting_time``.
        - If no such clause exists, ``(type >= waiting_time AND type <= 24h)``
          is appended to that rule as a new OR-clause.

        Args:
            state: The state to derive the new state from.
            enable_prints: Whether to log when a new OR-clause is appended.

        Returns:
            The updated state, or ``state`` itself when no change is needed.

        Raises:
            AssertionError: If ``waiting_time`` exceeds ``24 * 60 * 60``.
        """
        one_day = 24 * 60 * 60
        timetable = state.timetable
        task_id = self.params["task_id"]
        waiting_time = self.params["waiting_time"]
        # Named ``rule_type`` to avoid shadowing the ``type`` builtin.
        rule_type = self.params["type"]
        # The key previously carried a trailing space ("duration_fn "), which
        # silently discarded any caller-supplied duration_fn.
        duration_fn = self.params.get("duration_fn", "1")

        assert waiting_time <= one_day, "Waiting time must be less than 24 hours"

        existing_task_rules = timetable.get_batching_rules_for_task(task_id)

        if not existing_task_rules:
            new_batching_rule = BatchingRule.from_task_id(
                task_id=task_id,
                duration_fn=duration_fn,
                firing_rules=[
                    FiringRule.gte(rule_type, waiting_time),
                    FiringRule.lte(rule_type, one_day),
                ],
            )
            return state.replace_timetable(batch_processing=timetable.batch_processing + [new_batching_rule])

        # Only the first batching rule of the task is modified.
        rule = existing_task_rules[0]
        index = timetable.batch_processing.index(rule)

        for or_index, and_rules in enumerate(rule.firing_rules):
            # Only OR-clauses with exactly two AND-conditions can match.
            if len(and_rules) != 2:
                continue
            lower_bound = and_rules[0]
            upper_bound = and_rules[1]

            is_matching_clause = (
                lower_bound.attribute == rule_type
                and upper_bound.attribute == rule_type
                and lower_bound.is_gte
                and upper_bound.is_lte
                and upper_bound.value == one_day
            )
            if not is_matching_clause:
                continue

            # The existing threshold is already at or below the requested
            # waiting time -> the rule would fire anyway.
            if lower_bound.value <= waiting_time:
                return state

            # Lower the existing clause's threshold to the requested waiting time.
            updated_firing_rule = FiringRule.gte(rule_type, waiting_time)
            selector = RuleSelector(batching_rule_task_id=task_id, firing_rule_index=(or_index, 0))
            return replace(
                state,
                timetable=timetable.replace_firing_rule(selector, updated_firing_rule, duration_fn=duration_fn),
            )

        new_or_rule = [
            FiringRule.gte(rule_type, waiting_time),
            # This is needed, or else we get an error
            FiringRule.lte(rule_type, one_day),
        ]

        updated_rule = rule.add_firing_rules(new_or_rule)

        if enable_prints:
            info(f"\t\t>> Adding rule for {task_id} with large_wt >= {waiting_time}")

        return state.replace_timetable(
            batch_processing=timetable.batch_processing[:index]
            + [updated_rule]
            + timetable.batch_processing[index + 1 :],
        )

    @override
    @staticmethod
    @abstractmethod
    def rate_self(store: Store, input: "Solution") -> RateSelfReturnType["AddReadyLargeWTRuleBaseAction"]:
        """Rate this action type for the given solution; implemented by subclasses."""
        pass

    @staticmethod
    def get_default_rating() -> RATING:
        """Return the default rating for this action."""
        return RATING.MEDIUM
class AddReadyLargeWTRuleAction(AddReadyLargeWTRuleBaseAction):
    """Concrete waiting-time rule action.

    ``apply`` is inherited from ``AddReadyLargeWTRuleBaseAction``. Self-rating
    is not supported: calling ``rate_self`` raises ``NotImplementedError``.
    """

    params: AddReadyLargeWTRuleBaseActionParamsType

    @override
    @staticmethod
    def rate_self(store: Store, input: Solution) -> RateSelfReturnType["AddReadyLargeWTRuleAction"]:
        """Not supported for this action.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("rate_self is not implemented")