/*
 * URL request trigger client
 *
 * Copyright (C) 2007 Darius Ivanauskas <dasilt@gmail.com>
 * Ubiquiti Networks, Inc.
 * 
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 */

#define _GNU_SOURCE
#include <string.h>
#include <stdarg.h>
#include <stdlib.h>
#include <stdio.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/types.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <errno.h>
#include <netinet/in.h>
#include <netdb.h>
#include <ctype.h>

#ifdef MATRIX_SSL_COMPAT
#include <matrixssl/matrixssl_helper.h>
#else
#include <openssl/ssl.h>
#endif

/* Socket readiness an I/O step waits for; combined and tested as a bit mask. */
enum
{
	OP_NONE = 0,
	OP_READ = 1,
	OP_WRITE = 2
};

/* Default value for the -t option: seconds each select() wait may take. */
enum { DEFAULT_TIMEOUT = 2 };

/* States of the request state machine driven by main(). */
enum
{
	TRIGGER_CONNECTING = 1,
	TRIGGER_SSL_CONNECT = 2,
	TRIGGER_SEND = 3,
	TRIGGER_RECEIVE = 4,
	TRIGGER_DONE = 5
};

/*
 * Print an optional error message followed by the program banner and help.
 *
 * The first variadic argument is a printf-style format string (or a null
 * char pointer for "no message"); any further arguments feed that format.
 * Returns -1 when a message was printed, 0 otherwise, so callers can
 * `return usage(...)` straight out of main().
 */
static int usage(char* progname, ...)
{
	va_list args;
	const char *message_fmt;
	int status;

	va_start(args, progname);
	message_fmt = va_arg(args, char *);
	status = (message_fmt != NULL) ? -1 : 0;
	if (status != 0)
	{
		vprintf(message_fmt, args);
	}
	va_end(args);

	printf("HTTP[S] URL request trigger client v0.1 (c) Ubiquiti Networks\n");
	printf("Usage: %s [options] <url>\n"
		"\t-t <seconds>\t\tconnection timeout in seconds; default: %d\n"
		"\t-h\t\t\tdisplay this help and exit.\n",
		progname, DEFAULT_TIMEOUT);
	return status;
}

/*
 * Standard base64 alphabet (RFC 4648 section 4). HTTP Basic credentials
 * (RFC 7617) must use this alphabet; the crypt-style "./" tail that was here
 * before produced wrong Authorization headers for some inputs.
 */
static const char b64[] =
	"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
	"abcdefghijklmnopqrstuvwxyz"
	"0123456789+/";

/*
 * Encode len bytes from `from` as padded base64 into `to` and NUL-terminate it.
 *
 * `to` must have room for 4 * ((len + 2) / 3) + 1 characters.
 * A len of zero or less produces an empty string.
 */
static void base64encode(const unsigned char *from, char *to, int len)
{
	unsigned int k;

	/* Each full 3-byte group becomes 4 output characters. */
	while (len > 2)
	{
		k = ((unsigned int)from[0] << 16) | ((unsigned int)from[1] << 8) | from[2];
		from += 3;
		len -= 3;
		*to++ = b64[(k >> 18) & 0x3f];
		*to++ = b64[(k >> 12) & 0x3f];
		*to++ = b64[(k >> 6) & 0x3f];
		*to++ = b64[k & 0x3f];
	}
	/* A trailing 1 or 2 bytes become 2 or 3 characters plus '=' padding. */
	if (len > 0)
	{
		k = (unsigned int)from[0] << 16;
		if (len > 1)
		{
			k |= (unsigned int)from[1] << 8;
		}
		*to++ = b64[(k >> 18) & 0x3f];
		*to++ = b64[(k >> 12) & 0x3f];
		*to++ = (len > 1) ? b64[(k >> 6) & 0x3f] : '=';
		*to++ = '=';
	}
	*to = '\0';
}

/*
 * Look up the default TCP port of a URL scheme.
 *
 * Stores 80 for "http" or 443 for "https" (exact, case-sensitive match) in
 * *port and returns 0. Any other scheme returns -1 and leaves *port untouched.
 */
static int get_scheme_port(const char* scheme, u_int16_t *port)
{
	static const struct
	{
		const char *name;
		u_int16_t default_port;
	} known_schemes[] = {
		{ "http", 80 },
		{ "https", 443 },
	};
	size_t i;

	for (i = 0; i < sizeof(known_schemes) / sizeof(known_schemes[0]); i++)
	{
		if (!strcmp(scheme, known_schemes[i].name))
		{
			*port = known_schemes[i].default_port;
			return 0;
		}
	}
	return -1;
}

/*
 * A parsed trigger URL. All string members are heap-allocated and owned by
 * the structure; release the whole thing with free_url().
 */
typedef struct url {
	char *scheme;			/* text before "://", or NULL if the URL had none */
	char *auth;			/* userinfo before '@' (sent as Basic auth), or NULL */
	char *host;			/* host name or address */
	u_int16_t port;			/* TCP port, host byte order */
	char *path;			/* request target, starts with '/' */
	struct sockaddr_in dst_addr;	/* resolved IPv4 destination */
	int ssl;			/* non-zero when the request goes over TLS */
} url_t;

/* Release a url_t together with every string it owns; NULL is ignored. */
static void free_url(url_t* url)
{
	if (url == NULL)
	{
		return;
	}
	/* free(NULL) is a no-op, so unset members need no guard. */
	free(url->scheme);
	free(url->auth);
	free(url->host);
	free(url->path);
	free(url);
}

/*
 * Resolve hostname to an IPv4 socket address via getaddrinfo() (AF_INET,
 * SOCK_STREAM) and copy the first result into *addr.
 *
 * No service is passed to getaddrinfo(), so the caller is expected to fill
 * in sin_port itself afterwards.
 * Returns 0 on success, -1 if the name could not be resolved.
 */
int resolve_host(const char* hostname, struct sockaddr_in* addr)
{
	struct addrinfo hints;
	struct addrinfo* found = NULL;

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_INET;
	hints.ai_socktype = SOCK_STREAM;

	if (getaddrinfo(hostname, NULL, &hints, &found) != 0)
	{
		return -1;
	}
	memcpy(addr, found->ai_addr, sizeof(*addr));
	freeaddrinfo(found);
	return 0;
}

/*
 * Parse "[scheme://][userinfo@]host[:port][/path|?query|;params]" into a
 * freshly allocated url_t and resolve the host to an IPv4 address.
 *
 * Only the "http" and "https" schemes are accepted; without a scheme port 80
 * and plain HTTP are assumed. The stored path always starts with '/'.
 * Userinfo is only recognised inside the authority part, so an '@' in the
 * path or query string is not mistaken for credentials.
 *
 * Returns the parsed url (release it with free_url()), or NULL when the url
 * is empty or malformed, uses an unsupported scheme or an invalid port, the
 * host cannot be resolved, or memory allocation fails.
 */
static url_t* parse_url(const char* url_string)
{
    url_t* url;
    const char* pos;
    const char* path;
    const char* begin = url_string;
    const char* end = url_string + strlen(url_string);
    int invalid = 0;

    if (*url_string == '\0')
    {
        /* Empty url */
        return NULL;
    }
    url = calloc(1, sizeof(*url));
    if (!url)
    {
        /* Memory allocation error */
        return NULL;
    }
    url->port = 80; /* default port if no scheme specified */

    /* Optional "scheme://" prefix. */
    pos = strchr(begin, ':');
    if (pos && end - pos > 2 && pos[1] == '/' && pos[2] == '/')
    {
        url->scheme = strndup(begin, pos - begin);
        if (!url->scheme || get_scheme_port(url->scheme, &url->port))
        {
            /* allocation failure or unsupported scheme */
            invalid = 1;
        }
        begin = pos + 3;
    }

    /* The authority ends at the first path, query or parameter delimiter. */
    for (path = begin; path < end; path++)
    {
        if (*path == '/' || *path == '?' || *path == ';')
        {
            break;
        }
    }

    /* Optional "userinfo@", searched for only inside the authority so that
       begin can never move past path. */
    pos = memchr(begin, '@', (size_t)(path - begin));
    if (pos)
    {
        url->auth = strndup(begin, pos - begin);
        if (!url->auth)
        {
            invalid = 1;
        }
        begin = pos + 1;
    }

    /* host[:port] */
    pos = memchr(begin, ':', (size_t)(path - begin));
    if (pos)
    {
        long port_val = 0;
        char* port_end;

        url->host = strndup(begin, pos - begin);
        ++pos;
        /* Decimal digits only, running exactly up to the path: rejects an
           empty port, signs, blanks and trailing garbage. strtol stops at
           path because '/', '?', ';' and NUL are not digits. */
        if (pos < path && isdigit((unsigned char)*pos))
        {
            port_val = strtol(pos, &port_end, 10);
            if (port_end != path)
            {
                port_val = 0;
            }
        }
        if (port_val < 1 || port_val > 65535)
        {
            /* Invalid port */
            invalid = 1;
        }
        else
        {
            url->port = (u_int16_t)port_val;
        }
    }
    else
    {
        url->host = strndup(begin, path - begin);
    }

    /* Request path: always starts with '/'. An empty path becomes "/" and a
       bare query/parameter part ("?x", ";x") gets a '/' prepended. */
    if (path < end && *path == '/')
    {
        url->path = strndup(path, end - path);
    }
    else
    {
        size_t rest = (size_t)(end - path);

        url->path = malloc(rest + 2);
        if (url->path)
        {
            url->path[0] = '/';
            memcpy(url->path + 1, path, rest);
            url->path[rest + 1] = '\0';
        }
    }

    if (!url->path || !url->host || *url->host == '\0')
    {
        /* allocation failure or empty host */
        invalid = 1;
    }
    if (invalid || resolve_host(url->host, &url->dst_addr))
    {
        free_url(url);
        return NULL;
    }
    url->dst_addr.sin_port = htons(url->port);
    url->ssl = (url->scheme && strcmp(url->scheme, "https") == 0);
    return url;
}

/*
 * Write the HTTP/1.0 trigger request for url into buf (capacity size):
 *
 *   GET <path> HTTP/1.0
 *   User-Agent: Speedtest/0.1
 *   Authorization: Basic <base64(auth)>     (only when url->auth is set)
 *   Host: <host>
 *   <empty line>
 *
 * Text that does not fit is truncated by snprintf(). The base64 credentials
 * are left out entirely (not cut) when they do not fit.
 * Returns the length of the NUL-terminated request now in buf.
 */
size_t fill_trigger_http_request(char *buf, size_t size, url_t *url)
{
	char *cursor = buf;

	snprintf(cursor, size,
		"GET %s HTTP/1.0\r\n"
		"User-Agent: Speedtest/0.1\r\n", url->path);
	cursor += strlen(cursor);

	if (url->auth != NULL)
	{
		size_t auth_len = strlen(url->auth);
		/* base64 text length plus the terminating NUL */
		size_t encoded_size = (auth_len + 2) / 3 * 4 + 1;

		snprintf(cursor, size - (size_t)(cursor - buf), "Authorization: Basic ");
		cursor += strlen(cursor);
		if (encoded_size < size - (size_t)(cursor - buf))
		{
			base64encode((unsigned char *)url->auth, cursor, (int)auth_len);
			cursor += strlen(cursor);
		}
		/* Need room for CR, LF and still one byte for the NUL written below. */
		if (size - (size_t)(cursor - buf) > 2)
		{
			*cursor++ = '\r';
			*cursor++ = '\n';
		}
	}

	snprintf(cursor, size - (size_t)(cursor - buf), "Host: %s\r\n\r\n", url->host);
	return strlen(buf);
}

/* Return the first "\r\n" lying entirely inside [from, limit), or NULL. */
static const char* find_crlf(const char* from, const char* limit)
{
	while (from < limit)
	{
		const char* cr = memchr(from, '\r', (size_t)(limit - from));

		if (!cr)
		{
			break;
		}
		if (cr + 1 < limit && cr[1] == '\n')
		{
			return cr;
		}
		from = cr + 1;
	}
	return NULL;
}

/*
 * Scan the CRLF-terminated header lines in buf[0..len) for a "Location:"
 * header (name matched case-insensitively, leading blanks skipped) and print
 * its value, followed by a newline, on stdout.
 *
 * Returns:
 *   0  Location header found and printed
 *  -1  empty line (end of headers) reached without a Location header
 *   1  data ends before either of the above (headers incomplete)
 */
int extract_redirect_location(const char* buf, size_t len)
{
	const char* limit = buf + len;
	const char* pos = buf;
	const char* eol;

	for (; pos < limit; pos = eol + 2)
	{
		/* Search only the bytes remaining after pos, never past buf + len. */
		eol = find_crlf(pos, limit);
		if (!eol)
		{
			return 1;
		}
		if (eol == pos)
		{
			return -1;
		}
		/* ctype functions need a non-negative value: cast through unsigned char. */
		while (pos < eol && isblank((unsigned char)*pos))
		{
			++pos;
		}
		if (eol - pos < 9 || strncasecmp(pos, "Location:", 9))
		{
			continue;
		}
		pos += 9;
		while (pos < eol && isblank((unsigned char)*pos))
		{
			++pos;
		}
		/* %.*s takes an int precision, not a ptrdiff_t. */
		printf("%.*s\n", (int)(eol - pos), pos);
		return 0;
	}
	return 1;
}
/*
 * Decide whether a failed I/O call can be retried once the socket is ready.
 *
 * result is the return value of send()/recv() (ssl == NULL) or of an SSL_*
 * call on ssl. Returns the readiness to wait for before retrying (OP_READ or
 * OP_WRITE), or 0 if the call cannot be retried: it returned 0, or it failed
 * with anything other than a would-block condition.
 * For plain sockets the caller supplies the readiness via operation_hint,
 * because errno does not say which direction would block.
 */
static int handle_io_error_return(SSL *ssl, int result, int operation_hint)
{
	if (result == 0)
	{
		return 0;
	}
	if (ssl)
	{
		switch (SSL_get_error(ssl, result))
		{
		case SSL_ERROR_WANT_READ:
			return OP_READ;
		case SSL_ERROR_WANT_WRITE:
			return OP_WRITE;
		default:
			return 0;
		}
	}
	/* POSIX lets a non-blocking socket report either EAGAIN or EWOULDBLOCK
	   (they need not be equal), so accept both. */
	if (errno == EAGAIN || errno == EWOULDBLOCK)
	{
		return operation_hint;
	}
	return 0;
}

/*
 * Put fd into non-blocking mode while keeping its other file status flags
 * (O_APPEND, O_ASYNC, ...); a bare F_SETFL with O_NONBLOCK would clear them.
 * Returns 0 on success, -1 on failure (errno set by fcntl()).
 */
int set_nonblock(int fd)
{
	int flags = fcntl(fd, F_GETFL, 0);

	if (flags < 0)
	{
		return -1;
	}
	if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) < 0)
	{
		return -1;
	}
	return 0;
}

/*
 * Open an IPv4 TCP socket and switch it to non-blocking mode.
 *
 * Returns the descriptor on success. If socket() fails, its negative return
 * value is passed through. If the socket cannot be made non-blocking, it is
 * closed and -2 is returned.
 */
int create_socket(void)
{
	int sock = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);

	if (sock < 0)
	{
		return sock;
	}
	if (set_nonblock(sock) < 0)
	{
		close(sock);
		return -2;
	}
	return sock;
}

/*
 * Return values:
 * 
 *   0: operation success(help or test processed)
 *  -1: argument error
 *  -2: socket creation error
 *  -3: trigger connect() error
 *  -4: trigger SSL_connect() error
 *  -5: trigger send() or SSL_write() error
 *  -6: trigger recv() or SSL_read() error
 *  -7: invalid http trigger response
 *  -8: too long http trigger response line ? (buffer overflow atack?)
 *  -9: unsuccessful trigger http response
 * -10: authorization failed http response
 * -11: redirect - url found
 * -12: redirect - no url
 * -20-<status>: operation timeout() on <status>
 * -40-<status>: select() error on <status>
 * <status> values:
 * TRIGGER_CONNECTING	1
 * TRIGGER_SSL_CONNECT	2
 * TRIGGER_SEND			3
 * TRIGGER_RECEIVE		4
 * TRIGGER_DONE			5
 * 
 */
int main(int argc, char* argv[])
{
	int timeout_seconds = DEFAULT_TIMEOUT;
	
	int fd = -1;
	int result = 0;
	SSL* ssl = NULL;
#ifdef MATRIX_SSL_COMPAT
        sslKeys_t *ctx = NULL;
#else   
        SSL_CTX *ctx = NULL;
#endif
	int status = TRIGGER_CONNECTING;
	char buf[65536];
	char *start = buf;
	char *end = buf;
	url_t *remote_trigger_url = NULL;
	int operation = 0;
	int original_operation = 0;
	fd_set read_fdset;
	fd_set write_fdset;	
	struct timeval timeout;
	
	while (1)
	{
	    int c = getopt(argc, argv, ":th");
	    int int_value;
	    if (c == -1)
	    {
	    	/* finished options */
	    	break;
	    }
	    switch(c)
	    {
	    	case 't':
	    		int_value = atoi(optarg);
	    		if (int_value < 1)
	    		{
	    			return usage(argv[0], "Invalid timeout specified.\n");
	    		}
	    		timeout_seconds = int_value;
	    		break;
	    	case '?':
	    		return usage(argv[0], "Unknown option: %c\n", optopt); 
	    		break;
	    	case ':':
	    		return usage(argv[0], "Option '%c' requires an argument.\n", optopt); 
    			break;
    		case 'h':
    			return usage(argv[0], NULL);
	    }
	}
	if (optind >= argc)
	{
		return usage(argv[0], "No remote trigger url specified.\n"); 
	}
	remote_trigger_url = parse_url(argv[optind]);
	if (!remote_trigger_url)
	{
		return usage(argv[0], "Invalid remote trigger url specified.\n");
	}

	fd = create_socket();
	if (fd < 0)
	{
		return -2;
	}

	for (;;)
	{
		if (status == TRIGGER_CONNECTING)
		{
			if (operation == OP_WRITE)
			{
				int so_value;
				socklen_t so_len = sizeof(so_value);				
				result = getsockopt(fd, SOL_SOCKET, SO_ERROR, &so_value,
					&so_len);
				if (result == 0)
				{
					if (so_value)
					{
						result = -1;
						errno = so_value;
					}
					else
					{
						result = 0;
					}
				}					
			}
			else
			{ 
				result = connect(fd, 
					(struct sockaddr *)&remote_trigger_url->dst_addr,
					(socklen_t)sizeof(remote_trigger_url->dst_addr));
				if (remote_trigger_url->ssl)
				{
#ifdef MATRIX_SSL_COMPAT
					matrixSslOpen();
#else
					SSL_library_init();
					ctx = SSL_CTX_new(SSLv23_client_method());
#endif
					ssl = SSL_new(ctx);
					SSL_set_fd (ssl, fd);
				}
			}
			if (result < 0 && errno != EINPROGRESS)
			{
				result = -3;
				/* TODO: error handling messages */
				/* connect error */
				break;
			}
			else if (result == 0)
			{
				/* connected to the remote */
				status = ssl ? TRIGGER_SSL_CONNECT : TRIGGER_SEND;
				end = buf + fill_trigger_http_request(buf, sizeof(buf),
					remote_trigger_url);
			}
			else
			{
				original_operation = OP_WRITE;
			}
		}
		
		if (status == TRIGGER_SSL_CONNECT)
		{
			result = SSL_connect(ssl);
			if (result > 0)
			{
				status = TRIGGER_SEND;
			}
			else if (!(original_operation =
				handle_io_error_return(ssl, result, 0)))
			{
				/* TODO: error handling messages */
				/* SSL_connect error */
				result = -4;
				break;					
			}
		}
				
		if (status == TRIGGER_SEND)
		{			
			while (start < end)
			{
				if (ssl)
				{
					result = SSL_write(ssl, start, end - start);
				}
				else
				{
					result = send(fd, start, end - start, 0);
				}
				if (result <= 0)
				{
					break;
				}
				start += result;
			}			
			if (result > 0)
			{
				start = end = buf;
				status = TRIGGER_RECEIVE;
			}
			else if (!(original_operation =
				handle_io_error_return(ssl, result, OP_WRITE)))
			{
				/* TODO: error handling messages */
				/* send|SSL_write error */
				result = -5;
				break;
			}
		}

		if (status == TRIGGER_RECEIVE)
		{
			while (buf + sizeof(buf) > end)
			{
				if (ssl)
				{
					result = SSL_read(ssl, end, sizeof(buf) - (end - buf));
				}
				else
				{
					result = recv(fd, end, sizeof(buf) - (end - buf), 0);
				}
				if (result > 0)
				{
					const char* pos;
					end += result;
					pos = (char*)memmem(start, end - start,	"\r\n", 2);
					if (pos)
					{
						pos = (char*)memchr(start, ' ', pos - start);
						if (!pos || pos > end - 4)
						{
							/* Invalid http response. */						
							result = -7;
						}
						else if (pos[1] == '2') 
						/* Start of Success code 2xx */
						{
							/* success */
							status = TRIGGER_DONE;
						}
						else if (pos[1] == '4' && pos[2] == '0' 
							&& pos[3] == '1') /* Unauthorized code 401 */
						{
							/* unauthorized */
							result = -10;
						}
						else if (pos[1] == '3' && pos[2] == '0' 
							&& pos[3] == '1') /* Moved Permanently 301 */
						{
							result = extract_redirect_location(pos+2, end - pos);
							if (result < 0)
							{
								result = -12; /* redirect - no url */
							}
							else if (result == 0)
							{
								result = -11; /* redirect - url found */
							}
						}
						else
						{
							/* unsuccessful http response. */
							result = -9;
						}
						break;
					}
					if (buf + sizeof(buf) >= end)
					{
						/* Buffer allready full and we didn't found CRLF. */
						result = -8;
					}
				}
				else if ((original_operation =
					handle_io_error_return(ssl, result, OP_READ)))
				{
					result = 1;
					break;
				}	
				else
				{
					/* TODO: error handling messages */
					/* recv|SSL_read error */
					result = -6;
					break; 
				}		
			}
			if (status == TRIGGER_DONE)
			{
				result = 0;
				break;
			}
			if (result <= 0)
			{
				break;
			}
		} 
		
		FD_ZERO(&read_fdset);
		if (original_operation & OP_READ)
		{
			FD_SET(fd, &read_fdset);
		}
		FD_ZERO(&write_fdset);
		if (original_operation & OP_WRITE)
		{
			FD_SET(fd, &write_fdset);
		}

		timeout.tv_sec = timeout_seconds;
		timeout.tv_usec = 0;
		
		result = select(fd + 1, &read_fdset, &write_fdset, NULL, 
			&timeout);

		operation = 0;
		if (result > 0)
		{
			if (FD_ISSET(fd, &read_fdset))
			{
				operation |= OP_READ; 
			}
			if (FD_ISSET(fd, &write_fdset))
			{
				operation |= OP_WRITE; 
			}
		}
		else if (result == 0)
		{			
			result = -20 - status;
			break;
		}
		else if (errno != EINTR)
		{
			/* select error */
			result = -40 - status;
			break;
		}
	}
	close(fd);
	free_url(remote_trigger_url);
	if (ssl)
	{
		SSL_free(ssl);
#ifndef MATRIX_SSL_COMPAT
		SSL_CTX_free(ctx);
#endif
		ssl = NULL;
	}
	return result;
}
