Coverage for o2/actions/base_actions/base_action.py: 87%

39 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-05-16 11:18 +0000

1import functools 

2from abc import ABC, abstractmethod 

3from collections.abc import Generator 

4from dataclasses import dataclass 

5from typing import ( 

6 TYPE_CHECKING, 

7 Optional, 

8 TypeVar, 

9) 

10 

11from dataclass_wizard import JSONSerializable 

12from typing_extensions import TypedDict 

13 

14from o2.util.helper import hash_string 

15from o2.util.logger import warn 

16 

17if TYPE_CHECKING: 

18 from o2.models.self_rating import RATING 

19 from o2.models.solution import Solution 

20 from o2.store import State, Store 

21 

22 

# Type variable standing for a concrete action class (any subclass of BaseAction).
ActionT = TypeVar("ActionT", bound="BaseAction")

# A rating paired with the action it belongs to; the action slot may be None.
ActionRatingTuple = tuple["RATING", Optional[ActionT]]

# Generator contract used by ``rate_self``: it yields rating/action pairs,
# accepts a bool via ``send()`` and may return a final pair (or None).
RateSelfReturnType = Generator[
    ActionRatingTuple[ActionT],
    bool,
    Optional[ActionRatingTuple[ActionT]],
]

28 

29 

class BaseActionParamsType(TypedDict):
    """Root TypedDict for action parameters.

    Declares no keys of its own; it serves as the annotation type of
    ``BaseAction.params``.
    """

32 

33 

@dataclass(frozen=True)
class BaseAction(JSONSerializable, ABC, str=False):
    """Abstract, immutable base for every action.

    An action carries a ``params`` mapping and knows how to turn one state
    into another via ``apply``. Concrete subclasses must implement ``apply``
    and ``rate_self``.
    """

    # Parameters that fully describe this action instance.
    params: BaseActionParamsType

    @abstractmethod
    def apply(self, state: "State", enable_prints: bool = True) -> "State":
        """Return the state that results from applying this action to ``state``.

        Args:
            state: The state the action is applied to.
            enable_prints: Flag handed to the implementation; ``check_if_valid``
                passes ``False`` here.
        """

    @staticmethod
    @abstractmethod
    def rate_self(store: "Store", input: "Solution") -> RateSelfReturnType[ActionT]:
        """Produce candidate parameters for this action together with a self-rating.

        Implementations are generators following ``RateSelfReturnType``.
        """

    def check_if_valid(self, store: "Store", mark_no_change_as_invalid: bool = False) -> bool:
        """Tell whether applying this action to the store's current state is valid.

        Args:
            store: Provides ``current_state`` and the ``constraints`` to verify.
            mark_no_change_as_invalid: When true, an action whose result equals
                the current state is reported as invalid.

        Returns:
            ``False`` if applying the action (or the no-change comparison)
            raises, or if the no-change rule triggers. Otherwise the combined
            outcome of the new state's own validity check and the legacy and
            batching constraint checks on its timetable.
        """
        try:
            candidate = self.apply(store.current_state, enable_prints=False)
            # The comparison stays inside the try so a failing equality check
            # is reported the same way as a failing apply.
            unchanged = mark_no_change_as_invalid and candidate == store.current_state
        except Exception as e:
            warn(f"Error applying action {self}: {e}")
            return False

        if unchanged:
            return False

        # Short-circuits: later checks only run if the earlier ones passed.
        return (
            candidate.is_valid()
            and store.constraints.verify_legacy_constraints(candidate.timetable)
            and store.constraints.verify_batching_constraints(candidate.timetable)
        )

    def __str__(self) -> str:
        """Render the action as ``ClassName(params)``."""
        return f"{self.__class__.__name__}({self.params})"

    def __eq__(self, other: object) -> bool:
        """Compare by concrete class and params; defer for non-action operands."""
        if isinstance(other, BaseAction):
            return self.__class__ == other.__class__ and self.params == other.params
        return NotImplemented

    @functools.cached_property
    def id(self) -> str:
        """Return a hash string derived from the action's params.

        The params are ordered by key, rendered as ``key=value`` and joined
        with ``|`` before hashing. Only the params feed the hash; the class
        name is not part of the input. The value is computed once per instance
        and then cached.
        """
        ordered_items = sorted(self.params.items(), key=lambda item: item[0])
        rendered = [f"{k}={v}" for k, v in ordered_items]
        return hash_string("|".join(rendered))