cleanup and simplify subset function

This commit is contained in:
Nikolai Hartmann 2023-09-19 09:24:08 +02:00
parent cdca6df821
commit c9164a080a

View file

@ -2,27 +2,15 @@
from itertools import chain, combinations from itertools import chain, combinations
from typing import Generator from typing import Generator
balances = {
# "A": -20,
# "B": 10,
# "C": 3,
# "D": -43,
# should be possible with 3 transactions (A, B, C balance excactly)
"A": 50,
"B": -30,
"C": -20,
"D": -40,
}
balances["E"] = -sum(balances.values())
def zerosum_subgroups(
def find_zerosum_subgroups(
balances: dict[str, int], balances: dict[str, int],
) -> Generator[tuple[str, ...], None, None]: ) -> Generator[tuple[str, ...], None, None]:
for n in range(2, len(balances)): if len(balances) < 3:
for combination in combinations(balances, n): return
if sum(balances[key] for key in combination) == 0: for combination in combinations(balances, len(balances) - 2):
yield combination if sum(balances[key] for key in combination) == 0:
yield combination
def solve_greedily(balances: dict[str, int]) -> dict[tuple[str, str], int]: def solve_greedily(balances: dict[str, int]) -> dict[tuple[str, str], int]:
@ -34,7 +22,7 @@ def solve_greedily(balances: dict[str, int]) -> dict[tuple[str, str], int]:
else: else:
debitors[k] = v debitors[k] = v
txn = {} transactions = {}
while not all(value == 0 for value in chain(creditors.values(), debitors.values())): while not all(value == 0 for value in chain(creditors.values(), debitors.values())):
for debitor, debit_value in sorted(debitors.items(), key=lambda x: x[1]): for debitor, debit_value in sorted(debitors.items(), key=lambda x: x[1]):
for creditor, credit_value in sorted( for creditor, credit_value in sorted(
@ -44,25 +32,34 @@ def solve_greedily(balances: dict[str, int]) -> dict[tuple[str, str], int]:
if abs(debit_value) <= credit_value: if abs(debit_value) <= credit_value:
del debitors[debitor] del debitors[debitor]
creditors[creditor] = sum_value creditors[creditor] = sum_value
txn[debitor, creditor] = abs(debit_value) transactions[debitor, creditor] = abs(debit_value)
else: else:
debitors[debitor] = sum_value
del creditors[creditor] del creditors[creditor]
txn[debitor, creditor] = credit_value debitors[debitor] = sum_value
transactions[debitor, creditor] = credit_value
break break
return txn return transactions
def solve(balances: dict[str, int]) -> dict[tuple[str, str], int]: def solve(balances: dict[str, int]) -> dict[tuple[str, str], int]:
possibilities = [] possibilities = []
for subgroup in find_zerosum_subgroups(balances): for subgroup in zerosum_subgroups(balances):
txn_sub = solve({k: balances[k] for k in subgroup}) transactions_sub = solve({k: balances[k] for k in subgroup})
txn_other = solve({k: balances[k] for k in balances if not k in subgroup}) transactions_other = solve({k: balances[k] for k in balances if not k in subgroup})
possibilities.append(txn_sub | txn_other) possibilities.append(transactions_sub | transactions_other)
if not possibilities: if not possibilities:
possibilities.append(solve_greedily(balances)) possibilities.append(solve_greedily(balances))
return min(possibilities, key=lambda x: len(x)) return min(possibilities, key=lambda x: len(x))
print(solve(balances)) if __name__ == "__main__":
# should be possible with 3 transactions (A, B, C balance excactly)
balances = {
"A": 50,
"B": -30,
"C": -20,
"D": -40,
}
balances["E"] = -sum(balances.values())
print(solve(balances))