/*
 * MatrixSSL helper functions
 *
 * Copyright (C) 2007 Darius Ivanauskas <dasilt@gmail.com>
 *
 * Full remake of older version from Nicolas Thill. Now it is
 * more OpenSSL compatible, including support for
 * SSL_ERROR_WANT_READ/SSL_ERROR_WANT_WRITE functionality.
 * Also implemented SSL_connect() and SSL_get_error() functions.
 * It was tested with the patched boa web server version and also
 * client utility for triggering https url.
 * 
 * Copyright (C) 2005 Nicolas Thill <nthill@free.fr>
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * Portions borrowed from MatrixSSL example code
 *
 */

#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <fcntl.h>
#include <unistd.h>
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <errno.h>

#include "../matrixssl_helper.h"

/*
 * Values stored in ssl->status to explain a zero return:
 *   SSL_SOCKET_EOF          - send()/recv() returned 0
 *   SSL_SOCKET_CLOSE_NOTIFY - a close_notify alert was decoded
 */
#define SSL_SOCKET_EOF  0x0001
#define SSL_SOCKET_CLOSE_NOTIFY  0x0002

/*
 * Values stored in ssl->operation to remember which socket operation is
 * pending. They alias the SSL_ERROR_* codes so that the pending operation
 * can be copied straight into ssl->error when the socket would block.
 */
#define SSL_OPERATION_NONE	SSL_ERROR_NONE
#define SSL_OPERATION_READ	SSL_ERROR_WANT_READ
#define SSL_OPERATION_WRITE	SSL_ERROR_WANT_WRITE

#ifndef min
/*
 * Smaller of two values.
 *
 * The whole expansion is parenthesized so the macro can be used inside a
 * larger expression (after a cast, next to an arithmetic operator, ...)
 * without the surrounding operator binding to the comparison instead of
 * the result. Each argument is evaluated more than once, so do not pass
 * expressions with side effects.
 */
#define min(a, b)  ( ( (a) < (b) ) ? (a) : (b) )
#endif

/*
 * Allocate the backing storage of an SSL buffer.
 *
 * buf, start and end are all set to the result of malloc(), so on success
 * the buffer is empty and positioned at the beginning of its storage, and
 * on failure all three pointers are NULL. The size field is only written
 * on success.
 *
 * Returns 0 on success, -1 when the allocation fails.
 */
static int _ssl_buf_alloc(sslBuf_t* buf, int size)
{
	unsigned char *storage = (unsigned char *)malloc(size);

	buf->end = storage;
	buf->start = storage;
	buf->buf = storage;

	if (storage == NULL)
	{
		return -1;
	}

	buf->size = size;
	return 0;
}

/*
 * Move the unread bytes of a buffer to the beginning of its storage.
 *
 * Nothing is touched when start already equals buf. Otherwise the bytes in
 * [start, end) are slid down to buf (memmove, the regions may overlap) and
 * start/end are rebased accordingly.
 *
 * Returns the number of bytes held in the buffer (end - start).
 */
static int _ssl_buf_compact(sslBuf_t* buf)
{
	unsigned char *base = buf->buf;
	unsigned char *from = buf->start;

	if (from != base)
	{
		size_t used = (size_t)(buf->end - from);

		if (used != 0)
		{
			memmove(base, from, used);
		}
		buf->start = base;
		buf->end = base + used;
	}
	return (int)(buf->end - buf->start);
}

/*
 * Double the capacity of a buffer, keeping its contents.
 *
 * The positions of start and end are recorded as offsets from buf before
 * calling realloc() and re-applied to the new storage afterwards. This
 * avoids doing arithmetic with pointers into the old (possibly freed)
 * block and keeps the buffered bytes addressable even when start is not
 * at the beginning of the storage.
 *
 * Returns 0 on success. Returns -1, leaving the buffer untouched, when the
 * buffer has no valid size to double or when realloc() fails.
 */
static int _ssl_buf_grow(sslBuf_t* buf)
{
	int new_size;
	size_t start_off;
	size_t end_off;
	unsigned char *new_buf;

	/* A buffer that was never allocated (size 0) cannot be doubled. */
	if (buf->size <= 0)
	{
		return -1;
	}
	new_size = buf->size * 2;

	start_off = (size_t)(buf->start - buf->buf);
	end_off = (size_t)(buf->end - buf->buf);

	new_buf = (unsigned char *)realloc(buf->buf, new_size);
	if (new_buf == NULL)
	{
		/* realloc() leaves the original block valid on failure. */
		return -1;
	}

	buf->buf = new_buf;
	buf->start = new_buf + start_off;
	buf->end = new_buf + end_off;
	buf->size = new_size;
	return 0;
}

/*
 * Translate the result of a failed send()/recv() into SSL state.
 *
 * return_code is the value returned by the socket call (expected <= 0).
 *
 *  - 0: the status becomes SSL_SOCKET_EOF and 0 is returned.
 *  - transient failure (EAGAIN, EWOULDBLOCK, EINTR): ssl->error is set to
 *    the pending operation (the SSL_OPERATION_* values alias
 *    SSL_ERROR_WANT_READ / SSL_ERROR_WANT_WRITE) so the caller knows what
 *    to wait for before retrying; status is cleared; -1 is returned.
 *  - anything else: ssl->error is SSL_ERROR_SYSCALL, status holds the
 *    errno value; -1 is returned.
 */
static int _ssl_handle_io_error(SSL *ssl, int return_code)
{
	/* Snapshot errno first so nothing below can disturb it. */
	int saved_errno = errno;
	int retryable;

	if (return_code == 0)
	{
		ssl->status = SSL_SOCKET_EOF;
		return 0;
	}

	retryable = (saved_errno == EAGAIN || saved_errno == EINTR);
#if defined(EWOULDBLOCK) && (EWOULDBLOCK != EAGAIN)
	/* POSIX allows EWOULDBLOCK to be a value distinct from EAGAIN. */
	retryable = retryable || (saved_errno == EWOULDBLOCK);
#endif

	if (retryable)
	{
		ssl->error = ssl->operation;
		ssl->status = 0;
	}
	else
	{
		ssl->error = SSL_ERROR_SYSCALL;
		ssl->status = saved_errno;
	}
	return -1;
}

/*
 * Hand bytes waiting in ssl->outbuf over to the caller.
 *
 * Copies at most len bytes from the front of the out buffer into buf and
 * advances the buffer's start. Once the buffer has been drained, start and
 * end are rewound to the beginning of the storage.
 *
 * Returns the number of bytes copied (0 when nothing was waiting).
 */
static int _ssl_read_flush(SSL *ssl, void *buf, int len)
{
	int copied = 0;

	if (ssl->outbuf.end > ssl->outbuf.start)
	{
		/* Take whichever is smaller: the request or what is available. */
		if (len < (ssl->outbuf.end - ssl->outbuf.start))
		{
			copied = len;
		}
		else
		{
			copied = (int)(ssl->outbuf.end - ssl->outbuf.start);
		}
		memcpy(buf, ssl->outbuf.start, copied);
		ssl->outbuf.start += copied;
	}

	if (ssl->outbuf.end == ssl->outbuf.start)
	{
		ssl->outbuf.end = ssl->outbuf.buf;
		ssl->outbuf.start = ssl->outbuf.buf;
	}
	return copied;
}

/*
 * Push everything held in ssl->outbuf to the socket.
 *
 * ssl->operation is SSL_OPERATION_WRITE while the transfer is in progress
 * and stays that way if send() fails, so a later call can resume. When the
 * buffer has been emptied it is rewound and the operation is cleared.
 *
 * Returns the number of bytes that were queued on entry, the result of
 * _ssl_handle_io_error() when send() returns <= 0, or -1 with
 * ssl->error = SSL_ERROR_SSL when a decode error was pending (the flag is
 * consumed).
 */
static int _ssl_send(SSL *ssl)
{
	const int queued = (int)(ssl->outbuf.end - ssl->outbuf.start);
	int sent;

	ssl->operation = SSL_OPERATION_WRITE;

	for (;;)
	{
		if (ssl->outbuf.start == ssl->outbuf.end)
		{
			break;
		}
		sent = send(ssl->fd, (char *)ssl->outbuf.start,
			(int)(ssl->outbuf.end - ssl->outbuf.start), MSG_NOSIGNAL);
		if (sent <= 0)
		{
			return _ssl_handle_io_error(ssl, sent);
		}
		ssl->outbuf.start += sent;
	}

	/* Fully drained: rewind the buffer and clear the pending operation. */
	ssl->outbuf.end = ssl->outbuf.buf;
	ssl->outbuf.start = ssl->outbuf.buf;
	ssl->operation = SSL_OPERATION_NONE;

	if (!ssl->decode_error)
	{
		return queued;
	}

	/* The output has gone out; now report the deferred decode failure. */
	ssl->error = SSL_ERROR_SSL;
	ssl->decode_error = 0;
	return -1;
}

/*
 * Read from the socket into the free tail of ssl->inbuf.
 *
 * ssl->operation is SSL_OPERATION_READ while the call is pending and stays
 * that way if recv() fails, so a later call retries the read. On success
 * the buffer's end is advanced and the operation is cleared.
 *
 * Returns the number of bytes received, or the result of
 * _ssl_handle_io_error() when recv() returns <= 0.
 */
static int _ssl_recv(SSL *ssl)
{
	int room = (int)((ssl->inbuf.buf + ssl->inbuf.size) - ssl->inbuf.end);
	int got;

	ssl->operation = SSL_OPERATION_READ;
	got = recv(ssl->fd, (char *)ssl->inbuf.end, room, MSG_NOSIGNAL);
	if (got > 0)
	{
		ssl->inbuf.end += got;
		ssl->operation = SSL_OPERATION_NONE;
		return got;
	}
	return _ssl_handle_io_error(ssl, got);
}

/*
 * Run one step of the receive path.
 *
 * Depending on ssl->operation this first hands back bytes already waiting
 * in ssl->outbuf, or finishes a pending send. It then reads from the
 * socket when needed, passes ssl->inbuf through matrixSslDecode() (which
 * writes its output into ssl->outbuf) and acts on the decoder's result.
 *
 * Returns:
 *   > 0  number of bytes copied into buf
 *     0  nothing for the caller in this step; ssl->status is
 *        SSL_SOCKET_EOF or SSL_SOCKET_CLOSE_NOTIFY when the connection
 *        has ended, otherwise the callers in this file simply call again
 *    -1  failure, ssl->error holds the SSL_ERROR_* code
 */
static int _ssl_read(SSL *ssl, void *buf, int len)
{
	int rc = 0;
	/* error is filled in by matrixSslDecode() but not inspected here. */
	unsigned char error, alertLevel, alertDescription;

	ssl->status = 0;
	
	if (ssl->operation == SSL_OPERATION_NONE)
	{
		/* Idle: hand over anything left in the out buffer first. */
		if ((rc = _ssl_read_flush(ssl, buf, len)))
		{
			return rc;
		}
	}
	else if(ssl->operation == SSL_OPERATION_WRITE)
	{
		/* A send is still pending: finish it before reading more. */
		if ((rc = _ssl_send(ssl)) <= 0)
		{
			return rc;
		}
	}
	/*
	 * Go to the socket only when there is nothing buffered to decode or
	 * when an earlier step asked for more input (SSL_OPERATION_READ).
	 */
	if (_ssl_buf_compact(&ssl->inbuf) == 0 /* If empty */
		|| ssl->operation == SSL_OPERATION_READ)
	{
		if ((rc = _ssl_recv(ssl)) <= 0)
		{
			return rc;
		}
	}
	
	error = 0;
	alertLevel = 0;
	alertDescription = 0;

	/* Decode; while the decoder reports SSL_FULL, enlarge outbuf and retry. */
	while ((rc = matrixSslDecode(ssl->ssl, &ssl->inbuf, &ssl->outbuf, &error,
		&alertLevel, &alertDescription)) == SSL_FULL)
	{
		_ssl_buf_compact(&ssl->outbuf);
		if (_ssl_buf_grow(&ssl->outbuf))
		{   /* buffer grow failed */
			ssl->error = SSL_ERROR_SYSCALL;
			return -1;
		}
	}
	
	switch (rc)
	{
		case SSL_ALERT:
			/* close_notify is reported as a zero return, other alerts as errors. */
			if (alertDescription == SSL_ALERT_CLOSE_NOTIFY)
			{
				ssl->status = SSL_SOCKET_CLOSE_NOTIFY;
				return 0;
			}
			ssl->error = SSL_ERROR_SSL;
			return -1;
		case SSL_ERROR:
			/*
			 * We might need to send the out buffer before reporting the
			 * error: set the flag so _ssl_send() fails with SSL_ERROR_SSL
			 * once the buffer is drained. Falls through on purpose.
			 */
			ssl->decode_error = 1;
		case SSL_SEND_RESPONSE:
			rc = _ssl_send(ssl);
			/* Bytes sent here are not data for the caller: report 0. */
			if (rc > 0)
			{
				return rc = 0;
			}
			return rc;
		case SSL_PARTIAL:
			/* More input is needed; make room first if inbuf is full. */
			if (ssl->inbuf.end == (ssl->inbuf.buf + ssl->inbuf.size) && 
				_ssl_buf_grow(&ssl->inbuf))
			{   /* buffer grow failed */
				ssl->error = SSL_ERROR_SYSCALL;
				return -1;
			}
			/* Force a recv() on the next step even though inbuf is not empty. */
			ssl->operation = SSL_OPERATION_READ;
			return 0;
		case SSL_SUCCESS:
		case SSL_PROCESS_DATA:
			/* Copy whatever the decoder left in outbuf to the caller. */
			return _ssl_read_flush(ssl, buf, len);
	}
	/* unknown rc code - should not happen. */
	ssl->error = SSL_ERROR_SSL;
	return -1;
}

/*
 * Run one step of the send path.
 *
 * When ssl->outbuf still holds bytes, this step only tries to send them
 * and returns the result of _ssl_send(). Otherwise the caller's data is
 * encoded into ssl->outbuf with matrixSslEncode(), growing the buffer for
 * as long as the encoder reports SSL_FULL; the encoded bytes are sent by a
 * later step.
 *
 * Returns the result of _ssl_send() in the first case, 0 after a
 * successful encode, -1 on failure with ssl->error set.
 */
int _ssl_write(SSL *ssl, const void *buf, int len)
{
	int rc;

	ssl->status = 0;

	/* Pending output takes priority over encoding anything new. */
	if (ssl->outbuf.start < ssl->outbuf.end)
	{
		return _ssl_send(ssl);
	}

	_ssl_buf_compact(&ssl->outbuf);
	for (;;)
	{
		rc = matrixSslEncode(ssl->ssl, (unsigned char *)buf, len,
			&ssl->outbuf);
		if (rc != SSL_FULL)
		{
			break;
		}
		if (_ssl_buf_grow(&ssl->outbuf) != 0)
		{   /* buffer grow failed */
			ssl->error = SSL_ERROR_SYSCALL;
			return -1;
		}
	}

	if (rc >= 0)
	{
		return 0;
	}
	ssl->error = SSL_ERROR_SSL;
	return -1;
}

/*
 * Map the return value of an SSL_* call to an SSL_ERROR_* code.
 *
 *   ret > 0   SSL_ERROR_NONE
 *   ret == 0  SSL_ERROR_ZERO_RETURN
 *   ret < 0   the code last stored in ssl->error
 */
int SSL_get_error(const SSL *ssl, int ret)
{
	if (ret < 0)
	{
		return ssl->error;
	}
	return (ret == 0) ? SSL_ERROR_ZERO_RETURN : SSL_ERROR_NONE;
}

/*
 * Create a new SSL object.
 *
 * keys is stored in the object and passed to matrixSslNewSession(); the
 * session is created with SSL_FLAGS_SERVER when keys is non-NULL and with
 * no flags otherwise. The descriptor starts out as -1 (see SSL_set_fd())
 * and both I/O buffers are allocated up front.
 *
 * Returns the new object, or NULL when memory allocation or session
 * creation fails. Release the object with SSL_free().
 */
SSL * SSL_new(sslKeys_t *keys)
{
	enum { SSL_INITIAL_BUF_SIZE = 1024 };
	SSL *ssl = (SSL *)calloc(1, sizeof(SSL));

	if (ssl == NULL)
	{
		return NULL;
	}

	ssl->fd = -1;
	ssl->keys = keys;
	if (matrixSslNewSession(&(ssl->ssl), ssl->keys, NULL,
		ssl->keys ? SSL_FLAGS_SERVER : 0) < 0)
	{
		SSL_free(ssl);
		return NULL;
	}

	/*
	 * Without their buffers the read/write paths would operate on NULL
	 * pointers, so a failed allocation aborts construction. SSL_free()
	 * releases whichever buffer did get allocated.
	 */
	if (_ssl_buf_alloc(&ssl->inbuf, SSL_INITIAL_BUF_SIZE) != 0
		|| _ssl_buf_alloc(&ssl->outbuf, SSL_INITIAL_BUF_SIZE) != 0)
	{
		SSL_free(ssl);
		return NULL;
	}
	return ssl;
}

/*
 * Drive the handshake until it completes or cannot continue.
 *
 * Repeats _ssl_read() steps into a scratch buffer. Data handed back by a
 * step is treated as a protocol error.
 *
 * Returns 1 once matrixSslHandshakeIsComplete() reports completion, 0 when
 * the connection ended (ssl->status is SSL_SOCKET_EOF or
 * SSL_SOCKET_CLOSE_NOTIFY), and a negative value on failure with
 * ssl->error set.
 */
int SSL_accept(SSL *ssl)
{
	char scratch[1024];
	int rc;

	for (;;)
	{
		rc = _ssl_read(ssl, scratch, sizeof(scratch));
		if (rc > 0)
		{
			/* Data is not expected at this point. */
			ssl->error = SSL_ERROR_SSL;
			return -1;
		}
		if (rc < 0)
		{
			return rc;
		}
		if (ssl->status == SSL_SOCKET_EOF
			|| ssl->status == SSL_SOCKET_CLOSE_NOTIFY)
		{
			return 0;
		}
		if (matrixSslHandshakeIsComplete(ssl->ssl))
		{
			return 1;
		}
	}
}

#ifdef USE_CLIENT_SIDE_SSL
/*
 * Start (or continue) a client-side handshake.
 *
 * When no operation is pending, a client hello is encoded into
 * ssl->outbuf, growing the buffer for as long as the encoder reports
 * SSL_FULL, and marked as pending output. The rest of the handshake is
 * driven by SSL_accept(), whose result is returned.
 *
 * Returns -1 with ssl->error set when encoding the hello fails.
 */
int SSL_connect(SSL *ssl)
{
	int rc;

	if (ssl->operation != SSL_OPERATION_NONE)
	{
		/* An operation is already pending: just keep driving it. */
		return SSL_accept(ssl);
	}

	do
	{
		rc = matrixSslEncodeClientHello(ssl->ssl, &ssl->outbuf, 0);
		if (rc == SSL_FULL && _ssl_buf_grow(&ssl->outbuf))
		{   /* buffer grow failed */
			ssl->error = SSL_ERROR_SYSCALL;
			return -1;
		}
	} while (rc == SSL_FULL);

	if (rc == SSL_ERROR)
	{
		ssl->error = SSL_ERROR_SSL;
		return -1;
	}

	/* The hello sits in outbuf; the next read step sends it first. */
	ssl->operation = SSL_OPERATION_WRITE;
	return SSL_accept(ssl);
}
#endif /* USE_CLIENT_SIDE_SSL */

/*
 * Record the descriptor that send() and recv() are called on for this
 * SSL object. The descriptor is only stored; nothing is done with it here.
 */
void SSL_set_fd(SSL *ssl, int fd)
{
	ssl->fd = fd;
}

/*
 * Read data from the connection into buf.
 *
 * Repeats _ssl_read() steps for as long as a step returns 0 without the
 * connection having ended.
 *
 * Returns the number of bytes stored in buf (> 0), 0 when ssl->status is
 * SSL_SOCKET_EOF or SSL_SOCKET_CLOSE_NOTIFY, or a negative value on
 * failure (see SSL_get_error()).
 */
int SSL_read(SSL *ssl, void *buf, int len)
{
	int rc;

	do
	{
		rc = _ssl_read(ssl, buf, len);
	} while (rc == 0
			&& ssl->status != SSL_SOCKET_EOF
			&& ssl->status != SSL_SOCKET_CLOSE_NOTIFY);

	return rc;
}


/*
 * Write data from buf to the connection.
 *
 * At most SSL_MAX_PLAINTEXT_LEN bytes are taken per call; longer requests
 * are truncated to that size, so the caller has to check the return value
 * and call again for the remainder. _ssl_write() steps are repeated for as
 * long as a step returns 0 and the socket has not reported end of file.
 *
 * Returns the number of bytes of buf that were accepted (> 0), 0 when
 * ssl->status is SSL_SOCKET_EOF, or a negative value on failure
 * (see SSL_get_error()).
 */
int SSL_write(SSL *ssl, const void *buf, int len)
{
	/* Use smaller chunks when the request exceeds the limit. */
	int chunk = (len > SSL_MAX_PLAINTEXT_LEN) ? SSL_MAX_PLAINTEXT_LEN : len;
	int rc;

	do
	{
		rc = _ssl_write(ssl, buf, chunk);
	} while (rc == 0 && ssl->status != SSL_SOCKET_EOF);

	/*
	 * _ssl_write returns the length of the encrypted data; the caller is
	 * told how many bytes of its own data went out instead.
	 */
	return (rc > 0) ? chunk : rc;
}


/*
 * Destroy an SSL object: delete its MatrixSSL session, release both I/O
 * buffers and then the object itself. Passing NULL is allowed and does
 * nothing. The descriptor set with SSL_set_fd() is not touched.
 */
void SSL_free(SSL * ssl)
{
	if (ssl != NULL)
	{
		matrixSslDeleteSession(ssl->ssl);
		/* free(NULL) is a no-op, so unallocated buffers need no guard. */
		free(ssl->inbuf.buf);
		free(ssl->outbuf.buf);
		free(ssl);
	}
}
