package com.srbenoit.geom;

/**
 * A transformation matrix capable of representing an affine transformation and acting on points
 * and vectors. Only the 12 elements of the upper three rows are stored; the fourth row of the
 * equivalent 4x4 matrix is always (0, 0, 0, 1).
 *
 * <p>Instances are mutable, and {@link #invert()} works in scratch arrays owned by the instance,
 * so a single instance must not be shared between threads without external synchronization.
 */
public final class Transform3 {

    /** the first element of the first row */
    private double m00 = 1.0;

    /** the second element of the first row */
    private double m01;

    /** the third element of the first row */
    private double m02;

    /** the fourth element of the first row */
    private double m03;

    /** the first element of the second row */
    private double m10;

    /** the second element of the second row */
    private double m11 = 1.0;

    /** the third element of the second row */
    private double m12;

    /** the fourth element of the second row */
    private double m13;

    /** the first element of the third row */
    private double m20;

    /** the second element of the third row */
    private double m21;

    /** the third element of the third row */
    private double m22 = 1.0;

    /** the fourth element of the third row */
    private double m23;

    /** working array: implicit scale factor of each row, used while choosing pivots */
    private final double[] rowScale = new double[4];

    /** working array: row-major 4x4 copy of this matrix, replaced by its LU decomposition */
    private final double[] temp = new double[16];

    /** working array: row-major 4x4 matrix that receives the computed inverse */
    private final double[] result = new double[16];

    /** working array: row permutation recorded by the LU decomposition */
    private final int[] rowPerm = new int[4];

    /**
     * Constructs and initializes a <code>Transform3</code> to the identity transformation.
     */
    public Transform3() {

        // The field initializers already describe the identity matrix and allocate the working
        // arrays, so there is nothing left to do here.
    }

    /**
     * Returns a string that contains the values of this <code>Transform3</code>, one matrix row
     * per line with elements separated by ", ".
     *
     * @return  the <code>String</code> representation
     */
    @Override public String toString() {

        return this.m00 + ", " + this.m01 + ", " + this.m02 + ", " + this.m03 + "\n" + this.m10
            + ", " + this.m11 + ", " + this.m12 + ", " + this.m13 + "\n" + this.m20 + ", "
            + this.m21 + ", " + this.m22 + ", " + this.m23 + "\n";
    }

    /**
     * Sets the value of an element in the matrix. A request for a location outside the stored
     * elements (row not in 0..2 or column not in 0..3) is silently ignored.
     *
     * @param  row    the row (0 to 2)
     * @param  col    the column (0 to 3)
     * @param  value  the new value to place at that location in the matrix
     */
    public void set(final int row, final int col, final double value) {

        if ((row < 0) || (row > 2) || (col < 0) || (col > 3)) {
            return;
        }

        // The range check above guarantees that (4 * row) + col identifies exactly one element.
        switch ((4 * row) + col) {

        case 0:
            this.m00 = value;
            break;

        case 1:
            this.m01 = value;
            break;

        case 2:
            this.m02 = value;
            break;

        case 3:
            this.m03 = value;
            break;

        case 4:
            this.m10 = value;
            break;

        case 5:
            this.m11 = value;
            break;

        case 6:
            this.m12 = value;
            break;

        case 7:
            this.m13 = value;
            break;

        case 8:
            this.m20 = value;
            break;

        case 9:
            this.m21 = value;
            break;

        case 10:
            this.m22 = value;
            break;

        case 11:
            this.m23 = value;
            break;

        default:
            break;
        }
    }

    /**
     * Gets the value of an element in the matrix. A request for a location outside the stored
     * elements (row not in 0..2 or column not in 0..3) yields 0.
     *
     * @param   row  the row (0 to 2)
     * @param   col  the column (0 to 3)
     * @return  the value at that location in the matrix, or 0 if the location is out of range
     */
    public double get(final int row, final int col) {

        if ((row < 0) || (row > 2) || (col < 0) || (col > 3)) {
            return 0;
        }

        final double value;

        // The range check above guarantees that (4 * row) + col identifies exactly one element.
        switch ((4 * row) + col) {

        case 0:
            value = this.m00;
            break;

        case 1:
            value = this.m01;
            break;

        case 2:
            value = this.m02;
            break;

        case 3:
            value = this.m03;
            break;

        case 4:
            value = this.m10;
            break;

        case 5:
            value = this.m11;
            break;

        case 6:
            value = this.m12;
            break;

        case 7:
            value = this.m13;
            break;

        case 8:
            value = this.m20;
            break;

        case 9:
            value = this.m21;
            break;

        case 10:
            value = this.m22;
            break;

        case 11:
            value = this.m23;
            break;

        default:
            value = 0;
            break;
        }

        return value;
    }

    /**
     * Sets the elements of this matrix from another matrix.
     *
     * @param  matrix  the source matrix
     */
    public void set(final Transform3 matrix) {

        this.m00 = matrix.m00;
        this.m01 = matrix.m01;
        this.m02 = matrix.m02;
        this.m03 = matrix.m03;

        this.m10 = matrix.m10;
        this.m11 = matrix.m11;
        this.m12 = matrix.m12;
        this.m13 = matrix.m13;

        this.m20 = matrix.m20;
        this.m21 = matrix.m21;
        this.m22 = matrix.m22;
        this.m23 = matrix.m23;
    }

    /**
     * Right-multiplies this matrix by <code>matrix</code>, storing the product (this * matrix) in
     * this object. Both operands are treated as 4x4 matrices whose last row is (0, 0, 0, 1), which
     * is why the fourth column of the product also includes this matrix's own fourth column.
     * Passing this same object as the argument is permitted.
     *
     * @param  matrix  the matrix by which to multiply this matrix
     */
    public void mul(final Transform3 matrix) {

        // Snapshot both operands before any element is overwritten, so the product is computed
        // from the original values even when "matrix" is this same object.
        final double a00 = this.m00;
        final double a01 = this.m01;
        final double a02 = this.m02;
        final double a03 = this.m03;
        final double a10 = this.m10;
        final double a11 = this.m11;
        final double a12 = this.m12;
        final double a13 = this.m13;
        final double a20 = this.m20;
        final double a21 = this.m21;
        final double a22 = this.m22;
        final double a23 = this.m23;

        final double b00 = matrix.m00;
        final double b01 = matrix.m01;
        final double b02 = matrix.m02;
        final double b03 = matrix.m03;
        final double b10 = matrix.m10;
        final double b11 = matrix.m11;
        final double b12 = matrix.m12;
        final double b13 = matrix.m13;
        final double b20 = matrix.m20;
        final double b21 = matrix.m21;
        final double b22 = matrix.m22;
        final double b23 = matrix.m23;

        this.m00 = (a00 * b00) + (a01 * b10) + (a02 * b20);
        this.m01 = (a00 * b01) + (a01 * b11) + (a02 * b21);
        this.m02 = (a00 * b02) + (a01 * b12) + (a02 * b22);
        this.m03 = (a00 * b03) + (a01 * b13) + (a02 * b23) + a03;

        this.m10 = (a10 * b00) + (a11 * b10) + (a12 * b20);
        this.m11 = (a10 * b01) + (a11 * b11) + (a12 * b21);
        this.m12 = (a10 * b02) + (a11 * b12) + (a12 * b22);
        this.m13 = (a10 * b03) + (a11 * b13) + (a12 * b23) + a13;

        this.m20 = (a20 * b00) + (a21 * b10) + (a22 * b20);
        this.m21 = (a20 * b01) + (a21 * b11) + (a22 * b21);
        this.m22 = (a20 * b02) + (a21 * b12) + (a22 * b22);
        this.m23 = (a20 * b03) + (a21 * b13) + (a22 * b23) + a23;
    }

    /**
     * Inverts the <code>Transform3</code> in place. The elements of this transform are only
     * overwritten after the decomposition has succeeded, so they are left untouched when the
     * exception is thrown.
     *
     * @throws  SingularMatrixException  if the matrix is singular (not invertible)
     */
    public void invert() throws SingularMatrixException {

        // Use LU decomposition and back-substitution code specifically for 4x4 matrices.

        // Expand this transform into the row-major 4x4 working matrix; the implicit fourth row
        // is (0, 0, 0, 1).
        this.temp[0] = this.m00;
        this.temp[1] = this.m01;
        this.temp[2] = this.m02;
        this.temp[3] = this.m03;

        this.temp[4] = this.m10;
        this.temp[5] = this.m11;
        this.temp[6] = this.m12;
        this.temp[7] = this.m13;

        this.temp[8] = this.m20;
        this.temp[9] = this.m21;
        this.temp[10] = this.m22;
        this.temp[11] = this.m23;

        this.temp[12] = 0.0;
        this.temp[13] = 0.0;
        this.temp[14] = 0.0;
        this.temp[15] = 1.0;

        // Calculate the LU decomposition; this throws if the matrix is singular
        luDecomposition(this.temp, this.rowPerm);

        // Perform back substitution on the identity matrix, which yields the inverse
        for (int inx = 0; inx < 16; inx++) {
            this.result[inx] = 0.0;
        }

        this.result[0] = 1.0;
        this.result[5] = 1.0;
        this.result[10] = 1.0;
        this.result[15] = 1.0;
        luBacksubstitution(this.temp, this.rowPerm, this.result);

        // Keep the upper three rows of the computed inverse
        this.m00 = this.result[0];
        this.m01 = this.result[1];
        this.m02 = this.result[2];
        this.m03 = this.result[3];

        this.m10 = this.result[4];
        this.m11 = this.result[5];
        this.m12 = this.result[6];
        this.m13 = this.result[7];

        this.m20 = this.result[8];
        this.m21 = this.result[9];
        this.m22 = this.result[10];
        this.m23 = this.result[11];
    }

    /**
     * Given a 4x4 array <code>matrix0</code> (16 elements, row-major), replaces it with the LU
     * decomposition of a row-wise permutation of itself, using Crout's method with implicitly
     * scaled partial pivoting. The code is written specifically for 4x4 matrices and uses the
     * <code>rowScale</code> working array of this object.
     *
     * @param   matrix0   the matrix that is to be decomposed, and (on completion), the decomposed
     *                    matrix
     * @param   rowPerms  receives the row permutations resulting from partial pivoting
     * @throws  SingularMatrixException  if the matrix is singular (not invertible)
     */
    private void luDecomposition(final double[] matrix0, final int[] rowPerms)
        throws SingularMatrixException {

        // Determine the implicit scaling of each row: the reciprocal of the largest magnitude
        // found in THAT row (the row offset matters; every row must be scanned individually).
        for (int row = 0; row < 4; row++) {
            final int rowStart = 4 * row;
            double big = 0.0;

            for (int col = 0; col < 4; col++) {
                final double magnitude = Math.abs(matrix0[rowStart + col]);

                if (magnitude > big) {
                    big = magnitude;
                }
            }

            // A row of zeros means the matrix is singular
            if (big == 0.0) {
                throw new SingularMatrixException();
            }

            this.rowScale[row] = 1.0 / big;
        }

        // For all columns, execute Crout's method
        for (int col = 0; col < 4; col++) {

            // Determine elements of upper diagonal matrix U (rows above the diagonal)
            for (int inx = 0; inx < col; inx++) {
                final int target = (4 * inx) + col;
                double sum = matrix0[target];

                for (int knx = 0; knx < inx; knx++) {
                    sum -= matrix0[(4 * inx) + knx] * matrix0[(4 * knx) + col];
                }

                matrix0[target] = sum;
            }

            // Search for largest pivot element and calculate intermediate elements of lower
            // diagonal matrix L (rows on and below the diagonal).
            double big = 0.0;
            int imax = -1;

            for (int inx = col; inx < 4; inx++) {
                final int target = (4 * inx) + col;
                double sum = matrix0[target];

                for (int knx = 0; knx < col; knx++) {
                    sum -= matrix0[(4 * inx) + knx] * matrix0[(4 * knx) + col];
                }

                matrix0[target] = sum;

                // Is this the best pivot so far (compared after scaling by the row's factor)?
                final double scaled = this.rowScale[inx] * Math.abs(sum);

                if (scaled >= big) {
                    big = scaled;
                    imax = inx;
                }
            }

            // No candidate compared >= 0 (for example, every candidate was NaN)
            if (imax < 0) {
                throw new SingularMatrixException();
            }

            // Is a row exchange necessary?
            if (col != imax) {

                // Yes: exchange rows
                final int pivotRow = 4 * imax;
                final int currentRow = 4 * col;

                for (int knx = 0; knx < 4; knx++) {
                    final double swap = matrix0[pivotRow + knx];
                    matrix0[pivotRow + knx] = matrix0[currentRow + knx];
                    matrix0[currentRow + knx] = swap;
                }

                // Record change in scale factor: the row now stored at "imax" is the old row
                // "col"; the scale of row "col" itself is never consulted again.
                this.rowScale[imax] = this.rowScale[col];
            }

            // Record row permutation
            rowPerms[col] = imax;

            // A zero pivot means the matrix is singular
            final double pivot = matrix0[(4 * col) + col];

            if (pivot == 0.0) {
                throw new SingularMatrixException();
            }

            // Divide elements of lower diagonal matrix L by pivot (no rows remain below the
            // diagonal for the last column, so the loop is empty there)
            final double pivotInverse = 1.0 / pivot;

            for (int inx = col + 1; inx < 4; inx++) {
                matrix0[(4 * inx) + col] *= pivotInverse;
            }
        }
    }

    /**
     * Solves a set of linear equations.
     *
     * @param  matrix1   the matrix produced by <code>luDecomposition</code> (not changed by this
     *                   method)
     * @param  rowPerms  the row permutations resulting from partial pivoting produced by <code>
     *                   luDecomposition</code> (not changed by this method)
     * @param  matrix2   a set of column vectors assembled into a row-major 4x4 matrix of values
     *                   (the procedure takes each column of "matrix2" in turn and treats it as the
     *                   right-hand side of the matrix equation Ax = LUx = b. The solution vector
     *                   replaces the original column of the matrix. If <code>matrix2</code> is the
     *                   identity matrix, the procedure replaces its contents with the inverse of
     *                   the matrix from which <code>matrix1</code> was originally derived)
     */
    private static void luBacksubstitution(final double[] matrix1, final int[] rowPerms,
        final double[] matrix2) {

        for (int col = 0; col < 4; col++) {

            // Forward substitution (solves Ly = b while undoing the row permutation).
            // "firstNonZero" is the first row whose right-hand side value was nonzero; earlier
            // rows contribute nothing to the sums, so they are skipped.
            int firstNonZero = -1;

            for (int row = 0; row < 4; row++) {
                final int perm = rowPerms[row];
                double sum = matrix2[col + (4 * perm)];
                matrix2[col + (4 * perm)] = matrix2[col + (4 * row)];

                if (firstNonZero >= 0) {

                    for (int jnx = firstNonZero; jnx <= (row - 1); jnx++) {
                        sum -= matrix1[(row * 4) + jnx] * matrix2[col + (4 * jnx)];
                    }
                } else if (sum != 0.0) {
                    firstNonZero = row;
                }

                matrix2[col + (4 * row)] = sum;
            }

            // Back substitution (solves Ux = y), unrolled from the last row up to the first
            matrix2[col + 12] /= matrix1[15];

            matrix2[col + 8] = (matrix2[col + 8] - (matrix1[11] * matrix2[col + 12]))
                / matrix1[10];

            matrix2[col + 4] = (matrix2[col + 4] - (matrix1[6] * matrix2[col + 8])
                    - (matrix1[7] * matrix2[col + 12])) / matrix1[5];

            matrix2[col] = (matrix2[col] - (matrix1[1] * matrix2[col + 4])
                    - (matrix1[2] * matrix2[col + 8]) - (matrix1[3] * matrix2[col + 12]))
                / matrix1[0];
        }
    }

    /**
     * Transform the point <code>pointIn</code> using this <code>Transform3</code> and place the
     * result into <code>pointOut</code>. The fourth column of the matrix (the translation) is
     * added to the result. All three coordinates are computed before <code>pointOut</code> is
     * updated, so this method permits the same tuple to be used as the source and destination.
     *
     * @param  pointIn   the point to be transformed
     * @param  pointOut  the point into which the transformed values are placed
     */
    public void transformPoint(final Point3Int pointIn, final Point3Int pointOut) {

        pointOut.setPos((this.m00 * pointIn.getPosX()) + (this.m01 * pointIn.getPosY())
            + (this.m02 * pointIn.getPosZ()) + this.m03,
            (this.m10 * pointIn.getPosX()) + (this.m11 * pointIn.getPosY())
            + (this.m12 * pointIn.getPosZ()) + this.m13,
            (this.m20 * pointIn.getPosX()) + (this.m21 * pointIn.getPosY())
            + (this.m22 * pointIn.getPosZ()) + this.m23);
    }

    /**
     * Transform the vector <code>vecIn</code> using this <code>Transform3</code> and place the
     * result into <code>vecOut</code>. Only the upper-left 3x3 part of the matrix is applied; the
     * fourth column (the translation) is not added. All three components are computed before
     * <code>vecOut</code> is updated, so this method permits the same tuple to be used as the
     * source and destination.
     *
     * @param  vecIn   the vector to be transformed
     * @param  vecOut  the vector into which the transformed values are placed
     */
    public void transformVec(final Vector3Int vecIn, final Vector3Int vecOut) {

        vecOut.setVec((this.m00 * vecIn.getVecX()) + (this.m01 * vecIn.getVecY())
            + (this.m02 * vecIn.getVecZ()),
            (this.m10 * vecIn.getVecX()) + (this.m11 * vecIn.getVecY())
            + (this.m12 * vecIn.getVecZ()),
            (this.m20 * vecIn.getVecX()) + (this.m21 * vecIn.getVecY())
            + (this.m22 * vecIn.getVecZ()));
    }
}
