import edu.rice.cs.mint.runtime.Code;
import edu.rice.cs.mint.runtime.SafeCode;
import edu.rice.cs.mint.util.Benchmark;

import static edu.rice.cs.mint.util.Range.*;

/**
 * Benchmark comparing a plain triple-loop matrix multiply ({@link #mmult}) against a
 * Mint-staged version ({@link #smmult}) whose generator unrolls all three loops for
 * dimensions known at generation time.
 */
public class MatrixMultiply
{
    /**
     * Prints a boxed matrix to standard output, one row per line, each element followed by a
     * single space. Each row is printed with its own length, so jagged arrays are handled.
     *
     * @param a the matrix to print; rows must be non-null
     */
    public static void printArray(Double[][] a)
    {
        for (Double[] row : a) {
            for (Double x : row) {
                System.out.print(x + " ");
            }
            System.out.println();
        }
    }

    /**
     * Prints a primitive matrix to standard output, one row per line, each element followed by
     * a single space. Each row is printed with its own length, so jagged arrays are handled.
     *
     * @param a the matrix to print; rows must be non-null
     */
    public static void printArray(double[][] a)
    {
        for (double[] row : a) {
            for (double x : row) {
                System.out.print(x + " ");
            }
            System.out.println();
        }
    }

    /**
     * Computes the matrix product {@code a * b} with the classic triple loop.
     *
     * @param a      left operand with {@code f} rows; every row must have exactly
     *               {@code b.length} columns
     * @param b      right operand with at least one row; its column count {@code l} is taken
     *               from {@code b[0]}
     * @param output destination with at least {@code f} rows of at least {@code l} columns;
     *               if {@code null}, a new {@code f x l} matrix is allocated
     * @return {@code output} (or the newly allocated matrix) holding the product; if {@code a}
     *         has no rows, {@code output} is returned untouched ({@code b} is not inspected)
     * @throws IllegalArgumentException if {@code b} has no rows, or a row of {@code a} does not
     *         have {@code b.length} columns
     */
    public static double[][] mmult(double[][] a, double[][] b, double[][] output)
    {
        int rows = a.length;
        if (rows == 0) {
            return (output == null) ? new double[0][0] : output;
        }
        int inner = b.length;
        if (inner == 0) {
            throw new IllegalArgumentException("b must have at least one row");
        }
        int cols = b[0].length;
        // Without this check, rows of a wider than b.length would be silently truncated,
        // producing a wrong product instead of an error.
        for (int i = 0; i < rows; i++) {
            if (a[i].length != inner) {
                throw new IllegalArgumentException(
                    "a[" + i + "] has " + a[i].length + " columns but b has " + inner + " rows");
            }
        }

        double[][] result = (output == null) ? new double[rows][cols] : output;
        for (int i = 0; i < rows; i++) {
            double[] aRow = a[i];
            double[] outRow = result[i];
            for (int j = 0; j < cols; j++) {
                // Seed with the k == 0 term (not 0.0) so the sum is formed exactly like the
                // staged version in smmult, including the sign of a zero result.
                double sum = aRow[0] * b[0][j];
                for (int k = 1; k < inner; k++) {
                    sum += aRow[k] * b[k][j];
                }
                outRow[j] = sum;
            }
        }
        return result;
    }

    /**
     * Generates code that multiplies the {@code f x m} matrix {@code a} by the {@code m x l}
     * matrix {@code b} and stores the product into {@code output}.
     *
     * All three loops run at generation time, so the returned code contains no loops. It holds
     * one assignment {@code output[i][j] = a[i][0]*b[0][j] + ... + a[i][m-1]*b[m-1][j]} per
     * output element, with constant indices. If {@code m == 0}, each element is assigned
     * {@code 0.0}.
     *
     * NOTE(review): {@code separable} is a Mint modifier; {@code main} calls this method from
     * inside an escape within a bracket.
     *
     * @param a      code for the left operand
     * @param b      code for the right operand
     * @param output code for the destination matrix
     * @param f      number of rows of {@code a} (and of the result)
     * @param m      number of columns of {@code a} / rows of {@code b}
     * @param l      number of columns of {@code b} (and of the result)
     * @return code for the sequence of assignments
     */
    public static separable Code<Void> smmult(Code<double[][]> a, Code<double[][]> b, Code<double[][]> output, int f, int m, int l)
    {
        Code<Void> stats = <| { } |>;
        for(final int i: range(0, f)) { 
            for(final int j: range(0, l)) {
                // Build the dot product for element (i, j) term by term.
                Code<Double> c = <| 0.0 |>;
                for(final int k: range(0, m)) {
                    Code<Double> temp1 = <| (`(a))[i][k] |>;
                    Code<Double> temp2 = <| (`(b))[k][j] |>;
                    // The first term replaces the 0.0 seed rather than adding to it.
                    if(k == 0)
                        c = <| `temp1 * `temp2 |>;
                    else
                        c = <|`c + `temp1 * `temp2 |>;
                }
                // Append this element's assignment after the previously generated statements.
                stats = <| {`stats; (`output)[i][j] = `c;} |>;
            }
        }
        return stats;
    }

    /**
     * Benchmarks the unstaged {@link #mmult} against code generated by {@link #smmult} on a
     * fixed 4x4 by 4x3 product, then prints the timings.
     *
     * NOTE(review): presumably {@code Benchmark.stagingTasks} times the first task directly and
     * the second after generating its code from the thunk; confirm against {@code Benchmark}.
     */
    public static void main(String[] args) {
        System.out.println("mmult(a,b)");
        Benchmark.TimedTask[] results =
            Benchmark.benchmark(Benchmark.stagingTasks(new Benchmark.Task() {
            public void run() {
                final double[][] a = {{2.3, 4.2, 3.2, 4.6}, {2.3, 4.2, 3.2, 4.6}, {2.3, 4.2, 3.2, 4.6}, {2.3, 4.2, 3.2, 4.6}};
                final double[][] b = {{2.3, 4.2, 3.2}, {2.3, 3.2, 4.6}, {2.3, 4.2, 4.6}, {2.3, 4.2, 3.2}};
                double[][] c = new double[4][3];
                // Unstaged: generic triple-loop multiply.
                mmult(a, b, c);
            }
        }, new Benchmark.Thunk<Code<? extends Benchmark.Task>>() {
            public Code<? extends Benchmark.Task> value() {
                return <| new Benchmark.Task() {
                    public void run() {
                        final double[][] aa = {{2.3, 4.2, 3.2, 4.6}, {2.3, 4.2, 3.2, 4.6}, {2.3, 4.2, 3.2, 4.6}, {2.3, 4.2, 3.2, 4.6}};
                        final double[][] bb = {{2.3, 4.2, 3.2}, {2.3, 3.2, 4.6}, {2.3, 4.2, 4.6}, {2.3, 4.2, 3.2}};
                        double[][] cc = new double[4][3];
                        // Staged: splice in the unrolled code generated for a 4x4 * 4x3 product.
                        `(smmult(<| aa |>, <| bb |>, <| cc |>, 4, 4, 3));
                    }
                } |>;
            }
        }));

        // Report timings relative to the staged task (the last entry in results).
        // To report relative to the unstaged task instead, use Benchmark.print(0, results).
        Benchmark.print(results.length-1, results);
    }
}
