/* Splits a trace based on any MPLS labels or VLAN tags that may be present
 * within the packets. Each label/tag will create a separate trace file, named
 * after the label number or tag ID.
 *
 * @author Shane Alcock, salcock@cs.waikato.ac.nz
 */
#include "libtrace.h"
#include <stdio.h>
#include <assert.h>
#include <getopt.h>

#include <map>

/* Associates a label / tag value with the output trace that packets carrying
 * that value are written to. */
typedef std::map<uint32_t, libtrace_out_t *> TraceMap;

/* Every output trace requested so far, keyed by label / tag value. An entry
 * holds NULL when opening the trace for that label failed. */
TraceMap outputs;

/* Directory that the per-label output traces are created in. Defaults to the
 * current directory and can be overridden with the -d option. Declared as
 * pointer-to-const because it initially refers to a string literal, which
 * must never be written through. */
const char *directory = ".";

/* Extracts the 20-bit label from the MPLS header at the top of the label
 * stack.
 *
 * @param mpls_hdr	Pointer to the first byte of the MPLS header
 * @param remaining	Number of bytes available starting at mpls_hdr
 * @return The label value (0 - 0xfffff), or ~0U if mpls_hdr is NULL or fewer
 * than 4 bytes are available.
 */
uint32_t extract_mpls_label(void *mpls_hdr, uint32_t remaining) {

	const uint8_t *bytes = static_cast<const uint8_t *>(mpls_hdr);

	if (bytes == NULL || remaining < 4)
		return ~0U;

	/* The label occupies the top 20 bits of the big-endian 32-bit header.
	 * It is assembled one byte at a time because the header generally is
	 * not 4-byte aligned, so loading it through a uint32_t pointer is not
	 * safe on every architecture. Reading bytes in network order also
	 * means no byte-order conversion is required. */
	return (static_cast<uint32_t>(bytes[0]) << 12) |
		(static_cast<uint32_t>(bytes[1]) << 4) |
		(static_cast<uint32_t>(bytes[2]) >> 4);
}

/* Extracts the 12-bit VLAN ID from an 802.1Q header.
 *
 * @param vlan_hdr	Pointer to the first byte of the VLAN header
 * @param remaining	Number of bytes available starting at vlan_hdr
 * @return The VLAN ID (0 - 0xfff), or ~0U if vlan_hdr is NULL or fewer than
 * 4 bytes are available.
 */
uint32_t extract_vlan_tag(void *vlan_hdr, uint32_t remaining) {

	const uint8_t *bytes = static_cast<const uint8_t *>(vlan_hdr);
	uint32_t tci;

	if (bytes == NULL || remaining < 4)
		return ~0U;

	/* The first two bytes are a big-endian 16-bit field; read them one
	 * byte at a time so that neither the alignment of the header nor the
	 * host byte order matters. */
	tci = (static_cast<uint32_t>(bytes[0]) << 8) | bytes[1];

	/* The VLAN ID is the low 12 bits; the upper 4 bits are discarded */
	return tci & 0x0fff;
}

/* Finds the topmost VLAN tag or MPLS label carried by an Ethernet packet.
 *
 * @param packet	The packet to examine
 * @param[out] label	Set to the VLAN ID or MPLS label when 1 is returned.
 * 			It is set to ~0U if the VLAN / MPLS header was too
 * 			short to read.
 * @return 1 if the packet carries a VLAN or MPLS header directly after the
 * Ethernet header, 0 otherwise (not Ethernet, no layer 2 header, no payload,
 * or some other ethertype).
 */
int extract_label(libtrace_packet_t *packet, uint32_t *label) {
	void *ethernet = NULL;
	void *payload = NULL;
	uint16_t ethertype = 0;
	libtrace_linktype_t linktype;
	uint32_t remaining = 0;

	/* trace_get_layer2() is documented to return NULL when the packet
	 * has no layer 2 header. Test the pointer first so that neither a
	 * possibly unset linktype is read nor a NULL header is passed on. */
	ethernet = trace_get_layer2(packet, &linktype, &remaining);
	if (ethernet == NULL || linktype != TRACE_TYPE_ETH)
		return 0;

	payload = trace_get_payload_from_layer2(ethernet, linktype, 
			&ethertype, &remaining);

	/* This while(1) may seem unnecessary now, but I'll need it if I
	 * ever have to do anything with labels below the topmost one */
	while (1) {
		if (payload == NULL || remaining == 0)
			return 0;
		switch(ethertype) {
		/* Just go with the topmost labels for now, so we can
		 * insta-return as soon as we find something of interest */
		case 0x8100:
			/* 802.1Q VLAN */
			*label = extract_vlan_tag(payload, remaining);
			return 1;
		case 0x8847:
			/* MPLS */
			*label = extract_mpls_label(payload, remaining);
			return 1;
		case 0x0800:
			/* IP */
			return 0;
		case 0x9000:
			/* Loopback / Loopguard */
			return 0;
		default:
			fprintf(stderr,"Unexpected ethertype: %u\n",ethertype);
			return 0;
		}
	}

	/* Every path through the loop above returns, so this point cannot be
	 * reached. The explicit return keeps the function well-defined even
	 * when assertions are compiled out. */
	assert(0);
	return 0;
}

/* Creates, configures and starts the output trace for the given label.
 *
 * The trace is written to "<directory>/<label>.erf.gz" using the erf: output
 * format, with the compression level option set to 1.
 *
 * @param label		The VLAN ID or MPLS label the trace is for
 * @return The started output trace, or NULL if the URI did not fit in the
 * buffer or the trace could not be created, configured or started. Any
 * partially set up trace is destroyed before NULL is returned.
 */
libtrace_out_t *open_new_trace(uint32_t label) {
	char uri[1024];
	libtrace_out_t *out = NULL;
	/* The compression level option is read as an int by libtrace */
	int level = 1;
	int written;
	
	written = snprintf(uri, sizeof(uri), "erf:%s/%u.erf.gz", directory,
			label);
	/* Refuse to continue with a truncated URI: it would name the wrong
	 * file */
	if (written < 0 || (size_t)written >= sizeof(uri)) {
		fprintf(stderr, "Output URI for label %u is too long\n", label);
		return NULL;
	}
	
	/* trace_create_output() hands back an object even on failure, so it
	 * has to be destroyed on every error path below */
	out = trace_create_output(uri);
	if (trace_is_err_output(out)) {
		trace_perror_output(out, "%s", uri);
		trace_destroy_output(out);
		return NULL;
	}

	if (trace_config_output(out,TRACE_OPTION_OUTPUT_COMPRESS,&level) == -1)
	{
		trace_perror_output(out, "Failed to set compression level");
		trace_destroy_output(out);
		return NULL;
	}

	if (trace_start_output(out) == -1) {
		trace_perror_output(out, "%s", uri);
		trace_destroy_output(out);
		return NULL;
	}

	return out;
}
	

/* Looks up the output trace for a label, opening it on first use.
 *
 * @param label		The VLAN ID or MPLS label to find the output trace for
 * @return The output trace recorded for the label. This is NULL if opening
 * the trace failed; the NULL is recorded too, so the open is not retried.
 */
libtrace_out_t *find_output_trace(uint32_t label) {
	TraceMap::iterator known = outputs.find(label);

	/* Seen this label before: reuse whatever was recorded for it */
	if (known != outputs.end())
		return known->second;

	/* First packet with this label: open its trace and remember it */
	libtrace_out_t *fresh = open_new_trace(label);
	outputs[label] = fresh;
	return fresh;
}

/* Writes a packet to the output trace matching its VLAN tag or MPLS label.
 * Packets without a label, packets whose label could not be read, and
 * packets whose output trace is unavailable are silently skipped.
 *
 * @param packet	The packet to process
 */
void per_packet(libtrace_packet_t *packet)
{
	uint32_t tag = ~0U;

	/* No vlan / mpls header at all, or one too short to hold a label
	 * (reported as ~0U) */
	if (!extract_label(packet, &tag) || tag == ~0U)
		return;

	/* Only write when an output trace exists for this label */
	libtrace_out_t *dest = find_output_trace(tag);
	if (dest != NULL)
		trace_write_packet(dest, packet);
}

/* Destroys every output trace recorded in the outputs map and then empties
 * the map, so that no destroyed trace can be reached through it afterwards
 * and calling this function again is harmless. */
void close_traces() {
	TraceMap::iterator it;
	
	for (it = outputs.begin(); it != outputs.end(); ++it) {
		/* Entries are NULL when opening the trace failed */
		if (it->second != NULL)
			trace_destroy_output(it->second);
	}

	outputs.clear();
}

/* Prints a summary of the command line syntax to standard output.
 *
 * @param prog		The name the program was invoked with
 */
void usage(char *prog) {
	printf("Usage: %s [-d outputdir] inputuri [inputuri ...]\n"
		"\n"
		"Splits the packets from each input trace into separate output traces\n"
		"based on their topmost VLAN tag or MPLS label. Each output trace is\n"
		"named after the tag ID or label number.\n"
		"\n"
		"Options:\n"
		"  -d outputdir   directory to create the output traces in\n"
		"                 (default: current directory)\n",
		prog);
}

int main(int argc, char *argv[])
{
	libtrace_t *trace;
	libtrace_packet_t *packet;
	int opt, i;
	
	if (argc<2) {
		usage(argv[0]);
		return 1;
	}

	while ((opt = getopt(argc, argv, "d:")) != EOF) {
		switch (opt) {
			case 'd':
				directory = optarg;
				break;
			default:
				usage(argv[0]);
		}
	}

	if (optind + 1 > argc) {
		usage(argv[0]);
		return 1;
	}
	packet = trace_create_packet();

	for (i = optind; i < argc; i++) {
	
		trace = trace_create(argv[i]);

		if (trace_is_err(trace)) {
			trace_perror(trace,"Opening trace file");
			return 1;
		}

		if (trace_start(trace)) {
			trace_perror(trace,"Starting trace");
			trace_destroy(trace);
			return 1;
		}


		while (trace_read_packet(trace,packet)>0) {
			per_packet(packet);
		}


		if (trace_is_err(trace)) {
			trace_perror(trace,"Reading packets");
			trace_destroy(trace);
			return 1;
		}

		trace_destroy(trace);
	}
		
	trace_destroy_packet(packet);
	close_traces();
	return 0;
}
