#
# Copyright (c) 2011, EPFL (Ecole Politechnique Federale de Lausanne)
# All rights reserved.
#
# Created by Marco Canini, Daniele Venzano, Dejan Kostic, Jennifer Rexford
# Contributed to this file: Peter Peresini, Maciej Kuzniar
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
#   -  Redistributions of source code must retain the above copyright notice,
#      this list of conditions and the following disclaimer.
#   -  Redistributions in binary form must reproduce the above copyright notice,
#      this list of conditions and the following disclaimer in the documentation
#      and/or other materials provided with the distribution.
#   -  Neither the names of the contributors, nor their associated universities or
#      organizations may be used to endorse or promote products derived from this
#      software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
# OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
# SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

# When True, ModelChecker.currentStateHash identifies a state by the MD5 hex
# digest of its serialized form; when False it uses Python's built-in hash() of
# the same data. Must be defined before the conditional hashlib import below.
USE_MD5 = False

import sys
import signal
import random
import cPickle
if USE_MD5:
	import hashlib
import platform
import time
import logging
log = logging.getLogger("nice.mc")
from stats import getStats
# Module-wide statistics object; the model checker uses its
# pushProfile()/popProfile() pair to time sections of the search.
stats = getStats()
import utils

from lib.action import Action
from lib.strategies import RandomWalk
from lib.strategies.dpor import DynamicPartialOrderReduction

import invariants.invariant_dispatcher as Invariants

class ModelChecker:
	"""Explicit-state model checker that drives a model through its transitions.

	The search is depth-first over ``self.state_stack``, whose entries are
	``[replay_list, enabled_actions, state_hash, state_list]``:

	* ``replay_list``     -- actions leading from the initial state to this one;
	* ``enabled_actions`` -- actions still to be explored from this state;
	* ``state_hash``      -- value of currentStateHash() for this state;
	* ``state_list``      -- per-step node states, only filled in replay-debug
	  mode and used by replayActions() to validate the replay.

	Backtracking rebuilds a state by creating a fresh model and replaying its
	``replay_list``.  Action selection is delegated to the configured strategy
	and, when enabled, to dynamic partial-order reduction (DPOR).  The
	symbolic-execution (SE) side is reached only through the queues passed in
	``symbolic_options["queues"]``; its implementation is outside this class.
	"""

	# Action targets that consume one unit of the fault-injection budget.
	_FAULT_TARGETS = ("drop_packet", "duplicate_packet", "reorder_packet")

	def __init__(self, options, symbolic_options):
		"""Read the configuration and register the model's invariants.

		:param options: configuration object queried with ``options.get(key)``.
		:param symbolic_options: mapping whose ``"queues"`` entry is the pair
		    ``(queue_inputs, queue_states)`` used to talk to the SE side.
		"""
		self.options = options
		self.model_class = options.get("model.class")
		self.strategy_class = options.get("strategy.class")
		self.strategy = None
		self.useDpor = options.get("strategy.dpor")
		self.dpor = None
		self.good_transitions_count = 0
		self.model = None # The model instance
		self.state_stack = None
		self.unique_states = set()
		self.unique_states_count = 0
		self.old_states_count = 0
		self.max_path_length = 0
		self.total_last_sec = 0
		self.sym_engine_call_count = 0
		self.start_time = None
		self.debug = options.get("runtime.replay_debug")
		# start() seeds every strategy that is an instance of RandomWalk, so the
		# seed must also be prepared for RandomWalk subclasses.
		if self._usesRandomWalk():
			seed = options.get("randomwalk.seed")
			# -1 means "derive a fresh seed from the wall clock"
			if seed == -1:
				self.random_seed = int(time.time())
			else:
				self.random_seed = seed
		self.fault_injections = options.get("model.faults")
		graph_file = options.get("runtime.graph")
		if graph_file is not None:
			from model_checker_graph import ModelCheckingGraph
			self.graph = ModelCheckingGraph(graph_file, self)
		else:
			self.graph = None

		self.quiet = options.get("runtime.quiet")
		cutoff = options.get("model.cutoff")
		# -1 disables the path-length cutoff
		if cutoff != -1:
			self.path_length_cutoff = cutoff
		else:
			self.path_length_cutoff = sys.maxint

		self.initial_inputs = []
		self.initial_state = None

		self.queue_inputs = symbolic_options["queues"][0]
		self.queue_states = symbolic_options["queues"][1]
		self.generate_inputs = self.model_class.generate_inputs
		self.generate_stats = self.model_class.generate_stats

		self.contr_app_states_stats = {} # States we already explored symbolically for stats

		Invariants.model_checker = self
		for invariant_class in self.model_class.invariants:
			Invariants.registerInvariant(invariant_class())

		if self.options.get("runtime.progress"):
			print("MC init complete")

	def _usesRandomWalk(self):
		"""Return True if the configured strategy class is RandomWalk or a subclass."""
		try:
			return issubclass(self.strategy_class, RandomWalk)
		except TypeError:
			# strategy.class is not a class (e.g. a factory callable): fall back
			# to the plain equality test.
			return self.strategy_class == RandomWalk

	def loadInitialInputs(self, inputs):
		"""Store ``inputs`` in ``self.initial_inputs`` (not read in this class)."""
		self.initial_inputs = inputs

	def initializeSE(self):
		"""Send the "INITIAL" marker and then the client MACs to the SE side."""
		self.sym_engine_call_count += 1
		self.queue_states.put("INITIAL")
		self.queue_states.join()
		self.sendMacsToSymExec(self.model)
		self.queue_states.join()

	def _querySymbolicEngine(self, state):
		"""Send ``state`` to the SE side and return its unpickled reply.

		NOTE(review): the reply is unpickled; this assumes the queue peer is a
		trusted component of this tool.
		"""
		self.sym_engine_call_count += 1
		self.queue_states.put(state)
		self.queue_states.join()
		return cPickle.loads(self.queue_inputs.get())

	def generateInputs(self, state):
		"""Ask the SE side for the inputs of ``state``."""
		return self._querySymbolicEngine(state)

	def generatePortStats(self, state):
		"""Ask the SE side for port statistics of a serialized controller state."""
		assert state != "INITIAL"
		return self._querySymbolicEngine(state)

	def sendMacsToSymExec(self, model):
		"""Send the number of client MACs, then each MAC pickled, to the SE side.

		Every put() is followed by join() before the next item is sent.
		"""
		macs = model.getClientMacAddresses()
		self.queue_states.put(len(macs))
		self.queue_states.join()
		for mac in macs:
			# -1 is the pickle protocol (highest available); it used to be passed
			# to put() as its ``block`` flag by a misplaced parenthesis.
			self.queue_states.put(cPickle.dumps(mac, -1))
			self.queue_states.join()

	def start(self):
		"""Create the strategy (and DPOR), reach the initial state and search."""
		self.start_time = time.clock()
		self.strategy = self.strategy_class(self)
		if self.useDpor:
			self.dpor = DynamicPartialOrderReduction(self, self.strategy)

		if isinstance(self.strategy, RandomWalk):
			self.strategy.setSeed(self.random_seed)
		if self.options.get("runtime.progress") and platform.system() != 'Windows':
			# progress report every second, first one after 0.1 s
			signal.signal(signal.SIGALRM, self.printStats)
			signal.setitimer(signal.ITIMER_REAL, 0.1, 1)
		self.gotoInitialState()
		self.initial_state = self.currentStateHash()
		if self.queue_states is not None:
			self.initializeSE()
		self.state_stack = [[[], self.getEnabledActions(), self.initial_state, []]]
		self.modelCheck()

	def getEnabledActions(self):
		"""Return the actions the strategy considers enabled in the current state."""
		return self.strategy.getEnabledActions()

	def onEnableAction(self, node, action):
		"""Forward an action-enabled notification from ``node`` to the strategy."""
		return self.strategy.onEnableAction(node, action)

	def chooseAction(self, enabled_actions):
		"""Pick the next action, via DPOR when enabled, else via the strategy."""
		if self.useDpor:
			return self.dpor.chooseAction(enabled_actions)
		else:
			return self.strategy.chooseAction(enabled_actions)

	def checkManageFaults(self, next_action):
		"""Charge one unit of the fault-injection budget for fault actions."""
		if next_action.target in self._FAULT_TARGETS:
			self.model.setFaultInjectionCount(self.model.fault_injection_count - 1)

	def modelCheck(self):
		"""Run the depth-first exploration until ``self.state_stack`` is empty.

		``backtrack`` is False only right after a new state was pushed, i.e.
		when the model already is in the state about to be popped; otherwise
		the popped entry is rebuilt with replayActions().
		``path_end_cached_state`` records whether the previous path ended
		early (state already visited, or DPOR had nothing left to run) and is
		reported to the "path_end" invariant test point.
		"""
		backtrack = False
		path_end_cached_state = False
		while len(self.state_stack) > 0:
			log.warning("------> starting model checker transition (stack len: %d) <-------", len(self.state_stack))
			replay_list, enabled_actions, state_hash, state_list = self.state_stack.pop()
			log.debug("Enabled actions: %s", enabled_actions)

			# replay the list of actions if necessary
			if backtrack:
				log.info("Backtracking, previous path early termination: %s", path_end_cached_state)
				Invariants.testPoint("path_end", model=self.model, cached_state=path_end_cached_state)
				path_end_cached_state = False
				if self.useDpor:
					stats.pushProfile("dpor")
					try:
						pop = self.dpor.startBacktracking(replay_list)
					finally:
						stats.popProfile()
					if pop:
						path_end_cached_state = True
						continue
				self.replayActions(replay_list, state_list)

			next_action = self.chooseAction(enabled_actions)

			# DPOR decided that none of the remaining enabled actions needs to run
			if next_action is None:
				backtrack = True
				path_end_cached_state = True
				continue

			if backtrack:
				# sanity checks that the replay reproduced the popped state
				if self.debug and self.currentStateHash() != state_hash:
					utils.crash("State after replay is different than original")
				for a in [next_action] + enabled_actions:
					assert a in self.getEnabledActions() or a.node_name == "model_checker"
			backtrack = True

			# add the current state back to the stack, if there is still something to explore
			if len(enabled_actions) > 0:
				self.state_stack.append([replay_list, enabled_actions, state_hash, state_list])

			if next_action.target == "port_stats_special":
				self.portStatsSpecial(next_action.args[0])

			# Performing the transition would take us over the cutoff limit: backtrack
			if len(replay_list) + 1 > self.path_length_cutoff:
				continue

			self.checkManageFaults(next_action)

			if self.graph is not None:
				self.graph.startTransition(self.model)

			if not isinstance(self.strategy, RandomWalk):
				assert next_action in self.getEnabledActions() or next_action.node_name == "model_checker"

			# Info needed for invariant violations
			self.current_replay_list = replay_list
			self.next_action = next_action

			Invariants.testPoint("transition_start", model=self.model)

			# Execute the transition. executeAction was modified (xiaoye) to return
			# a string that is appended to the transition's label in the graph.
			information = self.model.executeAction(next_action, len(replay_list) + 1)

			if self.graph is not None:
				# tolerate executeAction implementations that return None
				transition_name = str(next_action).replace(":", "-") + (information or "")
				self.graph.endTransition(self.model, transition_name)

			Invariants.testPoint("transition_end", model=self.model)
			if Invariants.checkNewViolations():
				# a new violation was found on this path: backtrack
				continue

			for h in self.model.clients:
				if hasattr(h, "enableDiscoveryIfNeeded"):
					if self.graph is not None:
						self.graph.startTransition(self.model)
					h.enableDiscoveryIfNeeded()
					if self.graph is not None:
						self.graph.endTransition(self.model, "enable_discovery")

			if self.max_path_length < len(replay_list) + 1:
				self.max_path_length = len(replay_list) + 1

			self.good_transitions_count += 1
			state_hash = self.currentStateHash()

			# If we reached a new state, remember its hash and put it on top of the stack
			if state_hash not in self.unique_states:
				self.unique_states_count += 1
				self.unique_states.add(state_hash)
				new_actions = self.getEnabledActions()
				current_list = replay_list + [next_action]
				if self.debug:
					state_list = state_list + [self.model.nodesState()]
				if len(new_actions) > 0:
					self.state_stack.append([current_list, new_actions, state_hash, state_list])
					backtrack = False
			else:
				path_end_cached_state = True
				self.old_states_count += 1

	def printStats(self, sig, frame):
		"""SIGALRM handler installed by start(): write a one-line progress report.

		The tr/sec figure is the number of good transitions since the previous
		call; start() arms the timer with a 1 s interval.
		"""
		total_per_sec = self.good_transitions_count - self.total_last_sec
		line = ("Total: %d, unique: %d, revisited: %d, max path len: %d, "
			"violations: %d, SE calls: %d (%d tr/sec)\r") % (
				self.good_transitions_count, self.unique_states_count,
				self.old_states_count, self.max_path_length,
				Invariants.countViolations(), self.sym_engine_call_count,
				total_per_sec)
		sys.stdout.write(line)
		sys.stdout.flush()
		self.total_last_sec = self.good_transitions_count

	def gotoInitialState(self):
		"""Discard the current model and build a fresh one in its initial state.

		The ``random`` module is re-seeded with a constant before the model
		starts, presumably so that start-up is reproducible across replays
		(replayActions() checks the initial state hash).  With DPOR enabled,
		every node gets DPOR's bookkeeping callbacks attached.
		"""
		del self.model  # drop the old model before building its replacement
		self.model = self.model_class(self.options)
		self.model.initTopology(None)
		self.model.setFaultInjectionCount(self.fault_injections)

		random.seed("NSL")
		self.model.start(self)
		self.strategy.visitModel(self.model)
		if self.useDpor:
			self.dpor.updateDependencies(self.model)
			for node in self.model.getNodes():
				node.communicationObjectUsed = self.dpor.communicationObjectUsed
				node.startActionExecution = self.dpor.startActionExecution
				node.finishActionExecution = self.dpor.finishActionExecution

		Invariants.testPoint("path_start", model=self.model)

	def currentStateHash(self):
		"""Return a hash of the serialized model state plus the invariants' state.

		An MD5 hex digest when USE_MD5 is set, Python's hash() otherwise.
		"""
		stats.pushProfile("state hashing")
		try:
			model_state = self.model.serializedState()
			invariants_state = Invariants.serializedState()
			if USE_MD5:
				digest = hashlib.md5()
				digest.update(model_state)
				digest.update(invariants_state)
				return digest.hexdigest()
			return hash(model_state + invariants_state)
		finally:
			# keep the profiling stack balanced even if serialization fails
			stats.popProfile()

	def replayActions(self, action_list, state_list):
		"""Rebuild a state by replaying ``action_list`` from the initial state.

		Calls utils.crash() if the initial state hash differs from the one
		recorded by start() or if an action is not enabled during the replay.
		In replay-debug mode every intermediate node state is compared with
		``state_list`` (see _checkReplayedState).
		"""
		log.info("Start replay")
		stats.pushProfile("replay actions")
		try:
			self.gotoInitialState()
			if self.initial_state != self.currentStateHash():
				utils.crash("Different initial state")
			for depth, a in enumerate(action_list, 1):
				if self.useDpor:
					self.dpor.replayAction(a)

				if a.target == "port_stats_special":
					self.portStatsSpecial(a.args[0])

				if a not in self.model.getAllEnabledActions():
					log.error("Action: %s", a)
					self.model.printEnabledActions()
					utils.crash("Action not available during replay")
				else:
					self.checkManageFaults(a)
					self.model.executeAction(a, depth)

					for h in self.model.clients:
						if hasattr(h, "enableDiscoveryIfNeeded"):
							h.enableDiscoveryIfNeeded()

				if self.debug:
					# state_list[i] holds the node states recorded after action i+1
					self._checkReplayedState(state_list[depth - 1])
		finally:
			stats.popProfile()
		log.info("End replay")

	def _checkReplayedState(self, expected_state):
		"""Compare the current node states with ``expected_state``.

		On a mismatch the pickle disassemblies of the expected and current
		states are written to ``es.ds`` and ``cs.ds``, both states are printed,
		and RuntimeError is raised once every node has been checked.
		"""
		import pickletools
		from pprint import pprint
		state = self.model.nodesState()
		problem = False
		for k in state:
			es = expected_state[k]
			cs = state[k]
			if es != cs:
				with open('es.ds', 'w') as dump:
					pickletools.dis(es, out=dump)
				with open('cs.ds', 'w') as dump:
					pickletools.dis(cs, out=dump)
				log.error("State of %s during replay is different than original", k)
				pprint(cPickle.loads(es).__dict__)
				pprint(cPickle.loads(cs).__dict__)
				problem = True
		if problem:
			raise RuntimeError("Node state during replay is different than original")
		log.debug("Replay state: ok")

	def portStatsSpecial(self, dp_id):
		"""Enable one controller ``port_stats`` action per SE-provided input.

		SE results are cached per pickled (dp_id, controller app state) pair.
		"""
		state = self.model.getControllerAppState()
		ser_state = cPickle.dumps((dp_id, state), -1)
		if ser_state not in self.contr_app_states_stats:
			self.contr_app_states_stats[ser_state] = self.generatePortStats(ser_state)

		inputs = self.contr_app_states_stats[ser_state]
		assert len(inputs) > 0
		for inp in inputs:
			# FIXME see if it is possible to re-use the enableAction() function ?
			a = Action("ctrl", "port_stats", (dp_id, inp["ports"]))
			self.model.controller.enabled_actions.append(a)


class DebugModelChecker(ModelChecker):
	def __init__(self, options, symbolic_options):
		ModelChecker.__init__(self, options, symbolic_options)
		self.action_list = []
		self.action_cache = []

	def chooseAction(self, enabled_actions):
		print "Actions taken:"
		print self.action_list
		print "Choose next action:"
		for i, e in enumerate(enabled_actions):
			print i, "--", e
			if e.target == "process_command":
				print "CMD:", self.model.nodes[e.node_name].command_queue[0]
			elif e.target == "process_packet":
				for p in self.model.nodes[e.node_name].ports.values():
					if len(p.in_buffer) > 0:
						print "PKT:", p.in_buffer[0]
		if self.action_cache == []:
			actions = sys.stdin.readline().strip()
			for action in actions.split(","):
				action_id = int(action.strip())
				self.action_cache.append(action_id)
		assert len(self.action_cache) > 0

		action_id = self.action_cache.pop(0)
		self.action_list.append(action_id)
		return enabled_actions.pop(action_id)

