aoc23/12/part2.py

85 lines
3.1 KiB
Python
Raw Normal View History

# This demonstrates classic dynamic programming and prioritizes explainability.
# Of course, all of this would have been cleaner, and would perform the same, if
# we had used recursion with memoization instead (e.g. @functools.cache).
import sys
import collections
import dataclasses
@dataclasses.dataclass(frozen=True)
class Entry:
    """Immutable pair of arrangement counts for one cell of the DP table.

    process() splits each cell by the state chosen for its last spring:
    ``working`` counts arrangements where that spring is '.', ``broken``
    counts those where it is '#'.
    """

    working: int  # arrangements whose last spring is working ('.')
    broken: int  # arrangements whose last spring is broken ('#')

    def total(self):
        """Return all arrangements of the cell, whatever its last spring is."""
        working, broken = self.working, self.broken
        return working + broken
def no_dot_between(l, start, end):
    """Return True if no '.' lies strictly between indices ``start`` and ``end``.

    Checks positions ``start + 1`` through ``end - 1``.  ``start`` may be
    negative (``-1`` means "before the first spring"); positions before
    index 0 do not exist and are ignored.  A raw negative slice bound would
    instead wrap around to the end of ``l`` and inspect unrelated characters.

    Args:
        l: the spring record (a string of '.', '#', '?').
        start: exclusive lower index; may be negative.
        end: exclusive upper index.

    Returns:
        True if the open interval (start, end) contains no '.'.
    """
    # Clamp so a negative start never wraps to the end of the string.
    return "." not in l[max(start + 1, 0):end]
def process(springs, counts):
    """Count the ways to resolve every '?' in ``springs`` so it matches ``counts``.

    Bottom-up dynamic programming.  ``entries[(i, j)]`` is the number of ways
    the first ``i + 1`` springs (``springs[0..i]``) can contain exactly the
    first ``j + 1`` broken groups (``counts[0..j]``), split by whether
    ``springs[i]`` was chosen to be working ('.') or broken ('#').  An index
    of ``-1`` stands for "no springs" / "no groups" yet.

    Args:
        springs: the condition record, a string of '.', '#' and '?'.
        counts: the sizes of the contiguous broken groups, in order.

    Returns:
        The number of valid arrangements.

    NOTE(review): the input characters are not validated.
    """
    # Missing cells are impossible states.  Reading them with .get() instead
    # of through a defaultdict avoids inserting junk keys on every miss.
    empty = Entry(working=0, broken=0)
    # Boundary condition: no springs and no groups is one (empty)
    # arrangement.  It counts as "working" so a group may start at index 0.
    entries = {(-1, -1): Entry(working=1, broken=0)}

    # Boundary condition: springs but no groups -- every spring so far
    # must be working.
    for i, spring in enumerate(springs):
        prev = entries.get((i - 1, -1), empty)
        entries[(i, -1)] = Entry(
            working=prev.working if spring in (".", "?") else 0,
            broken=0)

    # Build the rest of the table.
    for i, spring in enumerate(springs):
        for j, size in enumerate(counts):
            # springs[i] working: extend any arrangement of springs[0..i-1]
            # that already holds groups 0..j, whatever its last spring was.
            if spring in (".", "?"):
                working = entries.get((i - 1, j), empty).total()
            else:
                working = 0

            # springs[i] broken: it must close group j, so springs
            # start+1..i are all broken and springs[start] is working (or
            # start == -1, i.e. the group begins at the first spring).  A
            # start below -1 means the group does not fit at all.
            start = i - size
            if (spring in ("#", "?") and start >= -1
                    and no_dot_between(springs, start, i)):
                broken = entries.get((start, j - 1), empty).working
            else:
                broken = 0

            entries[(i, j)] = Entry(working=working, broken=broken)

    final = entries.get((len(springs) - 1, len(counts) - 1), empty)
    return final.total()
if __name__ == "__main__":
    # Number of copies each record is unfolded into: 5 for part 2.
    # Set to 1 to get the part 1 answer back.
    fold = 5
    total = 0
    with open(sys.argv[1], "rt", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing empty line), which would
            # otherwise crash the two-field unpacking below.
            if not line.strip():
                continue
            springs, counts = line.split()
            counts = [int(c) for c in counts.split(",")]
            # Unfold: the record repeated `fold` times joined by '?', and
            # the group sizes repeated `fold` times.
            springs = "?".join([springs] * fold)
            counts = counts * fold
            total += process(springs, counts)
    print(f"total = {total}")