import edu.rice.cs.mint.runtime.Code;
import edu.rice.cs.mint.runtime.SafeCode;
import edu.rice.cs.mint.util.Benchmark;

public class unroll_part2{

    /**
     * Benchmarks an ordinary counting loop against the same loop partially
     * unrolled with Mint staging.
     *
     * Both variants start at 0 and step by {@code incr} for {@code iterations}
     * steps. The staged variant unrolls the body in blocks of
     * {@code blockSize}.
     */
    public static void main(String[] args){
        final LoopIteration li = new PrintIteration();
        final sLoopIteration sli = new sPrintIteration();
        // Shared by both variants so the rolled and staged loops always do the
        // same work and the printed label always matches what is measured.
        final int incr = 20;
        final int iterations = 500;
        final int blockSize = 8;

        System.out.println("unroll_part(i, 0, " + incr + ", " + iterations
                           + ", sli, " + blockSize + ")");
        Benchmark.TimedTask[] results =
            Benchmark.benchmark(Benchmark.stagingTasks(new Benchmark.Task() {
            public void run() {
                // Unstaged baseline: an ordinary counting loop.
                final intCell i = new intCell();
                Integer init = 0;
                rolled(i, init, incr, iterations, li);
                // Both variants read the final value after the loop.
                int j = i.value;
            }
        }, new Benchmark.Thunk<Code<? extends Benchmark.Task>>() {
            public Code<? extends Benchmark.Task> value() {
                // Staged variant: the loop is partially unrolled at
                // code-generation time and spliced into the task body.
                return <| new Benchmark.Task() {
                    public void run() {
                        final intCell i = new intCell();
                        Integer init = 0;
                        `(unroll_part( <| i |>, <| init |>, incr, iterations, sli, blockSize));
                        int j = i.value;
                    }
                } |>;
            }
        }));

//        Benchmark.print(0, results); // relative to unstaged
//        System.out.println();
        Benchmark.print(results.length-1, results); // relative to staged
    }


    /**
     * Generates a fully unrolled loop.
     *
     * The generated code contains one copy of {@code F}'s iteration code per
     * step, in order, with the step value {@code init + x * incr} for
     * x = 0 .. iterations-1. It then sets {@code i.value} to
     * {@code init + iterations * incr}, matching what
     * {@link #rolled(intCell, Integer, int, int, LoopIteration)} leaves behind.
     *
     * @param i          code for the cell that receives the final loop value
     * @param init       code for the start value; it is spliced (not cached),
     *                   so the generated code re-evaluates it at every use
     * @param incr       amount added per step
     * @param iterations number of copies of the body to emit
     * @param F          generator for the loop body
     * @return the generated statement block
     */
    public static separable Code<Void> unroll(Code<intCell> i,
                                              Code<Integer> init,
                                              int incr,
                                              int iterations,
                                              sLoopIteration F){
        Code<Void> C = <| { } |>;
        final int fIncr = incr;
        final int fIterations = iterations;
        for(int x = 0; x < iterations; x++){
            final int xx = x;
            // Each step nests the previously generated code and appends one
            // more body instance, so statement order matches a rolled loop.
            C = <| { `C;
                `( F.iteration(<| `init + xx * fIncr |>)); } |>;
        }
        C = <| { `C; `(i).value = `init + fIterations * fIncr; } |>;
        return C;
    }

    /**
     * Runs the loop directly, without staging.
     *
     * Calls {@code F} with {@code init + x * incr} for x = 0 .. iterations-1,
     * then stores {@code init + iterations * incr} in {@code i.value}.
     *
     * @param i          cell that receives the final loop value
     * @param init       start value
     * @param incr       amount added per step
     * @param iterations number of steps
     * @param F          loop body
     */
    public static void rolled(intCell i, Integer init, int incr, int iterations, LoopIteration F){
        for(int x = 0; x < iterations; x++){
            F.iteration(init + x * incr);
        }
        i.value = init + iterations * incr;
    }

    /**
     * Generates a partially unrolled loop that performs the same steps as
     * {@link #rolled(intCell, Integer, int, int, LoopIteration)}.
     *
     * The body is unrolled {@code blockSize} times inside a loop that runs
     * {@code iterations / blockSize} times. That loop is followed by a fully
     * unrolled tail for the remaining {@code iterations % blockSize} steps.
     *
     * If fewer than two full blocks fit, or {@code blockSize < 1}, the whole
     * loop is unrolled via
     * {@link #unroll(Code, Code, int, int, sLoopIteration)} instead.
     *
     * @param i          code for the cell used as the running loop value;
     *                   it holds {@code init + iterations * incr} afterwards
     * @param init       code for the start value
     * @param incr       amount added per step (any sign, including 0)
     * @param iterations number of steps
     * @param F          generator for the loop body
     * @param blockSize  number of body copies per generated loop iteration
     * @return the generated statement block
     */
    public static separable Code<Void> unroll_part(Code<intCell> i,
                                                   Code<Integer> init,
                                                   int incr,
                                                   int iterations,
                                                   sLoopIteration F,
                                                   int blockSize){
        // A block size of 0 would divide by zero below. Negative sizes already
        // led to full unrolling, so treat every size below 1 that way.
        if(blockSize < 1)
            return unroll(i, init, incr, iterations, F);
        final int loops = iterations / blockSize;
        final int leftover = iterations % blockSize;
        if(loops < 2)
            return unroll(i, init, incr, iterations, F);

        // Inside the generated loop the running value lives in i.value. Each
        // unrolled block starts from the cell's current value and advances it
        // by blockSize * incr when it finishes.
        final Code<Integer> current = <| `(i).value |>;
        final Code<Void> block = unroll(i, current, incr, blockSize, F);
        final Code<Void> tail = unroll(i, current, incr, leftover, F);

        // Count blocks explicitly. The previous exit test compared i.value
        // against init + loops * blockSize * incr. That test ran zero blocks
        // when incr <= 0, and the product could overflow.
        return <| {
            `(i).value = `init;
            for (int unrollPartBlock = 0; unrollPartBlock < loops; unrollPartBlock++) {
                `block;
            }
            `tail;
        } |>;
    }

    /** Mutable holder for the loop's running/final value. */
    public static class intCell{
        public Integer value = 0;
    }

    /** Unstaged loop body, called once per step with the step value. */
    public static interface LoopIteration{
        public void iteration (Integer i);
    }

    /**
     * Benchmark loop body.
     *
     * Despite its name, it does not print. It computes {@code i * i} into a
     * freshly allocated {@code Integer} and discards it. {@code new Integer}
     * (rather than {@code Integer.valueOf}) always allocates, so every step
     * does the same work.
     */
    public static class PrintIteration implements LoopIteration{
        public void iteration (Integer i){
            new Integer(i * i);
        }
    }

    /** Staged loop body: given code for the step value, returns code for one step. */
    public static interface sLoopIteration{
        public separable Code<Void> iteration (Code<Integer> i);
    }

    /**
     * Staged counterpart of {@link PrintIteration}.
     *
     * It generates {@code new Integer(i * i)}. Note that the step-value code
     * {@code i} is spliced twice.
     */
    public static class sPrintIteration implements sLoopIteration{
        public separable Code<Void> iteration (Code<Integer> i){
            return <| {new Integer(`i * `i); } |>;
        }
    }
}
