package com.srbenoit.math.optimizers;

import java.util.logging.Level;
import com.srbenoit.log.LoggedObject;

/**
 * An optimizer based on the Nelder-Mead Simplex method to locate a local minimum in a function.
 *
 * <p>The simplex has <code>dimension + 1</code> vertices. Each pass reflects the worst vertex
 * through the centroid of the remaining vertices, then expands, contracts, or shrinks the simplex
 * depending on how the reflected value compares with the best and second-worst values.
 *
 * <p>An instance reuses an internal scratch array between calls, so a single instance should not
 * be used by several threads at the same time.
 */
public class Optimizer extends LoggedObject {

    /** maximum allowed iterations of the algorithm */
    public static final int MAXITERATIONS = 100000;

    /** fractional change in the evaluated function to stop algorithm */
    public static final double TOLERANCE = 1E-14;

    /** distance from an integer within which a coordinate is snapped when favoring integers */
    private static final double INT_SNAP_DISTANCE = 1E-4;

    /** the optimizable function to be optimized */
    private final transient Optimizable function;

    /** the dimension of the optimizable function */
    private final transient int dimension;

    /** the scale in each dimension of the original simplex */
    private final transient double[] scale;

    /** preallocated double array for a point being tested */
    private final transient double[] pointToTry;

    /** flag indicating optimizer should favor integers in its output */
    private final transient boolean favorInts;

    /**
     * Constructs a new <code>Optimizer</code>.
     *
     * @param  theFunction    the optimizable function to be optimized
     * @param  theScale       the approximate scale of the function in each dimension - a step
     *                        size over which the function will vary appreciably; the array is
     *                        copied
     * @param  favorIntegers  if <code>true</code>, the optimizer will check the optimal solution
     *                        and if any values are very close to integers, it will check the
     *                        function value at the integer and if no worse than the minimum, it
     *                        will return the integer; if <code>false</code>, it will return the
     *                        optimum as initially found
     */
    public Optimizer(final Optimizable theFunction, final double[] theScale,
        final boolean favorIntegers) {

        this.function = theFunction;
        this.dimension = theFunction.dimension();
        this.scale = theScale.clone();
        this.pointToTry = new double[theFunction.dimension()];
        this.favorInts = favorIntegers;
    }

    /**
     * Generates optimal parameter values.
     *
     * <p>The initial simplex consists of the guess itself plus, for each dimension, the guess
     * displaced by that dimension's scale along the corresponding axis.
     *
     * @param   guess     the point at which to start optimization (not modified)
     * @param   logLevel  the level of logging to do
     * @return  a newly allocated array of doubles with the best-fit values for the parameters
     * @throws  FailedToConvergeException  if the optimizer failed to converge
     */
    public double[] optimize(final double[] guess, final Level logLevel)
        throws FailedToConvergeException {

        LOG.setLevel(logLevel);

        LOG.log(Level.INFO, "Optimizing in dimension {0}", this.dimension);

        final double[][] simplex = new double[this.dimension + 1][this.dimension];
        final double[] values = new double[this.dimension + 1];

        // Establish an initial simplex: every vertex starts as a copy of the guess...
        for (int i = 0; i <= this.dimension; i++) {
            System.arraycopy(guess, 0, simplex[i], 0, this.dimension);
        }

        // ...and vertex j+1 is then displaced along axis j by that axis' scale.
        for (int j = 0; j < this.dimension; j++) {
            simplex[j + 1][j] += this.scale[j];
        }

        // Compute the values at simplex vertices.
        for (int i = 0; i <= this.dimension; i++) {
            final OutputValue value = this.function.evaluate(simplex[i]);
            values[i] = value.getValue();
        }

        // Now, run the simplex algorithm...
        if (!simplex(simplex, values)) {
            throw new FailedToConvergeException("Simplex Optimizer failed to converge");
        }

        // The result is now the first simplex vertex
        return simplex[0];
    }

    /**
     * Performs simplex optimization.
     *
     * <p>Convergence is declared when the relative spread between the highest and lowest vertex
     * values drops below {@link #TOLERANCE}. The method reads <code>values[1]</code>
     * unconditionally, so it requires a dimension of at least one.
     *
     * @param   simplex  on input, the starting simplex; on output, the final simplex, with the
     *                   vertex having the lowest value moved to index 0 if the algorithm
     *                   converged
     * @param   values   on input, the initial values at the vertices of the simplex; on output,
     *                   the values at the new vertices
     * @return  <code>true</code> if algorithm converged on a solution before reaching maximum
     *          iterations; <code>false</code> if not
     */
    private boolean simplex(final double[][] simplex, final double[] values) {

        final int mpts = this.dimension + 1;
        final double[] vertexSums = new double[this.dimension];
        final StringBuilder str = new StringBuilder(200);
        boolean converged = false;
        int loops = 0;

        // Load the per-coordinate sums over all vertices; attempt() derives the centroid of the
        // face opposite the high point from these.
        sumVertices(simplex, vertexSums);

        for (;;) {

            // First, find the highest, next highest, and lowest vertices.
            int ilo = 0;
            int ihi;
            int inhi;

            if (values[0] > values[1]) {
                ihi = 0;
                inhi = 1;
            } else {
                ihi = 1;
                inhi = 0;
            }

            for (int i = 0; i < mpts; i++) {

                if (values[i] <= values[ilo]) {
                    ilo = i;
                }

                if (values[i] > values[ihi]) {
                    inhi = ihi;
                    ihi = i;
                } else if ((values[i] > values[inhi]) && (i != ihi)) {
                    // The (i != ihi) guard keeps the highest vertex from also being recorded as
                    // the next-highest when the scan revisits it.
                    inhi = i;
                }
            }

            // Compute tolerance between values, see if we're done.
            final double rtol;

            if ((values[ihi] - values[ilo]) == 0) {
                // Avoids 0/0 when both extreme values are zero.
                rtol = 0;
            } else {
                rtol = 2.0 * Math.abs(values[ihi] - values[ilo])
                    / (Math.abs(values[ihi]) + Math.abs(values[ilo]));
            }

            // Diagnostic output for the last few passes before the iteration limit. This is
            // emitted before the convergence test, so it can also appear on a converging pass.
            if (loops > (MAXITERATIONS - 10)) {
                str.setLength(0);
                str.append("FAILING: lo=");
                str.append(values[ilo]);
                str.append(", hi=");
                str.append(values[ihi]);
                str.append(", tol=");
                str.append(rtol);
                LOG.warning(str.toString());
            }

            if (rtol < TOLERANCE) {

                // Put the lowest value in values[0];
                final double swap = values[0];
                values[0] = values[ilo];
                values[ilo] = swap;

                // Put the vertex with the lowest error in simplex[0]
                final double[] swapvertex = simplex[0];
                simplex[0] = simplex[ilo];
                simplex[ilo] = swapvertex;

                converged = true;
                LOG.log(Level.INFO, "converged after {0} iterations: {1}",
                    new Object[] { loops, rtol });

                break;
            }

            // If we've looped too many times, bail out.
            if (loops > MAXITERATIONS) {
                LOG.log(Level.WARNING,
                    "Simplex optimizer exceeded {0} iterations - failed to converge",
                    MAXITERATIONS);

                break;
            }

            // The counter advances by two per pass; it is reduced by one below when the plain
            // reflection is kept, and increased by one more when the whole simplex shrinks.
            loops += 2;

            // Extrapolate the simplex by a factor of -1 through the face of
            // the simplex across from the high point.
            double ytry = attempt(simplex, values, vertexSums, ihi, -1.0);

            if (ytry <= values[ilo]) {

                // This is at least as good as our best point, so try an additional
                // extrapolation by a factor of 2
                ytry = attempt(simplex, values, vertexSums, ihi, 2.0);

                // NOTE(review): ilo was determined before the attempts above, which only modify
                // the vertex at ihi, so this passes the previous best vertex - confirm intent.
                this.function.updatedParameters(simplex[ilo]);
            } else if (ytry >= values[inhi]) {

                // reflected point is worse than the second-highest, so look
                // for an intermediate lower point (contract simplex)
                final double ysave = values[ihi];
                ytry = attempt(simplex, values, vertexSums, ihi, 0.5);

                if (ytry >= ysave) {

                    // Can't get rid of high point, better contract around
                    // the lowest point
                    for (int i = 0; i < mpts; i++) {

                        if (i != ilo) {

                            for (int j = 0; j < this.dimension; j++) {
                                simplex[i][j] = 0.5 * (simplex[i][j] + simplex[ilo][j]);
                            }

                            final OutputValue value = this.function.evaluate(simplex[i]);
                            values[i] = value.getValue();
                        }
                    }

                    loops++;

                    // Every vertex but the lowest moved, so recompute the coordinate sums
                    sumVertices(simplex, vertexSums);
                }
            } else {
                loops--;
            }
        }

        // One last thing - if any of the values are very near an integer, we
        // try adjusting them to the integer value and evaluating. That way
        // we can find integer minima exactly.
        if (converged && this.favorInts) {
            final double[] tryInts = simplex[0].clone();
            OutputValue value = this.function.evaluate(tryInts);
            final double ysave = value.getValue();

            for (int i = 0; i < tryInts.length; i++) {
                final double ceil = Math.ceil(tryInts[i]);
                final double floor = Math.floor(tryInts[i]);

                if ((ceil - tryInts[i]) < INT_SNAP_DISTANCE) {
                    tryInts[i] = ceil;
                } else if ((tryInts[i] - floor) < INT_SNAP_DISTANCE) {
                    tryInts[i] = floor;
                }

                // Get rid of weird "negative zero" value. The comparison is true for both
                // 0.0 and -0.0, so either one is normalized to positive zero.
                if (tryInts[i] == -0.0) {
                    tryInts[i] = 0.0;
                }
            }

            value = this.function.evaluate(tryInts);
            final double ytry = value.getValue();

            // Keep the snapped point only if it is no worse than the converged minimum.
            if (ytry <= ysave) {
                System.arraycopy(tryInts, 0, simplex[0], 0, tryInts.length);
                values[0] = ytry;
            }
        }

        return converged;
    }

    /**
     * Computes, for each coordinate, the sum of that coordinate over all simplex vertices.
     *
     * @param  simplex  the simplex vertices
     * @param  sums     receives the per-coordinate sums; must have one entry per dimension
     */
    private void sumVertices(final double[][] simplex, final double[] sums) {

        for (int j = 0; j < this.dimension; j++) {
            double sum = 0.0;

            for (int i = 0; i < simplex.length; i++) {
                sum += simplex[i][j];
            }

            sums[j] = sum;
        }
    }

    /**
     * Extrapolates through the face of the simplex across from the high point, checking the value
     * there, and replacing the high point if the new value is lower.
     *
     * <p>With <code>c</code> the centroid of all vertices other than the high point, the trial
     * point is <code>c * (1 - fac) + high * fac</code>; a factor of -1 is therefore a reflection
     * of the high point through the opposite face.
     *
     * @param   simplex     on input, the starting simplex; on output, the adjusted simplex
     * @param   values      on input, the initial values at the vertices of the simplex; on
     *                      output, the values at the new vertices
     * @param   vertexSums  the per-coordinate sums over all vertices; updated if the high point
     *                      is replaced
     * @param   highIndex   the index of the vertex with the highest value
     * @param   fac         the factor by which to extrapolate
     * @return  the function value at the trial point
     */
    private double attempt(final double[][] simplex, final double[] values,
        final double[] vertexSums, final int highIndex, final double fac) {

        final double fac1 = (1 - fac) / this.dimension;
        final double fac2 = fac1 - fac;

        for (int j = 0; j < this.dimension; j++) {
            this.pointToTry[j] = (vertexSums[j] * fac1) - (simplex[highIndex][j] * fac2);
        }

        // Evaluate the function at the new point.
        final OutputValue value = this.function.evaluate(this.pointToTry);
        final double result = value.getValue();

        // If the value is smaller, move the simplex.
        if (result < values[highIndex]) {
            values[highIndex] = result;

            // Adjust the sums incrementally for the one vertex that changes.
            for (int j = 0; j < this.dimension; j++) {
                vertexSums[j] += this.pointToTry[j] - simplex[highIndex][j];
            }

            System.arraycopy(this.pointToTry, 0, simplex[highIndex], 0, this.dimension);
        }

        return result;
    }
}
