#!/usr/bin/env python

# prithreshpng
# Threshold one image against another.
# Output a 1-bit per channel PNG, where l <= r

import collections

from array import array

import png

"""
prithreshpng file1.png file2.png

The `prithreshpng` tool compares channels from the input images.
"""

Image = collections.namedtuple("Image", "rows info")


class Error(Exception):
    """Base class for every error this tool raises itself."""


class ImageError(Error):
    """An input image is unusable, e.g. its size or channel count differs."""


def thresh(out, args):
    """Threshold the right input image against the left and write a PNG.

    Each channel value of the output is 1 when l <= r and 0 otherwise,
    where l and r are the corresponding channel values of the left
    (first) and right (second) input; that is, 1 when the right image is
    at least as bright as the left in that channel.  Values are compared
    on a common 0..1 scale, so inputs of different bit depths are
    compared fairly; for inputs of equal bit depth this is exactly
    l <= r.

    The output has the same size and channel layout (greyscale or
    colour, with or without alpha) as the inputs.

    `out` is a binary file object that receives the PNG.
    `args.input` is a sequence of exactly two paths, each opened with
    ``png.cli_open``.

    Raises Error when the number of inputs is not two, and ImageError
    when the inputs differ in number of channels or in size.
    """

    paths = args.input

    if len(paths) != 2:
        raise Error("Required input is missing.")

    images = []
    for path in paths:
        inp = png.cli_open(path)
        rows, info = png.Reader(file=inp).asDirect()[2:]
        # Materialise the rows so both images can be walked in lockstep.
        images.append(Image(list(rows), info))

    planes = images[0].info["planes"]
    size = images[0].info["size"]
    for image in images:
        if image.info["planes"] != planes:
            raise ImageError("All images should have same number of channels")
        if image.info["size"] != size:
            raise ImageError("All images should have same size")

    left, right = images
    # Maximum sample value of each input, used to compare
    # l / lmax <= r / rmax without division (both maxima are >= 1).
    lmax = 2 ** left.info["bitdepth"] - 1
    rmax = 2 ** right.info["bitdepth"] - 1

    def thresh_row_iter():
        """Yield each thresholded output row in turn."""
        for left_row, right_row in zip(left.rows, right.rows):
            yield array(
                "B",
                [l * rmax <= r * lmax for l, r in zip(left_row, right_row)],
            )

    # Each output row carries planes * width values, so the writer's
    # channel layout must match the inputs'.  PNG has no 1-bit colour or
    # alpha format; for those layouts pypng's Writer is expected to
    # rescale the 1-bit data to 8 bits (NOTE(review): confirm against the
    # installed pypng version's Writer bitdepth handling).
    w = png.Writer(
        size[0],
        size[1],
        greyscale=planes < 3,
        alpha=planes in (2, 4),
        bitdepth=1,
    )
    w.write(out, thresh_row_iter())


def main(argv=None):
    """Command line entry point.

    `argv` defaults to ``sys.argv``; its first element (the program
    name) is skipped.  Exactly two input paths are required; argparse
    exits with status 2 otherwise.  The thresholded PNG is written to
    binary standard output.
    """
    import argparse
    import sys

    if argv is None:
        argv = sys.argv

    parser = argparse.ArgumentParser(
        description="Threshold one PNG image against another."
    )
    parser.add_argument(
        "input",
        nargs=2,
        help="left and right PNG files; output is 1 where left <= right",
    )
    args = parser.parse_args(argv[1:])

    return thresh(png.binary_stdout(), args)


if __name__ == "__main__":
    main()
