Coverage for o2/models/timetable/batching_rule.py: 84%
238 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-05-16 11:18 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-05-16 11:18 +0000
1from collections import defaultdict
2from dataclasses import asdict, dataclass, replace
3from json import dumps
4from typing import Literal, Optional, Union
6from dataclass_wizard import JSONWizard
7from sympy import Symbol, lambdify
9from o2.models.days import DAY
10from o2.models.legacy_constraints import WorkMasks
11from o2.models.rule_selector import RuleSelector
12from o2.models.settings import Settings
13from o2.models.timetable.batch_type import BATCH_TYPE
14from o2.models.timetable.distribution import Distribution
15from o2.models.timetable.firing_rule import (
16 FiringRule,
17 OrRules,
18 rule_is_daily_hour,
19 rule_is_size,
20 rule_is_week_day,
21)
22from o2.models.timetable.rule_type import RULE_TYPE
23from o2.models.timetable.time_period import TimePeriod
24from o2.util.helper import hash_string
@dataclass(frozen=True)
class BatchingRule(JSONWizard):
    """Rules for when and how to batch tasks.

    ``firing_rules`` is a nested list: the outer list holds the OR-rules and
    every OR-rule is itself a list of firing rules. A single firing rule is
    addressed by the pair ``(or_index, and_index)``.

    When ``Settings.CHECK_FOR_TIMETABLE_EQUALITY`` is enabled, equality and
    hashing ignore the order of the OR-rules and the order of the firing rules
    inside each OR-rule. A sorted tuple-of-tuples copy of the firing rules is
    cached on the instance as ``_normalized`` for that purpose.
    """

    # Id of the task this rule applies to.
    task_id: str
    # Batch type of this rule.
    type: BATCH_TYPE
    # Size distribution entries (key = size as string).
    size_distrib: list[Distribution]
    # Duration distribution entries (key = size as string).
    duration_distrib: list[Distribution]
    # OR-rules; each one is a list of firing rules.
    firing_rules: OrRules

    def __post_init__(self) -> None:
        """Cache the normalized firing rules if timetable equality checks are enabled."""
        if not Settings.CHECK_FOR_TIMETABLE_EQUALITY:
            return
        # The dataclass is frozen, so the cache has to bypass the generated __setattr__.
        object.__setattr__(self, "_normalized", self._normalize_firing_rules(self.firing_rules))

    @staticmethod
    def _normalize_firing_rules(firing_rules: OrRules) -> tuple:
        """Return an order-independent, hashable representation of ``firing_rules``.

        - Each inner list is sorted (ignoring its original order).
        - The collection of rows is also sorted so that their order doesn't matter.
        - Using a tuple of tuples makes the result hashable.
        """
        return tuple(sorted(tuple(sorted(row)) for row in firing_rules))  # type: ignore

    def _get_normalized(self) -> tuple:
        """Return the cached normalized firing rules, computing them on first use.

        The cache can be missing when the instance was created while
        ``Settings.CHECK_FOR_TIMETABLE_EQUALITY`` was disabled, or when it was
        pickled before the normalization was implemented.
        """
        if "_normalized" not in self.__dict__:
            object.__setattr__(self, "_normalized", self._normalize_firing_rules(self.firing_rules))
        return self.__dict__["_normalized"]

    def __eq__(self, other: object) -> bool:
        """Check if two batching rules are equal.

        With ``Settings.CHECK_FOR_TIMETABLE_EQUALITY`` disabled all fields are
        compared as they are (so the order of the firing rules matters).
        Otherwise the normalized firing rules are compared, and
        ``NotImplemented`` is returned for objects that are not batching rules.
        """
        if not Settings.CHECK_FOR_TIMETABLE_EQUALITY:
            return isinstance(other, BatchingRule) and (
                self.task_id,
                self.type,
                self.size_distrib,
                self.duration_distrib,
                self.firing_rules,
            ) == (
                other.task_id,
                other.type,
                other.size_distrib,
                other.duration_distrib,
                other.firing_rules,
            )
        if not isinstance(other, BatchingRule):
            return NotImplemented

        return (
            self._get_normalized() == other._get_normalized()
            and self.task_id == other.task_id
            and self.type == other.type
            and self.size_distrib == other.size_distrib
            and self.duration_distrib == other.duration_distrib
        )

    def __hash__(self) -> int:
        """Hash the batching rule consistently with ``__eq__``.

        List-valued fields are converted to tuples first, because lists are
        not hashable. Fields that already are tuples hash exactly as before.
        """
        if Settings.CHECK_FOR_TIMETABLE_EQUALITY:
            # Computed lazily: the cache may be missing (see _get_normalized).
            firing_rules = self._get_normalized()
        else:
            firing_rules = tuple(tuple(and_rules) for and_rules in self.firing_rules)
        return hash(
            (
                self.task_id,
                self.type,
                tuple(self.size_distrib),
                tuple(self.duration_distrib),
                firing_rules,
            )
        )

    def id(self) -> str:
        """Generate a hash identifier for this batching rule.

        Creates a string hash based on the JSON-serialized dataclass fields of this rule.
        """
        return hash_string(dumps(asdict(self)).encode())

    def get_firing_rule_selectors(self, type: Optional[RULE_TYPE] = None) -> list["RuleSelector"]:
        """Get the selectors of all firing rules, optionally only those whose attribute equals ``type``."""
        return [
            RuleSelector.from_batching_rule(self, (or_index, and_index))
            for or_index, or_rules in enumerate(self.firing_rules)
            for and_index, rule in enumerate(or_rules)
            if type is None or rule.attribute == type
        ]

    def get_time_period_for_daily_hour_firing_rules(
        self,
    ) -> dict[
        tuple[Optional["RuleSelector"], "RuleSelector", "RuleSelector"],
        tuple[Optional[DAY], int, int],
    ]:
        """Get the time period for daily hour firing rules.

        Returns a dictionary with one entry per OR-rule. The key is the triple
        (selector of the week day rule, selector of the lower bound rule,
        selector of the upper bound rule), the value is the triple
        (day, lower bound, upper bound).

        For every OR-rule the tightest bounds win: the largest ``>``/``>=``
        value is the lower bound and the smallest ``<``/``<=`` value is the
        upper bound. A missing week day rule yields ``None`` for the day and
        its selector. A missing bound yields ``None`` for its selector and
        ``float("-inf")`` / ``float("inf")`` as the bound.

        NOTE: OR-rules that produce the same selector triple (e.g. several
        OR-rules without any week day or daily hour rule) share one key, so
        only the last one is kept.
        """
        time_periods = {}
        for or_index, or_rules in enumerate(self.firing_rules):
            day_selector = None
            lower_bound_selector = None
            upper_bound_selector = None
            day = None
            lower_bound = float("-inf")
            upper_bound = float("inf")
            for and_index, and_rule in enumerate(or_rules):
                if rule_is_week_day(and_rule):
                    day_selector = RuleSelector.from_batching_rule(self, (or_index, and_index))
                    day = and_rule.value
                if rule_is_daily_hour(and_rule):
                    if and_rule.is_lt_or_lte:
                        if and_rule.value < upper_bound:
                            upper_bound = and_rule.value
                            upper_bound_selector = RuleSelector.from_batching_rule(
                                self, (or_index, and_index)
                            )
                    elif and_rule.is_gt_or_gte and and_rule.value > lower_bound:
                        lower_bound = and_rule.value
                        lower_bound_selector = RuleSelector.from_batching_rule(
                            self, (or_index, and_index)
                        )
            time_periods[(day_selector, lower_bound_selector, upper_bound_selector)] = (
                day,
                lower_bound,
                upper_bound,
            )
        return time_periods

    def _has_firing_rule_at(self, or_index: int, and_index: int) -> bool:
        """Check that neither index lies past the end of the firing rules."""
        return or_index < len(self.firing_rules) and and_index < len(self.firing_rules[or_index])

    @staticmethod
    def _is_datetime_rule(rule: FiringRule) -> bool:
        """Check if the rule is a week day or a daily hour rule."""
        return rule.attribute in (RULE_TYPE.WEEK_DAY, RULE_TYPE.DAILY_HOUR)

    def get_firing_rule(self, rule_selector: "RuleSelector") -> Optional[FiringRule]:
        """Get a firing rule by rule selector.

        Returns ``None`` if the selector has no firing rule index or if the
        index lies past the end of the firing rules.
        """
        if rule_selector.firing_rule_index is None:
            return None
        or_index, and_index = rule_selector.firing_rule_index[0], rule_selector.firing_rule_index[1]
        if not self._has_firing_rule_at(or_index, and_index):
            return None
        return self.firing_rules[or_index][and_index]

    def can_remove_firing_rule(self, or_index: int, and_index: int) -> bool:
        """Check if a firing rule can be removed.

        Checks:
        - The indices must not lie past the end of the firing rules.
        - We cannot remove a size rule from an OR-rule that contains a DAILY_HOUR rule.
        """
        if not self._has_firing_rule_at(or_index, and_index):
            return False
        and_rules = self.firing_rules[or_index]
        if and_rules[and_index].attribute == RULE_TYPE.SIZE:
            return all(rule.attribute != RULE_TYPE.DAILY_HOUR for rule in and_rules)
        return True

    def remove_firing_rule(self, rule_selector: "RuleSelector") -> "Optional[BatchingRule]":
        """Remove a firing rule. Returns a new BatchingRule.

        An OR-rule that becomes empty is dropped as a whole. Returns ``None``
        if the index lies past the end of the firing rules or if no firing
        rule would be left after the removal.
        """
        assert rule_selector.firing_rule_index is not None
        or_index, and_index = rule_selector.firing_rule_index[0], rule_selector.firing_rule_index[1]
        if not self._has_firing_rule_at(or_index, and_index):
            return None

        and_rules = self.firing_rules[or_index][:and_index] + self.firing_rules[or_index][and_index + 1 :]
        leading = self.firing_rules[:or_index]
        trailing = self.firing_rules[or_index + 1 :]
        or_rules = leading + [and_rules] + trailing if and_rules else leading + trailing

        if not or_rules:
            return None
        return replace(self, firing_rules=or_rules)

    def generate_distrib(self, duration_fn: str) -> "BatchingRule":
        """Regenerate the duration and size distributions from the size rules.

        Collects every batch size allowed by any size rule and creates one
        duration and one size distribution entry per size.
        E.g. a size rule with <= 10 yields entries for the sizes 1-10.
        Open-ended rules (>, >=) are capped at a size of 100.

        The returned rule's ``duration_distrib`` and ``size_distrib`` consist
        only of the generated entries; the existing ones are replaced.

        Args:
            duration_fn: Expression in the symbol ``size`` that is evaluated
                (via ``lambdify``) to obtain the duration value for each size.
        """
        max_batch_size = 100
        sizes = set()
        for and_rules in self.firing_rules:
            for rule in and_rules:
                if rule.attribute != RULE_TYPE.SIZE:
                    continue
                # ``update`` adds every size of the range; ``add`` would insert the
                # range object itself as a single (bogus) size.
                if rule.is_eq:
                    sizes.add(rule.value)
                elif rule.is_gte:
                    sizes.update(range(rule.value, max_batch_size + 1))
                elif rule.is_gt:
                    sizes.update(range(rule.value + 1, max_batch_size + 1))
                elif rule.is_lte:
                    sizes.update(range(1, rule.value + 1))
                elif rule.is_lt:
                    sizes.update(range(1, rule.value))

        # Sorted, so the generated lists do not depend on set iteration order.
        ordered_sizes = sorted(sizes)
        duration_lambda = lambdify(Symbol("size"), duration_fn)
        new_duration_distrib = [
            Distribution(key=str(size), value=duration_lambda(size)) for size in ordered_sizes
        ]
        new_size_distrib = [Distribution(key=str(size), value=1) for size in ordered_sizes]
        # Special case: if size 1 is not allowed, add an explicit entry for it with value 0.
        if 1 not in sizes:
            new_size_distrib.append(Distribution(key="1", value=0))

        return replace(self, duration_distrib=new_duration_distrib, size_distrib=new_size_distrib)

    def replace_firing_rule(
        self,
        rule_selector: "RuleSelector",
        new_rule: FiringRule,
        skip_merge: bool = False,
        duration_fn: Optional[str] = None,
    ) -> "BatchingRule":
        """Replace a firing rule. Returns a new BatchingRule.

        Args:
            rule_selector: Selector of the rule to replace; its firing rule index must be set.
            new_rule: The rule to put in its place.
            skip_merge: If True, week day / daily hour rules are not merged after the replacement.
            duration_fn: If given, the distributions are regenerated with ``generate_distrib``.

        Returns:
            The updated rule, or ``self`` if the index lies past the end of the firing rules.
        """
        assert rule_selector.firing_rule_index is not None
        or_index, and_index = rule_selector.firing_rule_index[0], rule_selector.firing_rule_index[1]
        if not self._has_firing_rule_at(or_index, and_index):
            return self

        and_rules = (
            self.firing_rules[or_index][:and_index]
            + [new_rule]
            + self.firing_rules[or_index][and_index + 1 :]
        )
        or_rules = self.firing_rules[:or_index] + [and_rules] + self.firing_rules[or_index + 1 :]

        updated_batching_rule = replace(self, firing_rules=or_rules)
        if duration_fn is not None:
            updated_batching_rule = updated_batching_rule.generate_distrib(duration_fn)

        # ``and`` binds tighter than ``or``: the attribute check must be grouped,
        # otherwise ``skip_merge`` would be ignored for DAILY_HOUR rules.
        if not skip_merge and self._is_datetime_rule(new_rule):
            return updated_batching_rule._generate_merged_datetime_firing_rules()
        return updated_batching_rule

    def add_firing_rule(self, firing_rule: FiringRule) -> "BatchingRule":
        """Add a firing rule as a new OR-rule. Returns a new BatchingRule.

        Week day and daily hour rules trigger a merge of the datetime rules.
        """
        updated_batching_rule = replace(self, firing_rules=self.firing_rules + [[firing_rule]])
        if self._is_datetime_rule(firing_rule):
            return updated_batching_rule._generate_merged_datetime_firing_rules()
        return updated_batching_rule

    def add_firing_rules(self, firing_rules: list[FiringRule]) -> "BatchingRule":
        """Add a list of firing rules as one new OR-rule. Returns a new BatchingRule.

        If any of the rules is a week day or daily hour rule, the datetime rules are merged.
        """
        updated_batching_rule = replace(self, firing_rules=self.firing_rules + [firing_rules])
        if any(self._is_datetime_rule(rule) for rule in firing_rules):
            return updated_batching_rule._generate_merged_datetime_firing_rules()
        return updated_batching_rule

    def _generate_merged_datetime_firing_rules(self) -> "BatchingRule":
        """Generate merged firing rules for datetime rules.

        E.g. if there are multiple OR-Rules, that only contain daily hour rules,
        we can merge them into a single OR-Rule. Or if there are multiple OR-Rules,
        that only contain the same week day + daily hour rule,
        we can merge them into a single OR-Rule.

        An OR-rule takes part in the merge if it has at most four rules,
        contains a ``>=`` and a ``<`` daily hour rule and, depending on its
        length, a week day (``==``) and/or a size (``>``/``>=``) rule. Its hour
        range is entered into a work mask and the OR-rule is dropped. The mask
        is then turned back into one OR-rule per day and time period, placed
        before the OR-rules that were not merged.
        """
        merged_indices = set()
        work_mask = WorkMasks()
        # Largest size threshold seen, per day (or "ALL") and per start hour of the rule.
        size_dict: dict[Union[DAY, Literal["ALL"]], dict[int, int]] = defaultdict(dict)

        for index, or_rules in enumerate(self.firing_rules):
            length = len(or_rules)
            if length > 4:
                continue
            daily_hour_gte_rule: Optional[FiringRule[int]] = None
            daily_hour_lt_rule: Optional[FiringRule[int]] = None
            week_day_rule: Optional[FiringRule[DAY]] = None
            size_rule: Optional[FiringRule[int]] = None

            for rule in or_rules:
                if rule_is_daily_hour(rule) and rule.is_gte:
                    daily_hour_gte_rule = rule
                elif rule_is_daily_hour(rule) and rule.is_lt:
                    daily_hour_lt_rule = rule
                elif rule_is_week_day(rule) and rule.is_eq:
                    week_day_rule = rule
                elif rule_is_size(rule) and rule.is_gt_or_gte:
                    size_rule = rule

            if daily_hour_gte_rule is None or daily_hour_lt_rule is None:
                continue
            # Skip OR-rules that contain rules of any other kind.
            if length == 4 and (size_rule is None or week_day_rule is None):
                continue
            if length == 3 and (week_day_rule is None and size_rule is None):
                continue

            start_hour = daily_hour_gte_rule.value
            end_hour = daily_hour_lt_rule.value
            size_key: Union[DAY, Literal["ALL"]]
            if not week_day_rule:
                work_mask = work_mask.set_hour_range_for_every_day(start_hour, end_hour)
                size_key = "ALL"
            else:
                work_mask = work_mask.set_hour_range_for_day(week_day_rule.value, start_hour, end_hour)
                size_key = week_day_rule.value
            if size_rule:
                sizes_by_hour = size_dict[size_key]
                sizes_by_hour[start_hour] = max(sizes_by_hour.get(start_hour, 0), size_rule.value)
            merged_indices.add(index)

        new_or_rules = []
        for day in DAY:
            for period in TimePeriod.from_bitmask(work_mask.get(day), day):
                max_size = self._find_max_size(size_dict, period)
                rules = [
                    FiringRule.eq(RULE_TYPE.WEEK_DAY, day),
                    FiringRule.gte(RULE_TYPE.DAILY_HOUR, period.begin_time_hour),
                    FiringRule.lt(RULE_TYPE.DAILY_HOUR, period.end_time_hour),
                ]
                if max_size > 0:
                    rules.append(FiringRule.gte(RULE_TYPE.SIZE, max_size))
                new_or_rules.append(rules)

        untouched_or_rules = [
            or_rules for index, or_rules in enumerate(self.firing_rules) if index not in merged_indices
        ]
        return replace(self, firing_rules=new_or_rules + untouched_or_rules)

    def _find_max_size(
        self, size_dict: dict[Union[DAY, Literal["ALL"]], dict[int, int]], period: TimePeriod
    ) -> int:
        """Find the largest size recorded for any hour within the given period.

        Looks at the entries stored under ``"ALL"`` and under the day of the
        period (``period.from_``). Returns 0 if there is no entry for any hour
        of the period, or if the period does not span any hour.
        """
        all_entries = size_dict.get("ALL", {})
        day_entries = size_dict.get(period.from_, {})

        # Get maximum of all entries, that are between begin_time_hour and end_time_hour
        return max(
            (
                max(all_entries.get(hour, 0), day_entries.get(hour, 0))
                for hour in range(period.begin_time_hour, period.end_time_hour)
            ),
            default=0,
        )

    def is_valid(self) -> bool:
        """Check if the batching rule is valid.

        Currently the rule is considered invalid if:
        - an OR-rule occurs more than once,
        - an OR-rule is empty,
        - more than one OR-rule consists of just a single ``>=`` size rule,
        - a firing rule occurs more than once within an OR-rule,
        - a week day rule comes after (or is itself) a daily hour rule within an OR-rule,
        - within an OR-rule the smallest ``>``/``>=`` daily hour value is not
          below the largest ``<``/``<=`` daily hour value.
        """
        has_single_size_rule = False
        for and_rules in self.firing_rules:
            largest_smaller_than_time = None
            smallest_larger_than_time = None
            # Duplicate OR rules are not allowed
            if self.firing_rules.count(and_rules) > 1:
                return False
            # Empty AND rules are not allowed
            if len(and_rules) == 0:
                return False
            if len(and_rules) == 1 and rule_is_size(and_rules[0]) and and_rules[0].is_gte:
                if has_single_size_rule:
                    return False
                has_single_size_rule = True

            has_daily_hour_rule = False
            for rule in and_rules:
                # Duplicate rules within an OR rule are not allowed
                if and_rules.count(rule) > 1:
                    return False
                if rule_is_daily_hour(rule):
                    if rule.is_lt_or_lte and (
                        largest_smaller_than_time is None or rule.value > largest_smaller_than_time
                    ):
                        largest_smaller_than_time = rule.value
                    elif rule.is_gt_or_gte and (
                        smallest_larger_than_time is None or rule.value < smallest_larger_than_time
                    ):
                        smallest_larger_than_time = rule.value
                    has_daily_hour_rule = True
                if rule_is_week_day(rule) and has_daily_hour_rule:
                    return False

            if (
                largest_smaller_than_time is not None
                and smallest_larger_than_time is not None
                and smallest_larger_than_time >= largest_smaller_than_time
            ):
                return False

        return True

    @staticmethod
    def from_task_id(
        task_id: str,
        type: BATCH_TYPE = BATCH_TYPE.PARALLEL,
        firing_rules: list[FiringRule] = [],  # noqa: B006
        size: Optional[int] = None,
        duration_fn: Optional[str] = None,
    ) -> "BatchingRule":
        """Create a BatchingRule from a task id.

        The given ``firing_rules`` become the single OR-rule of the new
        batching rule.

        NOTE: Setting `size` to a value will limit the new rule to only
        this size. If it is omitted, size 1 gets the value 0.0 and the
        sizes 2-49 get the value 1.0; durations are generated for the sizes 1-49.
        NOTE(review): the previous docstring promised batches "up to 50",
        but the ranges end at 49 - confirm the intended limit.
        TODO: Get limit from constraints

        Args:
            task_id: Id of the task the rule applies to.
            type: Batch type of the rule.
            firing_rules: Firing rules forming the single OR-rule. The list is copied.
            size: Optional fixed batch size.
            duration_fn: Expression in the symbol ``size`` used to compute the
                duration values; defaults to ``"size"``.
        """
        duration_lambda = lambdify(Symbol("size"), duration_fn if duration_fn else "size")
        size_distrib = ([Distribution(key=str(1), value=0.0)] if size != 1 else []) + (
            [Distribution(key=str(new_size), value=1.0) for new_size in range(2, 50)]
            if size is None
            else [Distribution(key=str(size), value=1.0)]
        )
        duration_distrib = (
            [Distribution(key=str(new_size), value=duration_lambda(new_size)) for new_size in range(1, 50)]
            if size is None
            else [Distribution(key=str(size), value=duration_lambda(size))]
        )
        return BatchingRule(
            task_id=task_id,
            type=type,
            size_distrib=size_distrib,
            duration_distrib=duration_distrib,
            # Copy, so that neither the shared default list nor the caller's list
            # ends up inside the new rule.
            firing_rules=[list(firing_rules)],
        )