Coverage for o2/actions/base_actions/add_size_rule_base_action.py: 59%
54 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-05-16 11:18 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-05-16 11:18 +0000
1from abc import ABC, abstractmethod
2from dataclasses import dataclass, replace
4from sympy import Symbol, lambdify
5from typing_extensions import Required, override
7from o2.actions.base_actions.base_action import (
8 BaseAction,
9 BaseActionParamsType,
10 RateSelfReturnType,
11)
12from o2.models.rule_selector import RuleSelector
13from o2.models.self_rating import RATING
14from o2.models.solution import Solution
15from o2.models.state import State
16from o2.models.timetable import (
17 RULE_TYPE,
18 BatchingRule,
19 Distribution,
20 FiringRule,
21)
22from o2.store import Store
23from o2.util.helper import select_variants
class AddSizeRuleBaseActionParamsType(BaseActionParamsType):
    """Parameters for AddSizeRuleBaseAction.

    Keys:
        size: Requested batch size. ``AddSizeRuleBaseAction.apply`` rejects
            values below 1 and raises a size of 1 to 2.
        task_id: Id of the task whose batching rule is created or extended.
        duration_fn: Expression in the symbol ``size``. It is evaluated at the
            batch size to produce the duration-distribution value.
    """

    size: Required[int]
    task_id: Required[str]
    # Declared Required, although AddSizeRuleBaseAction.apply falls back to "1"
    # when the key is missing.
    duration_fn: Required[str]


@dataclass(frozen=True)
class AddSizeRuleBaseAction(BaseAction, ABC, str=False):
    """AddSizeRuleBaseAction will add a BatchingRule.

    If the task has no batching rule yet, a fresh rule is created that fires
    once the batch reaches the requested size. Otherwise the task's first
    batching rule is extended:
    - its size and duration distributions get an entry for the new size,
      replacing any entry with the same key;
    - a size firing rule is added as a new OR clause.

    This will affect the size distribution and the duration distribution of
    the rule, as well as the firing rule.
    """

    params: AddSizeRuleBaseActionParamsType

    @override
    def apply(self, state: State, enable_prints: bool = True) -> State:
        """Return a new state whose timetable contains the size batching rule.

        Args:
            state: The state to derive the new state from.
            enable_prints: Not used by this implementation; kept for the
                ``BaseAction.apply`` interface.

        Returns:
            A new ``State`` with the updated timetable.

        Raises:
            ValueError: If ``params["size"]`` is smaller than 1.
        """
        new_size = self.params["size"]
        task_id = self.params["task_id"]
        duration_fn = self.params.get("duration_fn", "1")

        # Validate first, so invalid input fails fast, before any timetable
        # lookup or sympy code generation happens.
        if new_size < 1:
            raise ValueError(f"Size must be at least 1, got {new_size}")

        # Make sure the size is at least 2, because 1 just means no batching
        batching_size = max(new_size, 2)

        timetable = state.timetable
        batching_rules = timetable.get_batching_rules_for_task(task_id)

        # Create fully fresh rule
        if len(batching_rules) == 0:
            new_batching_rule = BatchingRule.from_task_id(
                task_id=task_id,
                firing_rules=[FiringRule.gte(RULE_TYPE.SIZE, batching_size)],
                duration_fn=duration_fn,
            )
            return state.replace_timetable(
                batch_processing=timetable.batch_processing + [new_batching_rule]
            )

        # Add OR Case to existing rule
        existing_rule = batching_rules[0]
        size_key = str(batching_size)
        # Only this branch evaluates the expression, so lambdify it here only.
        duration_lambda = lambdify(Symbol("size"), duration_fn)

        # TODO: Check if a single size rule already exists, and if so, replace it
        new_clause = [FiringRule.gte(RULE_TYPE.SIZE, batching_size)]
        firing_rules = list(existing_rule.firing_rules)
        # Do not add an identical OR clause twice. This mirrors the per-key
        # de-duplication of the distributions below. If FiringRule has no
        # value equality, the check never matches and the clause is always
        # appended, as before.
        if new_clause not in firing_rules:
            firing_rules.append(new_clause)

        new_batching_rule = replace(
            existing_rule,
            # Integrate in existing size distribution
            size_distrib=[Distribution(key=size_key, value=1.0)]
            + [d for d in existing_rule.size_distrib if d.key != size_key],
            duration_distrib=[
                Distribution(key=size_key, value=duration_lambda(batching_size))
            ]
            + [d for d in existing_rule.duration_distrib if d.key != size_key],
            firing_rules=firing_rules,
        )
        new_timetable = timetable.replace_batching_rule(
            RuleSelector.from_batching_rule(new_batching_rule),
            new_batching_rule,
        )
        return replace(state, timetable=new_timetable)

    @override
    @staticmethod
    @abstractmethod
    def rate_self(store: Store, input: Solution) -> RateSelfReturnType:
        """Rate candidate actions of this kind; implemented by subclasses."""
        pass

    @staticmethod
    def get_default_rating() -> RATING:
        """Return the default rating for this action."""
        return RATING.MEDIUM


class AddSizeRuleAction(AddSizeRuleBaseAction):
    """AddSizeRuleAction will add a BatchingRule."""

    @override
    @staticmethod
    def rate_self(store: Store, input: Solution) -> RateSelfReturnType:
        """Yield a size-2 ``AddSizeRuleAction`` for each selected task.

        Candidate tasks are the current timetable's task ids, filtered through
        ``select_variants``. Each task uses the duration function configured
        for it in ``store.constraints``. ``input`` is not used by this
        implementation.

        Yields:
            ``(RATING.VERY_LOW, AddSizeRuleAction)`` tuples.
        """
        task_ids = store.current_timetable.get_task_ids()

        for task_id in select_variants(store, task_ids):
            duration_fn = store.constraints.get_duration_fn_for_task(task_id)
            yield (
                RATING.VERY_LOW,
                AddSizeRuleAction(
                    AddSizeRuleBaseActionParamsType(
                        task_id=task_id,
                        size=2,
                        duration_fn=duration_fn,
                    )
                ),
            )
135 )