#
# Copyright (c) 2011, EPFL (Ecole Politechnique Federale de Lausanne)
# All rights reserved.
#
# Created by Marco Canini, Daniele Venzano, Dejan Kostic, Jennifer Rexford
# Contributed to this file: Peter Peresini, Maciej Kuzniar
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
#   -  Redistributions of source code must retain the above copyright notice,
#      this list of conditions and the following disclaimer.
#   -  Redistributions in binary form must reproduce the above copyright notice,
#      this list of conditions and the following disclaimer in the documentation
#      and/or other materials provided with the distribution.
#   -  Neither the names of the contributors, nor their associated universities or
#      organizations may be used to endorse or promote products derived from this
#      software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
# OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
# SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

from lib.node import Node
from invariants.violation import Violation
from invariants.invariant_dispatcher import testPoint, reportViolation

from nox.lib.packet import ethernet

import nox.lib.openflow as openflow
import nox.lib.core as core
import nox.lib.util as of_util
from nox.lib.packet.mac_address import MacAddress
from nox.lib.packet.ip_address import IpAddress

from lib.of_packet_out_message import PacketOutMessage
from lib.of_modify_state_message import ModifyStateMessage
from lib.of_flow_table_modification_message import FlowTableModificationMessage

import logging
import utils

# Sentinel version number ("packet-in version"). It is the vn of flow entries
# created without an explicit version (e.g. WAIT entries installed when a
# packet-in is sent) and the next_entry_vn value meaning "no preflow field".
PKTI_VN = -1

class FlowTableEntry:
	"""A single rule in an OpenflowSwitch flow table.

	Standard fields: ``attrs`` (match dict; missing keys are wildcards),
	``actions`` (list of action lists), ``priority`` and ``send_flow_rem``.

	Consistent-update extensions:
	  vn              -- version number of the rule (PKTI_VN when unversioned)
	  wl              -- waiting list of (packet, inport) pairs parked on the
	                     entry while its action is WAIT
	  next_entry_attr -- "preflow" field: match attrs and version of another
	  next_entry_vn      rule, handed to a packet that hits this entry
	                     (presumably the rule expected downstream -- confirm)
	"""

	def __init__(self, attrs, actions, priority, send_flow_rem=False, vn = PKTI_VN):
		self.attrs = attrs
		self.actions = actions
		self.priority = priority
		self.send_flow_rem = send_flow_rem

# X-CHANGE
		# vn = version number, wl = waiting list of (packet, inport)
		self.vn = vn
		self.wl = []

# XX-CHANGE
		# preflow field; PKTI_VN in next_entry_vn means "not set"
		self.next_entry_attr = {}
		self.next_entry_vn = PKTI_VN

# XX-CHANGE
	def set_preflow_field(self, next_entry_attr=None, next_entry_vn=PKTI_VN):
		"""Store a preflow field on this entry.

		next_entry_attr -- match attrs of the announced rule; None (the
		                   default) stores a fresh empty dict. A None default
		                   replaces the former shared ``{}`` default, which
		                   aliased one dict across all entries.
		next_entry_vn   -- version of the announced rule (PKTI_VN = unset)
		"""
		if next_entry_attr is None:
			next_entry_attr = {}
		self.next_entry_attr = next_entry_attr
		self.next_entry_vn = next_entry_vn

	def clr_preflow_field(self):
		"""Reset the preflow field to its unset state."""
		self.next_entry_attr = {}
		self.next_entry_vn = PKTI_VN

	def __lt__(self, other):
		"""Ordering used by list.sort(): flattened attrs, then actions, then vn.

		Priority is intentionally not part of the ordering.
		"""
		my_attrs = utils.flatten_dict(self.attrs)
		other_attrs = utils.flatten_dict(other.attrs)
		if my_attrs != other_attrs:
			return my_attrs < other_attrs
# X-CHANGE
		if self.actions != other.actions:
			return self.actions < other.actions
		return self.vn < other.vn

	def __eq__(self, other):
		"""Rule equality on attrs, actions, priority and send_flow_rem.

		vn, wl and the preflow field are deliberately ignored, so two versions
		of the same rule compare equal. Comparing with a non-entry yields
		NotImplemented (so ``entry == None`` is False) instead of raising.
		"""
		if not isinstance(other, FlowTableEntry):
			return NotImplemented
		return (self.attrs == other.attrs
			and self.actions == other.actions
			and self.priority == other.priority
			and self.send_flow_rem == other.send_flow_rem)

	def __ne__(self, other):
		result = self.__eq__(other)
		if result is NotImplemented:
			return result
		return not result

	def __getstate__(self):
		"""Return a plain-dict snapshot of the entry (attrs via utils.serialize_dict).

		NOTE(review): presumably consumed by the model checker's state
		hashing rather than for pickle round-trips -- there is no __setstate__.
		"""
		return {
			"actions": self.actions,
			"attrs": utils.serialize_dict(self.attrs),
			"priority": self.priority,
			"send_flow_rem": self.send_flow_rem,
# X-CHANGE
			"vn": self.vn,
			"wl": self.wl,
# XX-CHANGE
			"next_entry_attr": self.next_entry_attr,
			"next_entry_vn": self.next_entry_vn,
		}

	def __repr__(self):
# X-CHANGE
		return str((self.attrs, self.actions, self.priority, self.vn, self.wl, self.next_entry_attr,self.next_entry_vn))

class OpenflowSwitch(Node):
	"""Model of an OpenFlow switch driven by the NICE model checker.

	Incoming packets wait in per-port input buffers until the
	``process_packet`` transition handles them; controller messages wait in
	``command_queue`` until ``process_command``. Transitions are enabled by
	method name through ``Node.enableAction``.

	On top of a plain switch this model adds versioned flow entries, WAIT
	entries that park packets on a waiting list until the matching rule is
	installed, and a "preflow" field carried by packets and entries.
	"""
	ALWAYS_NEW_STATE = False

	def __init__(self, name, port_count, of_id, expire_entries=False):
		Node.__init__(self, name, port_count)
		self.log = logging.getLogger("nice.mc.%s" % self.name)
		self.flow_table_object = []
		self.controller = None
		self.openflow_id = of_id
		self.buffers = []		# released buffer ids available for reuse
		self.next_buffer_id = 0
		self.packet_store = {}	# buffer_id -> (packet, inport) awaiting the controller
		self.command_queue = []
		self.fault_injection_count = 0
		self.state_cnt = 0
		self.expire_entries = expire_entries

	@property
	def flow_table(self):
#		self.communicationObjectUsed(self, self.name + ".flowTable")
		return self.flow_table_object

	def __repr__(self):
		return "%s (id: %d)" % (self.name, self.openflow_id)

	def setController(self, cont):
		self.controller = cont

	def getOpenflowID(self):
		return self.openflow_id

	def setFaultInjectionCount(self, count):
		""" this gets called after initTopology and when a fault action is executed """
		self.fault_injection_count = count

	# ------------------------------------------------------------------
	# Small shared helpers
	# ------------------------------------------------------------------

	def _flowAttrs(self, pkt, inport):
		"""Return the OpenFlow match attributes of ``pkt`` received on ``inport``."""
		attrs = of_util.extract_flow(pkt)
		attrs[core.IN_PORT] = inport
		return attrs

	@staticmethod
	def _isWaitEntry(entry):
		"""True if ``entry`` is a WAIT placeholder (first action is [OFPAT_WAIT]).

		An entry with an empty action list (a drop rule) is not a WAIT entry.
		"""
		return len(entry.actions) > 0 and entry.actions[0] == [openflow.OFPAT_WAIT]

	def _removeEntry(self, entry):
		"""Remove ``entry`` from the flow table, preferring the exact object.

		FlowTableEntry.__eq__ ignores vn and wl, so list.remove()/index() can
		hit a different entry that merely compares equal. Equality is used
		only as a fallback when the object itself is not in the table; like
		list.remove, raises ValueError when nothing matches.
		"""
		for idx, candidate in enumerate(self.flow_table):
			if candidate is entry:
				del self.flow_table[idx]
				return
		self.flow_table.remove(entry)

	def _sendFlowRemoved(self, entry, reason):
		"""Send an OfpFlowRemoved for ``entry`` to the controller (counters zeroed)."""
		msg = openflow.OfpFlowRemoved()
		msg.priority = entry.priority
		msg.reason = reason
		msg.table_id = 0
		msg.duration_sec = 0
		msg.duration_nsec = 0
		msg.idle_timeout = 0
		msg.packet_count = 0
		msg.byte_count = 0
		msg.match = entry.attrs # match attributes
		self.controller.flowRemoved(msg)

	# ------------------------------------------------------------------
	# Packet arrival and fault injection
	# ------------------------------------------------------------------

	# NOTE: function overloaded from Node.
	def enqueuePacket(self, packet, inport):
		"""Queue ``packet`` on ``inport`` and enable the processing transition."""
		self.log.debug("Queued packet %s on port %d" % (packet, inport))
		self.ports[inport].queueIn(packet)
		self.enableAction("process_packet", skip_dup=True)
		if self.fault_injection_count > 0:
			# drop_packet / reorder_packet are implemented but currently disabled
#			self.enableAction("drop_packet", args=(inport,), skip_dup=True)
			self.enableAction("duplicate_packet", args=(inport,), skip_dup=True)
#			self.enableAction("reorder_packet", args=(inport,), skip_dup=True)

	def drop_packet(self, inport):
		""" Dequeues the first packet from the specified port and throws it away """
		pkt = self.getWaitingPacket(inport)
		more_packets = pkt is not None and self.checkWaitingPacket(inport)
		return not (more_packets and self.fault_injection_count > 0)

	def duplicate_packet(self, inport):
		""" Creates a copy of the the first packet on the port and puts it on the end of the buffer """
		pkt = self.getWaitingPacket(inport)
		if pkt is None:
			self.log.debug("Empty buffer, no packets duplicated")
			return True

		in_buffer = self.ports[inport].in_buffer
		in_buffer.insert(0, pkt)	# put the original back at the head
		pkt2 = pkt.copy()
		pkt.fault_injection.append("HAS DUP")
		pkt2.fault_injection.append("DUP")
		in_buffer.append(pkt2)
		self.log.debug("Duplicated packet on port %d: %s" % (inport, pkt))

		return not self.fault_injection_count > 0

	def reorder_packet(self, inport):
		""" Appends the first packet at the end of the buffer """
		pkt = self.getWaitingPacket(inport)
		if pkt is None:
			return True

		pkt.fault_injection.append("REORD")
		self.ports[inport].in_buffer.append(pkt)

		return not self.fault_injection_count > 0

	# ------------------------------------------------------------------
	# Data plane
	# ------------------------------------------------------------------

	def processActions(self, packet, actions, inport, entry = None):
		"""Apply ``actions`` to ``packet`` received on ``inport``.

		OUTPUT forwards a copy of the packet (to one port, to every other
		port on FLOOD, or to the controller); SET_* actions rewrite the
		packet headers in place; WAIT parks (packet, inport) on ``entry``'s
		waiting list and must be the only action in the list.
		"""
		for action in actions:
			kind = action[0]
			if kind == openflow.OFPAT_OUTPUT:
				port = action[1][1]
				if port < openflow.OFPP_MAX:
					peer = self.getPeer(port)
					if peer is None:
						continue # Skip disconnected port
					peer.enqueuePacket(packet.copy(), self.getPeerPort(port))
					testPoint("switch_sent_packet_on_port", switch=self, packet=packet, port=port)
				elif port == openflow.OFPP_FLOOD:
					testPoint("switch_flood_packet_start", switch=self, packet=packet)
					for out_port in self.ports:
						if out_port == inport:
							continue # Skip the inport
						peer = self.getPeer(out_port)
						if peer is None:
							continue # Skip disconnected port
						peer.enqueuePacket(packet.copy(), self.getPeerPort(out_port))
						testPoint("switch_sent_packet_on_port", switch=self, packet=packet, port=out_port)
				elif port == openflow.OFPP_CONTROLLER:
					# NOTE(review): passes the output port (OFPP_CONTROLLER) where the
					# no-match path passes the ingress port -- confirm this is intended.
					self.controller.enqueueQuery(self.openflow_id, None, packet.copy(), port, openflow.OFPR_ACTION)
				else:
					utils.crash("Unknown port action: 0x%x" % port)

			elif kind == openflow.OFPAT_WAIT:
				assert isinstance(entry, FlowTableEntry)
				if len(actions) > 1:
					utils.crash("length of actions larger than 1 when action is WAIT")
				entry.wl.append((packet,inport))

			elif kind == openflow.OFPAT_SET_DL_SRC: # Set Ethernet source address
				packet.src = MacAddress(action[1])
			elif kind == openflow.OFPAT_SET_NW_SRC: # Set IPv4 source address
				packet.next.srcip = IpAddress(action[1]) # We assume the next is an ipv4
			elif kind == openflow.OFPAT_SET_DL_DST: # Set Ethernet destination address
				packet.dst = MacAddress(action[1])
			elif kind == openflow.OFPAT_SET_NW_DST: # Set IPv4 destination address
				packet.next.dstip = IpAddress(action[1]) # We assume the next is an ipv4
			else:
				utils.crash("Action not implemented: %x" % kind)

	def packetIsMatching(self, pkt, inport, attrs):
		"""Checks whether a packet is matching a specific table entry

		   attrs is a dictionary of attributes to match, missing attributes
		   are considered wildcarded. If "nw_src_n_wild" is present, only the
		   high bits of nw_src are compared; that is handled before the other
		   keys so the result does not depend on dict iteration order.
		"""
		pkt_attrs = self._flowAttrs(pkt, inport)

		nw_src_wild = "nw_src_n_wild" in attrs
		if nw_src_wild:
			mask = int(0xffffffff << attrs["nw_src_n_wild"]) # see openflow.h in NOX for this definition of mask
			if pkt_attrs[core.NW_SRC] & mask != attrs["nw_src"]:
				return False

		for a in attrs:
			if a == "nw_src_n_wild" or (a == "nw_src" and nw_src_wild):
				continue
			if attrs[a] != pkt_attrs[a]: # If this throws an exception, we have an usupported attribute
				return False
		return True

	def matchFlowTable(self, pkt, inport):
		"""Look up ``pkt`` in the flow table and run the winning entry's actions.

		The highest-priority matching entry wins; ties go to the entry that
		comes first in table order. NOTE(review): an earlier comment claimed
		"highest version number among the highest priority", but ties are
		broken by table order, not by vn -- confirm which is intended.

		If the packet carries no preflow field and the entry has one, the
		entry's preflow field is copied onto the packet and then cleared on
		the entry (and on any entry comparing equal to it).

		Returns True if an entry matched, False otherwise.
		"""
		pkt_attrs = self._flowAttrs(pkt, inport)
		self.communicationObjectUsed(self, "flowTable_read", pkt_attrs)

		matching_entries = []
		for entry in self.flow_table:
			if self.packetIsMatching(pkt, inport, entry.attrs):
				self.log.debug("match the entry: *FTE: " + str(entry))
				matching_entries.append(entry)
			else:
				self.log.debug("FTE: " + str(entry))

		if not matching_entries: # no match
			return False
		# max() returns the first maximal element, which is exactly the entry
		# a stable reverse sort by priority would put at index 0.
		entry = max(matching_entries, key=lambda x: x.priority)

		if pkt.next_entry_vn == PKTI_VN and entry.next_entry_vn != PKTI_VN:
			pkt.set_preflow_field(entry.next_entry_attr, entry.next_entry_vn)
		self.processActions(pkt, entry.actions, inport, entry)

		for other in self.flow_table:
			if other == entry:
				other.clr_preflow_field()

		return True

	# ------------------------------------------------------------------
	# Controller commands
	# ------------------------------------------------------------------

	def processPacketOutMessage(self, command):
		"""Handle a PacketOut: apply its actions, then flush the packet's WAIT entry.

		If ``command.wait_bit`` is set, the command's preflow field is first
		attached to the packet. Afterwards the packet-in WAIT entry
		(vn == PKTI_VN) created for this packet's header is removed, and its
		parked packets get the same actions.
		"""
		self.log.debug("Processing a PacketOut: %s" % repr(command))
		if command.buffer_id is not None:
			(packet, inport) = self.packet_store[command.buffer_id]
		else:
			packet = command.packet
			inport = command.inport

		# Key of the packet-in WAIT entry. Computed before the actions run,
		# because SET_* actions rewrite the packet headers in place and the
		# WAIT entry was keyed on the original header.
		pkt_attrs = self._flowAttrs(packet, inport)

		if command.wait_bit:
			packet.set_preflow_field(command.next_entry_attr, command.next_entry_vn)

		if command.actions:
			self.processActions(packet, command.actions, inport)
		else:
			self.log.debug("Dropping packet with empty action list")
		if command.buffer_id is not None:
			self.releaseBuffer(command.buffer_id)

		# Iterate over a snapshot: entries are removed inside the loop.
		for entry in list(self.flow_table):
			if not (self._isWaitEntry(entry) and entry.vn == PKTI_VN and entry.attrs == pkt_attrs):
				continue
			while entry.wl:
				(wl_pkt, wl_inport) = entry.wl.pop(0)
				if command.actions:
					self.processActions(wl_pkt, command.actions, wl_inport)
				else:
					self.log.debug("Dropping packet with empty action list")
			self._removeEntry(entry)

	def processFlowTableModification(self, command):
		"""Dispatch an OFPFC_ADD / OFPFC_DELETE / OFPFC_DELETE_STRICT command."""
		self.communicationObjectUsed(self, "flowTable_write", command.arguments["attrs"])
		self.log.debug("Processing a FlowTableModification command: %s" % repr(command))

		if command.command == openflow.OFPFC_ADD:
			self._addFlowEntry(command)
		elif command.command == openflow.OFPFC_DELETE:
			self._deleteFlowEntries(command, strict=False)
		elif command.command == openflow.OFPFC_DELETE_STRICT:
			self._deleteFlowEntries(command, strict=True)

	def _addFlowEntry(self, command):
		"""OFPFC_ADD: install a rule and release the WAIT entry it supersedes.

		1. Find the WAIT entry the rule replaces and collect its parked packets.
		   Without a buffered packet, that is a preflow-created WAIT entry with
		   the rule's own attrs and an older or equal version. With a buffered
		   packet, it is this switch's packet-in WAIT entry for that packet.
		2. Insert the new entry (unless an equal one exists) and re-sort.
		3. Re-run the buffered packet, if any, through the updated table.
		4. Re-run the collected parked packets.
		"""
		args = command.arguments
		buf_id = args["buffer_id"]
		# TODO: idle_timeout, hard_timeout

		buf_attrs = None
		if buf_id is not None:
			if buf_id not in self.packet_store:
				v = Violation(None, "Trying to access buffer %d %s" % (buf_id, self.packet_store))
				reportViolation(v)
			(packet, inport) = self.packet_store[buf_id]
			buf_attrs = self._flowAttrs(packet, inport)

		pop_wl = []
		for entry in self.flow_table:
			if not self._isWaitEntry(entry):
				continue
			if buf_id is None:
				replaced = (entry.attrs == args["attrs"]
					and entry.vn != PKTI_VN and entry.vn <= command.vn)
			else:
				replaced = entry.vn == PKTI_VN and entry.attrs == buf_attrs
			if replaced:
				pop_wl.extend(entry.wl)
				del entry.wl[:]
				self._removeEntry(entry)
				break	# the loop must stop: the table was just modified

		e = FlowTableEntry(args["attrs"], args["actions"], args["priority"], vn = command.vn)
		if buf_id is None and command.wait_bit:
			e.set_preflow_field(command.next_entry_attr, command.next_entry_vn)
		if e not in self.flow_table:
			self.flow_table.append(e)
		self.flow_table.sort()
		if self.expire_entries:
			self.enableAction("expire_entry", e)

		if buf_id is not None:
			if command.wait_bit:
				packet.set_preflow_field(command.next_entry_attr, command.next_entry_vn)
			self.matchFlowTable(packet, inport)
			self.releaseBuffer(buf_id)

		for (wl_pkt, wl_inport) in pop_wl:
			self.matchFlowTable(wl_pkt, wl_inport)

	def _deleteFlowEntries(self, command, strict):
		"""OFPFC_DELETE(_STRICT): remove entries whose attrs equal the command's.

		Non-strict removes every such entry; strict also requires an equal
		priority and removes at most one. A flow-removed message is sent per
		deleted entry when the command's flags contain OFPFF_SEND_FLOW_REM.
		"""
		args = command.arguments
		attrs = args["attrs"]
		if strict:
			priority = args["priority"]
			victims = [e for e in self.flow_table if e.attrs == attrs and e.priority == priority][:1]
		else:
			# Collect first: deleting while iterating the table skips entries.
			victims = [e for e in self.flow_table if e.attrs == attrs]

		for e in victims:
			self.log.debug("Deleting flow entry %s" % e)
			self._removeEntry(e)
			if "flags" in args and openflow.OFPFF_SEND_FLOW_REM in args["flags"]:
				self._sendFlowRemoved(e, openflow.OFPRR_DELETE)

	# ------------------------------------------------------------------
	# Buffers and queues
	# ------------------------------------------------------------------

	def acquireBuffer(self):
		"""Return a free buffer id, minting a new one when none is free."""
		if not self.buffers:
			self.buffers.append(self.next_buffer_id)
			self.next_buffer_id += 1
		return self.buffers.pop()

	def releaseBuffer(self, buffer_id):
		"""Forget the packet stored under ``buffer_id`` and recycle the id."""
		del self.packet_store[buffer_id]
		self.buffers.append(buffer_id)

	def enqueueCommand(self, command):
		"""Queue a controller command and enable the command transition."""
		self.command_queue.append(command)
		self.enableAction("process_command", skip_dup=True)
		testPoint("switch_enqueue_command", switch=self, command=command)
		self.log.debug("Queued command: %s" % repr(command))

	def expire_entry(self, entry):
		"""Transition: time out ``entry``, notifying the controller if requested."""
		self.communicationObjectUsed(self, "flowTable_write", entry.attrs)
		self._removeEntry(entry)
		if entry.send_flow_rem:
			self._sendFlowRemoved(entry, openflow.OFPRR_HARD_TIMEOUT) # could be also IDLE_TIMEOUT
		return True

	def getWaitingPacket(self, port_name):
		"""Pop and return the oldest packet queued on ``port_name``, or None."""
		in_buffer = self.ports[port_name].in_buffer
		if in_buffer:
			return in_buffer.pop(0)
		return None

	# ------------------------------------------------------------------
	# Model-checker transitions
	# ------------------------------------------------------------------

	def _installPreflowWaitEntry(self, pkt):
		"""Install a WAIT entry for the rule named by ``pkt``'s preflow field.

		Skipped when a non-WAIT entry with the same attrs and version is
		already installed.
		"""
		already_installed = any(
			entry.attrs == pkt.next_entry_attr
			and entry.vn == pkt.next_entry_vn
			and not self._isWaitEntry(entry)
			for entry in self.flow_table)
		if already_installed:
			return
		e = FlowTableEntry(pkt.next_entry_attr, [[openflow.OFPAT_WAIT]], openflow.OFP_DEFAULT_PRIORITY, vn = pkt.next_entry_vn)
		self.flow_table.append(e)
		self.flow_table.sort()
		if self.expire_entries:
			self.enableAction("expire_entry", e)

	def process_packet(self):
		""" Dequeues the first packet from all ports and processes it

		Returns [done, trace] where ``done`` is True when no port has more
		packets queued, and ``trace`` lists "port=P,id=ID|" per processed
		packet.
		"""
		more_packets = False
		trace = []
		for p in self.ports:
			pkt = self.getWaitingPacket(p)
			if pkt is None:
				continue
			if self.checkWaitingPacket(p):
				more_packets = True
			trace.append("port=" + str(p) + ",id=" + pkt.packet_id + "|")

			self.log.debug("Processing packet %s" % pkt)

			# A packet carrying a preflow field installs the announced WAIT entry.
			if pkt.next_entry_vn != PKTI_VN:
				self._installPreflowWaitEntry(pkt)
				pkt.clr_preflow_field()

			testPoint("switch_process_packet", switch=self, packet=pkt, port=p)
			if not self.matchFlowTable(pkt, p):
				buffer_id = self.acquireBuffer()
				self.controller.enqueueQuery(self.openflow_id, buffer_id, pkt, p, openflow.OFPR_NO_MATCH)
				self.packet_store[buffer_id] = (pkt, p)
				self.log.debug("Queued query to controller")

				# Park later packets of this exact flow until the controller answers.
				e = FlowTableEntry(self._flowAttrs(pkt, p), [[openflow.OFPAT_WAIT]], openflow.OFP_DEFAULT_PRIORITY , vn = PKTI_VN)
				self.flow_table.append(e)
				self.flow_table.sort()

		return [not more_packets, "".join(trace)]

	def process_command(self):
		""" Process a command from the controller

		Returns [done, summary] where ``done`` is True when the command queue
		is empty afterwards, and ``summary`` describes the command's versions.
		"""
		command = self.command_queue.pop(0)
		summary = ""
		if isinstance(command, PacketOutMessage):
			self.processPacketOutMessage(command)
			summary = "PKTO" + ".NEXT_vn=" + str(command.next_entry_vn)
		elif isinstance(command, FlowTableModificationMessage):
			self.processFlowTableModification(command)
			summary = "MFTE" + ".vn=" + str(command.vn) + ".NEXT_vn=" + str(command.next_entry_vn)
		else:
			utils.crash("Switch received an unknown command: %s" % command)
		return [len(self.command_queue) == 0, summary]

	def __getstate__(self):
		"""Node state plus the command queue and a snapshot of the flow table."""
		filtered_dict = Node.__getstate__(self)
		filtered_dict["command_queue"] = self.command_queue
		filtered_dict["flow_table"] = [entry.__getstate__() for entry in self.flow_table]

		if self.ALWAYS_NEW_STATE:
			self.state_cnt += 1
			filtered_dict["state_cnt"] = self.state_cnt

		return filtered_dict

