view lib/irccd/conn.c @ 1037:8f8ce47aba8a

make: switch to GNU make
author David Demelier <markand@malikania.fr>
date Tue, 27 Apr 2021 09:22:16 +0200
parents 3ea3361f0fc7
children 89478faef566
line wrap: on
line source

/*
 * conn.c -- an IRC server channel
 *
 * Copyright (c) 2013-2021 David Demelier <markand@malikania.fr>
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */

#include <sys/socket.h>
#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <poll.h>
#include <string.h>
#include <unistd.h>

#include "conn.h"
#include "log.h"
#include "server.h"
#include "util.h"

/*
 * Release the socket and, when built with SSL, the SSL objects held by the
 * connection, then put the connection back in the IRC_CONN_STATE_NONE state.
 *
 * This function is safe to call several times in a row: once closed the
 * descriptor is set to -1 and is not closed again.
 *
 * The value of errno on entry is restored before returning so that callers
 * that set errno before disconnecting do not see it overwritten by close(2).
 */
static void
cleanup(struct irc_conn *conn)
{
	const int saved = errno;

	/*
	 * Like the previous code, 0 is treated as "no socket" (presumably the
	 * value found in a zero-initialized connection -- confirm with the
	 * callers). Negative values are what we store below after closing.
	 */
	if (conn->fd > 0)
		close(conn->fd);

#if defined(IRCCD_WITH_SSL)
	if (conn->ssl)
		SSL_free(conn->ssl);
	if (conn->ctx)
		SSL_CTX_free(conn->ctx);

	conn->ssl_cond = IRC_CONN_SSL_ACT_NONE;
	conn->ssl_step = IRC_CONN_SSL_ACT_NONE;
	conn->ssl = NULL;
	conn->ctx = NULL;
#endif

	conn->state = IRC_CONN_STATE_NONE;
	conn->fd = -1;

	errno = saved;
}

/*
 * Extract the next space delimited token.
 *
 * On return *str points to the token (cut in place with a NUL byte) and
 * *line points right after the separator, or to the terminating NUL byte if
 * no more space was found.
 */
static inline void
scan(char **line, char **str)
{
	char *token = *line;
	char *sep = strchr(token, ' ');

	*str = token;

	if (sep == NULL) {
		/* Last token: leave the cursor on the end of string. */
		*line = token + strlen(token);
		return;
	}

	*sep = '\0';
	*line = sep + 1;
}

/*
 * Split an IRC line into its prefix, command and arguments.
 *
 * The line is copied into msg->buf and cut in place; msg->prefix, msg->cmd
 * and msg->args[] all point inside that copy.
 *
 * Returns 0 on success, otherwise -1 with errno set to:
 *
 * - EMSGSIZE if the arguments filled the whole msg->args array,
 * - EBADMSG if the command is empty.
 */
static int
parse(struct irc_conn_msg *msg, const char *line)
{
	char *ptr = msg->buf;
	size_t a;

	memset(msg, 0, sizeof (*msg));
	strlcpy(msg->buf, line, sizeof (msg->buf));

	/*
	 * IRC message is defined as following:
	 *
	 * [:prefix] command arg1 arg2 [:last-argument]
	 */
	if (*ptr == ':') {
		/* Skip the colon, it is not part of the prefix. */
		++ptr;
		scan(&ptr, &msg->prefix);
	}

	scan(&ptr, &msg->cmd);

	/* And finally arguments. */
	for (a = 0; *ptr && a < IRC_UTIL_SIZE(msg->args); ++a) {
		if (*ptr == ':') {
			/* Trailing argument: takes everything up to the end. */
			msg->args[a] = ptr + 1;
			ptr = strchr(ptr, '\0');
		} else
			scan(&ptr, &msg->args[a]);
	}

	if (a >= IRC_UTIL_SIZE(msg->args))
		return errno = EMSGSIZE, -1;

	/*
	 * scan() always stores a valid pointer, so a missing command shows up
	 * as an empty string rather than as NULL.
	 */
	if (msg->cmd == NULL || msg->cmd[0] == '\0')
		return errno = EBADMSG, -1;

	return 0;
}

static int
create(struct irc_conn *conn)
{
	struct addrinfo *ai = conn->aip;
	int cflags = 0;

	cleanup(conn);

	if ((conn->fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol)) < 0)
		return -1;
	if ((cflags = fcntl(conn->fd, F_GETFL)) < 0)
		return -1;
	if (fcntl(conn->fd, F_SETFL, cflags | O_NONBLOCK) < 0)
		return -1;

	return 0;
}

static inline int
update_ssl_state(struct irc_conn *conn, int ret)
{
	switch (SSL_get_error(conn->ssl, ret)) {
	case SSL_ERROR_WANT_READ:
		irc_log_debug("server %s: step %d now needs read condition",
		    conn->sv->name, conn->ssl_step);
		conn->ssl_cond = IRC_CONN_SSL_ACT_READ;
		break;
	case SSL_ERROR_WANT_WRITE:
		irc_log_debug("server %s: step %d now needs write condition",
		    conn->sv->name, conn->ssl_step);
		conn->ssl_cond = IRC_CONN_SSL_ACT_WRITE;
		break;
	case SSL_ERROR_SSL:
		return irc_conn_disconnect(conn), -1;
	default:
		break;
	}

	return 0;
}

/*
 * Read up to dstsz bytes from the SSL connection into dst.
 *
 * Returns the number of bytes read when positive. When SSL_read() did not
 * return data, the current step is recorded as a read and the result of
 * update_ssl_state() is returned instead.
 */
static inline ssize_t
input_ssl(struct irc_conn *conn, char *dst, size_t dstsz)
{
	const int count = SSL_read(conn->ssl, dst, dstsz);

	if (count > 0) {
		if (conn->ssl_cond)
			irc_log_debug("server %s: condition back to normal", conn->sv->name);

		/* The operation completed, no special condition is pending. */
		conn->ssl_cond = IRC_CONN_SSL_ACT_NONE;
		conn->ssl_step = IRC_CONN_SSL_ACT_NONE;

		return count;
	}

	irc_log_debug("server %s: SSL read incomplete", conn->sv->name);
	conn->ssl_step = IRC_CONN_SSL_ACT_READ;

	return update_ssl_state(conn, count);
}

/*
 * Read up to bufsz bytes from the plain socket into buf.
 *
 * Returns the number of bytes read, 0 if the read has to be retried later
 * (the socket is non-blocking) or -1 with errno set to ECONNRESET after
 * disconnecting if the peer closed the connection or a fatal error occurred.
 */
static inline ssize_t
input_clear(struct irc_conn *conn, char *buf, size_t bufsz)
{
	ssize_t nr;

	if ((nr = recv(conn->fd, buf, bufsz, 0)) > 0)
		return nr;

	/*
	 * Transient conditions on a non-blocking socket are not a reason to
	 * drop the connection, the caller will come back on the next POLLIN.
	 */
	if (nr < 0 && (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR))
		return 0;

	errno = ECONNRESET;

	return irc_conn_disconnect(conn), -1;
}

/*
 * Append incoming data to the conn->in string, using the SSL or the plain
 * transport depending on the connection flags.
 *
 * One byte of the buffer is always reserved for the terminating NUL.
 *
 * Returns what the transport function returned.
 */
static int
input(struct irc_conn *conn)
{
	char *tail = conn->in + strlen(conn->in);
	const size_t room = sizeof (conn->in) - (size_t)(tail - conn->in) - 1;
	ssize_t got;

	got = (conn->flags & IRC_CONN_SSL)
		? input_ssl(conn, tail, room)
		: input_clear(conn, tail, room);

	/* Keep the buffer a valid C string after the new data. */
	if (got > 0)
		tail[got] = '\0';

	return got;
}

/*
 * Send the pending conn->out string through the SSL connection.
 *
 * Returns the number of bytes written when positive. When SSL_write() did
 * not write anything, the current step is recorded as a write and the result
 * of update_ssl_state() is returned instead.
 */
static inline ssize_t
output_ssl(struct irc_conn *conn)
{
	const int count = SSL_write(conn->ssl, conn->out, strlen(conn->out));

	if (count > 0) {
		if (conn->ssl_cond)
			irc_log_debug("server %s: condition back to normal", conn->sv->name);

		/* The operation completed, no special condition is pending. */
		conn->ssl_cond = IRC_CONN_SSL_ACT_NONE;
		conn->ssl_step = IRC_CONN_SSL_ACT_NONE;

		return count;
	}

	irc_log_debug("server %s: SSL write incomplete", conn->sv->name);
	conn->ssl_step = IRC_CONN_SSL_ACT_WRITE;

	return update_ssl_state(conn, count);
}

/*
 * Send the pending conn->out string through the plain socket.
 *
 * Returns the number of bytes sent, 0 if the write has to be retried later
 * (the socket is non-blocking) or -1 after disconnecting on a fatal error.
 */
static inline ssize_t
output_clear(struct irc_conn *conn)
{
	ssize_t ns;

	if ((ns = send(conn->fd, conn->out, strlen(conn->out), 0)) >= 0)
		return ns;

	/*
	 * Transient conditions on a non-blocking socket are not a reason to
	 * drop the connection, the data stays in conn->out for the next try.
	 */
	if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
		return 0;

	return irc_conn_disconnect(conn), -1;
}

/*
 * Flush as much as possible of conn->out, using the SSL or the plain
 * transport depending on the connection flags, then remove what was sent
 * from the front of the buffer.
 *
 * Returns what the transport function returned.
 */
static int
output(struct irc_conn *conn)
{
	const ssize_t sent = (conn->flags & IRC_CONN_SSL)
		? output_ssl(conn)
		: output_clear(conn);

	if (sent <= 0)
		return sent;

	/*
	 * Shift the remaining bytes to the front, unless the amount sent
	 * covers the whole buffer capacity in which case the string is simply
	 * emptied.
	 */
	if ((size_t)sent < sizeof (conn->out) - 1)
		memmove(conn->out, conn->out + sent, sizeof (conn->out) - sent);
	else
		conn->out[0] = '\0';

	return sent;
}

/*
 * Perform (or continue) the SSL handshake if the connection uses SSL,
 * otherwise mark the connection ready immediately.
 *
 * Returns 0 if the connection is ready or if the handshake has to be
 * continued later, -1 after disconnecting on a fatal error.
 */
static int
handshake(struct irc_conn *conn)
{
	if (conn->flags & IRC_CONN_SSL) {
#if defined(IRCCD_WITH_SSL)
		int r;

		conn->state = IRC_CONN_STATE_HANDSHAKING;

		/*
		 * This function is called several times until it completes so
		 * we must keep the same context/ssl objects once they have
		 * been created.
		 */
		if (!conn->ctx && !(conn->ctx = SSL_CTX_new(TLS_method()))) {
			irc_log_warn("server %s: could not create SSL context",
			    conn->sv->name);
			return irc_conn_disconnect(conn), -1;
		}
		if (!conn->ssl) {
			if (!(conn->ssl = SSL_new(conn->ctx))) {
				irc_log_warn("server %s: could not create SSL object",
				    conn->sv->name);
				return irc_conn_disconnect(conn), -1;
			}

			/* SSL_set_fd returns 1 on success. */
			if (SSL_set_fd(conn->ssl, conn->fd) != 1) {
				irc_log_warn("server %s: could not attach socket to SSL",
				    conn->sv->name);
				return irc_conn_disconnect(conn), -1;
			}

			SSL_set_connect_state(conn->ssl);
		}

		/* Not finished yet (or failed): let the SSL state decide. */
		if ((r = SSL_do_handshake(conn->ssl)) <= 0)
			return update_ssl_state(conn, r);

		conn->state = IRC_CONN_STATE_READY;
		conn->ssl_cond = IRC_CONN_SSL_ACT_NONE;
		conn->ssl_step = IRC_CONN_SSL_ACT_NONE;
#endif
	} else
		conn->state = IRC_CONN_STATE_READY;

	return 0;
}

/*
 * Try to connect to the addresses of the lookup result, starting at the one
 * designated by conn->aip.
 *
 * Returns the result of handshake() if connect(2) completed immediately, 0
 * if the connection is in progress (state set to connecting) and -1 if no
 * address could be used.
 */
static int
dial(struct irc_conn *conn)
{
	/* No more address available. */
	if (!conn->aip) {
		irc_log_warn("server %s: could not connect", conn->sv->name);
		irc_conn_disconnect(conn);

		return -1;
	}

	while (conn->aip != NULL) {
		if (create(conn) >= 0) {
			/*
			 * The connection may complete right away, in that
			 * case we can go on with the handshake immediately.
			 */
			if (connect(conn->fd, conn->aip->ai_addr, conn->aip->ai_addrlen) == 0)
				return handshake(conn);

			/*
			 * Otherwise completion is reported later, once the
			 * socket becomes writable.
			 */
			if (errno == EINPROGRESS || errno == EAGAIN) {
				conn->state = IRC_CONN_STATE_CONNECTING;
				return 0;
			}
		}

		/* This address is unusable, go on with the next one. */
		conn->aip = conn->aip->ai_next;
	}

	return -1;
}

/*
 * Resolve conn->hostname and conn->port into a list of stream socket
 * addresses stored in conn->ai, and make conn->aip point to its first entry.
 *
 * Returns 0 on success or -1 (after logging the resolver error) on failure.
 */
static int
lookup(struct irc_conn *conn)
{
	struct addrinfo hints = {0};
	char service[16];
	int rc;

	/* The port is given numerically, no service name lookup is needed. */
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_flags = AI_NUMERICSERV;

	snprintf(service, sizeof (service), "%hu", conn->port);

	rc = getaddrinfo(conn->hostname, service, &hints, &conn->ai);

	if (rc != 0) {
		irc_log_warn("server %s: %s", conn->sv->name, gai_strerror(rc));
		return -1;
	}

	conn->aip = conn->ai;

	return 0;
}

/*
 * Determine whether the non-blocking connect(2) call succeeded.
 *
 * On success the handshake is started, otherwise the next address of the
 * lookup result is tried.
 *
 * Returns the result of handshake() or dial().
 */
static int
check_connect(struct irc_conn *conn)
{
	int err = -1;
	socklen_t len = sizeof (err);

	if (getsockopt(conn->fd, SOL_SOCKET, SO_ERROR, &err, &len) == 0 && err == 0)
		return handshake(conn);

	/*
	 * dial() returns from inside its loop when connect(2) is in progress,
	 * so conn->aip still designates the address that just failed. Skip it,
	 * otherwise the very same address would be retried over and over and
	 * the following ones would never get a chance.
	 */
	if (conn->aip)
		conn->aip = conn->aip->ai_next;

	return dial(conn);
}

/*
 * Add to pfd->events the condition that the pending SSL operation is
 * waiting for. Does nothing when built without SSL support.
 */
static inline void
prepare_ssl(const struct irc_conn *conn, struct pollfd *pfd)
{
#if defined(IRCCD_WITH_SSL)
	if (conn->ssl_cond == IRC_CONN_SSL_ACT_READ) {
		irc_log_debug("server %s: need read condition", conn->sv->name);
		pfd->events |= POLLIN;
	} else if (conn->ssl_cond == IRC_CONN_SSL_ACT_WRITE) {
		irc_log_debug("server %s: need write condition", conn->sv->name);
		pfd->events |= POLLOUT;
	}
#else
	(void)conn;
#endif
}

/*
 * Retry the I/O operation that was interrupted by an SSL condition: a read
 * if the recorded step is a read, a write otherwise.
 *
 * Returns the result of input() or output().
 */
static inline int
renegotiate(struct irc_conn *conn)
{
	irc_log_debug("server %s: renegociate step=%d", conn->sv->name, conn->ssl_step);

	if (conn->ssl_step == IRC_CONN_SSL_ACT_READ)
		return input(conn);

	return output(conn);
}

/*
 * Resolve the server address and start connecting to it.
 *
 * Returns -1 after disconnecting if the lookup failed, otherwise the result
 * of dial().
 */
int
irc_conn_connect(struct irc_conn *conn)
{
	assert(conn);

	if (lookup(conn) >= 0)
		return dial(conn);

	irc_conn_disconnect(conn);

	return -1;
}

void
irc_conn_disconnect(struct irc_conn *conn)
{
	assert(conn);

	cleanup(conn);
}

/*
 * Fill pfd with the connection descriptor and the events to wait for.
 *
 * A pending SSL condition takes precedence. Otherwise a connecting socket
 * waits for writability and a ready one always waits for input, plus
 * writability when there is something in conn->out. In other states
 * pfd->events is left untouched.
 */
void
irc_conn_prepare(const struct irc_conn *conn, struct pollfd *pfd)
{
	assert(conn);
	assert(pfd);

	pfd->fd = conn->fd;

	if (conn->ssl_cond) {
		prepare_ssl(conn, pfd);
		return;
	}

	if (conn->state == IRC_CONN_STATE_CONNECTING)
		pfd->events = POLLOUT;
	else if (conn->state == IRC_CONN_STATE_READY)
		pfd->events = conn->out[0] ? (POLLIN | POLLOUT) : POLLIN;
}

/*
 * Process the events reported in pfd according to the connection state.
 *
 * - connecting: check the result of the asynchronous connect(2),
 * - handshaking: continue the SSL handshake,
 * - ready: perform the pending input/output.
 *
 * Returns 0 on success (or when there is nothing to do in the current
 * state), otherwise -1 or the result of check_connect()/handshake().
 */
int
irc_conn_flush(struct irc_conn *conn, const struct pollfd *pfd)
{
	assert(conn);
	assert(pfd);

	int failed;

	if (conn->state == IRC_CONN_STATE_CONNECTING)
		return check_connect(conn);
	if (conn->state == IRC_CONN_STATE_HANDSHAKING)
		return handshake(conn);
	if (conn->state != IRC_CONN_STATE_READY)
		return 0;

	if (pfd->revents & (POLLERR | POLLHUP))
		failed = 1;
	else if (conn->ssl_cond)
		/* An SSL operation is pending, it must be retried first. */
		failed = renegotiate(conn) < 0;
	else
		/* Output is not attempted if input already failed. */
		failed = ((pfd->revents & POLLIN) && input(conn) < 0) ||
		         ((pfd->revents & POLLOUT) && output(conn) < 0);

	if (!failed)
		return 0;

	irc_conn_disconnect(conn);

	return -1;
}

/*
 * Extract the next complete message from the input buffer.
 *
 * Lines are delimited by "\r\n". Empty lines are removed from the buffer and
 * skipped, so that a positive return always comes with msg written by
 * parse(). As before, the result of parse() itself is not checked.
 *
 * Returns 1 if a message was stored in msg, 0 if no complete non-empty line
 * is available yet.
 */
int
irc_conn_poll(struct irc_conn *conn, struct irc_conn_msg *msg)
{
	assert(conn);
	assert(msg);

	char *pos;
	size_t length;

	while ((pos = strstr(conn->in, "\r\n"))) {
		/* Turn end of the string at delimiter. */
		*pos = '\0';
		length = pos - conn->in;

		/* parse() works on its own copy, we can drop the line after. */
		if (length > 0)
			parse(msg, conn->in);

		/* (Re)move the first message received. */
		memmove(conn->in, pos + 2, sizeof (conn->in) - (length + 2));

		if (length > 0)
			return 1;
	}

	return 0;
}

/*
 * Queue a message for sending: data followed by "\r\n" is appended to the
 * conn->out string.
 *
 * The append is all or nothing: if the message and its delimiter do not fit
 * entirely, the output buffer is left untouched.
 *
 * Returns 0 on success or -1 with errno set to EMSGSIZE if there is not
 * enough room.
 */
int
irc_conn_send(struct irc_conn *conn, const char *data)
{
	assert(conn);
	assert(data);

	const size_t used = strlen(conn->out);
	const size_t need = strlen(data);

	/* Room left for characters, the final NUL byte excluded. */
	const size_t avail = sizeof (conn->out) - used - 1;

	/* Written this way to avoid any overflow in the size arithmetic. */
	if (need > avail || avail - need < 2)
		return errno = EMSGSIZE, -1;

	memcpy(conn->out + used, data, need);
	memcpy(conn->out + used + need, "\r\n", 3);

	return 0;
}

void
irc_conn_finish(struct irc_conn *conn)
{
	assert(conn);

	cleanup(conn);
}