# This demonstrates classic dynamic programming and prioritizes explainability.
# Of course all of this would have been more concise, with comparable
# performance, using recursion with memoization (e.g. @functools.cache).
#
# Input: one record per line, "<springs> <counts>", where <springs> is a string
# over {".", "#", "?"} ("." working, "#" broken, "?" unknown) and <counts> is a
# comma-separated list of the sizes of the contiguous groups of broken springs.
#
# Usage: python <this script> INPUT_FILE [COPIES]
#   COPIES defaults to 5 (every record is unfolded five times); pass 1 to count
#   the records as written.

import sys
import dataclasses


@dataclasses.dataclass(frozen=True)
class Entry:
    """Arrangement counts for one DP cell, split by the state of its last spring."""

    working: int  # arrangements in which the last spring is working (".")
    broken: int  # arrangements in which the last spring is broken and closes the last group

    def total(self):
        """Return the number of arrangements regardless of the last spring's state."""
        return self.working + self.broken


def no_dot_between(l, start, end):
    """Return True if l[start + 1:end] (both bounds exclusive) contains no ".".

    ``start`` may be -1, in which case the check begins at index 0.
    """
    return "." not in l[start + 1:end]


def process(springs, counts):
    """Count the arrangements of ``springs`` consistent with the group sizes ``counts``.

    Bottom-up DP. ``entries[(i, j)]`` is an :class:`Entry` for the prefix
    ``springs[0..i]`` (inclusive) matched against exactly the groups
    ``counts[0..j]`` (inclusive); index -1 stands for "empty prefix" / "no groups".

    - ``working``: spring ``i`` is working; the prefix before it may end either way.
    - ``broken``: spring ``i`` is broken and is the last spring of group ``j``,
      i.e. springs ``i - counts[j] + 1 .. i`` form that group and spring
      ``i - counts[j]`` (if it exists) is working, separating it from group ``j - 1``.

    Returns:
        The total number of arrangements for the whole string and all groups.

    Raises:
        ValueError: if ``springs`` contains a character other than ".", "#" or
            "?" (checked while filling the table, i.e. when ``counts`` is non-empty).
    """
    entries = {}

    # Base case: empty prefix and no groups -> exactly one (empty) arrangement.
    entries[(-1, -1)] = Entry(working=1, broken=0)

    # Base case: springs but no groups -> one arrangement iff every spring so far
    # can be working.
    for i in range(len(springs)):
        can_work = springs[i] in (".", "?")
        entries[(i, -1)] = Entry(
            working=entries[(i - 1, -1)].working if can_work else 0,
            broken=0)

    # Base case: groups left but no springs -> impossible.
    for j in range(len(counts)):
        entries[(-1, j)] = Entry(working=0, broken=0)

    # Fill the rest of the table row by row (j ascending, so (i, j - 1) is ready).
    for i in range(len(springs)):
        for j in range(len(counts)):
            # Spring i working: any arrangement of the prefix ending at i - 1.
            ways_if_working = entries[(i - 1, j)].total()

            # Spring i closes group j: springs before_group+1 .. i-1 must not be
            # "." (spring i itself is checked by the branch below), and the
            # spring at before_group must be working -- or be the virtual spring
            # at index -1 when the group starts at the very beginning.
            before_group = i - counts[j]
            if before_group >= -1 and no_dot_between(springs, before_group, i):
                ways_if_broken = entries[(before_group, j - 1)].working
            else:
                ways_if_broken = 0

            spring = springs[i]
            if spring == ".":
                entries[(i, j)] = Entry(working=ways_if_working, broken=0)
            elif spring == "#":
                entries[(i, j)] = Entry(working=0, broken=ways_if_broken)
            elif spring == "?":
                entries[(i, j)] = Entry(working=ways_if_working, broken=ways_if_broken)
            else:
                # Without this the cell stays unset and a later lookup fails with
                # an unexplained KeyError.
                raise ValueError(f"unexpected spring {spring!r} at index {i}")

    return entries[(len(springs) - 1, len(counts) - 1)].total()


def unfold(springs, counts, copies=5):
    """Return the record repeated ``copies`` times.

    The springs are joined with "?" separators and the counts list is
    concatenated ``copies`` times; ``copies=1`` returns the record unchanged
    (with a fresh counts list).

    Raises:
        ValueError: if ``copies`` is smaller than 1.
    """
    if copies < 1:
        raise ValueError(f"copies must be >= 1, got {copies}")
    return "?".join([springs] * copies), counts * copies


def solve(lines, copies=5):
    """Sum :func:`process` over every non-blank "<springs> <counts>" record in ``lines``.

    Each record is unfolded ``copies`` times first (``copies=1`` leaves it as is).
    """
    total = 0
    for line in lines:
        line = line.strip()
        if not line:  # tolerate blank lines, e.g. a trailing empty line
            continue
        springs, counts = line.split()
        counts = [int(c) for c in counts.split(",")]
        total += process(*unfold(springs, counts, copies))
    return total


def main(argv):
    """Command-line entry point: ``argv`` is ``[prog, INPUT_FILE, [COPIES]]``."""
    if len(argv) < 2:
        sys.exit(f"usage: {argv[0]} INPUT_FILE [COPIES]")
    copies = int(argv[2]) if len(argv) > 2 else 5
    with open(argv[1], encoding="utf-8") as f:
        print(f"total = {solve(f, copies)}")


if __name__ == "__main__":
    main(sys.argv)