fix tolerance

This commit is contained in:
Nikolai Hartmann 2023-09-19 12:28:02 +02:00
parent 798d068ec0
commit 9288c9dd56
2 changed files with 21 additions and 0 deletions

View file

@ -60,6 +60,8 @@ def solve(balances: dict[str, int], tolerance: int = 0) -> dict[tuple[str, str],
possibilities = [] possibilities = []
for subgroup in zerosum_subgroups(balances, tolerance): for subgroup in zerosum_subgroups(balances, tolerance):
balances_sub, balances_other = split_dict(balances, lambda k, v: k in subgroup) balances_sub, balances_other = split_dict(balances, lambda k, v: k in subgroup)
if abs(sum(balances_other.values())) > tolerance:
continue
transactions_sub = solve(balances_sub, tolerance) transactions_sub = solve(balances_sub, tolerance)
transactions_other = solve(balances_other, tolerance) transactions_other = solve(balances_other, tolerance)
possibilities.append(transactions_sub | transactions_other) possibilities.append(transactions_sub | transactions_other)

19
test.py Normal file
View file

@ -0,0 +1,19 @@
import random
from splitbill import solve, perform_transfers
import pytest
@pytest.mark.parametrize("tolerance", [0, 1, 2, 10, 100])
def test_random(tolerance):
balances = {str(k): random.randint(-200, 200) for k in range(7)}
balances["7"] = -sum(balances.values())
balances = {k: v for k, v in balances.items() if v != 0}
transactions = solve(balances, tolerance)
assert abs(sum(perform_transfers(balances, transactions).values())) <= tolerance
def test_bug1():
tolerance = 100
balances = {'0': -39, '1': 139, '2': -54, '3': 99, '4': 13, '5': 175, '6': -173, '7': -160}
transactions = solve(balances, tolerance=tolerance)
assert abs(sum(perform_transfers(balances, transactions).values())) <= tolerance