/*
 * ---------------------------------------------------------------------------
 *  "THE BEER-WARE LICENSE" (Revision 42): <sam@wand.net.nz> wrote this file.
 *  As long as you retain this notice you can do whatever you want with this
 *  stuff. If we meet some day, and you think this stuff is worth it, you can
 *  buy me a beer in return - Sam Jansen.
 * ---------------------------------------------------------------------------
 *
 * tcpperf. See http://www.wand.net.nz/~stj2/nsc/software.html. By Sam
 * Jansen.
 * 
 * $Id: tcpperf.c 762 2005-02-14 03:06:27Z stj2 $ */
#define VERSION_STRING "1.754"

#include <stdio.h>
#include <stdlib.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <unistd.h>
#include <assert.h>
#include <netdb.h>
#include <string.h>
#include <sys/time.h>
#include <signal.h>
#include <fcntl.h>
#include <errno.h>
#include <arpa/inet.h>
#include <errno.h>

/*
 * Fallback definition, used only when no timersub macro is already visible.
 *
 * Stores (*a - *b) in *result, where all three are pointers to structures
 * with tv_sec and tv_usec members. When the microsecond difference is
 * negative, one second is borrowed so that tv_usec ends up non-negative.
 */
#ifndef timersub
# define timersub(a, b, result)						      \
  do {									      \
    (result)->tv_sec = (a)->tv_sec - (b)->tv_sec;			      \
    (result)->tv_usec = (a)->tv_usec - (b)->tv_usec;			      \
    if ((result)->tv_usec < 0) {					      \
      --(result)->tv_sec;						      \
      (result)->tv_usec += 1000000;					      \
   }									      \
  } while (0)
#endif

/* Test duration in whole seconds: the send loops stop once at least this
 * many seconds have elapsed since they started. Set with -t. */
int mtime = 10;
/* Number of bytes handed to each send() call in send_loop(). Set with -s. */
int write_size = 1024;
/* Verbosity level, incremented once per -v; when non-zero,
 * print_buf_sizes() reports the socket buffer sizes. */
int g_verbose = 0;

/*
 * Queries one integer SOL_SOCKET option on fd and prints it as
 * "<label>: <value>", or an error line when getsockopt() fails.
 */
static void print_one_buf_size(int fd, int option, const char *label)
{
    int bufsize = 0;
    /* getsockopt() takes a socklen_t *, which need not be the same type
     * or width as int. */
    socklen_t buflen = sizeof(bufsize);

    if(getsockopt(fd, SOL_SOCKET, option, &bufsize, &buflen) == 0) {
        printf("%s: %d\n", label, bufsize);
    } else {
        printf("Error retrieving %s size.\n", label);
    }
}

/*
 * When verbose output is enabled (g_verbose non-zero), prints the send and
 * receive buffer sizes of socket fd to stdout. Does nothing otherwise.
 */
void print_buf_sizes(int fd)
{
    if(!g_verbose)
        return;

    print_one_buf_size(fd, SO_SNDBUF, "SO_SNDBUF");
    print_one_buf_size(fd, SO_RCVBUF, "SO_RCVBUF");
}

/*
 * Writes the command-line synopsis to stdout, using n as the program name,
 * and terminates the process with exit status 1. Never returns.
 */
void usage(char *n)
{
    static const char synopsis[] =
        "Usage: %s [-c <client name>] [-p port] [-t seconds] "
            "[-s write size] [-w socket buffer size] [-n] [-v] [-V]\n";

    fprintf(stdout, synopsis, n);
    exit(1);
}

void print_version()
{
    printf("tcpperf by Sam Jansen. See "
            "http://www.wand.net.nz/~stj2/nsc/software.html\n");
    printf("Version " VERSION_STRING " compiled on " __DATE__ "\n\n");
}

/*
 * Prints transfer statistics for a finished connection, then closes it.
 *
 * start, end:  timestamps bracketing the transfer.
 * bytes_sent:  number of bytes transferred between start and end.
 * fd:          the connection's descriptor. O_NONBLOCK is cleared on it,
 *              it is closed, and the time close() took is printed.
 *
 * Output goes to stdout. Bandwidth figures are reported as 0 when the
 * measured duration is not positive.
 */
void summarise(struct timeval start, struct timeval end, long long bytes_sent, 
        int fd)
{
    struct timeval tv, close_start, close_end;
    unsigned long long total = (unsigned long long)bytes_sent;
    double elapsed_usec;
    double bps = 0.0, kbps = 0.0, mbps = 0.0;
    int flags;

    timersub(&end, &start, &tv);
    elapsed_usec = (double)tv.tv_sec * 1000000.0 + (double)tv.tv_usec;

    /* Avoid printing inf/nan when no measurable time has elapsed. */
    if(elapsed_usec > 0.0) {
        bps  = (double)bytes_sent * 1000000.0 * 8.0 / elapsed_usec;
        kbps = (double)bytes_sent * 1000.0 * 8.0 / elapsed_usec;
        mbps = (double)bytes_sent * 8.0 / elapsed_usec;
    }

    /* time_t and suseconds_t have no printf specifier of their own, so the
     * arguments are cast to the exact types the format string names. */
    printf("Duration: %lu %luusec\n", (unsigned long)tv.tv_sec,
            (unsigned long)tv.tv_usec);
    printf("Bytes sent: %llu (%llu kB %llu MB %llu GB)\n", total,
            total / 1024, total / (1024*1024),
            total / (1024*1024*1024));
    printf("Bandwidth: %.0f b/s (%.2f kb/s %.2f Mb/s)\n", bps, kbps, mbps);

    /* Only rewrite the status flags if they could actually be read;
     * F_GETFL returns -1 on failure, which must not be fed to F_SETFL. */
    flags = fcntl(fd, F_GETFL);
    if(flags != -1)
        fcntl(fd, F_SETFL, flags & ~O_NONBLOCK);

    gettimeofday(&close_start, NULL);
    close(fd);
    gettimeofday(&close_end, NULL);

    timersub(&close_end, &close_start, &tv);
    printf("Close duration: %lus %dusec\n", (unsigned long)tv.tv_sec,
            (int)tv.tv_usec);
}

// Nonblocking version
/*
 * Connects fd to client_addr, switches the socket to non-blocking mode and
 * sends fixed-size 1024-byte writes for mtime seconds, then prints a
 * summary via summarise() (which also closes fd).
 *
 * Pacing: every iteration adds one packet's worth of bytes to an
 * outstanding budget, sends until the budget is used up or send() stops
 * accepting data, and then sleeps for the time one packet takes at bw
 * bits per second. The global write_size is not used in this mode.
 *
 * Exits the process with status 1 if connecting or configuring the socket
 * fails. A send error other than "would block" or an interrupted call ends
 * the loop early; the summary is still printed.
 */
void send_loop_nb(int fd, struct sockaddr_in *client_addr)
{
    int psize = 1024, buffer = 0, error = 0, flags;
    long long bytes_sent = 0;
    double bw = 2000000.0f * 2.0f;
    double send_rate = (psize * 8.0f) / bw;
    struct timeval start, end, tv, temp_tv;
    char data[40000];

    /* The payload content is irrelevant, but it must not be whatever
     * happened to be on the stack: that would leak memory to the peer. */
    memset(data, 0, sizeof(data));

    gettimeofday(&start, NULL);

    error = connect(fd, (struct sockaddr *)client_addr, 
            sizeof(*client_addr));

    if(error) {
        perror("connecting");
        exit(1);
    }

    print_buf_sizes(fd);

    // Set non-blocking, keeping whatever status flags are already set
    flags = fcntl(fd, F_GETFL);
    if(flags == -1 || fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1) {
        perror("fcntl(O_NONBLOCK)");
        exit(1);
    }

    while(1) {
        ssize_t ret = 0;

        buffer += psize;

        do { 
            ret = send(fd, data, psize, 0);
            if(ret > 0) {
                buffer -= ret;
                bytes_sent += ret;
            }
        } while(ret > 0 && buffer > 0);

        /* A full socket buffer is expected here; anything else (e.g. the
         * peer went away) will not recover, so stop instead of idling
         * until the timer runs out. */
        if(ret < 0 && errno != EAGAIN && errno != EWOULDBLOCK
                && errno != EINTR) {
            perror("send");
            break;
        }

        gettimeofday(&temp_tv, NULL);
        timersub(&temp_tv, &start, &tv);
        if(tv.tv_sec >= mtime)
            break;

        usleep((unsigned long)(send_rate * 1000000.0));

        gettimeofday(&temp_tv, NULL);
        timersub(&temp_tv, &start, &tv);
        if(tv.tv_sec >= mtime)
            break;
    }

    gettimeofday(&end, NULL);

    summarise(start, end, bytes_sent, fd);
    printf("Information is inaccurate in non-blocking mode currently.\n");
}

/*
 * Blocking-socket sender: connects fd to client_addr and, for mtime
 * seconds, sends write_size bytes whenever select() reports the socket
 * writable. Finishes by printing a summary via summarise() (which also
 * closes fd).
 *
 * The elapsed time is checked once per batch of ten select() calls, each
 * with a 100 microsecond timeout.
 *
 * write_size is limited to the size of the local payload buffer (32 KiB);
 * a negative value is treated as 0. A warning is printed to stderr when
 * the value had to be adjusted.
 *
 * Exits the process with status 1 if connecting fails. A select() or
 * send() error other than an interrupted call or "would block" ends the
 * loop early; the summary is still printed.
 */
void send_loop(int fd, struct sockaddr_in *client_addr)
{
    int error = 0, i, failed = 0;
    long long bytes_sent = 0;
    struct timeval start, end, temp_tv;
    unsigned char data[1024*32];
    size_t chunk;
    fd_set wfds;

    /* Do not transmit uninitialised stack contents. */
    memset(data, 0, sizeof(data));

    /* Never let send() read past the end of data. */
    if(write_size < 0)
        chunk = 0;
    else if((size_t)write_size > sizeof(data))
        chunk = sizeof(data);
    else
        chunk = (size_t)write_size;

    if(write_size < 0 || chunk != (size_t)write_size) {
        fprintf(stderr, "Write size %d out of range, using %lu bytes.\n",
                write_size, (unsigned long)chunk);
    }

    gettimeofday(&start, NULL);

    error = connect(fd, (struct sockaddr *)client_addr,
            sizeof(*client_addr));

    if(error) {
        perror("connecting");
        exit(1);
    }

    print_buf_sizes(fd);

    FD_ZERO(&wfds);
    FD_SET(fd, &wfds);

    while(!failed) {
        struct timeval tv;

        for(i = 0; i < 10 && !failed; i++) {
            fd_set write_fds = wfds;
            int ret;

            /* select() may modify both the set and the timeout, so both
             * are re-initialised before every call. */
            tv.tv_sec = 0;
            tv.tv_usec = 100; 

            ret = select(fd+1, NULL, &write_fds, NULL, &tv);

            if(ret > 0 && FD_ISSET(fd, &write_fds)) {
                ssize_t sent = send(fd, data, chunk, 0);

                if(sent > 0) {
                    bytes_sent += sent;
                } else if(sent < 0 && errno != EINTR && errno != EAGAIN
                        && errno != EWOULDBLOCK) {
                    /* The connection is gone; retrying would only spin
                     * until the timer expires. */
                    perror("send");
                    failed = 1;
                }
            } else if(ret < 0 && errno != EINTR) {
                perror("select");
                failed = 1;
            }
        }

        gettimeofday(&temp_tv, NULL);
        timersub(&temp_tv, &start, &tv);
        if(tv.tv_sec >= mtime)
            break;
    }

    gettimeofday(&end, NULL);
    
    summarise(start, end, bytes_sent, fd);
}

/*
 * Server side: loops forever on the listening socket fd, accepting
 * connections and counting the bytes received on the current one. When the
 * peer closes (or a read fails), the totals are printed via summarise(),
 * which also closes the connection.
 *
 * Only one connection is tracked at a time. If a new connection arrives
 * while another is still open, the older one is closed and the byte count
 * starts again from zero for the new one.
 *
 * This function does not return.
 */
void server_loop(int fd)
{
    fd_set rfds;
    int cfd = -1;
    long long recv_bytes = 0;
    /* Kept outside the loop: it is written when a connection is accepted
     * and read in a later iteration when that connection ends. */
    struct timeval accept_time;

    memset(&accept_time, 0, sizeof(accept_time));

    FD_ZERO(&rfds);
    FD_SET(fd, &rfds);

    while(1) {
        fd_set read_fds = rfds;
        struct timeval tv;
        int ret;
        int max = (fd > cfd) ? fd : cfd;

        tv.tv_sec = 0;
        tv.tv_usec = 10000;

        ret = select(max +1, &read_fds, NULL, NULL, &tv);
        if(ret <= 0)
            continue;

        if(FD_ISSET(fd, &read_fds)) {
            struct sockaddr_in caddr;
            socklen_t sin_size = sizeof(caddr);
            int nfd = accept(fd, (struct sockaddr *)&caddr, &sin_size);

            if(nfd == -1) {
                /* Passing -1 to FD_SET would be undefined behaviour;
                 * keep the current state and try again. */
                perror("accept");
                continue;
            }

            if(cfd != -1) {
                /* Drop the previous connection so its descriptor is not
                 * leaked and does not stay in the select set. */
                FD_CLR(cfd, &rfds);
                close(cfd);
            }

            cfd = nfd;
            FD_SET(cfd, &rfds);

            gettimeofday(&accept_time, NULL);
            recv_bytes = 0;

            printf("Accepted connection.\n");

            print_buf_sizes(cfd);
        }
        else if(cfd != -1 && FD_ISSET(cfd, &read_fds)) {
            unsigned char buf[8192];
            ssize_t got = read(cfd, buf, sizeof(buf));

            if(got > 0) {
                recv_bytes += got;
                continue;
            }

            if(got == -1) {
                if(errno == EINTR || errno == EAGAIN)
                    continue;
                perror("read");
            }

            /* End of stream or unrecoverable error: report and release
             * the connection. summarise() closes cfd. */
            {
                struct timeval end_time;

                FD_CLR(cfd, &rfds);
                gettimeofday(&end_time, NULL);
                summarise(accept_time, end_time, recv_bytes, cfd);
                cfd = -1;
            }
        }
    }
}

/*
 * Parses s as a base-10 integer in [min, max] for the option described by
 * what. On malformed or out-of-range input, prints a message to stderr and
 * exits with status 1.
 */
static int parse_int_arg(const char *s, long min, long max, const char *what)
{
    char *end = NULL;
    long val;

    errno = 0;
    val = strtol(s, &end, 10);

    if(end == s || *end != '\0' || errno == ERANGE
            || val < min || val > max) {
        fprintf(stderr, "Invalid %s '%s' (expected %ld..%ld).\n",
                what, s, min, max);
        exit(1);
    }

    return (int)val;
}

/*
 * Entry point. Without -c the program runs as a server: it binds to the
 * chosen port on all interfaces and hands the listening socket to
 * server_loop(). With -c <name> it resolves the name and runs one of the
 * send loops against it (send_loop_nb() when -n is given, send_loop()
 * otherwise).
 *
 * Options: -c client name, -p port (default 9000), -t seconds,
 * -s write size, -w socket buffer size, -n non-blocking sender,
 * -v verbose (repeatable), -V print version and exit.
 *
 * Returns 0 on normal completion; setup failures exit with status 1.
 */
int main(int argc, char *argv[])
{
    /* Upper bounds for option values. MAX_WRITE_SIZE matches the 32 KiB
     * payload buffer in send_loop(); MAX_INT_ARG keeps values within the
     * range of a 32-bit int. */
    enum { MAX_PORT = 65535, MAX_WRITE_SIZE = 1024*32 };
    const long MAX_INT_ARG = 2147483647L;

    int ch, port = 9000;
    char *client = NULL;
    struct hostent *host;
    struct sockaddr_in client_addr, my_addr;
    int fd = -1, nonblocking = 0, window_size = 0;

    /* Command line argument parsing: */
    while ((ch = getopt(argc, argv, "c:p:t:s:w:nvV")) != -1) {
        switch (ch) {
            case 'c':
                /* Client machine to connect to */
                client = optarg;
                break;
            case 'p':
                /* port; must fit the 16-bit field filled in by htons() */
                port = parse_int_arg(optarg, 1, MAX_PORT, "port");
                break;
            case 't':
                /* time */
                mtime = parse_int_arg(optarg, 0, MAX_INT_ARG, "time");
                break;
            case 's':
                write_size = parse_int_arg(optarg, 1, MAX_WRITE_SIZE,
                        "write size");
                break;
            case 'n':
                nonblocking = 1;
                break;
            case 'w':
                /* 0 leaves the system default buffer sizes in place */
                window_size = parse_int_arg(optarg, 0, MAX_INT_ARG,
                        "socket buffer size");
                break;
            case 'v':
                g_verbose++;
                break;
            case 'V':
                print_version();
                return 0;
            case '?':
            default:
                usage(argv[0]);
        }
    }

    /* Create socket */
    fd = socket(AF_INET, SOCK_STREAM, 0);
    if(-1 == fd) {
        perror("socket creation");
        exit(1);
    }

    if(window_size != 0) {
        /* A failure here is reported but not fatal: the test can still
         * run with the default buffer sizes. */
        if(setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &window_size,
                    sizeof(window_size)) == -1)
            perror("setsockopt(SO_SNDBUF)");
        if(setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &window_size,
                    sizeof(window_size)) == -1)
            perror("setsockopt(SO_RCVBUF)");
    }

    if(client == NULL) {
        // server operation
        int opt = 1;

        if(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) == -1)
        {
            perror("setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, 1)");
            exit(1);
        }

        memset(&my_addr, 0, sizeof(my_addr));
        my_addr.sin_family = AF_INET;
        my_addr.sin_port = htons(port);              // network byte order
        my_addr.sin_addr.s_addr = htonl(INADDR_ANY); // any local address

        if(bind(fd, (struct sockaddr *)&my_addr, sizeof(my_addr)) == -1) {
            perror("bind");
            exit(1);
        }

        if(listen(fd, 5)) {
            perror("listen");
            exit(1);
        }

        server_loop(fd);

    } else {
        // client operation

        /* Do DNS (or whatever) lookup */
        host = gethostbyname(client);
        if(!host) {
            fprintf(stderr, "Error resolving client name.\n");
            exit(1);
        }

        /* The address is copied into an IPv4 sockaddr below, so make sure
         * the resolver really returned an IPv4 address of the right size. */
        if(host->h_addrtype != AF_INET
                || host->h_length != (int)sizeof(struct in_addr)
                || host->h_addr_list[0] == NULL) {
            fprintf(stderr, "Error resolving client name.\n");
            exit(1);
        }

        /* Ignore SIGPIPE so that writing to a closed connection makes
         * send() fail instead of killing the process. */
        signal(SIGPIPE, SIG_IGN);

        memset(&client_addr, 0, sizeof(client_addr));
        client_addr.sin_family = AF_INET;
        client_addr.sin_port = htons(port);
        memcpy(&client_addr.sin_addr, host->h_addr_list[0], 
                sizeof(struct in_addr));

        if(nonblocking)
            send_loop_nb(fd, &client_addr);
        else
            send_loop(fd, &client_addr);
    }

    return 0;
}
