//===----------------------------------------------------------------------===//
//
//                    The PACE Application Aware Partitioner
//
// Copyright (C) 2009 - 2010, ET International, Inc. All rights reserved.
//
// The information and source code contained herein is the exclusive property
// of ET International, Inc. and may not be disclosed, examined or reproduced
// in whole or in part without explicit written authorization from the company.
//
// This software was produced under a U.S. Government contract with the Air
// Force Research Lab. The U.S. Government is licensed to use, reproduce,
// modify, and distribute this software for use within the U.S. Government.
// These rights are equivalent to:
// GOVERNMENT PURPOSE RIGHTS, CONTRACT: F33615-09-C-7915
//
//===----------------------------------------------------------------------===//

#include <assert.h>

#include <fstream>
#include <iostream>
#include <sstream>
#include <stack>

#include "analysis/arraypadding.h"
#include "utils/options.h"

#include <clang/AST/Type.h>

#include <llvm/Support/Debug.h>
#include <llvm/ADT/EquivalenceClasses.h>

using namespace aap;
using namespace std;

// A mapping from array names to the RPUs that use them.
static map<string,set<string> >   rpuUses_;

// Prefix of the pragma text emitted for arrays found safe to pad; the array
// name is appended after a single space (see safePaddingPragma). Never
// modified, hence const.
static const string kSafePaddingPragma = "AAP ArrayPadding safe";

namespace {
/// Represents an array under consideration for padding.
class Array
{
public:
    Array(std::string n, unsigned d, unsigned o)
        : name(n), global(false), dimensions(d), offset(o) { }

    std::string name;           ///< The array's name.
    bool global;                ///< True for globally scoped variables.
    unsigned dimensions;        ///< The number of array dimensions.
    unsigned offset;            ///< Position within a function.

    /// Orders arrays by name only.
    bool operator<(const Array& other) const
    {
        return name < other.name;
    }

    /// Returns the string used to identify the array within a function.
    /// This value differentiates arrays that shadow one another by including
    /// the offset after the name (separated by a colon).
    std::string localId(void) const
    {
        std::ostringstream oss;
        oss << name  << ":" << offset;
        return oss.str();
    }
};


/// Represents a current program scope. New instances are placed at the top of
/// a stack, accessible through the top pointer. Deleting an object
/// automatically pops it off of the stack.
///
/// A Scope owns the Array objects stored in its `arrays` map and deletes them
/// when it is destroyed, so Scope objects must not be copied.
class Scope
{
public:
    /// Pushes a new scope named @p n onto the stack. An empty name marks a
    /// nested (anonymous) block, which shares the function scope of its
    /// enclosing scope; any other name makes this scope its own function
    /// scope.
    explicit Scope(const std::string& n)
        : name(n), next(top), function(this)
    {
        if (n.empty() && top) function = top->function;
        top = this;
    }

    /// Restores `top` to the next scope (only correct when this scope is the
    /// current top) and releases the arrays declared in this scope.
    ~Scope()
    {
        top = next;
        std::map<std::string, Array*>::iterator arrayI, end = arrays.end();
        for (arrayI = arrays.begin(); arrayI != end; ++arrayI)
            delete arrayI->second;
    }

    /// Moves information about arrays that are not declared in the current
    /// scope to the parent scope. Eventually, they will propagate to the scope
    /// in which they are declared.
    void mergeArraysToParent(void)
    {
        if (!next) return;
        // Put each safe array in the parent scope (unless it's unsafe there).
        std::set<std::string>::iterator arrayI, end = safeArrays.end();
        for (arrayI = safeArrays.begin(); arrayI != end; ++arrayI) {
            // If the array is declared locally, it's irrelevant to the parent.
            if (arrays.count(*arrayI)) continue;
            // If the array isn't already unsafe in the parent, mark it safe.
            if (!next->unsafeArrays.count(*arrayI)) {
                DEBUG(llvm::dbgs() << "merging to parent " << *arrayI << "\n");
                next->safeArrays.insert(*arrayI);
            }
        }

        // NOTE(review): unlike the loop above, this skips arrays declared in
        // the *parent* rather than in this scope, so unsafety stops one step
        // short of the declaring scope, contrary to the comment on this
        // method. Confirm whether `arrays.count(*arrayI)` was intended.
        end = unsafeArrays.end();
        for (arrayI = unsafeArrays.begin(); arrayI != end; ++arrayI) {
            if (next->arrays.count(*arrayI)) continue;
            next->unsafeArrays.insert(*arrayI);
        }
    }

    /// Moves the safe arrays declared in this scope into the enclosing
    /// function Scope (recorded by local id in its nestedArrays) and removes
    /// them from this scope's safe set.
    void mergeArraysToFunction(void)
    {
        if (!function) return;
        // Every entry of `arrays` is declared in this scope (it is keyed by
        // the array's own name), so only the safety check is needed.
        std::map<std::string,Array*>::iterator arrayI, end = arrays.end();
        for (arrayI = arrays.begin(); arrayI != end; ++arrayI) {
            Array* array = arrayI->second;
            // If the array isn't safe for padding, we don't care about it.
            if (!safeArrays.count(array->name)) continue;
            DEBUG (llvm::dbgs() << "merging to function "
                   << array->localId() << "\n");
            // if (array->global) global.nestedArrays.insert(array->localId());
            // else
            function->nestedArrays.insert(array->localId());
            safeArrays.erase(array->name);
        }
    }

    std::string name;
    std::set<std::string> safeIndexes;
    std::set<std::string> unsafeIndexes;
    std::set<std::string> declaredVariables;
    std::set<std::string> safeArrays;
    std::set<std::string> unsafeArrays;
    /// Function scopes hold the finalized safe arrays from nested scopes.
    std::set<std::string> nestedArrays;
    /// Mapping from index variables to names of arrays that use them.
    std::map<std::string, std::set<std::string> > uses;

    /// Mapping from array names to actual objects (owned by this scope).
    std::map<std::string, Array*> arrays;

    static Scope  global;       // The global scope (bottom of stack).
    static Scope* top;          // The current top of the stack.
    Scope* next;                // The next scope on the stack.
    Scope* function;            // The next scope representing a function.

private:
    // Not copyable: a copy would share (and later double-delete) the owned
    // Array objects and would not be linked into the scope stack.
    Scope(const Scope&);
    Scope& operator=(const Scope&);
};

// `top` must stay zero-initialized (no dynamic initializer): the constructor
// of `global` runs during dynamic initialization and sets `top` to itself.
Scope  Scope::global("<global>");
Scope* Scope::top;

} // anonymous namespace

// Function scopes, keyed by their globally unique function name. An entry is
// added when a named (function) scope is popped; nothing in this file deletes
// these Scope objects, so their safe arrays can still be queried later.
typedef std::map<std::string, Scope*> FunctionScopeMap;
static FunctionScopeMap functionScopes;

/// Pops the current scope and resolves what can be resolved at its end.
///
/// Unsafe index variables either taint every array they index (when declared
/// here or when there is no parent) or are handed to the parent along with
/// their uses. Safe index variables not declared here are handed to the
/// parent as well. Anonymous scopes then merge their array information
/// upward and are deleted; named (function) scopes are unlinked, pruned and
/// kept in functionScopes.
void
ArrayPadding::pop (void)
{
    if (!kPaddingActive) return;

    assert (Scope::top);
    // Release builds: tolerate an unbalanced pop instead of dereferencing
    // null, matching the other entry points that check Scope::top.
    if (!Scope::top) return;

    Scope *scope = Scope::top;
    DEBUG (llvm::dbgs() << "pop " << scope->name << "\n");

    Scope* parent = scope->next;

    // for each unsafe index variable U
    //   if U is declared in scope or there is no parent: mark all uses unsafe
    //   else: move the uses up to the next scope & note that it's unsafe
    set<string>::iterator index, indexE = scope->unsafeIndexes.end();
    for (index = scope->unsafeIndexes.begin(); index != indexE; ++index) {
        if (parent && !scope->declaredVariables.count(*index)) {
            const set<string>& uses = scope->uses[*index];
            parent->uses[*index].insert (uses.begin(), uses.end());
            parent->unsafeIndexes.insert (*index);
            parent->safeIndexes.erase (*index);
        } else if (scope->uses.count(*index)) {
            const set<string>& uses = scope->uses[*index];
            set<string>::const_iterator use, useE = uses.end();
            for (use = uses.begin(); use != useE; ++use) {
                scope->safeArrays.erase(*use);
                scope->unsafeArrays.insert(*use);
                DEBUG (llvm::dbgs() << *use << " uses unsafe index "
                       << *index << "\n");
            }
        }
    }

    // for each safe index variable S that is not declared in scope
    //   if there is no parent: the arrays it indexes are no longer safe
    //   else: propagate its uses, and its safety unless the parent already
    //         knows it to be unsafe
    indexE = scope->safeIndexes.end();
    for (index = scope->safeIndexes.begin(); index != indexE; ++index) {
        if (scope->declaredVariables.count(*index)) continue;

        const set<string>& uses = scope->uses[*index];
        if (!parent) {
            set<string>::const_iterator use, useE = uses.end();
            for (use = uses.begin(); use != useE; ++use) {
                scope->safeArrays.erase(*use);
                DEBUG (llvm::dbgs() << *use << " uses index "
                                    << *index << " at global scope\n");
            }
        } else {
            parent->uses[*index].insert (uses.begin(), uses.end());
            if (0 == parent->unsafeIndexes.count (*index))
                parent->safeIndexes.insert (*index);
        }
    }

    // If the scope represents a function, we want to hang onto it to keep a
    // record of the safe arrays declared in it.
    if (scope->name.empty()) {
        scope->mergeArraysToParent();
        scope->mergeArraysToFunction();
        delete scope;           // The destructor restores Scope::top.
    } else {
        Scope::top = scope->next;
        // Prune down the scope object as much as possible.
        scope->next = NULL;
        scope->unsafeArrays.clear();
        scope->safeIndexes.clear();
        scope->unsafeIndexes.clear();
        functionScopes[scope->name] = scope;
    }
}

/// Strips a leading "function " marker from a scope name, if present, so the
/// result is the bare function name; other names are returned unchanged.
static std::string
nonFunctionName (const std::string& name)
{
    const char marker[] = "function ";
    const std::string::size_type markerLen = sizeof(marker) - 1;

    if (name.compare (0, markerLen, marker) != 0)
        return name;
    return name.substr (markerLen);
}

/// Opens a new scope named after @p n (with any "function " prefix removed).
/// Does nothing while padding analysis is inactive.
void
ArrayPadding::push (const std::string& n)
{
    if (!kPaddingActive) return;

    const std::string scopeName = nonFunctionName(n);
    new Scope(scopeName);   // The constructor links the scope in as top.
    DEBUG (llvm::dbgs() << "push scope " << scopeName << "\n");
}

/// Records @p index as a safe index variable in the current scope, unless it
/// has already been marked unsafe there (unsafe always wins).
void
ArrayPadding::registerSafeIndex (const string& index)
{
    if (!kPaddingActive || !Scope::top) return;

    // Previously the debug message was printed even when the index was
    // rejected here, which misreported it as safe.
    if (Scope::top->unsafeIndexes.count (index)) return;

    Scope::top->safeIndexes.insert (index);
    DEBUG (llvm::dbgs() << index << " is a safe index\n");
}

void
ArrayPadding::markAsUnsafeIndex (const string& name)
{
    if (!kPaddingActive || name.empty() || !Scope::top) return;

    Scope::top->safeIndexes.erase (name);
    if (0 == Scope::top->unsafeIndexes.count(name)) {
        Scope::top->unsafeIndexes.insert (name);
        DEBUG (llvm::dbgs() << name << " is unsafe for indexing\n");
    }
}

void
ArrayPadding::registerGlobalArray(const string& name, int dimension,
                                  int offset)
{
    if (!kPaddingActive || Scope::global.unsafeArrays.count (name)) return;

    Array* array = new Array(name, dimension, offset);
    array->global = true;
    Scope::top->arrays[name] = array;
    Scope::top->safeArrays.insert(name);
    DEBUG (llvm::dbgs() << name << " might be paddable globally\n");
}

void
ArrayPadding::registerLocalArray(const string& name, int dimension,
                                 int offset)
{
    assert(Scope::top && "paddable array without a scope in place");
    if (!kPaddingActive || Scope::top->unsafeArrays.count (name)) return;

    Scope::top->arrays[name] = new Array(name, dimension, offset);
    Scope::top->safeArrays.insert(name);
    DEBUG (llvm::dbgs() << name << " might be paddable locally\n");
}

/// Marks the constant-size array declared by @p decl as unsafe for padding.
/// The scope stack is searched outward from the top; the innermost scope that
/// lists the array as safe has it moved to its unsafe set. Declarations of
/// other types are ignored.
void
ArrayPadding::registerUnsafeArray(const clang::ValueDecl* decl)
{
    if (!kPaddingActive) return;
    if (!isa<clang::ConstantArrayType>(decl->getType().getTypePtr())) return;

    const std::string name = decl->getName().str();
    if (name.empty()) return;

    DEBUG (llvm::dbgs() << name << " to be removed\n");

    // No assertion on Scope::top: recordArrayIndex calls this exactly when no
    // scope is in place, and an empty stack simply means nothing to update.
    for (Scope* scope = Scope::top; scope; scope = scope->next) {
        if (scope->safeArrays.erase(name)) {
            DEBUG (llvm::dbgs() << name << " is not safe for padding\n");
            scope->unsafeArrays.insert(name);
            return;
        }
    }
}

/// Records that the array declared by @p arrayDecl is subscripted by the
/// index variable @p index in the current scope.
void
ArrayPadding::recordArrayIndex (const clang::ValueDecl* arrayDecl,
                                const string& index)
{
    if (!kPaddingActive || index.empty()) return;

    // With no scope in place there is nowhere to record the use. Handing the
    // declaration to registerUnsafeArray here would only trip its scope
    // assertion, because that function walks the (empty) stack from
    // Scope::top.
    if (!Scope::top) return;

    const std::string arrayName = arrayDecl->getName().str();
    if (Scope::top->uses[index].insert (arrayName).second) {
        DEBUG (llvm::dbgs() << arrayName << " uses " << index << "\n");
    }
}


// (sub-procedure for ArrayPadding::report)
// Writes one <rpu-group> element per equivalence class in EC, listing every
// member of that class, all wrapped in a single <rpu-dependencies> element.
static void
printEquivClass (ostream& os, const llvm::EquivalenceClasses<string>& EC)
{
    typedef llvm::EquivalenceClasses<string> Classes;

    os << "<rpu-dependencies>\n";
    for (Classes::iterator cls = EC.begin(), last = EC.end();
         cls != last; ++cls) {
        // Only a leader starts a group; the other elements of a class are
        // reached through the leader's member list.
        if (!cls->isLeader())
            continue;

        os << "\t""<rpu-group>\n";
        for (Classes::member_iterator member = EC.member_begin(cls),
                 done = EC.member_end();
             member != done; ++member) {
            os << "\t\t<rpu name=\"" << *member << "\" />\n";
        }
        os << "\t""</rpu-group>\n";
    }
    os << "</rpu-dependencies>\n";
}

/// Output is directed into array-padding.xml in the output directory. The
/// file lists every safe array of the current top scope that has at least
/// one associated RPU, followed by the groups of RPUs that must be padded
/// consistently. No output will be generated if paddingActive() returns
/// false or no scope is in place.
void
ArrayPadding::report (void)
{
    // paddingActive must be *called*: the bare name decays to a non-null
    // function pointer, so `!paddingActive` was always false and the report
    // was written even with padding disabled.
    if (!paddingActive() || !Scope::top) return;

    // Classes of RPU's that are equivalent with respect to array padding - the
    // application of array padding must be consistent within each class. The
    // take away is that, if one RPU in the class is recompiled with different
    // array padding options, the entire class should be recompiled.
    llvm::EquivalenceClasses<string> rpuClasses;

    const string path = Options::outPath() + "/array-padding.xml";
    ofstream os (path.c_str());
    if (!os) {
        // Don't silently produce nothing when the file cannot be created.
        cerr << "array padding: unable to open " << path
             << " for writing\n";
        return;
    }

    os << "<?xml version=\"1.0\"?>\n""<array-padding>\n""<array-uses>\n";

    const std::set<std::string>& topSafe = Scope::top->safeArrays;
    set<string>::const_iterator array_it, array_end = topSafe.end();
    // Iterate over all of the safe arrays. For each one, list the RPU's in
    // which the array is used while building the equivalency for those RPUs.
    for (array_it = topSafe.begin(); array_it != array_end; ++array_it) {
        // find() instead of operator[]: avoids adding empty entries to
        // rpuUses_ and avoids copying the RPU set.
        map<string, set<string> >::const_iterator usesI =
            rpuUses_.find(*array_it);
        if (usesI == rpuUses_.end() || usesI->second.empty()) continue;
        const set<string>& rpus = usesI->second;

        os << "\t<array name=\"" << *array_it << "\"";
        os << " type=\"global\"";
        os << ">\n";

        string prevRpu;
        set<string>::const_iterator rpu_it, rpu_end = rpus.end();
        for (rpu_it = rpus.begin(); rpu_it != rpu_end; ++rpu_it) {
            os << "\t\t<rpu name=\"" << *rpu_it << "\" />\n";
            if (prevRpu.length())
                rpuClasses.unionSets(*rpu_it, prevRpu);
            else
                rpuClasses.insert(*rpu_it);
            prevRpu = *rpu_it;
        }
        os << "\t</array>\n";
    }
    os << "</array-uses>\n";

    printEquivClass (os, rpuClasses);

    os <<"</array-padding>\n";
}

/// Returns true when padding is active and @p name is listed as safe in the
/// current top scope. Returns false when no scope is in place, instead of
/// dereferencing null.
bool
ArrayPadding::shouldPadArray (const string& name)
{
    return kPaddingActive && Scope::top
        && Scope::top->safeArrays.count (name) != 0;
}

/// The names of arrays declared in functions are prefixed with the globally
/// unique name of the function (which, for static functions may reference the
/// file name) and followed by the offset position, separated with a ':'.
set<string>
ArrayPadding::safeArrays (void)
{
    std::set<std::string> safeArrays = Scope::top->safeArrays;
    std::map<std::string,Scope*>::iterator scopeI, sEnd = functionScopes.end();
    for (scopeI = functionScopes.begin(); scopeI != sEnd; ++scopeI) {
        std::string function = scopeI->first;
        Scope* scope = scopeI->second;
        std::set<std::string>::iterator arrayI, aEnd = scope->safeArrays.end();
        for (arrayI = scope->safeArrays.begin(); arrayI != aEnd; ++arrayI) {
            string name = function + ":" + scope->arrays[*arrayI]->localId();
            safeArrays.insert(name);
        }
        aEnd = scope->nestedArrays.end();
        for (arrayI = scope->nestedArrays.begin(); arrayI != aEnd; ++arrayI)
            safeArrays.insert(function + ":" + *arrayI);
    }

    return safeArrays;
}

/// Records that RPU @p rpu uses @p array, provided padding is active and the
/// array is safe in the current top scope. The association feeds report().
void
ArrayPadding::associateArrayWithRPU (const string& array, const string& rpu)
{
    if (!paddingActive() || !Scope::top) return;
    if (!Scope::top->safeArrays.count(array)) return;

    rpuUses_[array].insert (rpu);
}

int
ArrayPadding::arrayDimensions (const std::string& arrayName)
{
    if (!paddingActive()) return 0;
    if (!Scope::top->safeArrays.count(arrayName)) return 0;

    return Scope::top->arrays[arrayName]->dimensions;
}

/// Notes that variable @p var is declared in the current scope. Index
/// variables declared in a scope are resolved there when it is popped.
/// @p offset is currently unused.
void
ArrayPadding::registerDeclaration (const std::string& var,
                                   int offset)
{
    (void)offset;
    if (paddingActive() && Scope::top)
        Scope::top->declaredVariables.insert (var);
}

std::set<PragmaInfo>
ArrayPadding::paddingPragmasForFunction(const std::string& function)
{
    std::set<PragmaInfo> pragmas;
    std::string name = nonFunctionName(function);
    if (functionScopes.count(name)) {
        Scope* scope = functionScopes[name];
        std::set<std::string>::iterator arrayI, end = scope->safeArrays.end();
        for (arrayI = scope->safeArrays.begin(); arrayI != end; ++arrayI) {
            int offset = scope->arrays[(*arrayI)]->offset;
            pragmas.insert(safePaddingPragma(*arrayI, offset));
        }
    }
    return pragmas;
}

/// Returns the pragma marking @p array as safe to pad: the text is
/// kSafePaddingPragma followed by a space and the array name, placed at
/// @p offset.
PragmaInfo
ArrayPadding::safePaddingPragma(const std::string& array, int offset)
{
    std::string text(kSafePaddingPragma);
    text += ' ';
    text += array;
    return PragmaInfo(text, offset);
}
