/* Copyright 2011, Rice University.  All rights reserved.
   No warranty of usability express or implied.  Have a lovely day! */
#include <stdlib.h>
#include <stdint.h>
#include <stdio.h>
#include "barrier.h"


/* Return ceil(log2(num)), used as the number of dissemination rounds.
   Special cases: 0 -> 0, and 1 -> 1 (so a single thread still performs
   one round).  Results range from 0 to the bit width of unsigned int. */
static unsigned int log2_ceil(unsigned int num) {
  unsigned int res = 0;
  unsigned int v = num;

  if (num == 0) return 0;
  if (num == 1) return 1;

  /* res = floor(log2(num)) */
  while (v / 2 > 0) {
    v /= 2;
    res++;
  }

  /* Round up unless num is an exact power of two.  The shift must be
     unsigned: res can reach 31, and 1 << 31 overflows a signed int (UB). */
  if ((1u << res) != num) res++;

  return res;
}


/* Return the node for (round, tid) in the row-major array
   dissem_barrier_node[rounds][thrs] that starts at b.  The index is
   computed in size_t so round * thrs cannot wrap in unsigned int. */
static inline dissem_barrier_node * barrier_node(dissem_barrier_node * b, unsigned int round, unsigned int tid, unsigned int thrs) {
  return b + ((size_t)round * thrs + tid);
}

/* Create a dissemination barrier for num_threads threads.

   Allocates a dissem_barrier_node[rounds][num_threads] array, where
   rounds = log2_ceil(num_threads), aligned to a LINESIZE boundary.  In
   round r, thread t's node points at the round-r node of thread
   (t + 2^r) % num_threads, and both parity flags start at 0.

   Returns 0 on success.  Returns 1 if the required size would overflow
   size_t or malloc fails; b is left untouched in that case.  On success b
   owns the allocation; release it with dissem_barrier_destroy(). */
int dissem_barrier_create(unsigned int num_threads, dissem_barrier * b) {
  unsigned int t, r;
  dissem_barrier_node * nodes;
  void * mem;
  unsigned int rounds = log2_ceil(num_threads);

  /* Refuse sizes where rounds * num_threads * sizeof(node) + LINESIZE
     would wrap around size_t. */
  if (num_threads != 0 &&
      rounds > (SIZE_MAX - LINESIZE) / sizeof(dissem_barrier_node) / num_threads)
    return 1;

  /* Over-allocate by LINESIZE bytes so the array can be moved up to the
     next LINESIZE boundary without running past the end of the block. */
  mem = malloc((size_t)rounds * num_threads * sizeof(dissem_barrier_node) + LINESIZE);
  if (mem == NULL) return 1;

  /* Round up to a multiple of LINESIZE (assumed to be a power of two):
     ~(LINESIZE - 1) clears exactly the low-order offset bits. */
  nodes = (dissem_barrier_node *) (((uintptr_t)mem + (LINESIZE - 1)) & ~(uintptr_t)(LINESIZE - 1));

  /* Link every node to its partner for that round and clear its flags. */
  for (r = 0; r < rounds; ++r) {
    for (t = 0; t < num_threads; ++t) {
      dissem_barrier_node * node = barrier_node(nodes, r, t, num_threads);
      /* 1u: an unsigned shift stays well-defined even when r == 31. */
      node->data.partner = barrier_node(nodes, r, (t + (1u << r)) % num_threads, num_threads);
      node->data.flag[0] = 0;
      node->data.flag[1] = 0;
    }
  }

  b->mem_to_free = mem;
  b->nodes = nodes;

  return 0;
}

/* Free the memory allocated by dissem_barrier_create() and null out both
   of the barrier's pointers. */
void dissem_barrier_destroy(dissem_barrier * b) {
  void * block = b->mem_to_free;

  b->nodes = NULL;
  b->mem_to_free = NULL;
  free(block);
}

/* Initialize the per-thread state used by dissem_barrier_wait().

   Copies the shared node-array pointer from b, records this thread's id,
   the thread count and the round count (log2_ceil(num_threads)), and sets
   the starting episode state: sense = 1, parity = 0. */
void dissem_barrier_thread_local_data_init(dissem_barrier_thread_local_data * d, dissem_barrier * b, unsigned int thread_id, unsigned int num_threads) {
  unsigned int rounds = log2_ceil(num_threads);

  d->parity = 0;
  d->sense = 1;
  d->thread_id = thread_id;
  d->num_threads = num_threads;
  d->log_num_threads = rounds;
  d->nodes = b->nodes;
}

/* Wait at the barrier described by the thread-local state d.

   Runs d->log_num_threads rounds.  In round r the thread writes the
   current sense value into flag[parity] of its round-r partner node (the
   partner link is set up by dissem_barrier_create()).  It then spins until
   its own round-r flag[parity] equals sense.

   After the rounds, parity is flipped on every call.  sense is flipped only
   after a call made with parity == 1.  As a result, each flag slot sees
   alternating values on successive uses and never needs to be reset.

   NOTE(review): the flags are read and written with plain accesses, and no
   explicit memory fence appears here.  Unless barrier.h declares them
   volatile or atomic, this is a data race under the C11 memory model.  It
   may also fail to order pre-barrier writes on weakly ordered CPUs.
   Confirm against barrier.h. */
void dissem_barrier_wait(dissem_barrier_thread_local_data * d) {
  unsigned int log_num_threads = d->log_num_threads;  
  unsigned int num_threads = d->num_threads;
  unsigned int thread_id = d->thread_id;
  unsigned char parity = d->parity;
  unsigned char sense = d->sense;
  dissem_barrier_node * nodes = d->nodes;
  
  unsigned int r;
  
  for (r=0; r < log_num_threads; ++r) {
    /* Get current barrier node */
    dissem_barrier_node * c = barrier_node(nodes, r, thread_id, num_threads);
    /* Set partner's flag */
    c->data.partner->data.flag[parity] = sense;
    /* Spin until flag is set */ 
    /* NOTE(review): fflush(stderr) is presumably here as an opaque call so
       the compiler re-reads c->data.flag on every iteration instead of
       hoisting the load; confirm, and consider atomics instead. */
    while (c->data.flag[parity] != sense) fflush(stderr);
  }

  /* Reverse parity and sense */
  /* sense flips only after the parity-1 half of each pair of calls. */
  if (parity == 1)
    d->sense = 1 - sense;

  d->parity = 1 - parity;
}
