import edu.rice.hj.api.HjCallable;
import edu.rice.hj.api.HjFuture;
import edu.rice.hj.api.HjMetrics;
import edu.rice.hj.runtime.config.HjSystemProperty;
import edu.rice.hj.runtime.metrics.AbstractMetricsManager;
import edu.rice.hj.runtime.util.Pair;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import static edu.rice.hj.Module1.*;

/**
 * Pascal's Triangle --- computes the binomial coefficient (n choose k) using futures.
 * <p>
 * The purpose of this example is to illustrate the effects of memoization on the abstract metrics collected
 * while using futures. Four variants are run: plain recursion and memoized recursion, each in a sequential
 * form and in a future-based parallel form. All variants use the recurrence
 * <pre>
 * C(n, k) = C(n - 1, k - 1) + C(n - 1, k)
 * </pre>
 * with base cases C(n, 0) = C(n, n) = 1.
 * <p>
 * Usage: {@code PascalsTriangleMemoizedSolution [n [k]]}; {@code n} defaults to 8 and {@code k} defaults to
 * {@code n - 3}. The arguments must satisfy {@code 0 <= k <= n}.
 *
 * @author Shams Imam (shams@rice.edu)
 * @author Vivek Sarkar (vsarkar@rice.edu)
 */
public class PascalsTriangleMemoizedSolution {

    /**
     * Entry point: parses and validates {@code n} and {@code k}, then runs every variant through {@link #kernel}.
     *
     * @param args optional {@code n} (default 8) followed by optional {@code k} (default {@code n - 3})
     * @throws NumberFormatException    if an argument is not an integer
     * @throws IllegalArgumentException if the arguments do not satisfy {@code 0 <= k <= n}
     */
    public static void main(final String[] args) {

        final int n = args.length > 0 ? Integer.parseInt(args[0]) : 8;
        final int k = args.length > 1 ? Integer.parseInt(args[1]) : (n - 3);

        // Outside this range the recursions below either terminate with meaningless values (k < 0 or k > n)
        // or, for n < 0, can recurse without ever reaching a base case; reject such inputs up front.
        if (n < 0 || k < 0 || k > n) {
            throw new IllegalArgumentException("Expected 0 <= k <= n, but got n = " + n + ", k = " + k);
        }

        System.out.println(" N = " + n);
        System.out.println(" K = " + k);

        kernel("Recursive Version (Sequential)", n, k, () -> chooseRecursiveSeq(n, k));
        kernel("Recursive Version (Parallel)", n, k, () -> chooseRecursivePar(n, k));
        kernel("Memoized Version (Sequential)", n, k, () -> chooseMemoizedSeq(n, k));
        kernel("Memoized Version (Parallel)", n, k, () -> chooseMemoizedPar(n, k));

    }

    /**
     * Runs one variant inside a freshly initialized Habanero runtime with abstract metrics enabled, prints the
     * computed value, and then dumps the collected abstract metrics.
     *
     * @param mode        human-readable name of the variant, printed in the banner
     * @param N           value of n, used only for printing
     * @param K           value of k, used only for printing
     * @param hjProcedure the computation to run inside a {@code finish} scope
     */
    private static void kernel(final String mode, final int N, final int K, final HjCallable<Integer> hjProcedure) {

        System.out.println("===============================================");
        System.out.println("\n Running: " + mode);

        System.setProperty(HjSystemProperty.abstractMetrics.propertyKey(), "true");
        initializeHabanero();

        try {
            finish(() -> {
                try {
                    final int res = hjProcedure.call();
                    System.out.println(N + " choose " + K + " = " + res);
                } catch (final Exception e) {
                    // Report and carry on so that the remaining variants still run.
                    e.printStackTrace();
                }
            });
        } finally {
            // Shut the runtime down even if finish() itself throws, so it is not left initialized
            // (NOTE(review): presumably its worker threads could otherwise keep the JVM alive — confirm).
            finalizeHabanero();
        }

        final HjMetrics actualMetrics = abstractMetrics();
        AbstractMetricsManager.dumpStatistics(actualMetrics);

        System.out.println("===============================================");
    }

    /**
     * Adds two partial results after calling {@code doWork(1)} (presumably charging one unit of abstract work).
     *
     * @param left  C(n - 1, k - 1)
     * @param right C(n - 1, k)
     * @return {@code left + right}
     * @throws ArithmeticException if the sum overflows an {@code int}
     */
    private static int computeSum(final int left, final int right) {
        doWork(1);
        // addExact fails loudly instead of silently wrapping once C(n, k) exceeds Integer.MAX_VALUE
        // (already the case for C(34, 17)).
        return Math.addExact(left, right);
    }

    /**
     * Produces the value of a base case, C(n, 0) or C(n, n), after calling {@code doWork(1)}.
     *
     * @return always 1
     */
    private static int computeBaseCaseResult() {
        doWork(1);
        return 1;
    }

    /**
     * Computes C(N, K) by plain sequential recursion, without memoization.
     *
     * @param N the row of Pascal's triangle
     * @param K the position in the row, expected to satisfy {@code 0 <= K <= N}
     * @return C(N, K)
     */
    private static int chooseRecursiveSeq(final int N, final int K) {

        if (N == 0 || K == 0 || N == K) {
            return computeBaseCaseResult();
        }

        final int left = chooseRecursiveSeq(N - 1, K - 1);
        final int right = chooseRecursiveSeq(N - 1, K);

        return computeSum(left, right);
    }

    /**
     * Computes C(N, K) by recursion without memoization. Both subproblems are spawned as futures, and the
     * addition runs in its own future, whose value is awaited before returning.
     *
     * @param N the row of Pascal's triangle
     * @param K the position in the row, expected to satisfy {@code 0 <= K <= N}
     * @return C(N, K)
     */
    private static int chooseRecursivePar(final int N, final int K) {

        if (N == 0 || K == 0 || N == K) {
            return computeBaseCaseResult();
        }

        final HjFuture<Integer> left = future(() -> chooseRecursivePar(N - 1, K - 1));
        final HjFuture<Integer> right = future(() -> chooseRecursivePar(N - 1, K));

        final HjFuture<Integer> resultFuture = future(() -> {
            final Integer leftValue = left.get();
            final Integer rightValue = right.get();
            return computeSum(leftValue, rightValue);
        });
        return resultFuture.get();
    }

    // Cache of already-computed values for chooseMemoizedSeq, keyed by (N, K).
    // NOTE(review): assumes Pair implements value-based equals/hashCode — confirm.
    private static final Map<Pair<Integer, Integer>, Integer> chooseMemoizedSeqCache = new ConcurrentHashMap<>();

    /**
     * Computes C(N, K) by sequential recursion, caching every computed value so that each (N, K) pair is
     * evaluated at most once.
     *
     * @param N the row of Pascal's triangle
     * @param K the position in the row, expected to satisfy {@code 0 <= K <= N}
     * @return C(N, K)
     */
    private static int chooseMemoizedSeq(final int N, final int K) {

        final Pair<Integer, Integer> key = Pair.factory(N, K);
        // One get() instead of containsKey() + get(); ConcurrentHashMap never holds null values.
        final Integer cached = chooseMemoizedSeqCache.get(key);
        if (cached != null) {
            return cached;
        }

        // computeIfAbsent is deliberately not used: the recursive calls would modify the map from inside the
        // mapping function, which ConcurrentHashMap forbids.
        final int result;
        if (N == 0 || K == 0 || N == K) {
            result = computeBaseCaseResult();
        } else {
            final int left = chooseMemoizedSeq(N - 1, K - 1);
            final int right = chooseMemoizedSeq(N - 1, K);
            result = computeSum(left, right);
        }
        chooseMemoizedSeqCache.put(key, result);
        return result;
    }

    // Cache of futures for chooseMemoizedPar, keyed by (N, K); a future is stored as soon as it is spawned, so
    // later requests for the same pair wait on it instead of recomputing.
    // NOTE(review): assumes Pair implements value-based equals/hashCode — confirm.
    private static final Map<Pair<Integer, Integer>, HjFuture<Integer>> chooseMemoizedParCache = new ConcurrentHashMap<>();

    /**
     * Computes C(N, K) with futures, memoizing the future for every (N, K) pair so that each pair is spawned
     * at most once, even when several tasks request it concurrently.
     *
     * @param N the row of Pascal's triangle
     * @param K the position in the row, expected to satisfy {@code 0 <= K <= N}
     * @return C(N, K)
     */
    private static int chooseMemoizedPar(final int N, final int K) {

        final Pair<Integer, Integer> key = Pair.factory(N, K);
        // computeIfAbsent makes the lookup-or-spawn step atomic. The former containsKey()/put() sequence let two
        // concurrent tasks both miss the cache and spawn duplicate futures for the same pair. The mapping function
        // only spawns the future and does not wait for it, so the atomic section stays short (assumes future()
        // does not run the callable inline — confirm).
        final HjFuture<Integer> resultFuture =
                chooseMemoizedParCache.computeIfAbsent(key, absentKey -> spawnChooseMemoizedPar(N, K));
        return resultFuture.get();
    }

    /**
     * Spawns the future that computes C(N, K) for {@link #chooseMemoizedPar}, recursing through the memoized
     * entry point for both subproblems.
     *
     * @param N the row of Pascal's triangle
     * @param K the position in the row
     * @return a future holding C(N, K)
     */
    private static HjFuture<Integer> spawnChooseMemoizedPar(final int N, final int K) {
        return future(() -> {
            if (N == 0 || K == 0 || N == K) {
                return computeBaseCaseResult();
            }

            final HjFuture<Integer> left = future(() -> chooseMemoizedPar(N - 1, K - 1));
            final HjFuture<Integer> right = future(() -> chooseMemoizedPar(N - 1, K));

            final Integer leftValue = left.get();
            final Integer rightValue = right.get();
            return computeSum(leftValue, rightValue);
        });
    }

}