package com.srbenoit.math.optimizers;

import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
import com.srbenoit.log.LoggedObject;

/**
 * An optimizer that scans a region of solution space for local minima by partitioning the space
 * into a grid, evaluating the function being minimized at the center of each grid cell, and
 * initializing a minimizer at each grid point whose value is no greater than that of any of its
 * neighbors (the adjacent grid points along each coordinate axis).
 */
public class RegionalOptimizer extends LoggedObject {

    /**
     * the distance, measured in grid cells, below which two located minima are treated as
     * duplicates of one another
     */
    private static final double DUPLICATE_TOLERANCE = 0.001;

    /** the optimizable function to be optimized */
    private final transient RegionalOptimizable function;

    /** the range being evaluated in each dimension */
    private final transient SearchRange ranges;

    /**
     * the number of grid subdivisions to use in each dimension (the product of these values will
     * give the number of function evaluations performed to identify starting points for
     * optimizations)
     */
    private final transient int[] numSubdivisions;

    /**
     * Constructs a new <code>RegionalOptimizer</code>.
     *
     * @param  theFunction    the optimizable function to be optimized
     * @param  theRanges      the range being evaluated in each dimension; the minimum must be
     *                        strictly less than the maximum in every dimension
     * @param  theNumSubdivs  the number of grid subdivisions to use in each dimension (the product
     *                        of these values will give the number of function evaluations
     *                        performed to identify starting points for optimizations); the array
     *                        is copied, so later changes by the caller have no effect
     * @throws IllegalArgumentException if the ranges or subdivision array do not match the function
     *                                  dimension, if any range is empty, reversed, or has a NaN
     *                                  bound, if any subdivision count is not positive, or if the
     *                                  total number of grid points would exceed
     *                                  {@link Integer#MAX_VALUE}
     */
    public RegionalOptimizer(final RegionalOptimizable theFunction, final SearchRange theRanges,
        final int[] theNumSubdivs) {

        int dim;
        long totalPoints;

        dim = theFunction.dimension();

        this.function = theFunction;
        this.ranges = theRanges;
        this.numSubdivisions = theNumSubdivs.clone();

        if (this.ranges.getDimension() != dim) {
            throw new IllegalArgumentException("Ranges dimension must match function dimension");
        }

        if (this.numSubdivisions.length != dim) {
            throw new IllegalArgumentException(
                "Length of number of Subdivisions array must match function dimension");
        }

        totalPoints = 1L;

        for (int i = 0; i < dim; i++) {

            // Ensure ranges are in proper order (min strictly below max). The negated form also
            // rejects NaN bounds, for which every comparison is false.
            if (!(this.ranges.getMin(i) < this.ranges.getMax(i))) {
                throw new IllegalArgumentException("Invalid range in dimension " + i);
            }

            if (this.numSubdivisions[i] <= 0) {
                throw new IllegalArgumentException("Invalid number of subdivisions in dimension "
                    + i);
            }

            // optimize() sizes an int-indexed array by this product, so it must fit in an int.
            // Both factors are at most Integer.MAX_VALUE here, so the long product cannot wrap.
            totalPoints *= this.numSubdivisions[i];

            if (totalPoints > Integer.MAX_VALUE) {
                throw new IllegalArgumentException(
                    "Total number of grid points exceeds " + Integer.MAX_VALUE);
            }
        }
    }

    /**
     * Finds all local optimal parameter values within the region.
     *
     * <p>The function is first evaluated at the center of every grid cell. Out-of-range
     * evaluations are recorded as positive infinity. A cell seeds a run of the {@code Optimizer}
     * only if its value is neither positive infinity nor NaN, and no axis-adjacent cell has a
     * smaller value. Optimizer results that fall outside the ranges are discarded. Each remaining
     * result is reported to the function through {@code foundAMinium}. Finally, results within
     * {@link #DUPLICATE_TOLERANCE} grid cells of an earlier result are dropped.
     *
     * <p>As a side effect, this sets the level of {@code LOG} to {@code logLevel}.
     *
     * @param   logLevel  the level of logging to do
     * @return  a list containing the located minima. Each is an array of doubles with the (local)
     *          best-fit values for the parameters
     */
    public List<double[]> optimize(final Level logLevel) {

        final int dim;
        final double[] scale;
        final int[] point;
        final double[] guess;
        final Optimizer opt;
        final double[] evaluations;
        final List<double[]> accumulator;
        int numPoints;
        double[] optimum;

        LOG.setLevel(logLevel);

        dim = this.function.dimension();
        scale = new double[dim];
        point = new int[dim];
        guess = new double[dim];
        accumulator = new ArrayList<double[]>(30);

        // Determine the scale (grid cell width) for each dimension and build an optimizer
        for (int i = 0; i < dim; i++) {
            scale[i] = (this.ranges.getMax(i) - this.ranges.getMin(i)) / this.numSubdivisions[i];
        }

        opt = new Optimizer(this.function, scale, true);

        // The constructor guarantees that this product fits in an int
        numPoints = 1;

        for (int i = 0; i < dim; i++) {
            numPoints *= this.numSubdivisions[i];
        }

        LOG.log(Level.INFO, "A total of {0} points to test", numPoints);

        evaluations = evaluateGrid(numPoints, scale);

        LOG.info("Optimizing around minima ...");

        // 'point' starts at the all-zero cell and advances in the same order as evaluateGrid,
        // so cell index j and 'point' always refer to the same cell
        for (int j = 0; j < numPoints; j++) {

            if (isGridMinimum(evaluations, j, point)) {
                cellCenter(point, scale, guess);

                try {
                    // How the optimizer treats its start-point array is not visible here. Pass a
                    // copy so the result can never alias 'guess' or a previously stored minimum.
                    optimum = opt.optimize(guess.clone(), logLevel);

                    if (isWithinRanges(optimum)) {
                        accumulator.add(optimum);
                        this.function.foundAMinium(optimum);
                    }
                } catch (FailedToConvergeException e) {
                    LOG.log(Level.WARNING, "Failed to locate minima: {0}", e.getMessage());
                }
            }

            advance(point);
        }

        return removeDuplicates(accumulator, scale);
    }

    /**
     * Evaluates the function at the center of every grid cell.
     *
     * @param   numPoints  the total number of grid cells
     * @param   scale      the grid cell width in each dimension
     * @return  the function value for each cell, indexed so that dimension 0 varies fastest (the
     *          order produced by {@link #advance}); out-of-range evaluations are recorded as
     *          {@link Double#POSITIVE_INFINITY}
     */
    private double[] evaluateGrid(final int numPoints, final double[] scale) {

        final int[] point = new int[scale.length];
        final double[] guess = new double[scale.length];
        final double[] evaluations = new double[numPoints];
        OutputValue output;

        for (int j = 0; j < numPoints; j++) {
            cellCenter(point, scale, guess);
            output = this.function.evaluate(guess);
            evaluations[j] = output.isOutOfRange() ? Double.POSITIVE_INFINITY : output.getValue();
            advance(point);
        }

        return evaluations;
    }

    /**
     * Computes the coordinates of the center of a grid cell. Each coordinate is derived directly
     * from the cell index rather than by repeated addition, so no rounding error accumulates.
     *
     * @param  point  the cell index in each dimension
     * @param  scale  the grid cell width in each dimension
     * @param  guess  receives the cell-center coordinates
     */
    private void cellCenter(final int[] point, final double[] scale, final double[] guess) {

        for (int i = 0; i < point.length; i++) {
            guess[i] = this.ranges.getMin(i) + ((point[i] + 0.5) * scale[i]);
        }
    }

    /**
     * Advances a multi-dimensional cell index to the next cell, with dimension 0 varying fastest.
     * Advancing past the last cell wraps the index back to all zeros.
     *
     * @param  point  the cell index to advance in place
     */
    private void advance(final int[] point) {

        for (int i = 0; i < point.length; i++) {
            point[i]++;

            if (point[i] < this.numSubdivisions[i]) {
                return;
            }

            point[i] = 0;
        }
    }

    /**
     * Tests whether a grid cell should seed an optimization. Its value must be usable, and no
     * neighbor along any coordinate axis may have a smaller value. Ties with neighbors are
     * allowed.
     *
     * @param   evaluations  the grid evaluations, indexed as produced by {@link #evaluateGrid}
     * @param   j            the flat index of the cell to test
     * @param   point        the multi-dimensional index of the same cell
     * @return  true if the cell is a candidate local minimum
     */
    private boolean isGridMinimum(final double[] evaluations, final int j, final int[] point) {

        final double value = evaluations[j];
        int step;

        // An out-of-range (+infinity) or NaN cell can never lose a "<" comparison. Without this
        // check it would pass as a minimum and start an optimization from an invalid point.
        if ((value == Double.POSITIVE_INFINITY) || Double.isNaN(value)) {
            return false;
        }

        // 'step' is the flat-index distance between neighbors along dimension i
        step = 1;

        for (int i = 0; i < point.length; i++) {

            if ((point[i] > 0) && (evaluations[j - step] < value)) {
                return false;
            }

            if ((point[i] < (this.numSubdivisions[i] - 1)) && (evaluations[j + step] < value)) {
                return false;
            }

            step *= this.numSubdivisions[i];
        }

        return true;
    }

    /**
     * Tests whether a point lies within the search ranges (bounds inclusive).
     *
     * @param   candidate  the point to test
     * @return  true if every coordinate lies within its range
     */
    private boolean isWithinRanges(final double[] candidate) {

        for (int i = 0; i < this.numSubdivisions.length; i++) {

            if ((candidate[i] < this.ranges.getMin(i)) || (candidate[i] > this.ranges.getMax(i))) {
                return false;
            }
        }

        return true;
    }

    /**
     * Copies minima into a new list, omitting any minimum whose distance (measured in grid cells)
     * to an already-kept minimum is below {@link #DUPLICATE_TOLERANCE}. Order is preserved.
     *
     * @param   minima  the located minima, all already within the search ranges
     * @param   scale   the grid cell width in each dimension
     * @return  the de-duplicated list
     */
    private static List<double[]> removeDuplicates(final List<double[]> minima,
        final double[] scale) {

        final List<double[]> result = new ArrayList<double[]>(20);
        boolean isDuplicate;
        double dist;
        double delta;

        for (double[] minimum : minima) {
            isDuplicate = false;

            for (double[] kept : result) {
                dist = 0;

                for (int i = 0; i < scale.length; i++) {
                    delta = (kept[i] - minimum[i]) / scale[i];
                    dist += delta * delta;
                }

                if (Math.sqrt(dist) < DUPLICATE_TOLERANCE) {
                    isDuplicate = true;

                    break;
                }
            }

            if (!isDuplicate) {
                result.add(minimum);
            }
        }

        return result;
    }
}
