splitbill/splitbill.py
Nikolai Hartmann 9288c9dd56 fix tolerance
2023-09-19 12:28:02 +02:00

92 lines
3 KiB
Python
Executable file

#!/usr/bin/env python3
from itertools import chain, combinations
from typing import Generator, Callable, Any
def zerosum_subgroups(
    balances: dict[str, int],
    tolerance: int = 0,
) -> Generator[tuple[str, ...], None, None]:
    """Yield groups of ``len(balances) - 2`` names whose balances cancel out.

    A group qualifies when the absolute sum of its members' balances is at
    most ``tolerance``. Groups are tuples of keys, produced in
    ``itertools.combinations`` order. Nothing is yielded for fewer than three
    names.

    Only subsets that leave exactly two names outside are examined.
    NOTE(review): a split in which every zero-sum group has three or more
    members is therefore never produced here; confirm that is intended.
    """
    group_size = len(balances) - 2
    if group_size < 1:
        # combinations() would yield the empty tuple for r == 0 and raise
        # for r < 0, so tiny inputs are excluded up front.
        return
    yield from (
        group
        for group in combinations(balances, group_size)
        if -tolerance <= sum(map(balances.__getitem__, group)) <= tolerance
    )
def split_dict(d: dict, condition: Callable[[Any, Any], bool]) -> tuple[dict, dict]:
    """Partition ``d`` by ``condition(key, value)``.

    Returns ``(matching, rest)``. Items for which the condition is truthy go
    into the first dict and all others into the second. Both dicts keep the
    insertion order of ``d``. The condition is called exactly once per item.
    """
    buckets: tuple[dict, dict] = ({}, {})
    for key, value in d.items():
        buckets[0 if condition(key, value) else 1][key] = value
    return buckets
def solve_greedily(
    balances: dict[str, int], tolerance: int = 0
) -> dict[tuple[str, str], int]:
    """Settle ``balances`` with a greedy largest-debt / largest-credit heuristic.

    Positive balances are creditors. Zero and negative balances are debitors.
    Each pass walks the debitors from the most negative upwards and pairs each
    one with the currently largest creditor:

    * if the creditor covers the debt, the debitor is removed and the creditor
      keeps the remainder;
    * otherwise the creditor is removed and the debitor keeps the remainder.

    Passes repeat until every remaining balance is within ``tolerance`` of zero.

    Args:
        balances: Net balance per person. It should sum to zero within
            ``tolerance``.
        tolerance: Largest absolute residual per person that counts as
            settled.

    Returns:
        Mapping ``(debitor, creditor) -> amount``. Only strictly positive
        amounts are recorded; pairings that move no money are dropped.

    Raises:
        ValueError: If creditors or debitors run out while some balance is
            still outside ``tolerance``, i.e. the balances do not sum to zero
            within it. Without this check the loop would never terminate.
    """
    creditors = {k: v for k, v in balances.items() if v > 0}
    # Zero balances land with the debitors, mirroring the ``v > 0`` split.
    debitors = {k: v for k, v in balances.items() if v <= 0}
    transactions: dict[tuple[str, str], int] = {}
    while not all(
        abs(value) <= tolerance
        for value in chain(creditors.values(), debitors.values())
    ):
        # A pass only makes progress when both sides are non-empty.
        if not creditors or not debitors:
            raise ValueError(
                "balances do not sum to zero within the given tolerance"
            )
        # The snapshot is safe: only the current debitor is modified per step.
        for debitor, debit_value in sorted(debitors.items(), key=lambda x: x[1]):
            if not creditors:
                break
            # Largest creditor; on ties the first one in dict order wins.
            creditor = max(creditors, key=creditors.__getitem__)
            credit_value = creditors[creditor]
            remainder = credit_value + debit_value
            if abs(debit_value) <= credit_value:
                del debitors[debitor]
                creditors[creditor] = remainder
                amount = abs(debit_value)
            else:
                del creditors[creditor]
                debitors[debitor] = remainder
                amount = credit_value
            # Zero-balance debitors and exhausted (0) creditors would
            # otherwise produce meaningless zero-amount transactions.
            if amount:
                transactions[debitor, creditor] = amount
    return transactions
def solve(balances: dict[str, int], tolerance: int = 0) -> dict[tuple[str, str], int]:
    """Find a settlement of ``balances`` using few transactions.

    For every candidate group from ``zerosum_subgroups`` whose complement also
    sums to zero within ``tolerance``, both halves are solved recursively and
    their transactions merged. The merged result with the fewest transactions
    wins; ties go to the earliest candidate. If no such split exists, the
    greedy settlement from ``solve_greedily`` is returned.

    The recursion is not memoised, so identical sub-problems may be solved
    repeatedly.
    """
    candidates = []
    for subgroup in zerosum_subgroups(balances, tolerance):
        inside, outside = split_dict(balances, lambda k, v: k in subgroup)
        if abs(sum(outside.values())) <= tolerance:
            candidates.append(solve(inside, tolerance) | solve(outside, tolerance))
    if not candidates:
        return solve_greedily(balances, tolerance)
    return min(candidates, key=len)
def perform_transfers(
    balances: dict[str, int], transactions: dict[tuple[str, str], int]
) -> dict[str, int]:
    """Return a copy of ``balances`` with every transaction applied.

    For each ``(sender, recipient) -> amount`` entry, the amount is added to
    the sender's balance and subtracted from the recipient's. The input dict
    is left unmodified. An unknown name raises ``KeyError`` for a plain dict.
    """
    updated = balances.copy()
    for pair, amount in transactions.items():
        sender, recipient = pair
        updated[sender] = updated[sender] + amount
        updated[recipient] = updated[recipient] - amount
    return updated
if __name__ == "__main__":
    # Demo: A, B and C cancel out exactly, as do D and E, so the optimal
    # settlement needs only 3 transactions.
    balances = dict(A=50, B=-30, C=-20, D=-40)
    # E takes whatever remains so that the whole group sums to zero.
    balances["E"] = -sum(balances.values())
    print(solve(balances))