package com.srbenoit.modeling.mesh;

import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.GridLayout;
import java.awt.Toolkit;
import java.lang.reflect.InvocationTargetException;
import java.util.logging.Level;
import javax.swing.ButtonGroup;
import javax.swing.JFrame;
import javax.swing.JPanel;
import javax.swing.JRadioButton;
import javax.swing.JSlider;
import javax.swing.SwingConstants;
import javax.swing.SwingUtilities;
import javax.swing.event.ChangeEvent;
import javax.swing.event.ChangeListener;
import com.srbenoit.geom.BasedVector3;
import com.srbenoit.geom.Vector3;
import com.srbenoit.render.Camera;
import com.srbenoit.render.Light;
import com.srbenoit.render.RenderPanel;
import com.srbenoit.render.Scene;
import com.srbenoit.render.WorldFace;
import com.srbenoit.render.WorldVertex;

/**
 * An interactive test scene for checking the membrane force computations on a single vertex.
 *
 * <p>The scene contains a closed "cell" (membrane 1) made of 8 vertices and 12 triangular faces:
 * a hexagonal ring of six vertices in the X-Y plane, capped by one apex above it (vertex 6) and
 * one apex below it (vertex 7). A second object (membrane 2) is a square of 4 vertices above the
 * cell, triangulated twice with opposite windings so that it has faces on both sides.
 *
 * <p>Sliders move the upper apex, and a based vector drawn from that vertex displays the force
 * selected with the radio buttons.
 */
public class HexPatch extends Scene implements ChangeListener, Runnable {

    /** equilibrium separation of elements (\varepsilon in paper); also the scene's length scale */
    public static final double EPS = 1e-7;

    /** the Lennard-Jones well depth for element interactions */
    public static final double KLJ = 1e3;

    /** the bulk modulus, used to scale the pressure force */
    public static final double KB = 6750;

    /** the tension */
    public static final double TENSION = 3e-5;

    /** the elastic modulus, used to scale the curvature force */
    // public static final double KC = 7e-20;
    public static final double KC = 7e1;

    /** the soft-sphere well depth for anti-self-intersection interactions */
    public static final double KSS = 1e3;

    /** scale to apply to force vector for display */
    public static final double FORCE_SCALE = 1e-18;

    /** index of the vertex that the sliders move and whose force is displayed (upper apex) */
    private static final int PROBE_INDEX = 6;

    /**
     * all twelve vertices in the scene: indices 0-7 belong to the cell (membrane 1) and indices
     * 8-11 to the second object (membrane 2)
     */
    private final MembraneVertex[] vertices;

    /**
     * all sixteen faces in the scene: indices 0-11 form the cell (membrane 1) and indices 12-15
     * the second object (membrane 2)
     */
    private final MembraneFace[] faces;

    /** the panel that will render the system */
    private RenderPanel panel;

    /** the camera that will be used in rendering the scene */
    private Camera camera;

    /** slider to control X coordinate of central point */
    private JSlider xSlider;

    /** slider to control Y coordinate of central point */
    private JSlider ySlider;

    /** slider to control Z coordinate of central point */
    private JSlider zSlider;

    /** a based vector, anchored at the probe vertex, that displays the selected force */
    private BasedVector3 vec;

    /** radio button to select element interaction force */
    private JRadioButton elemFrc;

    /** radio button to select pressure force */
    private JRadioButton presFrc;

    /** radio button to select tension force */
    private JRadioButton tensFrc;

    /** radio button to select curvature force */
    private JRadioButton curvFrc;

    /** radio button to select self-intersection force */
    private JRadioButton siFrc;

    /** the equilibrium volume of the cell (its volume as constructed) */
    private double defVolume;

    /** the current volume of the cell, refreshed every frame in {@link #go()} */
    private double curVolume;

    /**
     * Constructs a new <code>HexPatch</code>, building the vertices, faces, camera, light, and
     * force-display vector, and recording the cell's initial volume as its equilibrium volume.
     */
    public HexPatch() {

        super(16);

        double r3o2;

        r3o2 = Math.sqrt(3) / 2;

        this.vertices = new MembraneVertex[12];
        this.faces = new MembraneFace[16];

        // Membrane 1: a hexagonal ring of radius EPS in the X-Y plane, plus upper and lower apexes.
        // (The fourth constructor argument appears to be the membrane ID and the fifth the
        // radius; TODO confirm against MembraneVertex.)
        this.vertices[0] = new MembraneVertex(0.5 * EPS, -r3o2 * EPS, 0, 1, EPS);
        this.vertices[1] = new MembraneVertex(EPS, 0, 0, 1, EPS);
        this.vertices[2] = new MembraneVertex(0.5 * EPS, r3o2 * EPS, 0, 1, EPS);
        this.vertices[3] = new MembraneVertex(-0.5 * EPS, r3o2 * EPS, 0, 1, EPS);
        this.vertices[4] = new MembraneVertex(-EPS, 0, 0, 1, EPS);
        this.vertices[5] = new MembraneVertex(-0.5 * EPS, -r3o2 * EPS, 0, 1, EPS);
        this.vertices[6] = new MembraneVertex(0, 0, 0.5 * EPS, 1, EPS);
        this.vertices[7] = new MembraneVertex(0, 0, -1.1 * EPS, 1, EPS);

        // Membrane 2: a square of half-width 0.3 * EPS at height 1.6 * EPS.
        this.vertices[8] = new MembraneVertex(0.3 * EPS, 0.3 * EPS, 1.6 * EPS, 2, 0.5 * EPS);
        this.vertices[9] = new MembraneVertex(0.3 * EPS, -0.3 * EPS, 1.6 * EPS, 2, 0.5 * EPS);
        this.vertices[10] = new MembraneVertex(-0.3 * EPS, -0.3 * EPS, 1.6 * EPS, 2, 0.5 * EPS);
        this.vertices[11] = new MembraneVertex(-0.3 * EPS, 0.3 * EPS, 1.6 * EPS, 2, 0.5 * EPS);

        for (int i = 0; i < this.vertices.length; i++) {
            addVertex(this.vertices[i]);
        }

        // the top surface of the "cell"
        this.faces[0] = new MembraneFace(this.vertices[0], this.vertices[1], this.vertices[6], 1);
        this.faces[1] = new MembraneFace(this.vertices[1], this.vertices[2], this.vertices[6], 1);
        this.faces[2] = new MembraneFace(this.vertices[2], this.vertices[3], this.vertices[6], 1);
        this.faces[3] = new MembraneFace(this.vertices[3], this.vertices[4], this.vertices[6], 1);
        this.faces[4] = new MembraneFace(this.vertices[4], this.vertices[5], this.vertices[6], 1);
        this.faces[5] = new MembraneFace(this.vertices[5], this.vertices[0], this.vertices[6], 1);

        // the bottom surface of the "cell" (ring vertices in reverse order)
        this.faces[6] = new MembraneFace(this.vertices[1], this.vertices[0], this.vertices[7], 1);
        this.faces[7] = new MembraneFace(this.vertices[2], this.vertices[1], this.vertices[7], 1);
        this.faces[8] = new MembraneFace(this.vertices[3], this.vertices[2], this.vertices[7], 1);
        this.faces[9] = new MembraneFace(this.vertices[4], this.vertices[3], this.vertices[7], 1);
        this.faces[10] = new MembraneFace(this.vertices[5], this.vertices[4], this.vertices[7], 1);
        this.faces[11] = new MembraneFace(this.vertices[0], this.vertices[5], this.vertices[7], 1);

        // an object for the cell to interact with: two triangles covering the square, each
        // repeated with the opposite winding so the square has faces on both sides
        this.faces[12] = new MembraneFace(this.vertices[8], this.vertices[9], this.vertices[10],
                2);
        this.faces[13] = new MembraneFace(this.vertices[8], this.vertices[10], this.vertices[11],
                2);
        this.faces[14] = new MembraneFace(this.vertices[8], this.vertices[10], this.vertices[9],
                2);
        this.faces[15] = new MembraneFace(this.vertices[8], this.vertices[11], this.vertices[10],
                2);

        for (int i = 0; i < this.faces.length; i++) {
            addFace(this.faces[i]);
        }

        // Create a camera that is slightly elevated above the X-Y plane, looking at the origin
        this.camera = new Camera(4 * EPS);
        this.camera.setPolarAngle(Math.PI / 3, true);

        addLight(new Light(80 * EPS, -80 * EPS, 100 * EPS, Color.WHITE));

        // anchored at the initial position of the probe vertex (0, 0, 0.5 * EPS)
        this.vec = new BasedVector3(0, 0, 0.5 * EPS, 0, 0, EPS);
        addBasedVector(this.vec);

        this.defVolume = computeVolume(1);
        LOG.log(Level.INFO, "Default volume = {0}", this.defVolume);
    }

    /**
     * Method to be called in the AWT event thread to construct the user interface: a render panel
     * with X/Y/Z sliders around it and a row of radio buttons selecting which force to display.
     */
    @Override public void run() {

        JFrame frame;
        JPanel content;
        JPanel radios;
        Dimension screen;
        Dimension size;
        ButtonGroup group;

        frame = new JFrame("Hex Patch Visualization");
        content = new JPanel(new BorderLayout());
        frame.setContentPane(content);
        frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);

        this.panel = new RenderPanel(500, 400, this.camera);
        this.panel.addMouseListener(this.panel);
        this.panel.addMouseMotionListener(this.panel);

        content.add(this.panel, BorderLayout.CENTER);

        this.xSlider = new JSlider(SwingConstants.HORIZONTAL, -1000, 1000, 0);
        this.xSlider.addChangeListener(this);
        content.add(this.xSlider, BorderLayout.SOUTH);

        this.ySlider = new JSlider(SwingConstants.VERTICAL, -1000, 1000, 0);
        this.ySlider.addChangeListener(this);
        content.add(this.ySlider, BorderLayout.WEST);

        // initial value 500 maps to Z = 0.5 * EPS, the probe vertex's starting height
        this.zSlider = new JSlider(SwingConstants.VERTICAL, -1000, 1000, 500);
        this.zSlider.addChangeListener(this);
        content.add(this.zSlider, BorderLayout.EAST);

        radios = new JPanel(new GridLayout(0, 5));
        this.elemFrc = new JRadioButton("elem");
        this.presFrc = new JRadioButton("pres");
        this.tensFrc = new JRadioButton("tens");
        this.curvFrc = new JRadioButton("curv");
        this.siFrc = new JRadioButton("si");
        radios.add(this.elemFrc);
        radios.add(this.presFrc);
        radios.add(this.tensFrc);
        radios.add(this.curvFrc);
        radios.add(this.siFrc);
        group = new ButtonGroup();
        group.add(this.elemFrc);
        group.add(this.presFrc);
        group.add(this.tensFrc);
        group.add(this.curvFrc);
        group.add(this.siFrc);
        this.curvFrc.setSelected(true);
        content.add(radios, BorderLayout.NORTH);

        frame.pack();
        screen = Toolkit.getDefaultToolkit().getScreenSize();
        size = frame.getSize();
        frame.setLocation((screen.width - size.width) / 2, (screen.height - size.height) / 2);
        frame.setVisible(true);
    }

    /**
     * Renders the scene repeatedly while the render panel is visible. Each frame recomputes the
     * cell volume, computes the selected force on the probe vertex into the display vector,
     * scales it for display, and renders. The loop stops early if the calling thread is
     * interrupted (the interrupt status is preserved).
     */
    public void go() {

        while (this.panel.isVisible()) {

            // computeVolume also recomputes the normals of the membrane-1 faces used below
            this.curVolume = computeVolume(1);

            // Compute the selected force on the probe vertex into the display vector
            if (this.elemFrc.isSelected()) {
                computeElemForce(PROBE_INDEX, this.vec);
            } else if (this.presFrc.isSelected()) {
                computePresForce(PROBE_INDEX, this.vec);
            } else if (this.tensFrc.isSelected()) {
                computeTensForce(PROBE_INDEX, this.vec);
            } else if (this.curvFrc.isSelected()) {
                computeCurvForce(PROBE_INDEX, this.vec);
            } else if (this.siFrc.isSelected()) {
                computeSiForce(PROBE_INDEX, this.vec);
            }

            this.vec.scaleVec(FORCE_SCALE);

            this.panel.render(this);

            try {
                Thread.sleep(20);
            } catch (InterruptedException ex) {
                // Preserve the interrupt for the caller and stop; continuing would make every
                // subsequent sleep throw immediately and turn this into a busy loop.
                Thread.currentThread().interrupt();
                break;
            }
        }
    }

    /**
     * Computes the element interaction force on a vertex: a Lennard-Jones-style interaction with
     * every vertex belonging to a different membrane, which is zero when the separation equals
     * {@code EPS + radius} of the other vertex.
     *
     * @param  index  the index of the vertex whose force to compute
     * @param  vec    the vector in which to store the computed force
     */
    private void computeElemForce(final int index, final BasedVector3 vec) {

        MembraneVertex target;
        MembraneVertex vert;
        Vector3 delta;
        double dist;
        double ratio;
        double cube;
        double sixth;
        double twelfth;
        double scale;

        vec.setVec(0, 0, 0);
        delta = new Vector3();

        target = this.vertices[index];

        for (int i = 0; i < this.vertices.length; i++) {
            vert = this.vertices[i];

            // only vertices of other membranes interact (this also excludes the target itself)
            if (vert.getMembraneId() == target.getMembraneId()) {
                continue;
            }

            delta.vectorBetween(target, vert);
            dist = delta.length();
            ratio = (EPS + vert.getRadius()) / dist;

            cube = ratio * ratio * ratio;
            sixth = cube * cube;
            twelfth = sixth * sixth;

            scale = -KLJ * (twelfth - sixth) / (dist * dist);
            vec.addVecScaled(scale, delta);
        }
    }

    /**
     * Computes the pressure force on a vertex. For each face incident on the vertex, the cross
     * product of the position vectors of the face's next and previous vertices (relative to this
     * vertex's position in the face) is accumulated, scaled by
     * {@code -KB * (curVolume - defVolume) / (6 * defVolume)}.
     *
     * @param  index  the index of the vertex whose force to compute
     * @param  vec    the vector in which to store the computed force
     */
    private void computePresForce(final int index, final BasedVector3 vec) {

        MembraneVertex vert;
        double scale;
        WorldFace face;
        int indexInFace;
        Vector3 other1;
        Vector3 other2;
        WorldVertex world;
        Vector3 cross;

        vert = this.vertices[index];
        other1 = new Vector3();
        other2 = new Vector3();
        cross = new Vector3();

        // TODO: Look up default volume based on membrane ID
        scale = (-KB / (6 * this.defVolume)) * (this.curVolume - this.defVolume);

        vec.setVec(0, 0, 0);

        for (int i = 0; i < vert.getNumFaces(); i++) {
            face = vert.getFace(i);
            indexInFace = vert.getIndexInFace(i);

            // the vertex after this one in the face
            world = face.getVertex(indexInFace + 1);
            other1.setVec(world.getPosX(), world.getPosY(), world.getPosZ());

            // the vertex before this one in the face (previously fetched the "next" vertex a
            // second time, which made the cross product, and thus the force, always zero)
            world = face.getVertex(indexInFace - 1);
            other2.setVec(world.getPosX(), world.getPosY(), world.getPosZ());

            cross.cross(other1, other2);
            vec.addVecScaled(scale, cross);
        }
    }

    /**
     * Computes the tension force on a vertex. For each incident face, the face normal is crossed
     * with the vector between the face's next and previous vertices, and the results are summed
     * with a factor of {@code -TENSION / 4}.
     *
     * @param  index  the index of the vertex whose force to compute
     * @param  vec    the vector in which to store the computed force
     */
    private void computeTensForce(final int index, final BasedVector3 vec) {

        MembraneVertex vert;
        double scale;
        WorldFace face;
        int indexInFace;
        WorldVertex other1;
        WorldVertex other2;
        Vector3 diff;
        Vector3 cross;

        vert = this.vertices[index];
        diff = new Vector3();
        cross = new Vector3();

        scale = -TENSION / 4;

        vec.setVec(0, 0, 0);

        for (int i = 0; i < vert.getNumFaces(); i++) {
            face = vert.getFace(i);
            indexInFace = vert.getIndexInFace(i);
            other1 = face.getVertex(indexInFace + 1);
            other2 = face.getVertex(indexInFace - 1);
            diff.vectorBetween(other1, other2);
            cross.cross(face, diff);
            vec.addVecScaled(scale, cross);
        }
    }

    /**
     * Computes the curvature force on a vertex. For each incident face, it finds the adjacent
     * incident face across the edge to the face's previous vertex, and adds that edge's
     * contribution (built from the two face normals). It then also adds a term from the
     * neighbouring face across the edge opposite the vertex. The total is scaled by
     * {@code 4 * KC / EPS}.
     *
     * @param  index  the index of the vertex whose force to compute
     * @param  vec    the vector in which to store the computed force
     */
    private void computeCurvForce(final int index, final BasedVector3 vec) {

        MembraneVertex vert;
        double scale;
        int numFaces;
        WorldFace face1;
        WorldFace face2;
        WorldFace face3;
        int indexInFace1;
        int indexInFace2;
        WorldVertex end1;
        WorldVertex end2;
        WorldVertex end3;
        Vector3 ej0;
        Vector3 ej1;
        Vector3 ej2;
        Vector3 ej1minusej0;
        Vector3 ej2minusej1;
        Vector3 cross01;
        Vector3 cross12;
        Vector3 left1;
        Vector3 right1;
        Vector3 left;
        Vector3 right;
        Vector3 cross;
        double dot;
        Vector3 term;
        double coeff;

        vert = this.vertices[index];
        ej0 = new Vector3();
        ej1 = new Vector3();
        ej2 = new Vector3();
        ej1minusej0 = new Vector3();
        ej2minusej1 = new Vector3();
        cross01 = new Vector3();
        cross12 = new Vector3();
        left1 = new Vector3();
        right1 = new Vector3();
        left = new Vector3();
        right = new Vector3();
        cross = new Vector3();
        term = new Vector3();

        scale = 4 * KC / EPS;

        vec.setVec(0, 0, 0);

        numFaces = vert.getNumFaces();

        for (int i = 0; i < numFaces; i++) {
            face1 = vert.getFace(i);
            indexInFace1 = vert.getIndexInFace(i);

            end1 = face1.getVertex(indexInFace1 + 1);
            end2 = face1.getVertex(indexInFace1 - 1);

            // Find another face whose right edge is face1's left edge
            for (int j = 0; j < numFaces; j++) {

                if (j == i) {
                    continue;
                }

                face2 = vert.getFace(j);
                indexInFace2 = vert.getIndexInFace(j);

                if (face2.getVertex(indexInFace2 + 1) == end2) {

                    term.setVec(0, 0, 0);

                    end3 = face2.getVertex(indexInFace2 - 1);

                    // Compute the contribution of the edge between two faces
                    ej0.vectorBetween(vert, end1);
                    ej1.vectorBetween(vert, end2);
                    ej2.vectorBetween(vert, end3);
                    ej1minusej0.subVec(ej1, ej0);
                    ej2minusej1.subVec(ej2, ej1);

                    dot = face1.dot(face2);
                    term.addVecScaled((1 - dot) / ((1 + dot) * ej1.length()), ej1);

                    coeff = 2 * ej1.length() / ((1 + dot) * (1 + dot));
                    cross01.cross(ej0, ej1);
                    cross12.cross(ej1, ej2);

                    left1.setVec(face2);
                    left1.addVecScaled(-dot, face1);
                    left1.scaleVec(1 / cross01.length());
                    left.cross(left1, ej1minusej0);

                    right1.setVec(face1);
                    right1.addVecScaled(-dot, face2);
                    right1.scaleVec(1 / cross12.length());
                    right.cross(right1, ej2minusej1);

                    left.addVec(right);
                    term.addVecScaled(coeff, left);

                    // Find the face that shares the edge opposite the vertex
                    for (int inx = 0; inx < end1.getNumFaces(); inx++) {
                        face3 = end1.getFace(inx);

                        if (face3 == face1) {
                            continue;
                        }

                        if ((face3.getVertex0() == end2) || (face3.getVertex1() == end2)
                                || (face3.getVertex2() == end2)) {

                            // face3 is the face with normal vector $\eta_j$.
                            dot = face1.dot(face3);
                            coeff = 2 * ej1minusej0.length() / ((1 + dot) * (1 + dot));

                            left.setVec(face3);
                            left.addVecScaled(-dot, face1);

                            right.setVec(ej1minusej0);
                            right.scaleVec(1 / cross01.length());

                            cross.cross(left, right);
                            term.addVecScaled(coeff, cross);

                            break;
                        }
                    }

                    vec.addVecScaled(scale, term);

                    break;
                }
            }
        }
    }

    /**
     * Computes the anti-self-intersection force on a vertex: a repulsive soft-sphere
     * interaction with every other vertex of the same membrane that does not share a face with
     * it.
     *
     * @param  index  the index of the vertex whose force to compute
     * @param  vec    the vector in which to store the computed force
     */
    private void computeSiForce(final int index, final BasedVector3 vec) {

        MembraneVertex target;
        MembraneVertex vert;
        int count1;
        int count2;
        Vector3 delta;
        double dist;
        double ratio;
        double cube;
        double sixth;
        double twelfth;
        double scale;
        boolean shared;

        vec.setVec(0, 0, 0);
        delta = new Vector3();

        target = this.vertices[index];
        count1 = target.getNumFaces();

        for (int i = 0; i < this.vertices.length; i++) {
            vert = this.vertices[i];

            if (vert.getMembraneId() != target.getMembraneId()) {
                continue;
            }

            // Never interact with itself: a face-less target would otherwise not be caught by the
            // shared-face test below and would divide by a zero distance.
            if (vert == target) {
                continue;
            }

            // If the vertex shares a face with the target vertex, ignore it
            count2 = vert.getNumFaces();
            shared = false;

outer:
            for (int j = 0; j < count1; j++) {

                for (int k = 0; k < count2; k++) {

                    if (target.getFace(j) == vert.getFace(k)) {
                        shared = true;

                        break outer;
                    }
                }
            }

            if (shared) {
                continue;
            }

            delta.vectorBetween(target, vert);
            dist = delta.length();
            ratio = EPS / dist;

            cube = ratio * ratio * ratio;
            sixth = cube * cube;
            twelfth = sixth * sixth;

            scale = -12 * KSS * twelfth / (dist * dist);
            vec.addVecScaled(scale, delta);
        }
    }

    /**
     * Handler for state changes from sliders: moves the probe vertex and the base of the force
     * display vector to the slider position, then recomputes all face normals.
     *
     * <p>NOTE(review): this runs on the Swing event thread, while {@link #go()} reads the same
     * vertex positions from the main thread without synchronization.
     *
     * @param  evt  the change event
     */
    @Override public void stateChanged(final ChangeEvent evt) {

        double xCoord;
        double yCoord;
        double zCoord;

        // NOTE(review): X is divided by 500 while Y and Z are divided by 1000, so the X slider
        // spans [-2 * EPS, 2 * EPS] but the others span [-EPS, EPS]; confirm this is intended.
        xCoord = EPS * this.xSlider.getValue() / 500.0;
        yCoord = EPS * this.ySlider.getValue() / 1000.0;
        zCoord = EPS * this.zSlider.getValue() / 1000.0;

        this.vertices[PROBE_INDEX].setPos(xCoord, yCoord, zCoord);
        this.vec.setPos(xCoord, yCoord, zCoord);

        for (int i = 0; i < this.faces.length; i++) {
            this.faces[i].computeNormal();
        }
    }

    /**
     * Computes the current volume of a membrane by summing, over each of its faces, the signed
     * volume of the tetrahedron formed with the origin. As a side effect, the normal of every
     * face in the membrane is recomputed.
     *
     * @param   index  the membrane ID whose faces are summed
     * @return  the volume
     */
    private double computeVolume(final int index) {

        double volume;
        MembraneFace face;
        WorldVertex vert0;
        WorldVertex vert1;
        WorldVertex vert2;
        Vector3 vector;
        Vector3 vec01;
        Vector3 vec02;
        double dot;
        Vector3 cross;

        volume = 0;
        vector = new Vector3();
        vec01 = new Vector3();
        vec02 = new Vector3();
        cross = new Vector3();

        for (int i = 0; i < this.faces.length; i++) {
            face = this.faces[i];

            if (face.getMembraneId() != index) {
                continue;
            }

            face.computeNormal();

            vert0 = face.getVertex0();
            vert1 = face.getVertex1();
            vert2 = face.getVertex2();

            vec01.vectorBetween(vert0, vert1);
            vec02.vectorBetween(vert0, vert2);

            // height of the origin-apex tetrahedron (along the face normal) times twice the
            // face area, divided by 6
            vector.setVec(vert0.getPosX(), vert0.getPosY(), vert0.getPosZ());
            dot = vector.dot(face);
            cross.cross(vec01, vec02);

            volume += cross.length() * dot / 6;
        }

        return volume;
    }

    /**
     * Main method to create a hex patch, build its user interface on the event thread, and run
     * the render loop on the calling thread.
     *
     * @param  args  command-line arguments (unused)
     */
    public static void main(final String... args) {

        HexPatch hex;

        hex = new HexPatch();

        try {
            SwingUtilities.invokeAndWait(hex);
            hex.go();
        } catch (InterruptedException ex) {
            LOG.log(Level.SEVERE, null, ex);
            Thread.currentThread().interrupt();
        } catch (InvocationTargetException ex) {
            LOG.log(Level.SEVERE, null, ex);
        }
    }
}
