import edu.rice.cs.mint.runtime.Code;
import edu.rice.cs.mint.runtime.SafeCode;
import edu.rice.cs.mint.util.Benchmark;

import java.util.Iterator;
import java.util.NoSuchElementException;

import static edu.rice.cs.mint.util.Range.*;
import static edu.rice.cs.mint.util.Lift.*;

public class MatrixMultiplySparse
{   
    public static void printArray(Double[][] a)
    {
        for(int i = 0; i< a.length; i++){
            for(int j = 0; j < a[0].length; j++)
                System.out.print(a[i][j] + " ");
            System.out.println();
        }
    }
    
    public static void printArray(double[][] a)
    {
        for(int i = 0; i< a.length; i++){
            for(int j = 0; j < a[0].length; j++)
                System.out.print(a[i][j] + " ");
            System.out.println();
        }
    }
    
    public static void mmult(double[][] a, double[][] b, double[][] output, int l, int m, int n)
    {
        for(int i = 0; i < l; i++)
            for(int j = 0; j < m; j++) {
            double c = 0.0;
            for(int k = 0; k < n; k++){
                if(a[i][k] == 0.0) {
                    continue;
                } else if(a[i][k] == 1.0)
                    c += b[k][j];
                else
                    c += a[i][k] * b[k][j];
            }
            output[i][j] = c;
        }
    }
    
    public static separable Code<Void> smmult(final double[][] a, Code<double[][]> b, Code<double[][]> output, int l, int m, int n) {
        Code<Void> stats = <| { } |>;
        for(final int i: range(0, l)) { 
            for(final int j: range(0, m)) {
                Code<Double> c = <| 0.0 |>;
                for(final int k: range(0, n)) {
                    if(a[i][k] == 0.0)
                        continue;
                    else if(a[i][k] == 1.0)
                        c = <| `c + (`b)[k][j] |>;
                    else
                        c = <| `c + (`(lift(a[i][k])) * (`b)[k][j]) |>;
                }
                stats = <| { `stats; (`output)[i][j] = `c; } |>;
            }
        }
        return stats;
    }
    
    public static double theta = Math.PI / 4;
    public static double[][] A = {{ 1.0, 0.0, 0.0, 0.0 },
        { 0.0, 1.0, 0.0, 0.0 },
        { 0.0, 0.0, Math.cos (theta), Math.sin (theta) },
        { 0.0, 0.0, -(Math.sin (theta)), Math.cos (theta) }};
    //    final public static double[][] B = {{4.2}, {3.2}, {4.6}, {5.2}};
    public static double[][] B = {{4.0, 7.0, 3.0, 0.2},
        {5.5, 8.2, 4.4, 0.0},
        {0.0, 5.6, 1.0, 5.7},
        {4.8, 3.8, 4.2, 3.4}};
    public static double[][] C;
    public static int l = 4;
    public static int m = 4;
    public static int n = 4;
    
    public static void main(String[] args) {        
        String fileName = "../misc/Tina_DisCal.mtx";
        if (args.length!=0) {
            fileName = args[0];
        }
        java.io.File f = new java.io.File(fileName);
        if (f.exists()) {
            System.out.println(fileName);
            A = mintTestUtil.MatrixMarket.load(f);
            B = mintTestUtil.MatrixMarket.load(f);
        }
        else {
            System.out.println("built-in matrices");
        }
        l = A.length;
        m = B[0].length;
        n = B.length;
        System.out.println("mmult(a,b) l="+l+", m="+m+", n="+n+", "+f);
        Benchmark.TimedTask[] results =
            Benchmark.benchmark(Benchmark.stagingTasks(new Benchmark.Task() {
            public void run() {
                C = new double[l][n];
                //
                // call unstaged code here
                //
                mmult(A, B, C, l, m, n);
                double[][] res = C;
                //
                //
                //
            }
        }, new Benchmark.Thunk<Code<? extends Benchmark.Task>>() {
            public Code<? extends Benchmark.Task> value() {
                return <| new Benchmark.Task() {
                    public void run() {
                        C = new double[l][n];
                        //
                        // splice in staged code here
                        //
                        `(smmult(A, <| B |>, <| C |>, l, m, n));
                        double[][] res = C;
                        //
                        //
                        //
                    }
                } |>;
            }
        }));
        
//        Benchmark.print(0, results); // relative to unstaged
//        System.out.println();
        Benchmark.print(results.length-1, results); // relative to staged
    }
}
