diff --git a/splitbill.py b/splitbill.py index 79da71f..d336e76 100755 --- a/splitbill.py +++ b/splitbill.py @@ -60,6 +60,8 @@ def solve(balances: dict[str, int], tolerance: int = 0) -> dict[tuple[str, str], possibilities = [] for subgroup in zerosum_subgroups(balances, tolerance): balances_sub, balances_other = split_dict(balances, lambda k, v: k in subgroup) + if abs(sum(balances_other.values())) > tolerance: + continue transactions_sub = solve(balances_sub, tolerance) transactions_other = solve(balances_other, tolerance) possibilities.append(transactions_sub | transactions_other) diff --git a/test.py b/test.py new file mode 100644 index 0000000..ef2d151 --- /dev/null +++ b/test.py @@ -0,0 +1,19 @@ +import random +from splitbill import solve, perform_transfers +import pytest + + +@pytest.mark.parametrize("tolerance", [0, 1, 2, 10, 100]) +def test_random(tolerance): + balances = {str(k): random.randint(-200, 200) for k in range(7)} + balances["7"] = -sum(balances.values()) + balances = {k: v for k, v in balances.items() if v != 0} + transactions = solve(balances, tolerance) + assert abs(sum(perform_transfers(balances, transactions).values())) <= tolerance + + +def test_bug1(): + tolerance = 100 + balances = {'0': -39, '1': 139, '2': -54, '3': 99, '4': 13, '5': 175, '6': -173, '7': -160} + transactions = solve(balances, tolerance=tolerance) + assert abs(sum(perform_transfers(balances, transactions).values())) <= tolerance