package com.srbenoit.microscopy;

import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import com.srbenoit.filter.AbstractFilter;
import com.srbenoit.filter.FilterException;
import com.srbenoit.filter.FilterInput;
import com.srbenoit.filter.FilterOutput;
import com.srbenoit.filter.FilterTreeExecutor;
import com.srbenoit.filter.Pipe;
import com.srbenoit.filter.items.ImageArrayPipeItem;

/**
 * A filter that compensates for slight motions in the video content, potentially increasing the
 * size of the frames to contain the shifted images.
 *
 * <p>For each pair of consecutive time points, the shift that minimizes the mean squared
 * difference between Z plane 0 of the two frames is found. Every frame is then drawn into an
 * enlarged canvas at its cumulative shift so the content stays (approximately) stationary.
 */
public class MotionCompensationFilter extends AbstractFilter {

    /** version number for serialization */
    private static final long serialVersionUID = -3712792142848530198L;

    /** Largest shift (pixels, along each axis) examined by the coarse search. */
    private static final int SEARCH_RADIUS = 18;

    /** Spacing of candidate shifts in the coarse search; also the radius of the fine search. */
    private static final int COARSE_STEP = 3;

    /** Output frame width and height are rounded up to a multiple of this many pixels. */
    private static final int SIZE_ALIGNMENT = 16;

    /**
     * Constructs a new <code>MotionCompensationFilter</code>.
     */
    public MotionCompensationFilter() {

        super("Motion Compensation", MotionCompensationFilter.class.getName());

        this.inputs.add(new FilterInput(ImageArrayPipeItem.class, "Images to stabilize"));
        this.outputs.add(new FilterOutput(ImageArrayPipeItem.class, "Stabilized images",
                "stabilized_images"));
        makeRenderer();
    }

    /**
     * Duplicates the filter including all of its settings, but returns an independent object.
     * This filter has no configurable settings, so a fresh instance is equivalent.
     *
     * @return  the duplicated object
     */
    @Override public AbstractFilter duplicate() {

        return new MotionCompensationFilter();
    }

    /**
     * Performs the filter operation.
     *
     * @param   executor  the <code>FilterTreeExecutor</code> that is executing the filter
     * @param   pipe      a pipe containing the input data items
     * @throws  FilterException  if the filter cannot complete
     */
    @Override public void filter(final FilterTreeExecutor executor, final Pipe pipe)
        throws FilterException {

        validateInputs(pipe);
        executor.indicateProgress(1);

        final ImageArrayPipeItem input =
            (ImageArrayPipeItem) pipe.get(this.inputs.get(0).getKey());
        executor.indicateProgress(2);

        // The output image item is created fresh, with the same number of time points as the
        // input but a single Z plane per time point
        final ImageArrayPipeItem output = new ImageArrayPipeItem(this.outputs.get(0).getKey(),
                "Motion-compensated images (PNG format)", pipe, input.getXLabel(),
                input.getYLabel(), input.getXSize(), 1, "png");
        pipe.add(output);
        executor.indicateProgress(3);

        runFilter(executor, input, output);

        executor.indicateProgress(80);

        if (!executor.isCancelled()) {
            pipe.save(executor);
        }

        executor.indicateProgress(100);
    }

    /**
     * Estimates the frame-to-frame shifts of the input series, then draws every image into an
     * enlarged canvas at its cumulative shift. Does nothing further once the executor reports
     * cancellation.
     *
     * @param  executor  the <code>FilterTreeExecutor</code> that is executing the filter
     * @param  input     the images to process (first index is time, second is Z plane)
     * @param  output    the image array in which to store the processed images
     */
    private void runFilter(final FilterTreeExecutor executor, final ImageArrayPipeItem input,
        final ImageArrayPipeItem output) {

        // The first image defines the size of every source frame
        final BufferedImage source = input.getImage(0, 0);
        final int width = source.getWidth();
        final int height = source.getHeight();

        final int frameCount = input.getXSize();
        final int[] bestDx = new int[frameCount];
        final int[] bestDy = new int[frameCount];

        estimateShifts(executor, input, bestDx, bestDy);

        if (!executor.isCancelled()) {
            renderFrames(executor, input, output, bestDx, bestDy, width, height);
        }
    }

    /**
     * Fills <code>bestDx</code>/<code>bestDy</code> with, for each time point t &gt;= 1, the shift of
     * frame t relative to frame t-1 (Z plane 0) that minimizes <code>difference</code>. Index 0
     * is left at 0. A coarse grid search over +/-SEARCH_RADIUS in steps of COARSE_STEP is followed
     * by an exhaustive search of the +/-COARSE_STEP neighborhood of the coarse winner. Returns
     * early if the executor reports cancellation.
     *
     * @param  executor  the executor used for progress reporting and cancellation checks
     * @param  input     the images being processed
     * @param  bestDx    receives the per-frame X shifts (length = number of time points)
     * @param  bestDy    receives the per-frame Y shifts (length = number of time points)
     */
    private void estimateShifts(final FilterTreeExecutor executor,
        final ImageArrayPipeItem input, final int[] bestDx, final int[] bestDy) {

        final int frameCount = input.getXSize();

        for (int t = 1; t < frameCount; t++) {

            if (executor.isCancelled()) {
                return;
            }

            // Both images are invariant across all candidate shifts for this time point
            final BufferedImage previous = input.getImage(t - 1, 0);
            final BufferedImage current = input.getImage(t, 0);
            double leastSquares = Double.POSITIVE_INFINITY;

            // Coarse search: every COARSE_STEP pixels within +/-SEARCH_RADIUS
            for (int dx = -SEARCH_RADIUS; dx <= SEARCH_RADIUS; dx += COARSE_STEP) {

                for (int dy = -SEARCH_RADIUS; dy <= SEARCH_RADIUS; dy += COARSE_STEP) {
                    final double score = difference(previous, current, dx, dy);

                    if (score < leastSquares) {
                        leastSquares = score;
                        bestDx[t] = dx;
                        bestDy[t] = dy;
                    }
                }
            }

            executor.indicateProgress(5 + (40 * t / frameCount));

            // Fine search: every pixel around the coarse winner (already scored, so skipped)
            final int coarseDx = bestDx[t];
            final int coarseDy = bestDy[t];

            for (int dx = coarseDx - COARSE_STEP; dx <= (coarseDx + COARSE_STEP); dx++) {

                for (int dy = coarseDy - COARSE_STEP; dy <= (coarseDy + COARSE_STEP); dy++) {

                    if ((dx == coarseDx) && (dy == coarseDy)) {
                        continue;
                    }

                    final double score = difference(previous, current, dx, dy);

                    if (score < leastSquares) {
                        leastSquares = score;
                        bestDx[t] = dx;
                        bestDy[t] = dy;
                    }
                }
            }
        }
    }

    /**
     * Allocates enlarged frames large enough to hold every cumulatively-shifted image, and draws
     * each input image at its cumulative shift. All Z planes of one time point share that time
     * point's shift. Returns early if the executor reports cancellation.
     *
     * @param  executor  the executor used for progress reporting and cancellation checks
     * @param  input     the images being processed
     * @param  output    the image array in which to store the processed images
     * @param  bestDx    per-frame X shift relative to the previous frame (index 0 is 0)
     * @param  bestDy    per-frame Y shift relative to the previous frame (index 0 is 0)
     * @param  width     the width of the source frames
     * @param  height    the height of the source frames
     */
    private static void renderFrames(final FilterTreeExecutor executor,
        final ImageArrayPipeItem input, final ImageArrayPipeItem output, final int[] bestDx,
        final int[] bestDy, final int width, final int height) {

        final int frameCount = input.getXSize();

        // Bounding box of the cumulative shifts (frame 0 sits at the origin)
        int minX = 0;
        int maxX = 0;
        int minY = 0;
        int maxY = 0;
        int cumX = 0;
        int cumY = 0;

        for (int t = 1; t < frameCount; t++) {
            cumX += bestDx[t];
            cumY += bestDy[t];
            minX = Math.min(minX, cumX);
            maxX = Math.max(maxX, cumX);
            minY = Math.min(minY, cumY);
            maxY = Math.max(maxY, cumY);
        }

        final int newWidth = roundUpToMultiple(width + (maxX - minX), SIZE_ALIGNMENT);
        final int newHeight = roundUpToMultiple(height + (maxY - minY), SIZE_ALIGNMENT);

        // Never write more planes than the output item was created with
        final int planeCount = Math.min(input.getYSize(), output.getYSize());

        // Start so that the most negative cumulative shift maps to 0
        cumX = -minX;
        cumY = -minY;

        for (int t = 0; t < frameCount; t++) {

            // Accumulate once per time point: every Z plane of a time point gets the same shift
            cumX += bestDx[t];
            cumY += bestDy[t];

            executor.indicateProgress(45 + (35 * t / frameCount));

            for (int z = 0; z < planeCount; z++) {

                if (executor.isCancelled()) {
                    return;
                }

                final BufferedImage newImage = new BufferedImage(newWidth, newHeight, // NOPMD SRB
                        BufferedImage.TYPE_INT_RGB);
                final Graphics2D graphics = newImage.createGraphics();

                try {
                    graphics.drawImage(input.getImage(t, z), cumX, cumY, null);
                } finally {
                    // Release the graphics context's resources promptly
                    graphics.dispose();
                }

                output.setImage(t, z, newImage);
            }
        }
    }

    /**
     * Rounds a non-negative value up to the next multiple of <code>multiple</code>.
     *
     * @param   value     the value to round (assumed non-negative)
     * @param   multiple  the alignment (must be positive)
     * @return  the smallest multiple of <code>multiple</code> that is &gt;= <code>value</code>
     */
    private static int roundUpToMultiple(final int value, final int multiple) {

        return ((value + multiple - 1) / multiple) * multiple;
    }

    /**
     * Computes the mean squared difference between two images, with the second offset by a
     * particular vector.
     *
     * <p>Pixel (x, y) of the first image is compared with pixel (x - xOff, y - yOff) of the
     * second. Only a central window is used: one quarter of the overlap's width by one quarter of
     * its height. Only the low 8 bits of each <code>getRGB</code> value (the blue channel) are
     * compared. This presumably assumes grayscale sources, where all channels are equal.
     * Assumes both images have the same dimensions.
     *
     * @param   first   the first image
     * @param   second  the second image
     * @param   xOff    the X offset to apply to the second image
     * @param   yOff    the Y offset to apply to the second image
     * @return  the mean squared intensity difference over the window, or
     *          <code>Double.POSITIVE_INFINITY</code> if the window is empty (offset too large for
     *          the image), so that such an offset never wins a minimization
     */
    private static double difference(final BufferedImage first, final BufferedImage second,
        final int xOff, final int yOff) {

        // Range (in first-image coordinates) where both images overlap after the offset
        final int overlapMinX = Math.max(xOff, 0);
        final int overlapMinY = Math.max(yOff, 0);
        final int overlapMaxX = first.getWidth() + Math.min(xOff, 0);
        final int overlapMaxY = first.getHeight() + Math.min(yOff, 0);

        // Narrow to the central window: +/- 1/8 of the overlap extent about its center
        final int centerX = (overlapMinX + overlapMaxX) / 2;
        final int centerY = (overlapMinY + overlapMaxY) / 2;
        final int halfX = (overlapMaxX - overlapMinX) / 8;
        final int halfY = (overlapMaxY - overlapMinY) / 8;
        final int minX = centerX - halfX;
        final int maxX = centerX + halfX;
        final int minY = centerY - halfY;
        final int maxY = centerY + halfY;

        if ((maxX <= minX) || (maxY <= minY)) {
            // No usable overlap; without this guard the quotient below is 0, -0 or NaN
            return Double.POSITIVE_INFINITY;
        }

        long total = 0;

        for (int x = minX; x < maxX; x++) {

            for (int y = minY; y < maxY; y++) {
                final int delta = (first.getRGB(x, y) & 0xFF)
                    - (second.getRGB(x - xOff, y - yOff) & 0xFF);
                total += delta * delta;
            }
        }

        return (double) total / ((long) (maxY - minY) * (maxX - minX));
    }

    /**
     * Generates the string representation of the filter.
     *
     * @return  the string representation
     */
    @Override public String toString() {

        return "MotionCompensationFilter";
    }
}
