import static edu.rice.hj.Module1.*;

/**
 * Class for running a one dimensional averaging as covered in class
 */
public class OneDimAveraging {

    /**
     * Reference result used by {@link #validateOutput()}; {@code null} until
     * {@link #setOutput(OneDimAveraging)} has been called.
     */
    private static double[] initialOutput;

    /**
     * Records the current {@code myVal} array of the given object as the reference output
     * that later runs are validated against. Prints a warning (and still overwrites) if a
     * reference has already been recorded.
     *
     * <p>The array is stored by reference, not copied, so the given object must not be run
     * again afterwards: its buffers are reused between iterations and the reference would be
     * overwritten.
     *
     * @param oneDimAveraging - a OneDimAveraging object whose result becomes the reference
     */
    public static void setOutput(final OneDimAveraging oneDimAveraging) {
        if (OneDimAveraging.initialOutput != null) {
            System.out.println("Warning: initialOutput has already been set.");
        }
        OneDimAveraging.initialOutput = oneDimAveraging.myVal;
    }

    /**
     * Prints out all the inputs being used for this run in a pretty way
     * @param tasks - number of groups
     * @param n - size of the array
     * @param iterations - how many iterations to run
     * @param rounds - how many times the whole set of timed executions is repeated
     */
    public static void printParams(final int tasks, final int n, final int iterations, final int rounds) {
        System.out.println("Configuration: ");
        System.out.println("  # groups: " + tasks);
        System.out.println("  Array size n: " + n);
        System.out.println("  # iterations: " + iterations);
        System.out.println("  Rounds: " + rounds + " (to reduce JIT overhead)");
        System.out.println();
    }

    /**
     * Times one execution of {@code actualBody} and prints the elapsed milliseconds.
     * {@code postExecBody} runs after the measurement has been taken, so its cost is not
     * included in the reported time.
     *
     * @param label - name of application running
     * @param actualBody - Runnable object which contains the code you want to time
     * @param postExecBody - Runnable object which contains code that needs to be executed
     *                     after running actualBody
     */
    private static void timeIt(
            final String label,
            final Runnable actualBody,
            final Runnable postExecBody) {

        // nanoTime is intended for elapsed-time measurement and is unaffected by
        // wall-clock adjustments, unlike currentTimeMillis.
        final long startNanos = System.nanoTime();
        actualBody.run();
        final long elapsedMillis = (System.nanoTime() - startNanos) / 1_000_000L;

        postExecBody.run();
        System.out.printf("%35s Time: %6d ms. \n", label, elapsedMillis);
    }

    /**
     * Performs division between two ints, rounding up for fractional results
     * @param n - numerator
     * @param d - denominator, must be positive
     * @return - result of the division rounded up
     */
    private static int ceilDiv(final int n, final int d) {
        final int quotient = n / d;
        // Integer division truncates toward zero, so only a positive remainder means the
        // true quotient lies above the truncated one.
        return (n % d > 0) ? quotient + 1 : quotient;
    }

    /**
     * Main method for running one dimensional averaging
     * @param args - Takes as input: number of groups to use (defaults to 4096 * number of worker threads),
     *             size of the array (defaults to 8*1024*2048),
     *             number of iterations to get a better average (defaults to 64),
     *             number of rounds to execute the program (will run the same program multiple times) (defaults to 5)
     *
     */
    public static void main(final String[] args) {

        initializeHabanero();
        try {
            final int numGroups = (args.length > 0) ? Integer.parseInt(args[0]) : (4_096 * numWorkerThreads());
            final int n = (args.length > 1) ? Integer.parseInt(args[1]) : (8 * 1_024 * 2_048);
            final int iterations = (args.length > 2) ? Integer.parseInt(args[2]) : 64;
            final int rounds = (args.length > 3) ? Integer.parseInt(args[3]) : 5;
            printParams(numGroups, n, iterations, rounds);

            {
                // initial run to set the output to be equal to the sequential run
                final OneDimAveraging initialObj = new OneDimAveraging(n);
                initialObj.runSequential(iterations, numGroups);
                setOutput(initialObj);
            }

            System.out.println("Timed executions:");
            finish(() -> {
                for (int r = 0; r < rounds; r++) {
                    // Fresh objects every round so each variant starts from the same initial state.
                    final OneDimAveraging serialBody = new OneDimAveraging(n);
                    final OneDimAveraging forkJoinBody = new OneDimAveraging(n);
                    final OneDimAveraging chunkedForkJoinBody = new OneDimAveraging(n);

                    System.out.println(" Round: " + r + (r == 0 ? " [ignore: warm up for JIT]" : ""));

                    timeIt("Sequential",
                            () -> serialBody.runSequential(iterations, numGroups),
                            serialBody::validateOutput);

                    timeIt(String.format("Fork-Join [numGroups=%d]", numGroups),
                            () -> forkJoinBody.runForkJoin(iterations, numGroups),
                            forkJoinBody::validateOutput);

                    timeIt(String.format("Chunked-ForkJoin [numGroups=%d]", numGroups),
                            () -> chunkedForkJoinBody.runChunkedForkJoin(iterations, numGroups),
                            chunkedForkJoinBody::validateOutput);

                    System.out.println();
                }
            });
        } finally {
            // Shut the runtime down even when parsing or a run fails.
            finalizeHabanero();
        }
    }

    /**
     * Scratch buffer that receives the averages of the current iteration; swapped with
     * {@code myVal} after every iteration.
     */
    public double[] myNew;

    /**
     * myVal is the array that is averaged upon; after a run it holds the latest result.
     */
    public double[] myVal;

    /**
     * Number of interior elements. Both arrays hold {@code n + 2} entries: indices
     * {@code 1..n} are averaged, indices {@code 0} and {@code n + 1} are fixed boundary cells.
     */
    public int n;

    /**
     * Constructor for the OneDimAveraging object, sets all the member variables.
     * Every cell starts at 0.0 except the right boundary {@code myVal[n + 1]}, which is 1.0.
     *
     * @param n - size of the array (number of interior elements), must not be negative
     * @throws IllegalArgumentException if {@code n} is negative
     */
    public OneDimAveraging(final int n) {
        if (n < 0) {
            throw new IllegalArgumentException("n must not be negative: " + n);
        }
        this.n = n;
        this.myNew = new double[n + 2];
        this.myVal = new double[n + 2];
        this.myVal[n + 1] = 1.0;
    }

    /**
     * Checks this object's result against the reference recorded by
     * {@link #setOutput(OneDimAveraging)}. Prints a message and returns without validating
     * if the reference or myVal is null. Otherwise reports the first element whose absolute
     * difference from the reference exceeds 1e-20, or a length mismatch if the two arrays
     * have different sizes. Problems are only printed; no exception is thrown.
     */
    public void validateOutput() {
        final double[] expected = OneDimAveraging.initialOutput;
        if (expected == null) {
            System.out.println("initialOutput is null");
            return;
        } else if (myVal == null) {
            System.out.println("myVal is null");
            return;
        }

        // A reference computed for a different n cannot be compared element by element.
        if (expected.length != myVal.length) {
            System.out.println("ERROR: validation failed!");
            System.out.println("  Length: myVal.length=" + myVal.length
                    + " != initialOutput.length=" + expected.length);
            return;
        }

        for (int i = 0; i < myVal.length; i++) {
            final double init = expected[i];
            final double curr = myVal[i];
            final double diff = Math.abs(init - curr);
            if (diff > 1e-20) {
                System.out.println("ERROR: validation failed!");
                System.out.println("  Diff: myVal[" + i + "]=" + curr + " != initialOutput[" + i + "]=" + init);
                break;
            }
        }
    }

    /**
     * Does oneDimAveraging sequentially
     * @param iterations - Number of times to perform the average
     * @param numGroups - How many chunks to break the array into, must be positive
     * @throws IllegalArgumentException if {@code numGroups} is not positive
     */
    public void runSequential(final int iterations, final int numGroups) {

        final int batchSize = groupSize(numGroups);
        for (int iter = 0; iter < iterations; iter++) {
            final double[] src = myVal;
            final double[] dst = myNew;
            forseq(0, numGroups - 1, (i) -> averageGroup(src, dst, i, batchSize));
            swapBuffers();
        }
    }

    /**
     * Fork-join algorithm for 1-D Iterative Averaging: every group of an iteration is
     * submitted as its own parallel task via {@code forall}.
     *
     * @param iterations - Number of times to perform the average
     * @param numGroups - How many chunks to break the array into, must be positive
     * @throws IllegalArgumentException if {@code numGroups} is not positive
     */
    public void runForkJoin(final int iterations, final int numGroups) {

        final int batchSize = groupSize(numGroups);
        for (int iter = 0; iter < iterations; iter++) {
            // Capture the buffers so every task of this iteration sees the same pair.
            final double[] src = myVal;
            final double[] dst = myNew;
            // Relies on forall joining all of its iterations before returning, so the
            // buffers are only swapped once the whole iteration has been written.
            forall(0, numGroups - 1, (i) -> averageGroup(src, dst, i, batchSize));
            swapBuffers();
        }
    }

    /**
     * Fork-join algorithm for 1-D Iterative Averaging with chunking: the groups of an
     * iteration are handed to {@code forallChunked}, which lets the runtime batch several
     * groups into one task instead of creating one task per group.
     *
     * @param iterations - Number of times to perform the average
     * @param numGroups - How many chunks to break the array into, must be positive
     * @throws IllegalArgumentException if {@code numGroups} is not positive
     */
    public void runChunkedForkJoin(final int iterations, final int numGroups) {

        final int batchSize = groupSize(numGroups);
        for (int iter = 0; iter < iterations; iter++) {
            // Capture the buffers so every task of this iteration sees the same pair.
            final double[] src = myVal;
            final double[] dst = myNew;
            // Relies on forallChunked joining all of its iterations before returning.
            forallChunked(0, numGroups - 1, (i) -> averageGroup(src, dst, i, batchSize));
            swapBuffers();
        }
    }

    /**
     * Computes how many interior elements each group covers.
     *
     * @param numGroups - number of groups the interior is split into
     * @return - n divided by numGroups, rounded up
     * @throws IllegalArgumentException if {@code numGroups} is not positive
     */
    private int groupSize(final int numGroups) {
        if (numGroups <= 0) {
            throw new IllegalArgumentException("numGroups must be positive: " + numGroups);
        }
        return ceilDiv(n, numGroups);
    }

    /**
     * Averages the neighbours of every interior element belonging to one group.
     * Group {@code g} covers indices {@code g * batchSize + 1} through
     * {@code min((g + 1) * batchSize, n)}; it reads only {@code src} and writes only its
     * own index range of {@code dst}.
     *
     * @param src - array holding the values of the previous iteration
     * @param dst - array receiving the new averages
     * @param group - zero-based index of the group to process
     * @param batchSize - number of elements per group
     */
    private void averageGroup(final double[] src, final double[] dst, final int group, final int batchSize) {
        // Computed in long so that a large group index cannot overflow int and wrap
        // back into the valid index range.
        final long first = (long) group * batchSize + 1;
        if (first > n) {
            // Trailing groups are empty when there are more groups than needed to cover n.
            return;
        }
        final int start = (int) first;
        final int end = (int) Math.min(first + batchSize - 1, n);

        for (int j = start; j <= end; j++) {
            dst[j] = (src[j - 1] + src[j + 1]) / 2.0;
        }
    }

    /**
     * Swaps {@code myVal} and {@code myNew} so that the freshly computed values become the
     * input of the next iteration.
     */
    private void swapBuffers() {
        final double[] temp = myNew;
        myNew = myVal;
        myVal = temp;
    }

}