import edu.rice.cs.mint.runtime.Code;

public class StagedDataTypes {
    /**
     * Unchecked exception thrown when a value is accessed at a type it does not
     * have, e.g. {@code intValue()} on a {@code Value} that does not override it.
     */
    public static class WrongTypeException extends RuntimeException {
        /**
         * Always throws a new {@code WrongTypeException}; never returns normally.
         * The generic return type lets a call appear wherever an expression of
         * some reference type is required, such as one arm of a conditional
         * expression (see its use in {@code Equals}).
         */
        public static separable <T> T returningThrow() { throw new WrongTypeException(); }
    }
    /**
     * Unchecked exception thrown when a name is looked up in an environment
     * that has no binding for it (see {@code env0} and {@code fenv0}).
     */
    public static class EnvironmentException extends RuntimeException {
        /**
         * Always throws a new {@code EnvironmentException}; never returns normally.
         * The generic return type lets a call appear in any expression position.
         */
        public static separable <T> T returningThrow() { throw new EnvironmentException(); }
    }

    /**
     * Base class of the interpreter's run-time values. Each accessor rejects
     * the request with a {@link WrongTypeException} unless a subclass overrides
     * it for the type it actually holds.
     */
    public static class Value {
        /**
         * @return this value as an {@code int}
         * @throws WrongTypeException always, in this base implementation
         */
        public int intValue() {
            throw new WrongTypeException();
        }

        /**
         * @return this value as a {@code boolean}
         * @throws WrongTypeException always, in this base implementation
         */
        public boolean booleanValue() {
            throw new WrongTypeException();
        }
    }
    
    public static class IntValue extends Value {
        private int _data;
        public IntValue(int data) { _data = data; }
        public int intValue() { return _data; }
        public String toString() { return _data+":IntValue"; }
    }
    
    public static class BooleanValue extends Value {
        private boolean _data;
        public BooleanValue(boolean data) { _data = data; }
        public boolean booleanValue() { return _data; }
        public String toString() { return _data+":BooleanValue"; }
    }
    
    /**
     * Staged value: a value that is represented by code rather than by data.
     * Every staged value can produce code for a boxed {@code Value}; the typed
     * accessors are only meaningful for subclasses that override them.
     */
    public static abstract class SValue {
        /** @return code that evaluates to this value as a boxed {@code Value} */
        public separable abstract Code<Value> codeValue();
        /**
         * @return code that evaluates to this value as an integer
         * @throws WrongTypeException always, unless overridden by a subclass
         */
        public separable Code<Integer> intCodeValue() { throw new WrongTypeException(); }
        /**
         * @return code that evaluates to this value as a boolean
         * @throws WrongTypeException always, unless overridden by a subclass
         */
        public separable Code<Boolean> booleanCodeValue() { throw new WrongTypeException(); }
    }
    
    /**
     * Staged value whose type is not known while code is being generated: it
     * holds code for a boxed {@code Value}. The typed accessors therefore
     * generate a call to {@code intValue()} / {@code booleanValue()} on that
     * value, so a type mismatch surfaces as a {@code WrongTypeException} only
     * when the generated code is run.
     */
    public static class SCodeValue extends SValue {
        private Code<Value> _valueData;
        /** @param valueData code evaluating to the boxed value */
        public separable SCodeValue(Code<Value> valueData) { _valueData = valueData; }
        /** @return the stored code, unchanged */
        public separable Code<Value> codeValue() { return _valueData; }
        /** @return code that unboxes the value via {@code intValue()} */
        public separable Code<Integer> intCodeValue() { return <| (`_valueData).intValue() |>; }
        /** @return code that unboxes the value via {@code booleanValue()} */
        public separable Code<Boolean> booleanCodeValue() { return <| (`_valueData).booleanValue() |>; }
    }    
    
    /**
     * Staged value known to be an integer while code is being generated; it
     * holds code of type {@code Integer}, so no boxing into {@code IntValue}
     * is generated unless {@code codeValue()} is requested.
     * {@code booleanCodeValue()} is inherited and throws {@code WrongTypeException}.
     */
    public static class SIntValue extends SValue {
        private Code<Integer> _data;
        /** @param data code evaluating to the integer */
        public separable SIntValue(Code<Integer> data) { _data = data; }
        /** @return code that boxes the integer into a new {@code IntValue} */
        public separable Code<Value> codeValue() { return <| (Value) new IntValue(`(_data)) |>; }
        /** @return the stored integer code, unchanged */
        public separable Code<Integer> intCodeValue() { return _data; }
    }

    /**
     * Staged value known to be a boolean while code is being generated; it
     * holds code of type {@code Boolean}, so no boxing into {@code BooleanValue}
     * is generated unless {@code codeValue()} is requested.
     * {@code intCodeValue()} is inherited and throws {@code WrongTypeException}.
     */
    public static class SBooleanValue extends SValue {
        private Code<Boolean> _data;
        /** @param data code evaluating to the boolean */
        public separable SBooleanValue(Code<Boolean> data) { _data = data; }
        /** @return code that boxes the boolean into a new {@code BooleanValue} */
        public separable Code<Value> codeValue() { return <| (Value) new BooleanValue(`(_data)) |>; }
        /** @return the stored boolean code, unchanged */
        public separable Code<Boolean> booleanCodeValue() { return _data; }
    }
    
    /** Expression of the interpreted language. */
    public static interface Exp {
        /**
         * Evaluates this expression to a staged value.
         *
         * @param e environment mapping variable names to staged values
         * @param f environment mapping function names to staged functions
         * @return the staged value of this expression
         */
        public separable SValue eval(Env e, FEnv f);
    }
    
    /** Constant expression: evaluates to the staged value it was built with. */
    public static class Val implements Exp {
        private SValue _value;

        /** @param value the staged value this expression denotes */
        public Val(SValue value) {
            this._value = value;
        }

        /** @return the stored staged value; both environments are ignored */
        public separable SValue eval(Env e, FEnv f) {
            return this._value;
        }
    }
    
    /** Variable reference: evaluates to whatever the environment binds to its name. */
    public static class Var implements Exp {
        private String _s;

        /** @param s the variable name */
        public Var(String s) {
            this._s = s;
        }

        /** @return the result of looking the name up in {@code e}; {@code f} is unused */
        public separable SValue eval(Env e, FEnv f) {
            return e.get(this._s);
        }
    }
    
    /** Application of a named one-argument function to an argument expression. */
    public static class App implements Exp {
        private final String _s;
        private final Exp _body;
        /**
         * @param s    name of the function to call
         * @param body expression producing the argument
         */
        public App(String s, Exp body) {
            _s = s;
            _body = body;
        }
        /**
         * Looks the function up in {@code f} and applies it to the staged
         * value of the argument expression.
         */
        public separable SValue eval(final Env e, final FEnv f) {
            // NOTE(review): the fields are copied into final locals before use;
            // presumably a workaround for the Mint compiler, like the one in If.eval - confirm.
            final String s = _s;
            final Exp body = _body;
            return f.get(s).apply(body.eval(e,f));
        }
    }
    
    /** Common base of the binary operators; it only stores the two operand expressions. */
    public static abstract class BinOp implements Exp {
        /** Left-hand operand expression. */
        protected Exp _left;
        /** Right-hand operand expression. */
        protected Exp _right;

        /**
         * @param left  left-hand operand expression
         * @param right right-hand operand expression
         */
        public BinOp(Exp left, Exp right) {
            this._left = left;
            this._right = right;
        }
    }
    
    public static class Add extends BinOp {
        public Add(Exp left, Exp right) { super(left, right); }
        public separable SValue eval(Env e, FEnv f) {
            return new SIntValue(<| `(_left.eval(e,f).intCodeValue()) + `(_right.eval(e,f).intCodeValue()) |>);
        }
    }
    
    public static class Sub extends BinOp {
        public Sub(Exp left, Exp right) { super(left, right); }
        public separable SValue eval(Env e, FEnv f) {
            return new SIntValue(<| `(_left.eval(e,f).intCodeValue()) - `(_right.eval(e,f).intCodeValue()) |>);
        }
    }
    
    public static class Mul extends BinOp {
        public Mul(Exp left, Exp right) { super(left, right); }
        public separable SValue eval(Env e, FEnv f) {
            return new SIntValue(<| `(_left.eval(e,f).intCodeValue()) * `(_right.eval(e,f).intCodeValue()) |>);
        }
    }
    
    public static class Div extends BinOp {
        public Div(Exp left, Exp right) { super(left, right); }
        public separable SValue eval(Env e, FEnv f) {
            return new SIntValue(<| `(_left.eval(e,f).intCodeValue()) / `(_right.eval(e,f).intCodeValue()) |>);
        }
    }
    
    /**
     * Equality test on two operands of the same type (both integers or both
     * booleans).
     */
    public static class Equals extends BinOp {
        public Equals(Exp left, Exp right) {
            super(left, right);
        }
        /**
         * Produces a staged boolean. When both operands are statically known to
         * be integers, or both booleans, the generated code compares the unboxed
         * values directly. Otherwise the generated code inspects the boxed
         * values when it runs and throws {@code WrongTypeException} if they are
         * not both {@code IntValue} or both {@code BooleanValue}.
         */
        public separable SValue eval(Env e, FEnv f) {
            final SValue l = _left.eval(e,f);
            final SValue r = _right.eval(e,f);
            
            // Both operands statically integers: compare unboxed ints.
            if ((l instanceof SIntValue) && (r instanceof SIntValue)) {
                return new SBooleanValue(<| `(l.intCodeValue())==`(r.intCodeValue()) |>);
            }
            // Both operands statically booleans: compare unboxed booleans.
            else if ((l instanceof SBooleanValue) && (r instanceof SBooleanValue)) {
                return new SBooleanValue(<| `(l.booleanCodeValue())==`(r.booleanCodeValue()) |>);
            }
            // Types not both known: defer the type dispatch to the generated code.
            else {
                return new SBooleanValue
                    (<|
                     let Value vl = `(l.codeValue()), vr = `(r.codeValue());
                     ((vl instanceof IntValue) && (vr instanceof IntValue))?(vl.intValue()==vr.intValue()):
                         (((vl instanceof BooleanValue) && (vr instanceof BooleanValue))?
                              (vl.booleanValue()==vr.booleanValue()):WrongTypeException.<Boolean>returningThrow()) |>);
            }
        }
    }
    
    /** Integer comparison: true when the left operand is strictly less than the right. */
    public static class Less extends BinOp {
        public Less(Exp left, Exp right) {
            super(left, right);
        }
        /**
         * @return a staged boolean whose code evaluates {@code left < right}
         *         on the operands' integer code
         * @throws WrongTypeException if an operand has no integer code
         */
        public separable SValue eval(Env e, FEnv f) {
            // The generated comparison must be '<'; it was previously '==',
            // which made Less behave exactly like integer equality.
            return new SBooleanValue(<| `(_left.eval(e,f).intCodeValue()) < `(_right.eval(e,f).intCodeValue()) |>);
        }
    }
    
    public static class And extends BinOp {
        public And(Exp left, Exp right) {
            super(left, right);
        }
        public separable SValue eval(Env e, FEnv f) {
            return new SBooleanValue(<| `(_left.eval(e,f).booleanCodeValue()) &&
                                     `(_right.eval(e,f).booleanCodeValue()) |>);
        }
    }
    
    public static class Not implements Exp {
        private Exp _exp;
        public Not(Exp exp) {
            _exp = exp;
        }
        public separable SValue eval(Env e, FEnv f) {
            return new SBooleanValue(<| !`(_exp.eval(e,f).booleanCodeValue()) |>);
        }
    }
    
    /**
     * Combines two staged values with a user-supplied code transformer,
     * keeping the most specific staged representation the two values share.
     */
    public static abstract class JoinApplyFun {
        // BUG: The Mint compiler has trouble with this generic method (more specifically the <X>):
        // It generates Code<X> where X is unbound
        /**
         * Combines two pieces of code of the same type into one.
         *
         * @param c1 code from the first staged value
         * @param c2 code from the second staged value
         * @param cy class token for {@code Y}
         * @param x  extra value of type {@code X}; joinApply always passes 5
         * @param cx class token for {@code X}
         * @return the combined code
         */
        public separable abstract <Y,X extends Integer> Code<Y> fun(Code<Y> c1, Code<Y> c2, Class<Y> cy, X x, Class<X> cx);
        
        /**
         * Applies {@link #fun} to the two values at the most specific shared
         * representation: integer code when both are {@code SIntValue},
         * boolean code when both are {@code SBooleanValue}, and boxed
         * {@code Value} code otherwise.
         */
        public separable SValue joinApply(SValue sv1, SValue sv2) {
            if ((sv1 instanceof SIntValue) && (sv2 instanceof SIntValue)) {
                return new SIntValue(fun(sv1.intCodeValue(), sv2.intCodeValue(), Integer.class, 5, Integer.class));
            }
            else if ((sv1 instanceof SBooleanValue) && (sv2 instanceof SBooleanValue)) {
                return new SBooleanValue(fun(sv1.booleanCodeValue(), sv2.booleanCodeValue(), Boolean.class, 5, Integer.class));
            }
            else {
                return new SCodeValue(fun(sv1.codeValue(), sv2.codeValue(), Value.class, 5, Integer.class));
            }
        }
    }

    /** Conditional expression: {@code if test then conseq else alt}. */
    public static class If implements Exp {
        private Exp _test, _conseq, _alt;
        /**
         * @param test   boolean condition
         * @param conseq expression used when the condition is true
         * @param alt    expression used when the condition is false
         */
        public If(Exp test, Exp conseq, Exp alt) {
            _test = test;
            _conseq = conseq;
            _alt = alt;
        }
        /**
         * Generates a conditional over the code of both branches.
         *
         * The branches are joined through {@code JoinApplyFun.joinApply}, which
         * passes 5 for {@code x}; the generated outer test {@code x < 0} is
         * therefore false when the code runs, so the second arm
         * ({@code test ? c1 : c2}, i.e. conseq when true, alt when false) is the
         * one taken. The code of the test expression is spliced into both arms.
         *
         * NOTE(review): the joined result is converted with {@code codeValue()}
         * and re-wrapped in an {@code SCodeValue}, so the result is always the
         * boxed representation even when both branches are statically typed.
         */
        public separable SValue eval(final Env e, final FEnv f) {
            final Exp lfTest = _test; // hack to help Mint compiler
            return new SCodeValue(new JoinApplyFun() {
                public separable <Y,X extends Integer> Code<Y> fun(Code<Y> c1, Code<Y> c2, final Class<Y> cy, final X x, final Class<X> cx) {
                    return <|
                        ((`(edu.rice.cs.mint.util.Lift.lift(x)))</* gratuitous cast to X */ (X)new Integer(0))
                        ? (/* gratuitous cast to Y */ (Y) ( `(lfTest.eval(e,f).booleanCodeValue()) ? `c2 : `c1 )) 
                        : (/* gratuitous cast to Y */ (Y) ( `(lfTest.eval(e,f).booleanCodeValue()) ? `c1 : `c2 )) |>;
                }
            }.joinApply(_conseq.eval(e,f), _alt.eval(e,f)).codeValue());
        }
    }

    /**
     * A run-time function of the interpreted language, mapping one
     * {@code Value} to another ({@code Value -> Value}).
     */
    public static interface Fun {
        /**
         * @param param the argument value
         * @return the function's result for {@code param}
         */
        Value apply(Value param);
    }

    /**
     * A staged function: maps a staged argument to a staged result
     * ({@code SValue -> SValue}).
     */
    public static interface SFun {
        /**
         * @param param the staged argument
         * @return the staged result
         */
        public separable SValue apply(SValue param);
    }
    
    /**
     * Converts code of a {@code Fun} into an {@code SFun}.
     *
     * @param codeFun code evaluating to a run-time function
     * @return a staged function whose application generates a call to
     *         {@code apply} on that function, passing the argument's boxed
     *         code; the result is an {@code SCodeValue}
     */
    public static separable SFun codeFunToSFun(final Code<? extends Fun> codeFun) {
        return new SFun() {
            public separable SValue apply(SValue param) {
                return new SCodeValue(<| (`codeFun).apply(`(param.codeValue())) |>);
            }
        };
    }

    /**
     * Converts an {@code SFun} into code of a {@code Fun}.
     *
     * @param sfun the staged function
     * @return code for an anonymous {@code Fun} whose {@code apply} body is the
     *         boxed code obtained by applying {@code sfun} to the code of the
     *         run-time parameter
     */
    public static separable Code<? extends Fun> sfunToCodeFun(final SFun sfun) {
        return <| new Fun() {
            public Value apply(Value param) {
                return `(sfun.apply(new SCodeValue(<| param |>)).codeValue());
            }
        } |>;
    }
        
    /**
     * Variable environment, implemented as a function object from
     * {@code String} (variable name) to {@code SValue}.
     */
    public static interface Env {
        /**
         * @param y the variable name to look up
         * @return the staged value bound to {@code y}
         */
        public separable SValue get(String y);
    }
    /** The empty variable environment: every lookup throws {@code EnvironmentException}. */
    public static final Env env0 = new Env() {
        public separable SValue get(String s) {
            throw new EnvironmentException();
        }
    };
    
    /**
     * Extends a variable environment with one binding.
     *
     * @param env the environment to fall back to
     * @param x   the name to bind
     * @param v   the staged value bound to {@code x}
     * @return an environment that returns {@code v} for {@code x} and
     *         delegates every other lookup to {@code env}
     */
    public static separable Env ext(final Env env,
                                    final String x,
                                    final SValue v) {
        return new Env() {
            public separable SValue get(String y) {
                // error: if (x.equals(y))
                // NOTE(review): names are compared by reference because the
                // equals() form above is marked as an error (presumably rejected
                // by the Mint compiler here - confirm). A lookup therefore only
                // matches when both names are the same String instance, which
                // holds for the string literals used by the sample programs.
                if (x==y)
                    return v;
                else
                    return env.get(y);
            }
        };
    }     
    
    /**
     * Function environment, implemented as a function object from
     * {@code String} (function name) to {@code SFun}.
     */
    public static interface FEnv {
        /**
         * @param s the function name to look up
         * @return the staged function bound to {@code s}
         */
        public separable SFun get(String s);
    }
    /** The empty function environment: every lookup throws {@code EnvironmentException}. */
    public static final FEnv fenv0 = new FEnv() {
        public separable SFun get(String s) {
            throw new EnvironmentException();
        }
    };
    
    /**
     * Extends a function environment with one binding.
     *
     * @param fenv the environment to fall back to
     * @param x    the function name to bind
     * @param v    the staged function bound to {@code x}
     * @return an environment that returns {@code v} for {@code x} and
     *         delegates every other lookup to {@code fenv}
     */
    public static separable FEnv fext(final FEnv fenv, final String x, final SFun v) {
        return new FEnv() {
            public separable SFun get(String y) {
                // error: if (x.equals(y))
                // NOTE(review): names are compared by reference because the
                // equals() form above is marked as an error (presumably rejected
                // by the Mint compiler here - confirm). A lookup therefore only
                // matches when both names are the same String instance.
                if (x==y)
                    return v;
                else
                    return fenv.get(y);
            }
        };
    }
    
    /**
     * Declaration of a one-parameter function: its name, the name of its
     * parameter, and its body expression.
     */
    public static class Declaration {
        private String _fun;
        private String _param;
        private Exp _body;

        /**
         * @param fun   the function's name
         * @param param the parameter's name
         * @param body  the function body
         */
        public Declaration(String fun, String param, Exp body) {
            this._fun = fun;
            this._param = param;
            this._body = body;
        }

        /** @return the function's name */
        public String fun() {
            return this._fun;
        }

        /** @return the parameter's name */
        public String param() {
            return this._param;
        }

        /** @return the function body */
        public Exp body() {
            return this._body;
        }
    }
    
    /**
     * A program: a list of function declarations followed by a main
     * expression that may call them.
     */
    public static class Program {
        private Declaration[] _defs;
        private Exp _body;
        /**
         * @param body the main expression
         * @param defs the function declarations, processed in the given order
         */
        public Program(Exp body, Declaration ... defs) {
            _defs = defs;
            _body = body;
        }
        /** @return the declarations array held by this program (not a copy) */
        public Declaration[] defs() { return _defs; }
        /** @return the main expression */
        public Exp body() { return _body; }
        
        /**
         * Partially evaluates the program, starting from the first declaration.
         *
         * @param env  initial variable environment
         * @param fenv initial function environment
         * @return the staged value of the main expression
         */
        public separable SValue peval(Env env, FEnv fenv) {
            return peval(env,fenv,0);
        }
        
        /**
         * Processes the declarations from {@code defIndex} onwards, then
         * evaluates the main expression in the resulting function environment.
         */
        private separable SValue peval(final Env env, final FEnv fenv, int defIndex) {
            // match p with
            if (_defs.length<=defIndex) {
                //     Program ([],e) -> eval e env fenv
                return _body.eval(env,fenv);
            }
            else {
                //    |Program (Declaration (s1,s2,e1)::tl,e) ->
                //        let rec f x = eval e1 (ext env s2 x) (ext fenv s1 f)
                //        in peval (Program(tl,e)) env (ext fenv s1 f)
                final Declaration d = _defs[defIndex];
                final String dParam = d._param;
                final String dFun = d._fun;
                final Exp dBody = d._body;
                // Code of a run-time Fun for this declaration. Inside its body the
                // parameter is bound to the code of the run-time argument, and the
                // function's own name is bound to a call on the Fun object itself
                // (ffun), which is how recursion is expressed in the generated code.
                final Code<? extends Fun> codeFun = <| new Fun() {
                    public Value apply(Value param) {
                        final Fun ffun = this;
                        return `(dBody.eval(ext(env, dParam, new SCodeValue(<| param |>)),
                                            fext(fenv, dFun, codeFunToSFun(<| ffun |>))).codeValue());
                    }
                } |>;
                // Bind the name for the rest of the program: a call evaluates the
                // body directly on the staged argument, while calls to the same name
                // from within that body go through the generated Fun (codeFun).
                return peval(env,fext(fenv, dFun, new SFun() {
                    public separable SValue apply(final SValue x) {
                        return dBody.eval(ext(env, dParam, x),
                                          fext(fenv, dFun, codeFunToSFun(codeFun)));
                    }
                }),defIndex+1);
            }
        }
    }

    /**
     * Sample program computing a factorial:
     * {@code f(x) = if x == 0 then 1 else x * f(x - 1)}, with main expression {@code f(10)}.
     */
    public static Program termFact = new Program
        (new App("f", new Val(new SIntValue(<|10|>))),
         new Declaration
             ("f", "x", new If
                  (new Equals(new Var("x"), new Val(new SIntValue(<|0|>))),
                   new Val(new SIntValue(<|1|>)),
                   new Mul(new Var("x"),
                           new App
                               ("f",
                                new Sub
                                    (new Var("x"),
                                     new Val(new SIntValue(<|1|>))))))));

    /**
     * Sample program computing a Fibonacci number:
     * {@code f(x) = if x == 0 then 0 else if x - 1 == 0 then 1 else f(x - 1) + f(x - 2)},
     * with main expression {@code f(10)}.
     */
    public static Program termFib = new Program
        (new App("f", new Val(new SIntValue(<|10|>))),
         new Declaration("f", "x", new If
                             (new Equals(new Var("x"), new Val(new SIntValue(<|0|>))),
                              new Val(new SIntValue(<|0|>)),
                              new If
                                  (new Equals(new Sub
                                                  (new Var("x"),
                                                   new Val(new SIntValue(<|1|>))),
                                              new Val(new SIntValue(<|0|>))),
                                   new Val(new SIntValue(<|1|>)),
                                   new Add
                                       (new App
                                            ("f",
                                             new Sub
                                                 (new Var("x"),
                                                  new Val(new SIntValue(<|1|>)))),
                                        new App
                                            ("f",
                                             new Sub
                                                 (new Var("x"),
                                                  new Val(new SIntValue(<|2|>)))))))));
    
    /**
     * Partially evaluates the two sample programs in the empty environments
     * and prints, for each, the generated code followed by the result of
     * running it.
     */
    public static void main(String[] args) {
        printGeneratedAndResult(termFact, "fact(10) = ");
        System.out.println("=======================================================");
        printGeneratedAndResult(termFib, "fib(10) = ");
    }

    /** Prints the code generated for one program, a separator line, then the labelled result of running it. */
    private static void printGeneratedAndResult(Program program, String label) {
        final SValue staged = program.peval(env0, fenv0);
        System.out.println(staged.codeValue());
        System.out.println("- - - - - - - - - - - - - - - - - - - - - - - - - - - -");
        System.out.println(label + staged.codeValue().run());
    }
}
