#!/usr/bin/env python3
"""Find low points and basins in a rectangular grid of integer heights.

When run as a script, one row of digits is read per line from standard
input and the results of :func:`part1` and :func:`part2` are printed under
"Part 1" / "Part 2" headings.
"""
import sys
from typing import Iterable, List, Sequence, Set, Tuple
from itertools import product
import math

Point = Tuple[int, int]


class Grid:
    """Read-only 2-D grid of integers addressed by ``(row, col)`` tuples."""

    def __init__(self, grid: Sequence[Sequence[int]]):
        """Wrap *grid*; the column count is taken from the first row."""
        self.grid = grid
        self.rows = len(grid)
        self.cols = len(grid[0]) if self.rows else 0

    def neighbors(self, point: Point) -> Set[Point]:
        """Return the in-bounds left/right/up/down neighbours of *point*."""
        i, j = point
        return {
            (r, c)
            for r, c in [(i, j - 1), (i, j + 1), (i - 1, j), (i + 1, j)]
            if 0 <= r < self.rows and 0 <= c < self.cols
        }

    def __getitem__(self, key: Point) -> int:
        """Return the value stored at ``key == (row, col)``."""
        r, c = key
        return self.grid[r][c]


def find_low_points(grid: Grid) -> Set[Point]:
    """Return every cell strictly lower than all of its neighbours.

    A cell with no neighbours (a 1x1 grid) counts as a low point.
    """
    low_points: Set[Point] = set()
    for i in range(grid.rows):
        for j in range(grid.cols):
            value = grid[i, j]
            if all(value < grid[n] for n in grid.neighbors((i, j))):
                low_points.add((i, j))
    return low_points


def find_basin(grid: Grid, point: Point) -> Set[Point]:
    """Return the basin that drains into *point*.

    The basin is every cell reachable from *point* by repeatedly stepping
    to a strictly higher neighbour; cells whose value is 9 act as walls
    and are never included. The walk uses an explicit stack rather than
    recursion so that large basins cannot exhaust the recursion limit.
    """
    basin: Set[Point] = set()
    seen: Set[Point] = {point}
    stack: List[Point] = [point]
    while stack:
        current = stack.pop()
        value = grid[current]
        if value == 9:
            # Walls are marked as seen but contribute nothing to the basin.
            continue
        basin.add(current)
        for neighbor in grid.neighbors(current):
            if neighbor not in seen and value < grid[neighbor]:
                seen.add(neighbor)
                stack.append(neighbor)
    return basin


def part1(grid: Grid) -> int:
    """Print and return the sum of ``value + 1`` over all low points."""
    total = sum(grid[point] + 1 for point in find_low_points(grid))
    print(total)
    return total


def part2(grid: Grid) -> int:
    """Print and return the product of the three largest basin sizes.

    If the grid has fewer than three basins, the product of the available
    sizes is used instead of raising ``IndexError`` (so a grid with no
    basins yields the empty product, 1).
    """
    sizes = [len(find_basin(grid, point)) for point in find_low_points(grid)]
    answer = math.prod(sorted(sizes, reverse=True)[:3])
    print(answer)
    return answer


def parse_grid(lines: Iterable[str]) -> Grid:
    """Build a :class:`Grid` from text lines holding one digit per cell.

    Surrounding whitespace is stripped and blank lines are skipped, so a
    trailing empty line cannot introduce an empty row.
    """
    rows: List[List[int]] = []
    for line in lines:
        digits = line.strip()
        if digits:
            rows.append([int(ch) for ch in digits])
    return Grid(rows)


if __name__ == "__main__":
    # Guarded so that importing this module does not consume stdin.
    grid = parse_grid(sys.stdin)
    print("Part 1")
    part1(grid)
    print("Part 2")
    part2(grid)