Coverage for o2/actions/base_actions/modify_size_rule_base_action.py: 85%
52 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-05-16 11:18 +0000
1from abc import ABC, abstractmethod
2from dataclasses import dataclass, replace
4from typing_extensions import override
6from o2.actions.base_actions.base_action import (
7 RateSelfReturnType,
8)
9from o2.actions.base_actions.batching_rule_base_action import (
10 BatchingRuleBaseAction,
11 BatchingRuleBaseActionParamsType,
12)
13from o2.models.self_rating import RATING
14from o2.models.solution import Solution
15from o2.models.state import State
16from o2.models.timetable import (
17 BatchingRule,
18 Distribution,
19 rule_is_size,
20)
21from o2.store import Store
22from o2.util.logger import warn
class ModifySizeRuleBaseActionParamsType(BatchingRuleBaseActionParamsType):
    """Parameters accepted by ModifySizeRuleBaseAction.

    Extends the batching-rule parameters with the amount by which the value
    of a size firing rule is shifted.
    """

    # Added to the current size firing-rule value; a negative value shrinks it.
    size_increment: int
@dataclass(frozen=True)
class ModifySizeRuleBaseAction(BatchingRuleBaseAction, ABC, str=False):
    """ModifySizeRuleBaseAction will modify the size of a BatchingRule.

    This will affect the size distribution and the duration distribution of the rule,
    as well as the firing rule.
    """

    params: ModifySizeRuleBaseActionParamsType

    @override
    def apply(self, state: State, enable_prints: bool = True) -> State:
        """Shift the value of the selected size firing rule by ``size_increment``.

        The rule is located via ``params["rule"]``. The updated firing rule and
        ``params["duration_fn"]`` (default ``"1"``) are passed to
        ``Timetable.replace_firing_rule``. The result is wrapped in a new
        ``State`` built with ``dataclasses.replace``.

        The input ``state`` is returned as-is when:
        - no batching rule matches the selector (a warning is emitted),
        - the batching rule has no matching firing rule (a warning is emitted),
        - the firing rule is not a size rule,
        - the resulting size would be 1 or less.

        Args:
            state: The state whose timetable should be modified.
            enable_prints: Not used by this implementation; kept for interface
                compatibility.

        Returns:
            A new ``State`` holding the updated timetable, or ``state`` itself
            when the action does not apply.
        """
        timetable = state.timetable
        rule_selector = self.params["rule"]
        duration_fn = self.params.get("duration_fn", "1")

        _, batching_rule = timetable.get_batching_rule(rule_selector)
        if batching_rule is None:
            warn(f"BatchingRule not found for {rule_selector}")
            return state

        firing_rule = batching_rule.get_firing_rule(rule_selector)
        if firing_rule is None:
            warn(f"FiringRule not found for {rule_selector}")
            return state

        if not rule_is_size(firing_rule):
            return state

        new_size = int(firing_rule.value) + self.params["size_increment"]
        # We don't allow size 1 (or less), as a batch of one basically means no
        # batching. The previous `< 1` check let size 1 through, contradicting this.
        if new_size <= 1:
            return state

        new_firing_rule = replace(firing_rule, value=new_size)

        new_timetable = timetable.replace_firing_rule(
            rule_selector, new_firing_rule, duration_fn=duration_fn
        )

        return replace(state, timetable=new_timetable)

    def get_dominant_distribution(self, old_rule: BatchingRule) -> Distribution:
        """Find the size distribution with the highest probability.

        Args:
            old_rule: The batching rule whose ``size_distrib`` entries are compared.

        Returns:
            The entry of ``old_rule.size_distrib`` with the largest ``value``.
            On ties, the first such entry is returned (``max`` semantics).

        Raises:
            ValueError: If ``old_rule.size_distrib`` is empty.
        """
        return max(
            old_rule.size_distrib,
            key=lambda distribution: distribution.value,
        )

    @override
    @staticmethod
    @abstractmethod
    def rate_self(store: Store, input: Solution) -> RateSelfReturnType:
        """Rate how promising this action is for ``input``; subclasses must implement."""

    @staticmethod
    def get_default_rating() -> RATING:
        """Return the default rating for this action."""
        return RATING.MEDIUM
class ModifySizeRuleAction(ModifySizeRuleBaseAction):
    """Concrete action that modifies the size of a BatchingRule.

    All size-changing behavior is inherited from ModifySizeRuleBaseAction;
    self-rating is not supported by this class.
    """

    @override
    @staticmethod
    def rate_self(store: Store, input: Solution) -> RateSelfReturnType:
        """Self-rating is unsupported for this action.

        Raises:
            NotImplementedError: Always.
        """
        message = "Not implemented"
        raise NotImplementedError(message)