package edu.rice.cs.mint.util;

import edu.rice.cs.mint.runtime.*;
import java.io.*;
import java.util.*;

/**
 * Benchmarking class. It calibrates the number of repetitions of each task so that
 * calibrating a task accumulates at least {@link #TARGET_SECONDS} seconds
 * (1 second by default), then times every task with its calibrated repetition count.
 */
public class Benchmark {
    /** System property that, if set to a positive finite number, overrides {@link #TARGET_SECONDS}. */
    private static final String TARGET_SECONDS_PROPERTY = "edu.rice.cs.mint.util.benchmark.target.seconds";

    /** Nanoseconds per second, used to convert {@link #TARGET_SECONDS} to {@link System#nanoTime()} units. */
    private static final long NANOS_PER_SECOND = 1000000000L;

    /** Milliseconds to sleep before benchmarking so that start-up disk/background activity can settle. */
    private static final long SETTLE_MILLIS = 5000L;

    /**
     * Number of seconds as target time. During calibration we double the number of repetitions
     * until the accumulated calibration time reaches this duration. Defaults to 1.0; it can be
     * overridden with the system property
     * {@code edu.rice.cs.mint.util.benchmark.target.seconds} (positive, finite values only).
     */
    public static float TARGET_SECONDS = 1.0f;
    static {
        try {
            String p = System.getProperty(TARGET_SECONDS_PROPERTY);
            if (p != null) {
                float f = Float.parseFloat(p);
                // An infinite target would make calibration loop forever; NaN and non-positive
                // targets would stop calibration after a single round. Keep the default instead.
                if (f > 0 && !Float.isInfinite(f)) {
                    TARGET_SECONDS = f;
                }
            }
        }
        catch (NumberFormatException ignored) { /* malformed value: keep it at 1.0f */ }
        catch (SecurityException ignored) { /* property not readable: keep it at 1.0f */ }
    }

    /**
     * Extended task with information about number of repetitions, measured time,
     * and speed relative to a reference task.
     */
    public static abstract class Task implements Runnable {
        /** Create an unnamed task. */
        public Task() { }

        /** Create a task with a given name.
          * @param name name of this task */
        public Task(String name) {
            this();
            _name = name;
        }

        /** Execute one repetition of this task. */
        @Override
        public abstract void run();

        /** Return the name of this task.
          * @return name of this task, or null if it was never named */
        public String name() { return _name; }
        /** Set the name of this task.
          * @param name new name of this task */
        public void setName(String name) { _name = name; }

        /** @return number of repetitions used when timing this task */
        public long reps() { return _reps; }
        /** @param reps number of repetitions used when timing this task */
        public void setReps(long reps) { _reps = reps; }

        /** @return total measured time for all repetitions, in nanoseconds */
        public long time() { return _t; }
        /** @param t total measured time for all repetitions, in nanoseconds */
        public void setTime(long t) { _t = t; }

        /** @return index of the task that represents 100% */
        public int normalIndex() { return _normalIndex; }
        /** @param i index of the task that represents 100% */
        public void setNormalIndex(int i) { _normalIndex = i; }

        /** @return per-repetition time as a percentage of the reference task's per-repetition time */
        public double percent() { return _percent; }
        /** @param percent per-repetition time as a percentage of the reference task's */
        public void setPercent(double percent) { _percent = percent; }

        /**
         * Format name, repetitions, total seconds, microseconds per repetition, and the
         * speed factor relative to the reference task (percent / 100).
         */
        @Override
        public String toString() {
            double seconds = ((double)_t)/1000000000.0;
            double micsPerRep = (_t/1000.0)/_reps;
            return String.format("%-8s x %,12d reps: ~%7.3f s @ %16.9f mics/rep := %,13.3fx",
                                 name(), _reps, seconds, micsPerRep, _percent/100);
        }

        protected String _name;
        protected long _reps = 1;
        protected int _normalIndex = 0; // index of the task that represents 100%
        protected long _t;
        protected double _percent;
    }

    /** Just a marker class to distinguish tasks that have been benchmarked already
      * from those that haven't. It snapshots the timing data of the wrapped task
      * and delegates {@link #run()} to it. */
    public static class TimedTask extends Task {
        /** Wrap a benchmarked task, copying its name, repetitions, time, and relative speed.
          * @param task the benchmarked task */
        public TimedTask(Task task) {
            _task = task;
            _name = task.name();
            _reps = task.reps();
            _normalIndex = task.normalIndex(); // index of the task that represents 100%
            _t = task.time();
            _percent = task.percent();
        }

        /** Run the wrapped task once. */
        @Override
        public void run() {
            _task.run();
        }

        protected Task _task;
    }

    /** A task that can actually return a value; running it computes and discards the value. */
    public static abstract class Thunk<T> extends Task {
        /** Compute {@link #value()} and discard the result. */
        @Override
        public void run() {
            value();
        }

        /** @return the value computed by this task */
        public abstract T value();
    }

    /**
     * Given an unstaged task and a thunk that generates the code of a staged task, return
     * an array of four tasks to be executed by benchmark(). These tasks separate out the
     * individual phases of the staged code: "unstaged", "gencode", "compile" and "staged".
     * The code is generated once and run once here to obtain the "staged" task.
     *
     * <p>Example: To test power(2,17), you would call:
     * <pre>{@code
     * stagingTasks(new Task() {
     *   public void run() { int i = power(2,17); }
     * },
     * new Thunk<Code<? extends Task>>() {
     *   public Code<? extends Task> value() {
     *     return <| new Task() {
     *       public void run() {
     *         int i = `(spower(<| 2 |>, 17));
     *       }
     *     } |>;
     *   }
     * });
     * }</pre>
     *
     * @param unstaged the unstaged code, passed in a Task; it is renamed "unstaged"
     * @param genCode the thunk that returns the code to create the staged Task; it is renamed "gencode"
     * @return array of tasks to benchmark staged code
     */
    public static Task[] stagingTasks(Task unstaged,
                                      Thunk<Code<? extends Task>> genCode) {
        unstaged.setName("unstaged");
        genCode.setName("gencode");
        // Generate the code once; the "compile" task re-runs it and the "staged" task is its result.
        final Code<? extends Task> code = genCode.value();
        Task compile = new Task("compile") {
            @Override
            public void run() {
                code.run();
            }
        };
        Task staged = code.run();
        staged.setName("staged");
        return new Task[] { unstaged, genCode, compile, staged };
    }

    /**
     * Benchmark the tasks given. Percentages are given relative to the first task (i.e. tasks[0] is 100%).
     * @param tasks tasks to benchmark
     * @return tasks with timing information
     */
    public static TimedTask[] benchmark(Task... tasks) {
        return benchmark(0, tasks);
    }

    /**
     * Benchmark the tasks given. Percentages are given relative to the specified task
     * (i.e. tasks[normalIndex] is 100%). Sleeps for five seconds first, then calibrates
     * every task, then times every task with its calibrated repetition count.
     * @param normalIndex index of the task that represents 100%
     * @param tasks tasks to benchmark
     * @return tasks with timing information
     * @throws IllegalArgumentException if tasks is non-empty and normalIndex is out of range
     * @throws RuntimeException if interrupted during the initial sleep (interrupt status is restored)
     */
    public static TimedTask[] benchmark(int normalIndex, Task... tasks) {
        // Fail fast instead of discovering a bad index after the whole (long) benchmark run.
        if (tasks.length > 0) {
            checkNormalIndex(normalIndex, tasks.length);
        }

        settle();

        // warmup and calibration
        for (Task task : tasks) {
            task.setNormalIndex(normalIndex);
            calibrate(task);
        }

        // timing: run through all tasks
        for (Task task : tasks) {
            measure(task);
        }

        normalize(normalIndex, tasks);

        TimedTask[] timed = new TimedTask[tasks.length];
        for (int i = 0; i < tasks.length; ++i) {
            timed[i] = new TimedTask(tasks[i]);
        }
        return timed;
    }

    /**
     * Print out the results.
     * Percentages are given relative to the first task (i.e. tasks[0] is 100%).
     * @param tasks benchmarked tasks to print
     */
    public static void print(TimedTask... tasks) {
        print(0, tasks);
    }

    /**
     * Print out the results.
     * Percentages are given relative to the specified task (i.e. tasks[normalIndex] is 100%).
     * @param normalIndex index of the task that represents 100%
     * @param tasks benchmarked tasks to print
     * @throws IllegalArgumentException if normalIndex is not a valid index into tasks
     */
    public static void print(int normalIndex, TimedTask... tasks) {
        checkNormalIndex(normalIndex, tasks.length);
        normalize(normalIndex, tasks);

        System.out.println("Relative to "+tasks[normalIndex].name());
        for (Task task : tasks) {
            System.out.println(task);
        }
    }

    /**
     * Benchmark the tasks given and print out the results.
     * Percentages are given relative to the first task (i.e. tasks[0] is 100%).
     * @param tasks tasks to benchmark
     * @return tasks with timing information
     */
    public static TimedTask[] benchmarkAndPrint(Task... tasks) {
        return benchmarkAndPrint(0, tasks);
    }

    /**
     * Benchmark the tasks given and print out the results.
     * Percentages are given relative to the specified task (i.e. tasks[normalIndex] is 100%).
     * @param normalIndex index of the task that represents 100%
     * @param tasks tasks to benchmark
     * @return tasks with timing information
     */
    public static TimedTask[] benchmarkAndPrint(int normalIndex, Task... tasks) {
        TimedTask[] results = benchmark(normalIndex, tasks);
        print(normalIndex, results);
        return results;
    }

    /** Throw IllegalArgumentException unless 0 <= normalIndex < count. */
    private static void checkNormalIndex(int normalIndex, int count) {
        if (normalIndex < 0 || normalIndex >= count) {
            throw new IllegalArgumentException("normalIndex " + normalIndex
                                               + " out of range for " + count + " task(s)");
        }
    }

    /** Sleep so that hard drive activity/background tasks from starting Java are done. */
    private static void settle() {
        try {
            Thread.sleep(SETTLE_MILLIS);
        }
        catch (InterruptedException ie) {
            Thread.currentThread().interrupt(); // preserve interrupt status for callers
            throw new RuntimeException(ie);
        }
    }

    /**
     * Double the repetition count until the accumulated run time reaches TARGET_SECONDS.
     * Note: this doubles the number of repetitions one more time, even when the target
     * has been reached, just to be on the safe side.
     */
    private static void calibrate(Task task) {
        final float targetNanos = NANOS_PER_SECOND * TARGET_SECONDS;
        long done = 0;                           // repetitions executed so far
        long total = Math.max(1L, task.reps());  // at least one repetition, or we never progress
        long elapsed = 0;
        do {
            long start = System.nanoTime();
            // long counter: an int would overflow (and loop forever) past Integer.MAX_VALUE reps
            for (long r = done; r < total; ++r) {
                task.run();
            }
            elapsed += System.nanoTime() - start;
            done = total;
            total *= 2;
        } while (elapsed < targetNanos);
        task.setReps(total);
    }

    /** Run the task for its calibrated number of repetitions and record the total nanoseconds. */
    private static void measure(Task task) {
        final long reps = task.reps();
        long start = System.nanoTime();
        for (long r = 0; r < reps; ++r) {
            task.run();
        }
        task.setTime(System.nanoTime() - start);
    }

    /** Set normalIndex and percent of every task relative to tasks[normalIndex]'s per-rep time. */
    private static void normalize(int normalIndex, Task[] tasks) {
        if (tasks.length == 0) {
            return;
        }
        double one100 = perRep(tasks[normalIndex]);
        for (int i = 0; i < tasks.length; ++i) {
            Task task = tasks[i];
            task.setNormalIndex(normalIndex);
            if (i == normalIndex) {
                task.setPercent(100);
            }
            else {
                task.setPercent(((double)100)/one100*perRep(task));
            }
        }
    }

    /** Time per repetition in nanoseconds. */
    private static double perRep(Task task) {
        double t = task.time();
        double reps = task.reps();
        return t / reps;
    }
}
