From c9164a080acf932abf159ca6b749154a2f57f340 Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann Date: Tue, 19 Sep 2023 09:24:08 +0200 Subject: [PATCH] cleanup and simplify subset function --- splitbill.py | 53 +++++++++++++++++++++++++--------------------------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/splitbill.py b/splitbill.py index 4284f52..bf50732 100755 --- a/splitbill.py +++ b/splitbill.py @@ -2,27 +2,15 @@ from itertools import chain, combinations from typing import Generator -balances = { - # "A": -20, - # "B": 10, - # "C": 3, - # "D": -43, - # should be possible with 3 transactions (A, B, C balance excactly) - "A": 50, - "B": -30, - "C": -20, - "D": -40, -} -balances["E"] = -sum(balances.values()) - -def find_zerosum_subgroups( +def zerosum_subgroups( balances: dict[str, int], ) -> Generator[tuple[str, ...], None, None]: - for n in range(2, len(balances)): - for combination in combinations(balances, n): - if sum(balances[key] for key in combination) == 0: - yield combination + if len(balances) < 3: + return + for combination in combinations(balances, len(balances) - 2): + if sum(balances[key] for key in combination) == 0: + yield combination def solve_greedily(balances: dict[str, int]) -> dict[tuple[str, str], int]: @@ -34,7 +22,7 @@ def solve_greedily(balances: dict[str, int]) -> dict[tuple[str, str], int]: else: debitors[k] = v - txn = {} + transactions = {} while not all(value == 0 for value in chain(creditors.values(), debitors.values())): for debitor, debit_value in sorted(debitors.items(), key=lambda x: x[1]): for creditor, credit_value in sorted( @@ -44,25 +32,34 @@ def solve_greedily(balances: dict[str, int]) -> dict[tuple[str, str], int]: if abs(debit_value) <= credit_value: del debitors[debitor] creditors[creditor] = sum_value - txn[debitor, creditor] = abs(debit_value) + transactions[debitor, creditor] = abs(debit_value) else: - debitors[debitor] = sum_value del creditors[creditor] - txn[debitor, creditor] = credit_value + debitors[debitor] = sum_value + transactions[debitor, creditor] = credit_value break - return txn + return transactions def solve(balances: dict[str, int]) -> dict[tuple[str, str], int]: possibilities = [] - for subgroup in find_zerosum_subgroups(balances): - txn_sub = solve({k: balances[k] for k in subgroup}) - txn_other = solve({k: balances[k] for k in balances if not k in subgroup}) - possibilities.append(txn_sub | txn_other) + for subgroup in zerosum_subgroups(balances): + transactions_sub = solve({k: balances[k] for k in subgroup}) + transactions_other = solve({k: balances[k] for k in balances if not k in subgroup}) + possibilities.append(transactions_sub | transactions_other) if not possibilities: possibilities.append(solve_greedily(balances)) return min(possibilities, key=lambda x: len(x)) -print(solve(balances)) +if __name__ == "__main__": + # should be possible with 3 transactions (A, B, C balance excactly) + balances = { + "A": 50, + "B": -30, + "C": -20, + "D": -40, + } + balances["E"] = -sum(balances.values()) + print(solve(balances))