Coverage for o2/actions/base_actions/base_action.py: 87%
39 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-05-16 11:18 +0000
1import functools
2from abc import ABC, abstractmethod
3from collections.abc import Generator
4from dataclasses import dataclass
5from typing import (
6 TYPE_CHECKING,
7 Optional,
8 TypeVar,
9)
11from dataclass_wizard import JSONSerializable
12from typing_extensions import TypedDict
14from o2.util.helper import hash_string
15from o2.util.logger import warn
17if TYPE_CHECKING:
18 from o2.models.self_rating import RATING
19 from o2.models.solution import Solution
20 from o2.store import State, Store
# Type variable for a concrete ``BaseAction`` subclass. The bound is a string
# forward reference because ``BaseAction`` is defined further down this module.
ActionT = TypeVar("ActionT", bound="BaseAction")

# A (rating, action) pair. ``RATING`` is only imported under TYPE_CHECKING, so it
# is referenced by name here; the action slot may be ``None``.
ActionRatingTuple = tuple["RATING", Optional[ActionT]]

# Return type of ``BaseAction.rate_self``: a generator that yields rating tuples,
# is sent ``bool`` values, and may finish by returning a final rating tuple or None
# (typing order: Generator[YieldType, SendType, ReturnType]).
RateSelfReturnType = Generator[ActionRatingTuple[ActionT], bool, Optional[ActionRatingTuple[ActionT]]]
class BaseActionParamsType(TypedDict):
    """Root parameter mapping shared by every action.

    Declares no keys of its own; it is the declared type of ``BaseAction.params``
    (concrete actions presumably subclass it to add their specific keys).
    """
@dataclass(frozen=True)
class BaseAction(JSONSerializable, ABC, str=False):
    """Abstract base class for all actions.

    A concrete action is a frozen dataclass whose behaviour is described by its
    ``params`` mapping. Subclasses implement ``apply`` and ``rate_self``; this
    base supplies validity checking, a readable ``str``, parameter-based
    equality and a parameter-derived ``id``.
    """

    # Parameters describing this particular action instance.
    params: BaseActionParamsType

    @abstractmethod
    def apply(self, state: "State", enable_prints: bool = True) -> "State":
        """Return the state produced by applying this action to ``state``.

        Args:
            state: The state the action starts from.
            enable_prints: Output flag for implementations; ``check_if_valid``
                passes False (presumably to keep trial applications quiet).
        """
        ...

    @staticmethod
    @abstractmethod
    def rate_self(store: "Store", input: "Solution") -> RateSelfReturnType[ActionT]:
        """Propose a best parameter set for this action and rate it.

        Per ``RateSelfReturnType`` this is a generator: it yields
        ``(rating, action)`` tuples, is sent ``bool`` values, and may return a
        final ``(rating, action)`` tuple or ``None``.
        """
        ...

    def check_if_valid(self, store: "Store", mark_no_change_as_invalid: bool = False) -> bool:
        """Report whether applying this action to the store's current state is valid.

        The action is applied to ``store.current_state`` with prints disabled.
        Any exception raised while applying it (or while comparing the result
        with the current state) is logged as a warning and counts as invalid.

        Args:
            store: Provides the current state and the constraints to verify.
            mark_no_change_as_invalid: When True, a result equal to the current
                state is reported as invalid.

        Returns:
            Whether the new state is valid and passes both the legacy and the
            batching constraint checks (evaluated in that order, short-circuit).
        """
        try:
            candidate = self.apply(store.current_state, enable_prints=False)
            # Optionally reject actions that leave the state untouched.
            if mark_no_change_as_invalid and candidate == store.current_state:
                return False
        except Exception as err:
            # Best-effort probe: a failing action is simply "not valid".
            warn(f"Error applying action {self}: {err}")
            return False

        return (
            candidate.is_valid()
            and store.constraints.verify_legacy_constraints(candidate.timetable)
            and store.constraints.verify_batching_constraints(candidate.timetable)
        )

    def __str__(self) -> str:
        """Render the action as ``ClassName(<params>)``."""
        class_name = type(self).__name__
        return f"{class_name}({self.params})"

    def __eq__(self, other: object) -> bool:
        """Actions are equal when they have the exact same class and equal params."""
        if isinstance(other, BaseAction):
            same_kind = type(self) is type(other)
            return same_kind and self.params == other.params
        return NotImplemented

    @functools.cached_property
    def id(self) -> str:
        """Hash of the parameters, independent of their insertion order.

        Each parameter is rendered as ``key=value``; the pairs are ordered by
        key, joined with ``|`` and passed to ``hash_string``. Cached per instance.
        """
        # NOTE(review): values containing "|" or "=" can make different params
        # serialize to the same string (and hence the same id).
        params = self.params
        serialized = "|".join(f"{key}={params[key]}" for key in sorted(params))
        return hash_string(serialized)