Coverage for o2/models/timetable/batching_rule.py: 84%

238 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-05-16 11:18 +0000

1from collections import defaultdict 

2from dataclasses import asdict, dataclass, replace 

3from json import dumps 

4from typing import Literal, Optional, Union 

5 

6from dataclass_wizard import JSONWizard 

7from sympy import Symbol, lambdify 

8 

9from o2.models.days import DAY 

10from o2.models.legacy_constraints import WorkMasks 

11from o2.models.rule_selector import RuleSelector 

12from o2.models.settings import Settings 

13from o2.models.timetable.batch_type import BATCH_TYPE 

14from o2.models.timetable.distribution import Distribution 

15from o2.models.timetable.firing_rule import ( 

16 FiringRule, 

17 OrRules, 

18 rule_is_daily_hour, 

19 rule_is_size, 

20 rule_is_week_day, 

21) 

22from o2.models.timetable.rule_type import RULE_TYPE 

23from o2.models.timetable.time_period import TimePeriod 

24from o2.util.helper import hash_string 

25 

26 

# Immutable (frozen) dataclass describing the batching behaviour of one task.
# The methods below never mutate an instance; they return modified copies.
27@dataclass(frozen=True) 

28class BatchingRule(JSONWizard): 

29 """Rules for when and how to batch tasks.""" 

30 

# Id of the task this rule was created for (see from_task_id).
31 task_id: str 

# Batch execution type; from_task_id defaults this to BATCH_TYPE.PARALLEL.
32 type: BATCH_TYPE 

# One entry per batch size, keyed by the size as a string.
# NOTE(review): the value looks like a weight/probability (0 or 1 in this file) - confirm.
33 size_distrib: list[Distribution] 

# One entry per batch size, keyed by the size as a string; the value is the
# duration function evaluated at that size.
34 duration_distrib: list[Distribution] 

# List of rule groups. The code in this class treats the outer list as OR-ed
# groups and every inner list as AND-ed firing rules.
35 firing_rules: OrRules 

36 

def __post_init__(self) -> None:
    """Cache an order-independent form of the firing rules.

    The cache is stored as ``_normalized`` and is only built when
    ``Settings.CHECK_FOR_TIMETABLE_EQUALITY`` is enabled; otherwise this
    hook does nothing.
    """
    if Settings.CHECK_FOR_TIMETABLE_EQUALITY:
        # Sort the rules inside every group first, then sort the groups,
        # so neither ordering influences comparisons. Nested tuples keep
        # the result hashable.
        sorted_groups = [tuple(sorted(group)) for group in self.firing_rules]  # type: ignore
        sorted_groups.sort()
        # The dataclass is frozen, so bypass its __setattr__ guard.
        object.__setattr__(self, "_normalized", tuple(sorted_groups))

47 

def __eq__(self, other: object) -> bool:
    """Compare two batching rules field by field.

    With ``Settings.CHECK_FOR_TIMETABLE_EQUALITY`` disabled the firing
    rules are compared as-is (order matters) and a foreign type compares
    as ``False``. With the setting enabled the order-independent
    ``_normalized`` form is compared instead and a foreign type yields
    ``NotImplemented``.
    """
    if not Settings.CHECK_FOR_TIMETABLE_EQUALITY:
        if not isinstance(other, BatchingRule):
            return False
        own_fields = (
            self.task_id,
            self.type,
            self.size_distrib,
            self.duration_distrib,
            self.firing_rules,
        )
        other_fields = (
            other.task_id,
            other.type,
            other.size_distrib,
            other.duration_distrib,
            other.firing_rules,
        )
        return own_fields == other_fields

    if not isinstance(other, BatchingRule):
        return NotImplemented

    # TODO: Some timetable objects were pickled before the normalization
    # was implemented, so the cache may be missing on either side.
    for batching_rule in (self, other):
        if "_normalized" not in batching_rule.__dict__:
            rebuilt = tuple(sorted(tuple(sorted(row)) for row in batching_rule.firing_rules))  # type: ignore
            object.__setattr__(batching_rule, "_normalized", rebuilt)

    return (
        self._normalized == other._normalized  # type: ignore
        and self.task_id == other.task_id
        and self.type == other.type
        and self.size_distrib == other.size_distrib
        and self.duration_distrib == other.duration_distrib
    )

83 

def __hash__(self) -> int:
    """Hash the batching rule consistently with ``__eq__``.

    The list-valued fields are converted to tuples before hashing,
    because hashing a tuple that contains a ``list`` raises ``TypeError``.
    Equal lists produce equal tuples, so objects that compare equal still
    hash equal.

    With ``Settings.CHECK_FOR_TIMETABLE_EQUALITY`` enabled the
    order-independent ``_normalized`` firing rules are hashed. If that
    cache is missing (e.g. the object was unpickled from an older version
    or created while the setting was disabled) it is rebuilt here, exactly
    like ``__eq__`` does.

    NOTE(review): this assumes ``Distribution`` and ``FiringRule``
    instances are hashable - confirm.
    """
    if Settings.CHECK_FOR_TIMETABLE_EQUALITY:
        if "_normalized" not in self.__dict__:
            normalized = tuple(sorted(tuple(sorted(row)) for row in self.firing_rules))  # type: ignore
            # The dataclass is frozen, so bypass its __setattr__ guard.
            object.__setattr__(self, "_normalized", normalized)
        firing_rules_key = self._normalized  # type: ignore
    else:
        # Order-sensitive, matching the plain comparison in __eq__.
        firing_rules_key = tuple(tuple(and_rules) for and_rules in self.firing_rules)

    return hash(
        (
            self.task_id,
            self.type,
            tuple(self.size_distrib),
            tuple(self.duration_distrib),
            firing_rules_key,
        )
    )

105 

def id(self) -> str:
    """Return a hash string identifying this batching rule.

    The rule is converted to a dict, dumped to JSON, and the encoded
    JSON text is passed to ``hash_string``.
    """
    serialized = dumps(asdict(self))
    payload = str(serialized).encode()
    return hash_string(payload)

112 

def get_firing_rule_selectors(self, type: Optional[RULE_TYPE] = None) -> list["RuleSelector"]:
    """Return a selector for every firing rule, optionally filtered by rule type.

    When ``type`` is given, only rules whose ``attribute`` equals it are
    included. Selectors are returned in (group index, rule index) order.
    """
    selectors = []
    for group_index, group in enumerate(self.firing_rules):
        for rule_index, rule in enumerate(group):
            if type is None or rule.attribute == type:
                selectors.append(RuleSelector.from_batching_rule(self, (group_index, rule_index)))
    return selectors

121 

def get_time_period_for_daily_hour_firing_rules(
    self,
) -> dict[
    tuple[Optional["RuleSelector"], "RuleSelector", "RuleSelector"],
    tuple[Optional[DAY], int, int],
]:
    """Collect the tightest daily-hour window of every rule group.

    For each group the result maps
    ``(day selector, lower bound selector, upper bound selector)`` to
    ``(day, lower bound, upper bound)``. A selector is ``None`` when the
    group has no matching rule; in that case the day is ``None`` and the
    bounds stay at ``-inf`` / ``inf``.
    """
    result = {}
    for group_index, group in enumerate(self.firing_rules):
        day = None
        day_selector = None
        lower, lower_selector = float("-inf"), None
        upper, upper_selector = float("inf"), None

        for rule_index, rule in enumerate(group):
            if rule_is_week_day(rule):
                day_selector = RuleSelector.from_batching_rule(self, (group_index, rule_index))
                day = rule.value
            if not rule_is_daily_hour(rule):
                continue
            if rule.is_lt_or_lte:
                # Keep the smallest upper bound seen so far.
                if rule.value < upper:
                    upper = rule.value
                    upper_selector = RuleSelector.from_batching_rule(self, (group_index, rule_index))
            elif rule.is_gt_or_gte and rule.value > lower:
                # Keep the largest lower bound seen so far.
                lower = rule.value
                lower_selector = RuleSelector.from_batching_rule(self, (group_index, rule_index))

        result[(day_selector, lower_selector, upper_selector)] = (day, lower, upper)
    return result

164 

def get_firing_rule(self, rule_selector: "RuleSelector") -> Optional[FiringRule]:
    """Look up the firing rule a selector points to.

    Returns ``None`` when the selector carries no firing rule index or
    when either index is past the end of the corresponding list.
    """
    position = rule_selector.firing_rule_index
    if position is None:
        return None
    group_index = position[0]
    rule_index = position[1]
    if group_index < len(self.firing_rules):
        group = self.firing_rules[group_index]
        if rule_index < len(group):
            return group[rule_index]
    return None

176 

def can_remove_firing_rule(self, or_index: int, and_index: int) -> bool:
    """Tell whether the rule at ``(or_index, and_index)`` may be removed.

    Returns ``False`` for indices past the end of the lists. A size rule
    may only be removed when its group contains no DAILY_HOUR rule; any
    other rule may always be removed.
    """
    if or_index >= len(self.firing_rules):
        return False
    group = self.firing_rules[or_index]
    if and_index >= len(group):
        return False
    if group[and_index].attribute == RULE_TYPE.SIZE:
        for rule in group:
            if rule.attribute != RULE_TYPE.DAILY_HOUR:
                continue
            # A daily-hour rule in the same group pins the size rule.
            return False
        return True
    return True

190 

def remove_firing_rule(self, rule_selector: "RuleSelector") -> "Optional[BatchingRule]":
    """Return a copy of this rule without the selected firing rule.

    A group that becomes empty is dropped entirely. Returns ``None`` when
    an index is past the end of its list or when no group is left after
    the removal.
    """
    assert rule_selector.firing_rule_index is not None
    group_index = rule_selector.firing_rule_index[0]
    rule_index = rule_selector.firing_rule_index[1]

    groups = self.firing_rules
    if group_index >= len(groups):
        return None
    if rule_index >= len(groups[group_index]):
        return None

    remaining = groups[group_index][:rule_index] + groups[group_index][rule_index + 1 :]
    preceding = groups[:group_index]
    following = groups[group_index + 1 :]
    if len(remaining) == 0:
        updated_groups = preceding + following
    else:
        updated_groups = preceding + [remaining] + following

    if len(updated_groups) == 0:
        return None
    return replace(self, firing_rules=updated_groups)

210 

def generate_distrib(self, duration_fn: str) -> "BatchingRule":
    """Rebuild the size and duration distributions from the size rules.

    Every size rule contributes the batch sizes it allows:
    ``== n`` adds ``n``, ``>= n`` adds ``n..100``, ``> n`` adds
    ``n+1..100``, ``<= n`` adds ``1..n`` and ``< n`` adds ``1..n-1``.
    E.g. a size rule ``<= 10`` yields distributions for the sizes 1-10.

    For each collected size the duration distribution gets
    ``duration_fn`` evaluated at that size and the size distribution gets
    the value 1. If size 1 was not collected, a size entry ``"1"`` with
    value 0 is appended.

    Args:
        duration_fn: Expression in the symbol ``size`` that is passed to
            ``sympy.lambdify``.

    Returns:
        A new BatchingRule whose ``size_distrib`` and ``duration_distrib``
        are replaced by the regenerated lists. This instance is left
        untouched.
    """
    sizes: set[int] = set()
    for and_rules in self.firing_rules:
        for rule in and_rules:
            if rule.attribute != RULE_TYPE.SIZE:
                continue
            # Ranges must be merged element-wise (update); add() would
            # store the range object itself as a single "size".
            if rule.is_eq:
                sizes.add(rule.value)
            elif rule.is_gte:
                sizes.update(range(rule.value, 101))
            elif rule.is_gt:
                sizes.update(range(rule.value + 1, 101))
            elif rule.is_lte:
                sizes.update(range(1, rule.value + 1))
            elif rule.is_lt:
                sizes.update(range(1, rule.value))

    duration_lambda = lambdify(Symbol("size"), duration_fn)
    # Sort so the generated lists do not depend on set iteration order.
    ordered_sizes = sorted(sizes)
    new_duration_distrib = [
        Distribution(key=str(size), value=duration_lambda(size)) for size in ordered_sizes
    ]
    new_size_distrib = [Distribution(key=str(size), value=1) for size in ordered_sizes]

    # Special case: if 1 is not in sizes, remove any distribution that has 1 as a key
    # and add a new one with value 0
    if 1 not in sizes:
        new_size_distrib = [distribution for distribution in new_size_distrib if distribution.key != "1"]
        new_size_distrib.append(Distribution(key="1", value=0))

    return replace(self, duration_distrib=new_duration_distrib, size_distrib=new_size_distrib)

250 

def replace_firing_rule(
    self,
    rule_selector: "RuleSelector",
    new_rule: FiringRule,
    skip_merge: bool = False,
    duration_fn: Optional[str] = None,
) -> "BatchingRule":
    """Replace the selected firing rule and return the resulting BatchingRule.

    Args:
        rule_selector: Selector whose ``firing_rule_index`` addresses the
            rule to replace; the index must not be ``None``.
        new_rule: Rule that takes the place of the selected one.
        skip_merge: When ``True`` the date/time rule groups are not
            merged afterwards, regardless of the type of ``new_rule``.
        duration_fn: If given, the distributions are regenerated with
            ``generate_distrib(duration_fn)`` after the replacement.

    Returns:
        ``self`` unchanged when an index is past the end of its list,
        otherwise a new BatchingRule. Unless ``skip_merge`` is set, the
        result is passed through ``_generate_merged_datetime_firing_rules``
        when ``new_rule`` is a WEEK_DAY or DAILY_HOUR rule.
    """
    assert rule_selector.firing_rule_index is not None
    or_index = rule_selector.firing_rule_index[0]
    and_index = rule_selector.firing_rule_index[1]
    if or_index >= len(self.firing_rules) or and_index >= len(self.firing_rules[or_index]):
        return self

    and_rules = (
        self.firing_rules[or_index][:and_index]
        + [new_rule]
        + self.firing_rules[or_index][and_index + 1 :]
    )
    or_rules = self.firing_rules[:or_index] + [and_rules] + self.firing_rules[or_index + 1 :]

    updated_batching_rule = replace(self, firing_rules=or_rules)
    if duration_fn is not None:
        updated_batching_rule = updated_batching_rule.generate_distrib(duration_fn)

    # The parentheses matter: `and` binds tighter than `or`, so without
    # them skip_merge would only apply to WEEK_DAY rules.
    is_datetime_rule = (
        new_rule.attribute == RULE_TYPE.WEEK_DAY or new_rule.attribute == RULE_TYPE.DAILY_HOUR
    )
    if not skip_merge and is_datetime_rule:
        return updated_batching_rule._generate_merged_datetime_firing_rules()
    return updated_batching_rule

283 

def add_firing_rule(self, firing_rule: FiringRule) -> "BatchingRule":
    """Return a copy with ``firing_rule`` appended as its own rule group.

    If the rule is a WEEK_DAY or DAILY_HOUR rule, the date/time groups of
    the copy are merged before it is returned.
    """
    extended_groups = self.firing_rules + [[firing_rule]]
    result = replace(self, firing_rules=extended_groups)
    if firing_rule.attribute == RULE_TYPE.WEEK_DAY:
        return result._generate_merged_datetime_firing_rules()
    if firing_rule.attribute == RULE_TYPE.DAILY_HOUR:
        return result._generate_merged_datetime_firing_rules()
    return result

290 

def add_firing_rules(self, firing_rules: list[FiringRule]) -> "BatchingRule":
    """Return a copy with ``firing_rules`` appended as one new rule group.

    If any of the added rules is a WEEK_DAY or DAILY_HOUR rule, the
    date/time groups of the copy are merged before it is returned.
    """
    result = replace(self, firing_rules=self.firing_rules + [firing_rules])
    for rule in firing_rules:
        if rule.attribute == RULE_TYPE.WEEK_DAY or rule.attribute == RULE_TYPE.DAILY_HOUR:
            return result._generate_merged_datetime_firing_rules()
    return result

300 

def _generate_merged_datetime_firing_rules(self) -> "BatchingRule":
    """Merge rule groups that only describe a time window.

    A group qualifies when it has at most four rules, contains both a
    ``>=`` and a ``<`` daily-hour rule, and its remaining rules are an
    ``==`` week-day rule and/or a ``>``/``>=`` size rule. All qualifying
    groups are written into one work mask (for a single day, or for every
    day when no week-day rule is present) and replaced by one group per
    day and contiguous period of that mask. If a size was recorded for an
    hour of the period, the largest such size is added as a ``>=`` size
    rule. The merged groups come first, followed by the untouched groups
    in their original order.
    """
    merged_indices = []
    mask = WorkMasks()
    # Day (or "ALL") -> start hour -> largest size requested for that start hour.
    sizes_by_day: dict[Union[DAY, Literal["ALL"]], dict[int, int]] = defaultdict(dict)

    for group_index, group in enumerate(self.firing_rules):
        rule_count = len(group)
        if rule_count > 4:
            continue

        start_rule: Optional[FiringRule[int]] = None
        end_rule: Optional[FiringRule[int]] = None
        day_rule: Optional[FiringRule[DAY]] = None
        min_size_rule: Optional[FiringRule[int]] = None
        for candidate in group:
            if rule_is_daily_hour(candidate) and candidate.is_gte:
                start_rule = candidate
            elif rule_is_daily_hour(candidate) and candidate.is_lt:
                end_rule = candidate
            elif rule_is_week_day(candidate) and candidate.is_eq:
                day_rule = candidate
            elif rule_is_size(candidate) and candidate.is_gt_or_gte:
                min_size_rule = candidate

        # Without a complete hour window there is nothing to merge.
        if start_rule is None or end_rule is None:
            continue
        # The extra rules of a 3- or 4-rule group must be the day/size
        # rules recognised above, otherwise the group is left alone.
        if rule_count == 4 and (min_size_rule is None or day_rule is None):
            continue
        if rule_count == 3 and (day_rule is None and min_size_rule is None):
            continue

        if day_rule:
            mask = mask.set_hour_range_for_day(
                day_rule.value,
                start_rule.value,
                end_rule.value,
            )
            size_key = day_rule.value
        else:
            mask = mask.set_hour_range_for_every_day(
                start_rule.value,
                end_rule.value,
            )
            size_key = "ALL"

        if min_size_rule:
            known_sizes = sizes_by_day[size_key]
            known_sizes[start_rule.value] = max(
                known_sizes.get(start_rule.value, 0),
                min_size_rule.value,
            )
        merged_indices.append(group_index)

    merged_groups = []
    for day in DAY:
        for period in TimePeriod.from_bitmask(mask.get(day), day):
            required_size = self._find_max_size(sizes_by_day, period)
            group = [
                FiringRule.eq(RULE_TYPE.WEEK_DAY, day),
                FiringRule.gte(RULE_TYPE.DAILY_HOUR, period.begin_time_hour),
                FiringRule.lt(RULE_TYPE.DAILY_HOUR, period.end_time_hour),
            ]
            if required_size > 0:
                group.append(FiringRule.gte(RULE_TYPE.SIZE, required_size))
            merged_groups.append(group)

    untouched_groups = [
        group for group_index, group in enumerate(self.firing_rules) if group_index not in merged_indices
    ]
    return replace(self, firing_rules=merged_groups + untouched_groups)

381 

def _find_max_size(
    self, size_dict: dict[Union[DAY, Literal["ALL"]], dict[int, int]], period: TimePeriod
) -> int:
    """Return the largest size recorded for any hour inside ``period``.

    Args:
        size_dict: Maps a day (or ``"ALL"``) to a dict of hour -> size.
        period: Period whose ``from_`` selects the day-specific entries
            and whose ``begin_time_hour`` (inclusive) / ``end_time_hour``
            (exclusive) bound the hours that are inspected.

    Returns:
        The maximum over both the ``"ALL"`` and the day-specific entries,
        or 0 when no size is recorded for the period or the period spans
        no hours.
    """
    all_entries = size_dict.get("ALL", {})
    day_entries = size_dict.get(period.from_, {})

    # default=0 keeps an empty hour range from raising ValueError.
    return max(
        (
            max(all_entries.get(hour, 0), day_entries.get(hour, 0))
            for hour in range(period.begin_time_hour, period.end_time_hour)
        ),
        default=0,
    )

393 

def is_valid(self) -> bool:
    """Check whether the firing rules of this batching rule are consistent.

    The rule is rejected when:
    - a rule group occurs more than once or is empty,
    - more than one group consists of a single ``>=`` size rule,
    - a rule occurs more than once inside its group,
    - a week-day rule follows a daily-hour rule inside a group,
    - the smallest lower daily-hour bound of a group is not below its
      largest upper daily-hour bound.
    """
    seen_lone_size_rule = False
    for group in self.firing_rules:
        # Groups must be unique and must contain at least one rule.
        if self.firing_rules.count(group) > 1:
            return False
        if len(group) == 0:
            return False
        if len(group) == 1 and rule_is_size(group[0]) and group[0].is_gte:
            # Only one group may consist solely of a ">=" size rule.
            if seen_lone_size_rule:
                return False
            seen_lone_size_rule = True

        max_upper_bound = None
        min_lower_bound = None
        daily_hour_seen = False
        for rule in group:
            # Rules must be unique inside their group.
            if group.count(rule) > 1:
                return False
            if rule_is_daily_hour(rule):
                if rule.is_lt_or_lte and (max_upper_bound is None or rule.value > max_upper_bound):
                    max_upper_bound = rule.value
                elif rule.is_gt_or_gte and (min_lower_bound is None or rule.value < min_lower_bound):
                    min_lower_bound = rule.value
                daily_hour_seen = True
            # Week-day rules have to precede every daily-hour rule.
            if rule_is_week_day(rule) and daily_hour_seen:
                return False

        if max_upper_bound is None or min_lower_bound is None:
            continue
        if min_lower_bound >= max_upper_bound:
            return False

    return True

442 

@staticmethod
def from_task_id(
    task_id: str,
    type: BATCH_TYPE = BATCH_TYPE.PARALLEL,
    firing_rules: list[FiringRule] = [],  # noqa: B006
    size: Optional[int] = None,
    duration_fn: Optional[str] = None,
) -> "BatchingRule":
    """Create a BatchingRule for a task id.

    Args:
        task_id: Id stored on the new rule.
        type: Batch type of the new rule.
        firing_rules: Rules that form the single rule group of the new
            rule. The list is copied, so neither the caller's list nor
            the shared default list ends up inside the new rule.
        size: When given, the distributions only cover this size (plus a
            zero-valued size entry for 1 if ``size`` is not 1). When
            omitted, the size distribution covers 2 through 49 and the
            duration distribution 1 through 49.
        duration_fn: Expression in the symbol ``size`` used to compute
            the durations; defaults to ``"size"``.

    TODO: Get limit from constraints
    """
    duration_lambda = lambdify(Symbol("size"), duration_fn if duration_fn else "size")

    # Size 1 gets value 0.0 unless it is the explicitly requested size.
    size_distrib = [Distribution(key=str(1), value=0.0)] if size != 1 else []
    if size is None:
        size_distrib += [Distribution(key=str(new_size), value=1.0) for new_size in range(2, 50)]
        duration_distrib = [
            Distribution(key=str(new_size), value=duration_lambda(new_size)) for new_size in range(1, 50)
        ]
    else:
        size_distrib += [Distribution(key=str(size), value=1.0)]
        duration_distrib = [Distribution(key=str(size), value=duration_lambda(size))]

    return BatchingRule(
        task_id=task_id,
        type=type,
        size_distrib=size_distrib,
        duration_distrib=duration_distrib,
        # Copy so the mutable default argument is never aliased.
        firing_rules=[list(firing_rules)],
    )