#
# Copyright (c) 2011, EPFL (Ecole Politechnique Federale de Lausanne)
# All rights reserved.
#
# Created by Marco Canini, Daniele Venzano, Dejan Kostic, Jennifer Rexford
# Contributed to this file: Peter Peresini, Maciej Kuzniar
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
#   -  Redistributions of source code must retain the above copyright notice,
#      this list of conditions and the following disclaimer.
#   -  Redistributions in binary form must reproduce the above copyright notice,
#      this list of conditions and the following disclaimer in the documentation
#      and/or other materials provided with the distribution.
#   -  Neither the names of the contributors, nor their associated universities or
#      organizations may be used to endorse or promote products derived from this
#      software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
# OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
# SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import sys
import copy

from lib.node import Node
from invariants.invariant_dispatcher import testPoint
from nox.lib import openflow

import logging
log = logging.getLogger("nice.mc.ctrl")

class Controller(Node):
	"""Interface for a controller implementation.

	Wraps a controller application object (an instance of ``component_class``)
	and exposes its event handlers as model-checker actions:

	* packet-in queries from switches are buffered per datapath in
	  ``in_connections`` and delivered one at a time by ``process_switch_query``;
	* timer callbacks posted by the application are stored in ``callbacks``
	  and delivered by ``process_callback``.
	"""
	def __init__(self, name, component_class, ctxt, max_callbacks=0):
		Node.__init__(self, name, 0)
		self.component_object = component_class(ctxt=ctxt)
		self.connection_count = 0
		# dp_id -> (switch, list of pending (buffer_id, packet, inport, reason))
		self.in_connections = {}
		# callbacks accepted by postCallback, addressed by their list index
		self.callbacks = []
		# run once by start(); the list is discarded afterwards
		self.start_callbacks = []
		# upper bound on len(self.callbacks); with the default of 0 every
		# postCallback is dropped
		self.max_callbacks = max_callbacks

	@property
	def component(self):
		"""The wrapped controller application object.

		Every access is reported via ``communicationObjectUsed`` before the
		object is returned (NOTE(review): presumably so the model checker can
		track accesses to shared state -- confirm against Node).
		"""
		self.communicationObjectUsed(self, "ctrl_component")
		return self.component_object

	def start(self, model_checker):
		"""Start the node, then run every registered start callback once."""
		Node.start(self, model_checker)
		for callback in self.start_callbacks:
			callback()
		# registering start callbacks after this point is not supported
		self.start_callbacks = None

	def install(self):
		"""Delegate installation to the wrapped application."""
		self.component.install()

	def enqueueQuery(self, dp_id, buffer_id, packet, inport, reason):
		"""Buffer a packet-in from switch ``dp_id`` and enable its delivery."""
		# store a copy so later changes to the sender's packet object do not
		# leak into the queued query
		self.in_connections[dp_id][1].append((buffer_id, packet.copy(), inport, reason))
		self.enableAction("process_switch_query", dp_id, skip_dup=True)

	def portStatsRequest(self, dp_id, ports):
		"""Enable the port-stats action for ``dp_id`` (``ports`` is unused)."""
		self.enableAction("port_stats_special", dp_id)

	def port_stats_special(self, dp_id):
		"""Action handler with no effect of its own; always returns True."""
		return True

	def port_stats(self, dp_id, stats):
		"""Deliver a deep copy of ``stats`` to the application.

		Afterwards every enabled action whose target is "port_stats" is
		removed. Returns False.
		"""
		self.component.port_stats_in_cb(dp_id, copy.deepcopy(stats))
		# slice-assign so the existing list object is updated in place
		self.enabled_actions[:] = [x for x in self.enabled_actions if x.target != "port_stats"]
		return False

	def postCallback(self, time, callback):
		"""Store ``callback`` and enable an action that will invoke it.

		``time`` is ignored. The callback is dropped silently once
		``max_callbacks`` callbacks have been stored.
		"""
		if len(self.callbacks) >= self.max_callbacks:
			# sorry, too many callbacks
			return
		self.enableAction("process_callback", len(self.callbacks))
		self.callbacks.append(callback)

	def process_callback(self, callback_id):
		"""Invoke the stored callback at index ``callback_id``; returns True."""
		# for now do not erase old callbacks, it is not worth it, the list is small anyway
		self.callbacks[int(callback_id)]()
		return True

	def process_switch_query(self, dp_id):
		"""Deliver the oldest buffered packet-in for ``dp_id`` to the application.

		Returns a two-element list: whether the queue for ``dp_id`` is empty
		after this delivery, and an ``"id=<packet_id>"`` label for the packet.
		"""
		buffer_id, packet, inport, reason = self.in_connections[dp_id][1].pop(0)
		testPoint("before_cnt_packet_in", buffer_id=buffer_id, packet=packet, inport=inport, reason=reason)
		length = 0
		result = self.component.packet_in_cb(dp_id, inport, reason, length, buffer_id, packet)
		testPoint("after_cnt_packet_in", controller=self.component, packet=packet, return_value=result)
		if reason == openflow.OFPR_ACTION:
			testPoint("packet_received", receiver=self, packet=packet, port=inport)

		# NOTE: this code is specific to the loadbalancer: an ARP packet sent to
		# 11:22:33:44:55:66 enables the delayed callback 0 (at most until
		# callback0_count reaches 2)
		if isinstance(self, LoadBalancerController) and self.callback0_count < 2:
			if packet.type == packet.ARP_TYPE and packet.dst == (0x11, 0x22, 0x33, 0x44, 0x55, 0x66):
				self.enableAction("process_callback", 0, skip_dup=True)
				self.count_packets += 1

		rv = []
		rv.append(len(self.in_connections[dp_id][1]) == 0)
		rv.append("id=" + packet.packet_id)
		return rv

	def addSwitch(self, switch):
		"""Register ``switch`` and notify the application that it joined."""
		# pass empty statistics when a new switch is connected
		stats = {'ports' : []}
		if self.component.datapath_join_cb is not None:
			self.component.datapath_join_cb(switch.getOpenflowID(), stats)
		self.in_connections[switch.getOpenflowID()] = (switch, [])
		self.connection_count += 1

	def isSameMicroflow(self, packet1, packet2):
		"""Return whether two packets belong to the same microflow (abstract)."""
		raise NotImplementedError

	def __getstate__(self):
		"""Return the node state plus the application state and pending queries.

		Only the pending-query buffer of each connection is kept; the switch
		objects are left out.
		"""
		filtered_dict = Node.__getstate__(self)
		filtered_dict["component"] = self.component.__getstate__()

		filtered_dict["in_connections"] = {}
		# sorted() works on both Python 2 and 3 (dict.keys() is a view on 3.x)
		for j in sorted(self.in_connections):
			filtered_dict["in_connections"][j] = self.in_connections[j][1] # get only the buffer

		return filtered_dict

	def getControllerAppState(self):
		"""Return the wrapped application's ``__getstate__()`` result."""
		return self.component.__getstate__()

	def process_packet(self):
		raise NotImplementedError

	def packetLeftNetworkHandler(self):
		"""Hook for subclasses; the base implementation does nothing."""
		pass

	def flowRemoved(self, msg):
		"""Enable the ``flow_removed`` action for ``msg``."""
		self.enableAction("flow_removed", msg)

	def flow_removed(self, msg):
		"""Forward ``msg`` to the application's flow-removed callback, if any."""
		if self.component.flow_removed_cb is not None:
			# NOTE(review): unlike the other application callbacks (which get a
			# dp_id first) this one receives the controller itself -- confirm
			self.component.flow_removed_cb(self, msg)
		return True

class PySwitchController(Controller):
	"""A pyswitch controller.

	``version`` selects the application class: "pyswitch" uses
	``pyswitch.pyswitch.pyswitch`` and "wildswitch" uses
	``wildswitch.wildswitch.wildswitch``. Any other value raises ValueError.
	"""
	def __init__(self, name, ctxt, version="pyswitch"):
		# imports are local so that only the selected application has to be
		# importable
		if version == "pyswitch":
			# x-CHANGE
			# Xiaoye: change the load module
			import pyswitch.pyswitch as pyswitch_mod
			component_class = pyswitch_mod.pyswitch
		elif version == "wildswitch":
			import wildswitch.wildswitch as wildswitch_mod
			component_class = wildswitch_mod.wildswitch
		else:
			# an assert would be stripped under "python -O", leaving an object
			# on which Controller.__init__ never ran; fail explicitly instead
			raise ValueError("unknown PySwitchController version: %r" % (version,))
		Controller.__init__(self, name, component_class, ctxt)

	def process_packet(self):
		raise NotImplementedError

	def isSameMicroflow(self, packet1, packet2):
		"""Packets are in the same microflow when their src and dst both match."""
		return (packet1.src == packet2.src and packet1.dst == packet2.dst)

	def __repr__(self):
		return str(self.component)

class EateController(Controller):
	"""Controller wrapping the ``eate`` class from module ``eate_app.<version>``.

	Port statistics for datapath 1 are requested at start-up and again
	whenever a packet leaves the network.
	"""
	def __init__(self, name, ctxt, version):
		# __import__ on a dotted name hands back the top-level package, so the
		# requested submodule is then looked up as an attribute of it
		top_package = __import__("eate_app." + version)
		app_module = getattr(top_package, version)
		Controller.__init__(self, name, app_module.eate, ctxt)

	def start(self, model_checker):
		"""Start the controller, then enable the port-stats action for dp 1."""
		Controller.start(self, model_checker)
		# NOTE(review): datapath id 1 is hard-coded -- presumably the single
		# switch in the eate scenario; confirm
		self.enableAction("port_stats_special", 1)

	def packetLeftNetworkHandler(self):
		"""Re-enable the port-stats action for dp 1."""
		self.enableAction("port_stats_special", 1)

	def process_packet(self):
		raise NotImplementedError

class LoadBalancerController(Controller):
	"""Controller for the load-balancer application (``lbtest``).

	At most two callbacks are accepted:

	* callback 0 is held back when posted. It is enabled when an ARP packet
	  to 11:22:33:44:55:66 is processed (see
	  ``Controller.process_switch_query``), and before running it the
	  ``Globals.ALPHAFILE`` configuration is switched via ``setConfigFile``.
	* callback 1 is enabled as soon as it is posted.
	"""
	# modules of the load-balancer package (lbtest first); reloaded together
	MOD_LIST = ["lbtest", "Alphas", "Arps",
			"Bins", "EvalRules", "Globals",
			"IPRules", "IPs", "IPTransition",
			"Multipath", "Stats"]

	# the two configurations assigned to Globals.ALPHAFILE
	CONF_0 = "1, 3\n2, 1"
	CONF_1 = "1, 1\n2, 3"

	def __init__(self, name, ctxt, use_fixed=False):
		if use_fixed:
			self.package = "loadbalancer.loadbalancer_fixed."
		else:
			self.package = "loadbalancer.loadbalancer."
		self.modules = [self.package + mod_name for mod_name in self.MOD_LIST]

		# how far the configuration switch in setConfigFile has progressed (0..2)
		self.callback0_count = 0
		# number of trigger ARP packets seen (incremented in process_switch_query)
		self.count_packets = 0
		if self.package + "lbtest" in sys.modules:
			# re-execute the already-imported modules so that module-level
			# state changed by a previous instance (e.g. Globals.ALPHAFILE)
			# starts fresh
			for m in self.modules:
				reload(sys.modules[m])
			lb_mod = sys.modules[self.package + "lbtest"]
		else:
			# strip the trailing "." -- a dotted name ending in "." is only
			# accepted by the Python 2 import machinery by accident
			lb_mod = __import__(self.package.rstrip("."))
			lb_mod = getattr(lb_mod, self.package.split(".")[1])
			lb_mod = lb_mod.lbtest

		sys.modules[self.package + "Globals"].ALPHAFILE = self.CONF_0
		Controller.__init__(self, name, lb_mod.lbtest, ctxt, max_callbacks=2)

	def setConfigFile(self):
		"""Advance the configuration switch by one step.

		The first call selects CONF_0, the second records the "reload_config"
		test point and selects CONF_1, and later calls do nothing.
		"""
		if self.callback0_count > 1:
			return
		elif self.callback0_count == 0:
			self.callback0_count = 1
			sys.modules[self.package + "Globals"].ALPHAFILE = self.CONF_0
		elif self.callback0_count == 1:
			testPoint("reload_config")
			self.callback0_count = 2
			sys.modules[self.package + "Globals"].ALPHAFILE = self.CONF_1

	def postCallback(self, time, callback):
		"""Store ``callback``; only the second one is enabled immediately.

		``time`` is ignored. Callbacks beyond ``max_callbacks`` are dropped.
		"""
		if len(self.callbacks) >= self.max_callbacks:
			# too many callbacks
			return
		self.callbacks.append(callback)
		if len(self.callbacks) == 1: # callback is the alphaFileUpdate
			return # delayed
		else: # callback is the arpRequestReplicas
			self.enableAction("process_callback", 1)

	def process_callback(self, callback_id):
		"""Invoke callback ``callback_id``; callback 0 first switches the config.

		Returns False while ``callback0_count == 1`` and two trigger packets
		have been seen, True otherwise (NOTE(review): presumably False keeps
		the action enabled for the model checker -- confirm against Node).
		"""
		# normalise once so the "== 0" test agrees with the index used below
		# (previously a string id such as "0" skipped setConfigFile())
		callback_id = int(callback_id)
		# for now do not erase old callbacks, it is not worth it, the list is small anyway
		if callback_id == 0:
			self.setConfigFile()

		self.callbacks[callback_id]()

		if self.callback0_count == 1 and self.count_packets == 2:
			return False

		return True

	def isSameMicroflow(self, packet1, packet2):
		"""Compare TCP flow ids when both packets carry TCP, else src/dst."""
		tcp_pkt1 = packet1.find("tcp")
		tcp_pkt2 = packet2.find("tcp")
		if tcp_pkt1 is None or tcp_pkt2 is None:
			return (packet1.src == packet2.src and packet1.dst == packet2.dst)
		else:
			return tcp_pkt1.flow_id == tcp_pkt2.flow_id

	def process_packet(self):
		raise NotImplementedError

