diff --git a/splitbill.py b/splitbill.py index c5d2b3c..79da71f 100755 --- a/splitbill.py +++ b/splitbill.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 from itertools import chain, combinations -from typing import Generator +from typing import Generator, Callable, Any def zerosum_subgroups( @@ -14,25 +14,34 @@ def zerosum_subgroups( yield combination +def split_dict(d: dict, condition: Callable[[Any, Any], bool]) -> tuple[dict, dict]: + first = {} + second = {} + for k, v in d.items(): + if condition(k, v): + first[k] = v + else: + second[k] = v + return first, second + + def solve_greedily( balances: dict[str, int], tolerance: int = 0 ) -> dict[tuple[str, str], int]: - creditors = {} - debitors = {} - for k, v in balances.items(): - if v > 0: - creditors[k] = v - else: - debitors[k] = v - + creditors, debitors = split_dict(balances, lambda k, v: v > 0) transactions = {} while not all( abs(value) <= tolerance for value in chain(creditors.values(), debitors.values()) ): - for debitor, debit_value in sorted(debitors.items(), key=lambda x: x[1]): + for debitor, debit_value in sorted( + debitors.items(), + key=lambda x: x[1], + ): for creditor, credit_value in sorted( - creditors.items(), key=lambda x: x[1], reverse=True + creditors.items(), + key=lambda x: x[1], + reverse=True, ): sum_value = credit_value + debit_value if abs(debit_value) <= credit_value: @@ -44,17 +53,15 @@ def solve_greedily( debitors[debitor] = sum_value transactions[debitor, creditor] = credit_value break - return transactions def solve(balances: dict[str, int], tolerance: int = 0) -> dict[tuple[str, str], int]: possibilities = [] for subgroup in zerosum_subgroups(balances, tolerance): - transactions_sub = solve({k: balances[k] for k in subgroup}, tolerance) - transactions_other = solve( - {k: balances[k] for k in balances if not k in subgroup}, tolerance - ) + balances_sub, balances_other = split_dict(balances, lambda k, v: k in subgroup) + transactions_sub = solve(balances_sub, tolerance) + transactions_other = solve(balances_other, tolerance) possibilities.append(transactions_sub | transactions_other) if not possibilities: possibilities.append(solve_greedily(balances, tolerance))