package hj.array.view.test;

import hj.lang.*;
import hj.array.*;
import hj.array.view.*;

import edu.rice.cs.mint.runtime.Code;
import edu.rice.cs.mint.runtime.SafeCode;
import edu.rice.cs.mint.util.Benchmark;

public class SDoubleArrayViewMatrixMultSparse {
    /**
     * Unstaged sparse-aware matrix product: {@code output = a * b}.
     *
     * Entries of {@code a} equal to 0.0 are skipped entirely (no read of {@code b}),
     * and entries equal to 1.0 add the {@code b} element without a multiplication.
     * This mirrors the specializations that {@link #smmult} performs at generation time.
     *
     * @param a      left operand, indexed as {@code a.get(i, k)} for {@code i < l}, {@code k < n}
     * @param b      right operand, indexed as {@code b.get(k, j)} for {@code k < n}, {@code j < m}
     * @param output result, written as {@code output.set(c, i, j)} for {@code i < l}, {@code j < m}
     * @param l      number of rows of {@code a} and {@code output}
     * @param m      number of columns of {@code b} and {@code output}
     * @param n      inner (shared) dimension
     */
    public static void mmult(DoubleArrayView a, DoubleArrayView b, DoubleArrayView output, int l, int m, int n)
    {
        for (int i = 0; i < l; i++) {
            for (int j = 0; j < m; j++) {
                double c = 0.0;
                for (int k = 0; k < n; k++) {
                    double a_ik = a.get(i,k);
                    if (a_ik == 0.0) {
                        // zero coefficient contributes nothing; skip reading b
                        continue;
                    } else if (a_ik == 1.0) {
                        // unit coefficient: no multiplication needed
                        c += b.get(k,j);
                    } else {
                        c += a_ik * b.get(k,j);
                    }
                }
                output.set(c,i,j);
            }
        }
    }

    /**
     * Generates (but does not run) code computing {@code output = a * b}, with the
     * values of {@code a} folded into the generated code as constants.
     *
     * {@code a} is read now, at generation time. For each {@code (i, j)} the generated
     * sum contains one term per nonzero {@code a_ik}. A term is just the staged
     * {@code b} element when {@code a_ik == 1.0}, and otherwise the constant times that element.
     * Zero entries of {@code a} produce no code at all. The per-entry stores are
     * chained into one nested statement block, so the generated code grows with {@code l * m}.
     *
     * @param a      matrix whose values are specialized into the generated code
     * @param b      staged view of the right operand (its element reads are emitted as code)
     * @param output staged view that the generated code writes into
     * @param l      number of rows of {@code a} and {@code output}
     * @param m      number of columns of {@code b} and {@code output}
     * @param n      inner (shared) dimension
     * @return generated statements performing the whole multiplication
     */
    public static separable Code<Void> smmult(final DoubleArrayView a, SDoubleArrayView b, SDoubleArrayView output, int l, int m, int n)
    {
        Code<Void> stats = <| { } |>;
        for (int i = 0; i < l; i++) {
            for (int j = 0; j < m; j++) {
                SafeCode<Double> c = <| 0.0 |>;
                for (int k = 0; k < n; k++) {
                    // read at generation time; becomes a literal in the generated code
                    final double a_ik = a.get(i,k);
                    if (a_ik == 0.0)
                        continue;
                    else if (a_ik == 1.0)
                        c = <| `c + `(b.get(k,j)) |>;
                    else
                        c = <| `c + (a_ik * `(b.get(k,j))) |>;
                }
                stats = <| { `stats; `(output.set(c,i,j)); } |>;
            }
        }
        return stats;
    }

    // Built-in inputs, used when no matrix file is found: A is the identity except
    // for a 2x2 rotation by theta in its lower-right block; B is a dense 4x4 matrix.
    public static double theta = Math.PI / 4;
    public static double[][] A = {{ 1.0, 0.0, 0.0, 0.0 },
        { 0.0, 1.0, 0.0, 0.0 },
        { 0.0, 0.0, Math.cos (theta), Math.sin (theta) },
        { 0.0, 0.0, -(Math.sin (theta)), Math.cos (theta) }};
    public static double[][] B = {{4.0, 7.0, 3.0, 0.2},
        {5.5, 8.2, 4.4, 0.0},
        {0.0, 5.6, 1.0, 5.7},
        {4.8, 3.8, 4.2, 3.4}};
    // Dimensions: A is l x n, B is n x m, C is l x m (recomputed in main).
    public static int l = 4;
    public static int m = 4;
    public static int n = 4;
    // Row-major flattened storage backing the views below.
    public static double[] Abase;
    public static double[] Bbase;
    public static double[] Cbase;
    public static DoubleArrayView Aview;
    public static DoubleArrayView Bview;
    public static DoubleArrayView Cview;
    public static SDoubleArrayView SBview;
    public static SDoubleArrayView SCview;

    /**
     * Benchmarks the unstaged {@link #mmult} against code generated by {@link #smmult}.
     *
     * The matrix is loaded from {@code args[0]}, or from a default Matrix Market path,
     * as both A and B. If that file does not exist, the built-in matrices are used.
     *
     * @throws IllegalArgumentException if a matrix is empty, ragged, or the inner
     *         dimensions of A and B disagree
     */
    public static void main(String[] args) {
        String fileName = "../misc/Tina_DisCal.mtx";
        if (args.length != 0) {
            fileName = args[0];
        }
        java.io.File f = new java.io.File(fileName);
        if (f.exists()) {
            System.out.println(fileName);
            A = mintTestUtil.MatrixMarket.load(f);
            B = mintTestUtil.MatrixMarket.load(f);
        }
        else {
            System.out.println("built-in matrices");
        }

        if (A.length == 0 || B.length == 0 || B[0].length == 0) {
            throw new IllegalArgumentException("matrices must be non-empty: A has " + A.length
                + " rows, B has " + B.length + " rows");
        }
        l = A.length;
        m = B[0].length;
        n = B.length;
        // A must be exactly l x n and B exactly n x m; otherwise flattening would
        // index out of bounds or silently drop entries of A.
        requireColumns(A, n, "A");
        requireColumns(B, m, "B");

        Abase = flatten(A, n);
        Bbase = flatten(B, m);
        Cbase = new double[l*m]; // Java zero-initializes new arrays

        Aview = new DoubleArrayView(Abase, 0, new RegionRectangular(new int[] { 0, l-1, 0, n-1 }, 0));
        Bview = new DoubleArrayView(Bbase, 0, new RegionRectangular(new int[] { 0, n-1, 0, m-1 }, 0));
        Cview = new DoubleArrayView(Cbase, 0, new RegionRectangular(new int[] { 0, l-1, 0, m-1 }, 0));
        SBview = new SDoubleArrayView(<| Bbase |>, 0, new SRegionRectangular(new int[] { 0, n-1, 0, m-1 }, 0));
        SCview = new SDoubleArrayView(<| Cbase |>, 0, new SRegionRectangular(new int[] { 0, l-1, 0, m-1 }, 0));

        System.out.println("mmult(a,b) l="+l+", m="+m+", n="+n);
        Benchmark.TimedTask[] results =
            Benchmark.benchmark(Benchmark.stagingTasks(new Benchmark.Task() {
            public void run() {
                // unstaged baseline
                mmult(Aview, Bview, Cview, l, m, n);
            }
        }, new Benchmark.Thunk<Code<? extends Benchmark.Task>>() {
            public Code<? extends Benchmark.Task> value() {
                return <| new Benchmark.Task() {
                    public void run() {
                        // splice in the generated, A-specialized multiplication
                        // FIXME: should be able to use B here
                        `(smmult(Aview, SBview, SCview, l, m, n));
                    }
                } |>;
            }
        }));

//        Benchmark.print(0, results); // relative to unstaged
//        System.out.println();
        Benchmark.print(results.length-1, results); // relative to staged
    }

    /**
     * Throws if any row of {@code matrix} does not have exactly {@code cols} entries.
     */
    private static void requireColumns(double[][] matrix, int cols, String name) {
        for (int r = 0; r < matrix.length; r++) {
            if (matrix[r].length != cols) {
                throw new IllegalArgumentException(name + " row " + r + " has "
                    + matrix[r].length + " columns, expected " + cols);
            }
        }
    }

    /**
     * Copies a rectangular matrix (every row exactly {@code cols} long) into a new
     * row-major array of {@code matrix.length * cols} entries.
     */
    private static double[] flatten(double[][] matrix, int cols) {
        double[] flat = new double[matrix.length * cols];
        for (int r = 0; r < matrix.length; r++) {
            System.arraycopy(matrix[r], 0, flat, r * cols, cols);
        }
        return flat;
    }
}
