advent2023/neunzehn.py

91 lines
2.5 KiB
Python
Executable File

#!/usr/bin/env python3
from useful import *
# Position of each rating category ('x', 'm', 'a', 's') inside the
# four-element range lists handled by split()/checkrange().
xmasind = {key: pos for pos, key in enumerate('xmas')}
def makerule(string):
    """Parse a single rule from a workflow's comma-separated rule list.

    A conditional rule looks like '<var><op><number>:<target>' and is
    returned as ((var, op, number), target), with number converted to int.
    A rule without ':' (the workflow's fallback target) is returned as the
    unchanged string.
    """
    if ':' not in string:
        return string
    condition, target = string.split(':')
    # condition[0] is the category letter, condition[1] the comparator.
    return (condition[0], condition[1], int(condition[2:])), target
def flow(line):
    """Parse a workflow line of the form 'name{rule,rule,...,fallback}'.

    Returns (name, rules), where each element of rules is the result of
    makerule().  The last character of the part after '{' (the closing
    brace) is dropped -- assumes the line carries no trailing newline.
    """
    name, body = line.split('{')
    return name, list(map(makerule, body[:-1].split(',')))
def values(line):
    """Map the first four numbers found in *line* to the keys x, m, a, s.

    Relies on numbers() (from the project's `useful` module) to extract the
    integers; raises IndexError if fewer than four are found.
    """
    nums = numbers(line)
    return {key: nums[pos] for pos, key in enumerate('xmas')}
def matches(rule, xmas):
    """Return whether the part *xmas* satisfies the condition *rule*.

    rule is a (var, comp, num) triple as produced by makerule(); xmas maps
    category letters to the part's ratings.

    Raises ValueError if comp is neither '<' nor '>'.
    """
    var, comp, num = rule
    if comp == '<':
        return xmas[var] < num
    if comp == '>':
        return xmas[var] > num
    # An explicit raise rather than `assert False`: asserts are stripped
    # under `python -O`, which would silently return None ("no match").
    raise ValueError('unknown comparison operator %r in rule %r' % (comp, rule))
def check(xmas, flow):
    """Route the part *xmas* through the workflows, starting at *flow*.

    Returns the sum of the part's ratings if it ends at 'A' and 0 if it
    ends at 'R'.  Reads the module-level `flows` dict of parsed workflows.
    """
    if flow in ('A', 'R'):
        return sum(xmas.values()) if flow == 'A' else 0
    rules = flows[flow]
    # The first conditional rule that matches wins; otherwise fall back to
    # the workflow's final, unconditional target.
    target = next(
        (dest for test, dest in rules[:-1] if matches(test, xmas)),
        rules[-1],
    )
    return check(xmas, target)
# Part 1: the input (read from stdin) is split by hfl() into workflow lines
# and part lines; sum the ratings of every accepted part.
top, bot = hfl(open(0))
flows = dict(map(flow, top))
vals = list(map(values, bot))
print(sum(check(part, 'in') for part in vals))
def split_single(low, hi, comp, pivot):
    """Split the half-open range [low, hi) by the test ``n <comp> pivot``.

    Returns ``[passing, failing]``, each a half-open ``(lo, hi)`` pair.
    A side containing no values is a zero-width pair ``(c, c)``, never an
    empty tuple.  Callers can therefore always unpack it as ``lo, hi``, and
    its width ``hi - lo`` is 0.

    Raises ValueError for a comparator other than '<' or '>'.
    """
    # Clamping the cut point into [low, hi] handles ranges that lie wholly
    # on one side of the pivot without special cases.
    if comp == '>':  # passing values are n >= pivot + 1
        cut = min(max(pivot + 1, low), hi)
        return [(cut, hi), (low, cut)]
    if comp == '<':  # passing values are n <= pivot - 1
        cut = min(max(pivot, low), hi)
        return [(low, cut), (cut, hi)]
    raise ValueError('unknown comparison operator %r' % (comp,))
def split(xmas, test):
    """Split the range list *xmas* on the condition *test*.

    test is a (var, comp, pivot) triple; xmas is a list of four
    (lo, hi) ranges indexed via xmasind.  Returns a generator yielding two
    new range lists -- the passing side first, then the failing side --
    identical to xmas except at var's position.
    """
    var, comp, pivot = test
    pos = xmasind[var]
    low, hi = xmas[pos]
    before, after = xmas[:pos], xmas[pos + 1:]
    return (before + [piece] + after
            for piece in split_single(low, hi, comp, pivot))
def checkrange(xmas, rulestogo, indent=0):
    """Count accepted rating combinations for the range list *xmas*.

    rulestogo is the remaining rule list of the current workflow: its head
    is 'A', 'R', a bare workflow name (when it is the only element), or a
    (test, target) pair.  Returns a list of [width_x, width_m, width_a,
    width_s] rows, one per leaf reached; 'R' leaves contribute [0, 0, 0, 0].
    With '-v' as the first command-line argument, each call is traced to
    stderr, indented by recursion depth.
    """
    if sys.argv[1:2] == ['-v']:
        print('%s%s%s' % (' ' * (indent * 2), xmas, rulestogo), file=sys.stderr)
    head = rulestogo[0]
    if head == 'A':
        return [[top - bottom for bottom, top in xmas], ]
    if head == 'R':
        return [[0, 0, 0, 0], ]
    if len(rulestogo) == 1:
        # Only a workflow name is left: continue with that workflow's rules.
        return checkrange(xmas, flows[head], indent + 1)
    test, target = head
    passing, failing = split(xmas, test)
    accepted = checkrange(passing, [target, ], indent + 1)
    return accepted + checkrange(failing, rulestogo[1:], indent + 1)
# Part 2: ranges are half-open (lower bound inclusive, upper exclusive), so a
# range's size is simply hi - lo.  Multiply the widths of each accepted leaf
# and sum over all leaves.
r = checkrange([(1, 4001)] * 4, flows['in'])
print(np.prod(np.asarray(r), axis=1).sum())