96 lines
3.1 KiB
Python
Executable file
96 lines
3.1 KiB
Python
Executable file
#!/usr/bin/env python3
|
|
from itertools import chain, combinations
|
|
from typing import Generator, Callable, Any
|
|
|
|
|
|
def zerosum_subgroups(
    balances: dict[str, int],
    tolerance: int = 0,
) -> Generator[tuple[str, ...], None, None]:
    """Yield groups of ``len(balances) - 2`` members whose balances net to zero.

    Only subsets that leave exactly two members outside are examined. They
    are produced lazily, in ``itertools.combinations`` order over the keys of
    *balances*. Nothing is yielded when there are fewer than three members.

    Args:
        balances: Mapping of member name to balance.
        tolerance: Largest absolute subset sum still accepted as "zero".

    Yields:
        Tuples of member names whose balances sum to within *tolerance*.
    """
    member_count = len(balances)
    if member_count >= 3:
        yield from (
            group
            for group in combinations(balances, member_count - 2)
            if abs(sum(map(balances.__getitem__, group))) <= tolerance
        )
|
|
|
|
|
|
def split_dict(d: dict, condition: Callable[[Any, Any], bool]) -> tuple[dict, dict]:
    """Partition *d* into two new dicts according to *condition*.

    *condition* is called exactly once per ``(key, value)`` pair, in the
    iteration order of *d*. Items for which it is truthy go into the first
    result, and all others into the second. *d* itself is not modified.

    Returns:
        ``(matching, non_matching)``, each preserving *d*'s insertion order.
    """
    matched: dict = {}
    rest: dict = {}
    for key, value in d.items():
        (matched if condition(key, value) else rest)[key] = value
    return matched, rest
|
|
|
|
|
|
def solve_greedily(
    balances: dict[str, int], tolerance: int = 0
) -> dict[tuple[str, str], int]:
    """Settle *balances* with a greedy "largest debt to largest credit" heuristic.

    Positive balances are creditors and negative balances are debitors. In
    each pass every remaining debitor pays the creditor with the largest
    outstanding balance as much as it can, from most indebted to least.
    Whichever side reaches zero drops out. Members whose balance is exactly
    zero never take part, so no zero-value transfers are produced.

    Args:
        balances: Mapping of member name to balance.
        tolerance: Largest absolute residual balance that still counts as
            settled.

    Returns:
        Mapping ``(debitor, creditor) -> amount`` of transfers to perform.

    Raises:
        ValueError: If some member is still more than *tolerance* away from
            zero after ``len(balances)`` passes, e.g. because the balances do
            not sum to (nearly) zero.
    """
    if not balances:
        # Nobody to settle. Without this guard the pass loop below would not
        # run at all and its ``else`` clause would raise.
        return {}

    # Zero balances are already settled and are left out entirely.
    creditors = {name: value for name, value in balances.items() if value > 0}
    debitors = {name: value for name, value in balances.items() if value < 0}
    transactions: dict[tuple[str, str], int] = {}

    for _ in range(len(balances)):
        # Iterate over a snapshot, so each debitor is handled at most once per pass.
        for debitor, debit_value in sorted(debitors.items(), key=lambda x: x[1]):
            if not creditors:
                break
            # Largest creditor; on ties, the one first in insertion order.
            creditor, credit_value = max(creditors.items(), key=lambda x: x[1])
            remaining = credit_value + debit_value
            if -debit_value <= credit_value:
                # The debitor's whole debt fits into this creditor's credit.
                del debitors[debitor]
                transactions[debitor, creditor] = -debit_value
                if remaining:
                    creditors[creditor] = remaining
                else:
                    # Drop settled creditors so they can never receive a 0 transfer.
                    del creditors[creditor]
            else:
                # The creditor is fully repaid; the debitor keeps the rest of its debt.
                del creditors[creditor]
                debitors[debitor] = remaining
                transactions[debitor, creditor] = credit_value
        if all(
            abs(value) <= tolerance
            for value in chain(creditors.values(), debitors.values())
        ):
            break
    else:
        raise ValueError("No solution within tolerance found")
    return transactions
|
|
|
|
|
|
def solve(balances: dict[str, int], tolerance: int = 0) -> dict[tuple[str, str], int]:
    """Find a small set of transfers that settles *balances*.

    The group is split into a subgroup from ``zerosum_subgroups`` (which
    leaves exactly two members outside) and its complement. A split is used
    only if the complement also sums to within *tolerance*. Both parts are
    solved recursively, and the split needing the fewest transfers wins. If no
    such split exists, the whole group is settled with ``solve_greedily``.
    Every qualifying split is explored recursively, so runtime grows quickly
    with the number of members.

    Args:
        balances: Mapping of member name to balance.
        tolerance: Largest absolute residual balance that still counts as
            settled.

    Returns:
        Mapping ``(debitor, creditor) -> amount`` of transfers to perform.

    Raises:
        ValueError: Propagated from ``solve_greedily`` when the balances
            cannot be settled within *tolerance*.
    """
    if not balances:
        # An empty group trivially needs no transfers.
        return {}

    possibilities = []
    for subgroup in zerosum_subgroups(balances, tolerance):
        members, others = split_dict(balances, lambda k, v: k in subgroup)
        if abs(sum(others.values())) > tolerance:
            continue
        # The two parts are disjoint, so their transfer keys cannot collide.
        possibilities.append(solve(members, tolerance) | solve(others, tolerance))

    if not possibilities:
        return solve_greedily(balances, tolerance)
    # On ties, the first split found wins.
    return min(possibilities, key=len)
|
|
|
|
|
|
def perform_transfers(
    balances: dict[str, int], transactions: dict[tuple[str, str], int]
) -> dict[str, int]:
    """Return the balances that result from applying *transactions*.

    For each ``(sender, recipient) -> amount`` entry, the sender's balance
    goes up by ``amount`` and the recipient's goes down by ``amount``. The
    result is a ``.copy()`` of *balances*, which is left untouched. A name in
    *transactions* that is missing from *balances* raises ``KeyError``.
    """
    updated = balances.copy()
    for pair, amount in transactions.items():
        sender, recipient = pair
        updated[sender] += amount
        updated[recipient] -= amount
    return updated
|
|
|
|
|
|
if __name__ == "__main__":
    # A, B and C settle among themselves exactly (50 - 30 - 20 == 0), and so
    # do D and E. That should allow a solution with only 3 transactions.
    balances = {
        "A": 50,
        "B": -30,
        "C": -20,
        "D": -40,
    }
    # E takes whatever balance makes the whole group sum to zero (here 40).
    balances["E"] = -sum(balances.values())

    print(solve(balances))
|