package com.srbenoit.microscopy;

import java.awt.Graphics;
import java.awt.Point;
import java.awt.image.BufferedImage;
import java.awt.image.WritableRaster;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
import com.srbenoit.filter.AbstractFilter;
import com.srbenoit.filter.FilterException;
import com.srbenoit.filter.FilterInput;
import com.srbenoit.filter.FilterOutput;
import com.srbenoit.filter.FilterTreeExecutor;
import com.srbenoit.filter.Pipe;
import com.srbenoit.filter.items.ImageArrayPipeItem;
import com.srbenoit.filter.items.ImagePoint;
import com.srbenoit.filter.items.PointSetArrayPipeItem;

/**
 * A filter that identifies local intensity maxima in each image of the input image array (indexed
 * by time point and z plane), marks them on RGB copies of the images, matches maxima from each
 * time point to the next to estimate their motion, and writes the results to a CSV file in the
 * pipe directory.
 */
public class MaximaFinderFilter extends AbstractFilter {

    /** version number for serialization */
    private static final long serialVersionUID = 2417170213367289037L;

    /**
     * half-width of the square neighborhood used for maxima detection and frame-to-frame
     * correlation; also the distance below which nearby maxima are merged
     */
    private static final int EXTREMA_SIZE = 14;

    /** maximum permitted motion between frames */
    private static final int MAX_FRAME_MOTION = 8;

    /** minimum band-0 (red) intensity for a pixel to be tested as a maximum */
    private static final int MIN_PEAK_INTENSITY = 80;

    /** a maximum must exceed the dimmest pixel in its neighborhood by more than this amount */
    private static final int MIN_PEAK_CONTRAST = 20;

    /** half-width of the neighborhood around a maximum that must contain no zero (black) pixels */
    private static final int BLACK_NEIGHBOR_RADIUS = 2;

    /** sample value written to the green band of pixels that mark a maximum */
    private static final int MARK_GREEN = 255;

    /** sample value written to the blue band of pixels that mark a maximum */
    private static final int MARK_BLUE = 0;

    /**
     * Constructs a new <code>MaximaFinderFilter</code>.
     */
    public MaximaFinderFilter() {

        super("Intensity Maxima Identifier", MaximaFinderFilter.class.getName());

        this.inputs.add(new FilterInput(ImageArrayPipeItem.class, "Source images"));
        this.outputs.add(new FilterOutput(ImageArrayPipeItem.class, "Images with maxima marked",
                "marked_images"));
        this.outputs.add(new FilterOutput(PointSetArrayPipeItem.class,
                "Points of maximal intensity", "maxima"));
        makeRenderer();
    }

    /**
     * Duplicates the filter including all of its settings, but returns an independent object.
     *
     * @return  the duplicated object
     */
    @Override public AbstractFilter duplicate() {

        return new MaximaFinderFilter();
    }

    /**
     * Performs the filter operation: creates the marked-image and maxima outputs in the pipe, runs
     * the detection and motion analysis, and saves the pipe unless the operation was cancelled.
     *
     * @param   executor  the <code>FilterTreeExecutor</code> that is executing the filter
     * @param   pipe      a pipe containing the input data items
     * @throws  FilterException  if the filter cannot complete
     */
    @Override public void filter(final FilterTreeExecutor executor, final Pipe pipe)
        throws FilterException {

        ImageArrayPipeItem input;
        ImageArrayPipeItem output;
        PointSetArrayPipeItem maxima;

        validateInputs(pipe);
        executor.indicateProgress(1);

        input = (ImageArrayPipeItem) pipe.get(this.inputs.get(0).getKey());

        // Install a dummy image array to test for persisted data
        output = new ImageArrayPipeItem(this.outputs.get(0).getKey(),
                "Images with maxima marked (PNG format)", pipe, "t", "z", input.getXSize(),
                input.getYSize(), "png");
        pipe.add(output);
        executor.indicateProgress(2);

        maxima = new PointSetArrayPipeItem(this.outputs.get(1).getKey(), "Intensity maxima", pipe,
                input.getXSize(), input.getYSize());
        pipe.add(maxima);
        executor.indicateProgress(3);

        runFilter(executor, pipe, input, output, maxima);

        executor.indicateProgress(80);

        if (!executor.isCancelled()) {
            pipe.save(executor);
        }

        executor.indicateProgress(100);
    }

    /**
     * Runs the filter. For every image in the input array (first index: time, second index: z
     * plane) it makes an RGB copy, finds and marks the intensity maxima, and records them. It then
     * matches maxima between consecutive time points to compute motion vectors. Finally it writes
     * all maxima with their motion to "maxima_and_motion.csv" in the pipe directory.
     *
     * @param  executor  the <code>FilterTreeExecutor</code> that is executing the filter
     * @param  pipe      a pipe containing the input data items
     * @param  input     the images to process
     * @param  output    the image array in which to store output images
     * @param  maxima    the point set array in which to store maxima that were found
     */
    private void runFilter(final FilterTreeExecutor executor, final Pipe pipe,
        final ImageArrayPipeItem input, final ImageArrayPipeItem output,
        final PointSetArrayPipeItem maxima) {

        final int width = input.getXSize();
        final int height = input.getYSize();
        final List<Point> frameMaxima = new ArrayList<Point>(30);

        for (int x = 0; x < width; x++) {

            for (int y = 0; y < height; y++) {

                if (executor.isCancelled()) {
                    break;
                }

                executor.indicateProgress(5 + (35 * ((x * height) + y) / (width * height)));

                final BufferedImage orig = input.getImage(x, y);
                final BufferedImage newImage = new BufferedImage(orig.getWidth(), // NOPMD SRB
                        orig.getHeight(), BufferedImage.TYPE_INT_RGB);

                // Copy the source into an RGB image so maxima can be marked in color, then
                // release the Graphics immediately (it is not needed for marking)
                final Graphics grx = newImage.getGraphics();

                try {
                    grx.drawImage(orig, 0, 0, null);
                } finally {
                    grx.dispose();
                }

                output.setImage(x, y, newImage);

                final WritableRaster raster = newImage.getRaster();
                frameMaxima.clear();
                findFrameMaxima(raster, frameMaxima);
                combineNearbyMaxima(frameMaxima);

                // Marking writes only the green/blue bands. Nothing else draws over the mark,
                // so the red band keeps the original intensity.
                for (Point peak : frameMaxima) {
                    markMaximum(raster, peak);
                    maxima.addPoint(x, y, new ImagePoint(peak.x, peak.y)); // NOPMD SRB
                }
            }
        }

        executor.indicateProgress(41);

        // At this point, maxima have been identified and stored in the PointSetArrayPipeItem, and
        // the maxima have been highlighted on the images.  The next step is to correlate cells
        // from one frame to the next and associate motion vectors with each cell position where
        // the cell moves from frame to frame.  At the same time, we compute the tissue vector for
        // the same location
        if (!executor.isCancelled()) {

            for (int x = 0; x < (width - 1); x++) {

                for (int y = 0; y < height; y++) {

                    if (executor.isCancelled()) {
                        break;
                    }

                    executor.indicateProgress(42
                        + (((x * height) + y) * 36 / ((width - 1) * height)));

                    doExtremaMotion(input, maxima, x, y);
                }
            }

            if (!executor.isCancelled()) {
                writeMotionCsv(pipe, maxima, width, height);
            }
        }

        executor.indicateProgress(79);
    }

    /**
     * Scans one frame for intensity maxima and appends them to <code>found</code>. Candidates are
     * pixels at least <code>EXTREMA_SIZE</code> from every border whose band-0 value is at least
     * <code>MIN_PEAK_INTENSITY</code>. A candidate is accepted when:
     * <ul>
     * <li>no pixel in its (2*EXTREMA_SIZE+1)-square neighborhood is brighter;</li>
     * <li>it exceeds the dimmest of those pixels by more than <code>MIN_PEAK_CONTRAST</code>;</li>
     * <li>no pixel within <code>BLACK_NEIGHBOR_RADIUS</code> of it is zero.</li>
     * </ul>
     *
     * @param  raster  the raster to scan (band 0 is used as the intensity)
     * @param  found   the list to which accepted maxima are appended
     */
    private static void findFrameMaxima(final WritableRaster raster, final List<Point> found) {

        final int imgWidth = raster.getWidth();
        final int imgHeight = raster.getHeight();

        for (int yCoord = EXTREMA_SIZE; yCoord < (imgHeight - EXTREMA_SIZE); yCoord++) {

            for (int xCoord = EXTREMA_SIZE; xCoord < (imgWidth - EXTREMA_SIZE); xCoord++) {
                final int value = raster.getSample(xCoord, yCoord, 0);

                if (value < MIN_PEAK_INTENSITY) {
                    continue;
                }

                if (isDominantPeak(raster, xCoord, yCoord, value)
                        && !hasBlackNeighbor(raster, xCoord, yCoord)) {
                    found.add(new Point(xCoord, yCoord)); // NOPMD SRB
                }
            }
        }
    }

    /**
     * Tests whether a pixel is the brightest in its (2*EXTREMA_SIZE+1)-square neighborhood (ties
     * allowed) and stands more than <code>MIN_PEAK_CONTRAST</code> above the dimmest pixel there.
     * The caller guarantees the whole neighborhood lies inside the raster.
     *
     * @param   raster  the raster being scanned
     * @param   xCoord  the candidate's x coordinate
     * @param   yCoord  the candidate's y coordinate
     * @param   value   the candidate's band-0 value
     * @return  true if the candidate qualifies as a peak
     */
    private static boolean isDominantPeak(final WritableRaster raster, final int xCoord,
        final int yCoord, final int value) {

        int maxDecrease = 0;

        for (int deltaY = -EXTREMA_SIZE; deltaY <= EXTREMA_SIZE; deltaY++) {

            for (int deltaX = -EXTREMA_SIZE; deltaX <= EXTREMA_SIZE; deltaX++) {
                final int neighbor = raster.getSample(xCoord + deltaX, yCoord + deltaY, 0);

                if (neighbor > value) {
                    return false;
                }

                maxDecrease = Math.max(maxDecrease, value - neighbor);
            }
        }

        return maxDecrease > MIN_PEAK_CONTRAST;
    }

    /**
     * Tests whether any pixel within <code>BLACK_NEIGHBOR_RADIUS</code> of a point has a band-0
     * value of zero. The caller guarantees the neighborhood lies inside the raster.
     *
     * @param   raster  the raster being scanned
     * @param   xCoord  the point's x coordinate
     * @param   yCoord  the point's y coordinate
     * @return  true if a zero-valued pixel is found
     */
    private static boolean hasBlackNeighbor(final WritableRaster raster, final int xCoord,
        final int yCoord) {

        for (int deltaY = -BLACK_NEIGHBOR_RADIUS; deltaY <= BLACK_NEIGHBOR_RADIUS; deltaY++) {

            for (int deltaX = -BLACK_NEIGHBOR_RADIUS; deltaX <= BLACK_NEIGHBOR_RADIUS; deltaX++) {

                if (raster.getSample(xCoord + deltaX, yCoord + deltaY, 0) == 0) {
                    return true;
                }
            }
        }

        return false;
    }

    /**
     * Marks the 3x3 block centered on a maximum by setting band 1 (green) to
     * <code>MARK_GREEN</code> and band 2 (blue) to <code>MARK_BLUE</code>. Band 0 (red) keeps the
     * original intensity so no data is lost. Pixels outside the raster are skipped.
     *
     * @param  raster  the raster of a TYPE_INT_RGB image
     * @param  center  the maximum to mark
     */
    private static void markMaximum(final WritableRaster raster, final Point center) {

        final int minX = Math.max(0, center.x - 1);
        final int maxX = Math.min(raster.getWidth() - 1, center.x + 1);
        final int minY = Math.max(0, center.y - 1);
        final int maxY = Math.min(raster.getHeight() - 1, center.y + 1);

        for (int px = minX; px <= maxX; px++) {

            for (int py = minY; py <= maxY; py++) {
                raster.setSample(px, py, 1, MARK_GREEN);
                raster.setSample(px, py, 2, MARK_BLUE);
            }
        }
    }

    /**
     * Writes every maximum for time points 0 through width-2 to "maxima_and_motion.csv" in the
     * pipe directory, with 1-based time/plane indices, position, velocity and ambient velocity.
     * Failures are logged rather than propagated, and the writer is always closed.
     *
     * @param  pipe    the pipe whose directory receives the file
     * @param  maxima  the maxima, with motion already computed
     * @param  width   the number of time points
     * @param  height  the number of planes
     */
    private void writeMotionCsv(final Pipe pipe, final PointSetArrayPipeItem maxima,
        final int width, final int height) {

        final File outfile = new File(pipe.getDir(), "maxima_and_motion.csv");

        try {
            final FileWriter writer = new FileWriter(outfile);

            try {
                writer.write(
                    "time, plane, x, y, move x, move y, tissue move x, tissue move y\r\n");

                for (int x = 0; x < (width - 1); x++) {

                    for (int y = 0; y < height; y++) {

                        for (int j = 0; j < maxima.getNumPoints(x, y); j++) {
                            final ImagePoint point = maxima.getPoint(x, y, j);
                            writer.write((x + 1) + ", " + (y + 1) + ", " + point.getXPos()
                                + ", " + point.getYPos() + ", " + point.getXVel() + ", "
                                + point.getYVel() + ", " + point.getXAmbientVel() + ", "
                                + point.getYAmbientVel() + "\r\n");
                        }
                    }
                }
            } finally {
                // Close even if a write fails, so the file handle is never leaked
                writer.close();
            }
        } catch (IOException e) {
            LOG.log(Level.WARNING, "Exception generating extremal motion", e);
        }
    }

    /**
     * Merges maxima that lie close together. For each maximum in list order, all later maxima
     * within (strictly less than) <code>EXTREMA_SIZE</code> of it are removed. The maximum is then
     * replaced by the integer average of itself and the removed points.
     *
     * @param  maxima  the list of identified maxima (modified in place)
     */
    private void combineNearbyMaxima(final List<Point> maxima) {

        final List<Integer> nearbyIndices = new ArrayList<Integer>(10);

        for (int i = 0; i < maxima.size(); i++) {
            final Point point = maxima.get(i);
            int totX = point.x;
            int totY = point.y;

            nearbyIndices.clear();

            for (int j = i + 1; j < maxima.size(); j++) {
                final Point other = maxima.get(j);

                if (point.distance(other) < EXTREMA_SIZE) {
                    nearbyIndices.add(Integer.valueOf(j));
                    totX += other.x;
                    totY += other.y;
                }
            }

            if (nearbyIndices.isEmpty()) {
                continue;
            }

            // Remove by index, highest first. Earlier indices (including i) stay valid, and a
            // different point that merely has equal coordinates is never removed by mistake.
            for (int k = nearbyIndices.size() - 1; k >= 0; k--) {
                maxima.remove(nearbyIndices.get(k).intValue());
            }

            final int merged = nearbyIndices.size() + 1;
            maxima.set(i, new Point(totX / merged, totY / merged)); // NOPMD SRB
        }
    }

    /**
     * Computes the motion vectors for one time point by greedily pairing maxima in this frame with
     * maxima in the next frame. On each round, the pair closer than
     * <code>MAX_FRAME_MOTION</code> with the highest <code>correlate</code> score is matched.
     * Its velocity and ambient motion are recorded, and both points are removed from further
     * matching.
     *
     * @param  images     the images to process
     * @param  maxima     the point set array in which to store maxima that were found
     * @param  timeIndex  the index of the frame we are processing (we compute vectors from this
     *                    time point to the next time point, so this value will always be at least
     *                    two less than the length of the images array)
     * @param  plane      the image plane we are processing
     */
    private void doExtremaMotion(final ImageArrayPipeItem images,
        final PointSetArrayPipeItem maxima, final int timeIndex, final int plane) {

        final BufferedImage image1 = images.getImage(timeIndex, plane);
        final BufferedImage image2 = images.getImage(timeIndex + 1, plane);

        final List<ImagePoint> list1 = copyPoints(maxima, timeIndex, plane);
        final List<ImagePoint> list2 = copyPoints(maxima, timeIndex + 1, plane);

        // Integer squared-distance test, equivalent to sqrt(d2) < MAX_FRAME_MOTION
        final int maxDistSq = MAX_FRAME_MOTION * MAX_FRAME_MOTION;

        while (!list1.isEmpty() && !list2.isEmpty()) {
            float bestCorr = 0.0f;
            int which1 = -1;
            int which2 = -1;

            for (int i = 0; i < list1.size(); i++) {
                final ImagePoint pt1 = list1.get(i);

                for (int j = 0; j < list2.size(); j++) {
                    final ImagePoint pt2 = list2.get(j);
                    final int distX = pt2.getXPos() - pt1.getXPos();
                    final int distY = pt2.getYPos() - pt1.getYPos();

                    if (((distX * distX) + (distY * distY)) < maxDistSq) {
                        final float corr = correlate(image1, pt1, image2, pt2);

                        if (corr > bestCorr) {
                            bestCorr = corr;
                            which1 = i;
                            which2 = j;
                        }
                    }
                }
            }

            if (which1 == -1) {

                // No more points within acceptable distance, so quit
                break;
            }

            final ImagePoint from = list1.remove(which1);
            final ImagePoint to = list2.remove(which2);

            from.setVel(to.getXPos() - from.getXPos(), to.getYPos() - from.getYPos());

            // Now, correlate a region surrounding (but not including) the
            // maxima to see how the surrounding tissue moved between frames
            ambientMotion(image1, image2, from);
        }
    }

    /**
     * Copies the maxima recorded for one time point and plane into a new mutable list.
     *
     * @param   maxima     the point set array
     * @param   timeIndex  the time index
     * @param   plane      the plane index
     * @return  a new list holding the recorded points (the points themselves are shared)
     */
    private static List<ImagePoint> copyPoints(final PointSetArrayPipeItem maxima,
        final int timeIndex, final int plane) {

        final int count = maxima.getNumPoints(timeIndex, plane);
        final List<ImagePoint> points = new ArrayList<ImagePoint>(count);

        for (int i = 0; i < count; i++) {
            points.add(maxima.getPoint(timeIndex, plane, i));
        }

        return points;
    }

    /**
     * Scores how well the neighborhood of a point in one frame matches the neighborhood of a point
     * in another frame. The comparison covers the band-0 samples in a centered
     * (2*EXTREMA_SIZE+1)-square window, the same window used for detection. Window pixels that
     * fall outside either raster are ignored.
     *
     * @param   image1  the first frame
     * @param   pt1     the point in the first frame
     * @param   image2  the second frame
     * @param   pt2     the point in the second frame
     * @return  the reciprocal of the sum of absolute sample differences (larger means more
     *          similar), or <code>Float.MAX_VALUE</code> when the windows are identical
     */
    private float correlate(final BufferedImage image1, final ImagePoint pt1,
        final BufferedImage image2, final ImagePoint pt2) {

        final WritableRaster ras1 = image1.getRaster();
        final WritableRaster ras2 = image2.getRaster();
        final int x1 = pt1.getXPos();
        final int y1 = pt1.getYPos();
        final int x2 = pt2.getXPos();
        final int y2 = pt2.getYPos();
        int delta = 0;

        for (int dy = -EXTREMA_SIZE; dy <= EXTREMA_SIZE; dy++) {

            for (int dx = -EXTREMA_SIZE; dx <= EXTREMA_SIZE; dx++) {

                if (!isInside(ras1, x1 + dx, y1 + dy) || !isInside(ras2, x2 + dx, y2 + dy)) {
                    continue;
                }

                // getSample reads band 0 regardless of how many bands the raster has
                delta += Math.abs(ras1.getSample(x1 + dx, y1 + dy, 0)
                    - ras2.getSample(x2 + dx, y2 + dy, 0));
            }
        }

        return (delta == 0) ? Float.MAX_VALUE : (1.0f / delta);
    }

    /**
     * Tests whether a coordinate lies within a raster's bounds.
     *
     * @param   raster  the raster
     * @param   x       the x coordinate
     * @param   y       the y coordinate
     * @return  true if (x, y) is inside the raster
     */
    private static boolean isInside(final WritableRaster raster, final int x, final int y) {

        return (x >= raster.getMinX()) && (y >= raster.getMinY())
            && (x < (raster.getMinX() + raster.getWidth()))
            && (y < (raster.getMinY() + raster.getHeight()));
    }

    /**
     * Estimates how the tissue around a point moved between two frames. The method searches
     * offsets within +/- <code>MAX_FRAME_MOTION</code> and picks the one that minimizes the mean
     * squared difference of the low byte of <code>getRGB</code>. The region compared is a window
     * of +/- width/12 and height/12 around the point, excluding a core of +/- (MAX_FRAME_MOTION -
     * 1) pixels. Pixels that are zero in either frame, or that fall outside either image, are
     * skipped. The chosen offset is stored with <code>setAmbientVel</code>.
     *
     * @param  image1  the first frame to examine
     * @param  image2  the second frame to examine
     * @param  point   the point about which to detect ambient movement
     */
    private void ambientMotion(final BufferedImage image1, final BufferedImage image2,
        final ImagePoint point) {

        // NOTE(review): the original author flagged this method as producing bad output
        // ("PROBLEM HERE SOMEWHERE"); the root cause is not identified. One candidate: the
        // excluded core is only +/- (MAX_FRAME_MOTION - 1) pixels, but maxima are detected over
        // +/- EXTREMA_SIZE, so part of the cell itself may contribute to the "ambient" estimate.
        // TODO confirm.

        final int width1 = image1.getWidth();
        final int height1 = image1.getHeight();
        final int width2 = image2.getWidth();
        final int height2 = image2.getHeight();
        final int scanWidth = width1 / 12;
        final int scanHeight = height1 / 12;
        final int px = point.getXPos();
        final int py = point.getYPos();

        int bestDx = 0;
        int bestDy = 0;
        double leastSquares = Double.MAX_VALUE;

        for (int dx = -MAX_FRAME_MOTION; dx <= MAX_FRAME_MOTION; dx++) {

            for (int dy = -MAX_FRAME_MOTION; dy <= MAX_FRAME_MOTION; dy++) {
                double total = 0;
                int count = 0;

                for (int x = px - scanWidth; x <= (px + scanWidth); x++) {

                    for (int y = py - scanHeight; y <= (py + scanHeight); y++) {

                        // Skip the core around the point so only surrounding tissue counts
                        if ((Math.abs(x - px) < MAX_FRAME_MOTION)
                                && (Math.abs(y - py) < MAX_FRAME_MOTION)) {
                            continue;
                        }

                        if ((x < 0) || (y < 0) || (x >= width1) || (y >= height1)) {
                            continue;
                        }

                        final int x2 = x + dx;
                        final int y2 = y + dy;

                        if ((x2 < 0) || (y2 < 0) || (x2 >= width2) || (y2 >= height2)) {
                            continue;
                        }

                        final int pix1 = image1.getRGB(x, y) & 0x00FF;
                        final int pix2 = image2.getRGB(x2, y2) & 0x00FF;

                        if ((pix1 == 0) || (pix2 == 0)) {
                            continue;
                        }

                        final int diff = pix2 - pix1;
                        total += diff * diff;
                        count++;
                    }
                }

                // No usable pixels for this offset: nothing to compare
                if (count == 0) {
                    continue;
                }

                final double normalized = total / count;

                if (normalized < leastSquares) {
                    leastSquares = normalized;
                    bestDx = dx;
                    bestDy = dy;
                }
            }
        }

        point.setAmbientVel(bestDx, bestDy);
    }

    /**
     * Generates the string representation of the filter.
     *
     * NOTE(review): this returns "ExtremalFinderFilter", which differs from the class name; it is
     * left unchanged in case displayed or persisted data depends on it.
     *
     * @return  the string representation
     */
    @Override public String toString() {

        return "ExtremalFinderFilter";
    }
}
