Hi all,

I have encountered strange behavior in SSL_write that causes my program to
loop and consume 100% CPU. I admit that my handling of the library is not
ideal in this case, but I still think the function should return an error
rather than loop.
See the attached source file: it is a simple server that handles SSL
connections. On line 235 there is a sleep(). If the client initiates a
renegotiation while the server sleeps, the subsequent call to
SSL_write loops.
As the client I am using just: openssl s_client -connect $IP:55555

So the repro is as follows:
compile the source: g++ ssl_server_poll.c -lssl -lcrypto
start the server (it loads its certificate and key from mycert.pem in the current directory): ./a.out
connect the client: openssl s_client -connect $IP:55555
type anything in the client, then <ENTER>
now the server hits sleep(3), so you have 3 seconds to start a renegotiation:
type R<ENTER> in the client
You should see the client disconnect with something like:

3074251016:error:140940F5:SSL routines:SSL3_READ_BYTES:unexpected
record:s3_pkt.c:1405

but the server starts running at 100% CPU.

I tried this on Debian and on Red Hat 6; OpenSSL version: OpenSSL 1.0.1e-fips

What do you think?

jd
#include <errno.h>
#include <unistd.h>
#include <malloc.h>
#include <string.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <netinet/in.h>
#include <resolv.h>
#include "openssl/ssl.h"
#include "openssl/err.h"
#include <iostream>
#include <sys/epoll.h>
#include <unistd.h>
#include <fcntl.h>
#include <poll.h>
#include <assert.h>
#include <map>
#include <sstream>

#define FAIL    (-1)
#define MAXEVENTS 164

/* True when fd is the listening socket. The argument is parenthesized so that
 * an expression such as IS_MAIN_SOCKET(a ? b : c) is compared as a whole. */
#define IS_MAIN_SOCKET(fd) (server_fd == (fd))

/* Stream-style logging to stderr/stdout, one line per call.
 * `x` is deliberately NOT parenthesized: callers pass `<<` chains such as
 * LOGI("READ:" << s), which must splice straight into the stream expression.
 * do { } while (0) makes each use behave as a single statement. */
#define LOGE(x) do { cerr << x << endl; } while (0)
#define LOGI(x) do { cout << x << endl; } while (0)

using namespace std;

/* Per-connection state used by processEvent(): CONNECTING until SSL_accept()
 * completes, then READY. */
enum SslState { NON_SSL, CONNECTING, READY, SHUTDOWN, STATE_MAX };

/* Owns the SSL object of one client connection and frees it on destruction.
 * The socket fd attached to the SSL is not closed here. */
struct SslInfo
{
	SSL *ssl;
	SslState state;

	SslInfo(SSL *s) : ssl(s), state(CONNECTING) {}
	~SslInfo()
	{
		SSL_free(ssl);
	}

private:
	/* Non-copyable: a copy would share `ssl`, and the second destructor would
	 * SSL_free() it again. Declared but intentionally not defined. */
	SslInfo(const SslInfo &);
	SslInfo &operator=(const SslInfo &);
};



/* Listening socket; assigned in main() from createServer(). */
int server_fd;
/* epoll instance from epoll_create1(); -1 until main() creates it. */
int efd = -1;
/* Server TLS context shared by every connection (built by InitServerCTX()). */
SSL_CTX *m_sslCtx;
/* Client fd -> heap-allocated SslInfo (new'd in acceptClient(),
 * deleted in closeSslConnection()). */
map<int, SslInfo*> m_conInfo;



/* Print (and thereby pop) every queued OpenSSL error to stderr. */
#define SSL_ERROR() ERR_print_errors_cb(ssl_print_cb, NULL)

/*
 * Callback for ERR_print_errors_cb(): write one formatted error line to stderr.
 *
 * ERR_print_errors_cb() stops walking the error queue as soon as the callback
 * returns <= 0, so a positive value must be returned for all queued errors to
 * be printed. OpenSSL's formatted line normally already ends in '\n'; a newline
 * is only added when it does not.
 *
 * @param str  formatted error line (not necessarily NUL-terminated)
 * @param len  number of bytes in str
 * @param u    user pointer passed to ERR_print_errors_cb (unused)
 * @return     1, so that iteration continues
 */
int ssl_print_cb(const char *str, size_t len, void *u)
{
	(void)u;
	std::cerr.write(str, static_cast<std::streamsize>(len));
	if (len == 0 || str[len - 1] != '\n')
		std::cerr << '\n';
	std::cerr.flush();
	return 1;
}

/*
 * Create a TCP socket listening on INADDR_ANY:port with a backlog of 10.
 * SO_REUSEADDR is requested so the port can be rebound right after a restart.
 * Aborts the process on any fatal failure.
 *
 * @param port  TCP port in host byte order, 0..65535
 * @return      the listening socket descriptor
 */
int createServer(int port)
{
	if (port < 0 || port > 65535)
	{
		fprintf(stderr, "invalid port %d\n", port);
		abort();
	}

	int sd = socket(PF_INET, SOCK_STREAM, 0);
	if (sd < 0)
	{
		perror("can't create socket");
		abort();
	}

	int reuse = 1;
	if (setsockopt(sd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) != 0)
	{
		/* Not fatal: bind() may still succeed without it. */
		perror("setsockopt(SO_REUSEADDR)");
	}

	struct sockaddr_in addr;
	memset(&addr, 0, sizeof(addr));
	addr.sin_family = AF_INET;
	addr.sin_port = htons(port);
	addr.sin_addr.s_addr = htonl(INADDR_ANY);
	if (bind(sd, (struct sockaddr*)&addr, sizeof(addr)) != 0)
	{
		perror("can't bind port");
		abort();
	}
	if (listen(sd, 10) != 0)
	{
		perror("Can't configure listening port");
		abort();
	}
	return sd;
}

/*
 * Register OpenSSL algorithms and error strings, then create the server
 * SSL_CTX. Aborts (after printing the OpenSSL error queue) on failure.
 *
 * SSLv23_server_method() negotiates the best protocol version both peers
 * support; SSLv2 and SSLv3 are then switched off because they are insecure
 * (SSLv3: POODLE, CVE-2014-3566). SSLv3_server_method() would instead allow
 * ONLY SSLv3.
 */
SSL_CTX* InitServerCTX(void)
{
	OpenSSL_add_all_algorithms();  /* load & register all ciphers/digests */
	SSL_load_error_strings();      /* human-readable error messages */

	SSL_CTX *ctx = SSL_CTX_new(SSLv23_server_method());
	if (ctx == NULL)
	{
		ERR_print_errors_fp(stderr);
		abort();
	}
	SSL_CTX_set_options(ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
	return ctx;
}

void LoadCertificates(SSL_CTX* ctx, char* CertFile, char* KeyFile)
{
	/* set the local certificate from CertFile */
	if ( SSL_CTX_use_certificate_file(ctx, CertFile, SSL_FILETYPE_PEM) <= 0 )
	{
		//ERR_print_errors_fp(stderr);
		SSL_ERROR();
		abort();
	}
	/* set the private key from KeyFile (may be the same as CertFile) */
	if ( SSL_CTX_use_PrivateKey_file(ctx, KeyFile, SSL_FILETYPE_PEM) <= 0 )
	{
		ERR_print_errors_fp(stderr);
		abort();
	}
	/* verify private key */
	if ( !SSL_CTX_check_private_key(ctx) )
	{
		fprintf(stderr, "Private key does not match the public certificate\n");
		abort();
	}
}

/*
 * Print the subject and issuer of the certificate presented by the peer of
 * ssl (on this server, the client), or "No certificates." if there is none.
 */
void ShowCerts(SSL* ssl)
{
	/* SSL_get_peer_certificate() returns a new reference; released below. */
	X509 *cert = SSL_get_peer_certificate(ssl);
	if (cert == NULL)
	{
		printf("No certificates.\n");
		return;
	}

	printf("Peer certificates:\n");

	/* X509_NAME_oneline(..., NULL, 0) allocates with OPENSSL_malloc, so it
	 * must be released with OPENSSL_free, not free(). It may return NULL. */
	char *line = X509_NAME_oneline(X509_get_subject_name(cert), 0, 0);
	printf("Subject: %s\n", line ? line : "(unavailable)");
	OPENSSL_free(line);

	line = X509_NAME_oneline(X509_get_issuer_name(cert), 0, 0);
	printf("Issuer: %s\n", line ? line : "(unavailable)");
	OPENSSL_free(line);

	X509_free(cert);
}

//const char* resp="{\"addr\":\"172.16.0.8/32\",\"dns\":\"8.8.8.8\",\"mtu\":\"1400\",\"pdu\":\"hs_resp\",\"route\":\"0.0.0.0/0\"}";


int doWrite(SSL *ssl, string buf)
{
	bool doWrite = true;
	int wcnt = 0;
	int r;
	while (doWrite && wcnt<10) 
	{
		r = SSL_write(ssl, buf.c_str(), buf.size());
		switch (SSL_get_error(ssl, r)) {
			case SSL_ERROR_WANT_WRITE:
				cout<<"SSL_read returned want_write"<<endl;
				wcnt++;
				continue;
			case SSL_ERROR_WANT_READ:
				cout<<"SSL_read returned want read"<<endl;
				wcnt++;
				continue;

			case SSL_ERROR_NONE:
				doWrite = false;
				return r;
		}
	}
	return -1;
}

/*
 * Tear down the connection that owns ssl: delete its SslInfo (which
 * SSL_free()s ssl), remove it from m_conInfo and close the socket.
 * ssl must not be used after this call. Closing the fd also drops it from the
 * epoll set, provided no duplicate of the descriptor exists.
 */
void closeSslConnection(SSL *ssl)
{
	int fd = SSL_get_fd(ssl);
	if (fd < 0) {
		LOGE("closeSslConnection: SSL has no socket");
		return;
	}

	/* find(), not operator[]: operator[] would insert an entry for an
	 * unknown fd. */
	map<int, SslInfo*>::iterator it = m_conInfo.find(fd);
	if (it != m_conInfo.end()) {
		delete it->second;
		m_conInfo.erase(it);
	}
	close(fd);
}

#define SZ 8194
void processEvent(struct epoll_event *ev) /* Serve the connection -- threadable */
{   
	if (!ev) {
		LOGE("ev == NULL");
		return;
	}

	int fd = ev->data.fd;

	SslInfo *sslInfo = m_conInfo[fd];

	if (!sslInfo) {
		LOGE("sslInfo == null");
		close(fd);
		return;
	}
	SSL *ssl = sslInfo->ssl;
	if (!ssl) {
		LOGE("ssl == null");
		close(fd);
		return;
	}

	char buf[SZ];
	int r;
	if (sslInfo->state == CONNECTING) {
		r = SSL_accept(ssl);
		switch (SSL_get_error(ssl, r)) {
			case SSL_ERROR_WANT_WRITE:
			case SSL_ERROR_WANT_READ:
				cout<<"still connecting...want R/W"<<endl;
				return;
			case SSL_ERROR_NONE:
				cout<<"connected..."<<endl;
				sslInfo->state = READY;
				return;
		}
	} else {
		if (ev->events & EPOLLIN) {
			string reply;
			int rcnt = 0;

			while (rcnt<10) 
			{
				r = SSL_read(ssl, buf, sizeof(buf));

				int err = SSL_get_error(ssl, r);
				cerr<<"ERR="<<err<<" r="<<r<<endl;
				switch (err) {
					case SSL_ERROR_WANT_WRITE:
						cout<<"SSL_read returned want_write"<<endl;
						rcnt++;
						continue;
					case SSL_ERROR_WANT_READ:
						{
							cout<<"SSL_read returned want read"<<endl;
							sleep(3);	
							LOGI("sending some more to client");
							string s = "some more data\n";
							if (doWrite(ssl, s) <= 0) {
								cerr<<"failed to write:"<<reply<<endl;
							}
							LOGI("more data sent");
						}

						return;
					case SSL_ERROR_ZERO_RETURN:
						//closed conn
						closeSslConnection(ssl);
						return;
					case SSL_ERROR_SSL:
					case SSL_ERROR_SYSCALL:
						//error
						closeSslConnection(ssl);
						return;

					case SSL_ERROR_NONE:
						LOGI("READ:"<<string(buf,r));
						reply = "REPLY to:"+ string(buf,r);
						if (doWrite(ssl, reply) <= 0) {
							cerr<<"failed to write:"<<reply<<endl;
						}
						continue;
					default:
						//what to do
						closeSslConnection(ssl);
						return;
				}
			}
			if (rcnt >= 10) {
				//TODO:close the socket, free the resources
				cerr<<"reached maximum rcnt"<<endl;
			}
		
		}
		else {
			cerr<<hex<<ev->events<<endl;
			if (ev->events & EPOLLOUT) {
				cerr<<"want pollout"<<endl;
			}
			return;
		}
	}

}

/*
 * Switch sfd to non-blocking mode by adding O_NONBLOCK to its existing file
 * status flags (other flags are preserved).
 *
 * @return 0 on success; -1 after perror("fcntl") if either fcntl() call fails.
 */
static int make_socket_non_blocking (int sfd)
{
	int current = fcntl(sfd, F_GETFL, 0);
	if (current == -1)
		goto fail;

	if (fcntl(sfd, F_SETFL, current | O_NONBLOCK) == -1)
		goto fail;

	return 0;

fail:
	perror("fcntl");
	return -1;
}

/* Zero an epoll_event (argument parenthesized for macro safety). */
#define clear_ev(e) memset(&(e), 0, sizeof(struct epoll_event))

/*
 * Register fd with the global epoll instance efd.
 * EPOLLHUP and EPOLLERR are OR-ed into `mode`; epoll reports them anyway.
 * EPOLLET is not set, so notifications are level-triggered.
 *
 * @param fd    descriptor to watch
 * @param mode  requested epoll event mask (e.g. EPOLLIN)
 * @return      true on success; false (after logging the error) otherwise
 */
bool addToEpoll(int fd, int mode)
{
	struct epoll_event event;
	clear_ev(event);

	event.data.fd = fd;
	event.events = mode | EPOLLHUP | EPOLLERR;

	if (epoll_ctl(efd, EPOLL_CTL_ADD, fd, &event) < 0) {
		/* Save errno first: the logging below may overwrite it. */
		int saved = errno;
		LOGE("problem with epoll_ctl("<< fd<<")");
		LOGE(string(strerror(saved)));
		return false;
	}
	return true;
}


int acceptClient()
{
	//because of possible  event triggering it is needed to accept all signalled
	//connections, i.e.  accept returns something != EAGAIN
	LOGI("acceptClient()");
	while (1) {
		struct sockaddr_in addr;
		socklen_t len = sizeof(addr);
		SSL *ssl = NULL;

		int client = accept(server_fd, (struct sockaddr*)&addr, &len);
		if (client == -1) {
			if ((errno == EAGAIN) || (errno == EWOULDBLOCK))	{
				LOGE("accept would block");
				break;
			} else {
				LOGE("error on accept");
				LOGE(string(strerror(errno)));
				break;
			}
		}
		if (make_socket_non_blocking(client)) {
			close(client);
			break;
		}
		stringstream sin;
		sin<<inet_ntoa(addr.sin_addr)<<":"<<ntohs(addr.sin_port);
		string ipPort = sin.str();
		LOGI("New connection: "<<ipPort);

		ssl = SSL_new(m_sslCtx);
		SSL_set_fd(ssl, client);


		m_conInfo[client] = new SslInfo(ssl);
		addToEpoll(client, EPOLLIN);
	}
	return 1;
}


int main(int count, char *strings[])
{   
	int s;
	char portnum[] = "55555";
	struct epoll_event event;
	struct epoll_event *events;

	SSL_library_init();

	m_sslCtx = InitServerCTX();   
	char certFile[] = "mycert.pem";
	LoadCertificates(m_sslCtx, certFile, certFile); 
	server_fd = createServer(atoi(portnum)); 
	cout<<"server fd="<<server_fd<<endl;
	if ( make_socket_non_blocking (server_fd)) {
		abort();
	}

	efd = epoll_create1(0);

	if (efd <= 0) {
		perror("epoll_create");
		abort();
	}


	addToEpoll(server_fd, EPOLLIN);
	events = new epoll_event[MAXEVENTS];

	while (1) 
	{
		int n, i;
		n = epoll_wait(efd, events, MAXEVENTS, 30000);
	
		for (i = 0; i<n; i++) {
			int fd = events[i].data.fd;
			if ( (events[i].events & EPOLLERR) || (events[i].events & EPOLLHUP) ) {
				if (IS_MAIN_SOCKET(fd)) {
					LOGE("error on main socket");
					assert(0);
				} else {
					close(fd);

				}
			} else {
				if (IS_MAIN_SOCKET(fd)) {
					acceptClient();
				} else {
					processEvent(&events[i]);
				}
			}
		}

	}

	close(server_fd);          /* close server socket */
	SSL_CTX_free(m_sslCtx);         /* release context */
}




Reply via email to