Mercurial > irccd
changeset 235:5566cf772073
Irccd: support SSL transport part 2
author | David Demelier <markand@malikania.fr> |
---|---|
date | Fri, 12 Aug 2016 15:01:16 +0200 |
parents | 3c631fb06ccf |
children | 5fa15cd7ffe0 |
files | lib/irccd/connection.cpp lib/irccd/connection.hpp lib/irccd/irccdctl.cpp |
diffstat | 3 files changed, 195 insertions(+), 8 deletions(-) [+] |
line wrap: on
line diff
--- a/lib/irccd/connection.cpp Fri Aug 12 13:16:59 2016 +0200 +++ b/lib/irccd/connection.cpp Fri Aug 12 15:01:16 2016 +0200 @@ -26,13 +26,18 @@ namespace irccd { +/* + * Connection. + * ------------------------------------------------------------------ + */ + void Connection::syncInput() { try { std::string buffer; buffer.resize(512); - buffer.resize(m_socket.recv(&buffer[0], buffer.size())); + buffer.resize(recv(&buffer[0], buffer.size())); if (buffer.empty()) throw std::runtime_error("connection lost"); @@ -47,7 +52,7 @@ void Connection::syncOutput() { try { - auto ns = m_socket.send(m_output.data(), m_output.length()); + auto ns = send(m_output.data(), m_output.length()); if (ns > 0) m_output.erase(0, ns); @@ -57,6 +62,16 @@ } } +unsigned Connection::recv(char *buffer, unsigned length) +{ + return m_socket.recv(buffer, length); +} + +unsigned Connection::send(const char *buffer, unsigned length) +{ + return m_socket.send(buffer, length); +} + Connection::Connection() : m_state(std::make_unique<DisconnectedState>()) { @@ -114,4 +129,99 @@ } } +/* + * TlsConnection. + * ------------------------------------------------------------------ + */ + +void TlsConnection::handshake() +{ + try { + m_ssl->handshake(); + m_handshake = HandshakeReady; + } catch (const net::WantReadError &) { + m_handshake = HandshakeRead; + } catch (const net::WantWriteError &) { + m_handshake = HandshakeWrite; + } catch (const std::exception &ex) { + m_state = std::make_unique<DisconnectedState>(); + onDisconnect(ex.what()); + } +} + +unsigned TlsConnection::recv(char *buffer, unsigned length) +{ + unsigned nread = 0; + + try { + nread = m_ssl->recv(buffer, length); + } catch (const net::WantReadError &) { + m_handshake = HandshakeRead; + } catch (const net::WantWriteError &) { + m_handshake = HandshakeWrite; + } + + return nread; +} + +unsigned TlsConnection::send(const char *buffer, unsigned length) +{ + unsigned nsent = 0; + + try { + nsent = m_ssl->send(buffer, length); + } catch (const net::WantReadError &) { + m_handshake = HandshakeRead; + } catch (const net::WantWriteError &) { + m_handshake = HandshakeWrite; + } + + return nsent; +} + +void TlsConnection::connect(const net::Address &address) +{ + Connection::connect(address); + + m_ssl = std::make_unique<net::TlsSocket>(m_socket, net::TlsSocket::Client); +} + +void TlsConnection::prepare(fd_set &in, fd_set &out, net::Handle &max) +{ + if (m_state->status() == Connecting) + Connection::prepare(in, out, max); + else { + if (m_socket.handle() > max) + max = m_socket.handle(); + + /* + * Attempt an immediate handshake immediately if connection succeeded + * in last iteration. + */ + if (m_handshake == HandshakeUndone) + handshake(); + + switch (m_handshake) { + case HandshakeRead: + FD_SET(m_socket.handle(), &in); + break; + case HandshakeWrite: + FD_SET(m_socket.handle(), &out); + break; + default: + Connection::prepare(in, out, max); + } + } +} + +void TlsConnection::sync(fd_set &in, fd_set &out) +{ + if (m_state->status() == Connecting) + Connection::sync(in, out); + else if (m_handshake != HandshakeReady) + handshake(); + else + Connection::sync(in, out); +} + } // !irccd
--- a/lib/irccd/connection.hpp Fri Aug 12 13:16:59 2016 +0200 +++ b/lib/irccd/connection.hpp Fri Aug 12 15:01:16 2016 +0200 @@ -125,15 +125,28 @@ class CheckingState; class ReadyState; -private: +protected: std::unique_ptr<State> m_state; std::unique_ptr<State> m_stateNext; - -protected: net::TcpSocket m_socket{net::Invalid}; - void syncInput(); - void syncOutput(); + /** + * Try to receive some data into the given buffer. + * + * \param buffer the destination buffer + * \param length the buffer length + * \return the number of bytes received + */ + virtual unsigned recv(char *buffer, unsigned length); + + /** + * Try to send some data into the given buffer. + * + * \param buffer the source buffer + * \param length the buffer length + * \return the number of bytes sent + */ + virtual unsigned send(const char *buffer, unsigned length); public: /** @@ -147,6 +160,20 @@ virtual ~Connection(); /** + * Convenient wrapper around recv(). + * + * Must be used in sync() function. + */ + void syncInput(); + + /** + * Convenient wrapper around send(). + * + * Must be used in sync() function. + */ + void syncOutput(); + + /** * Send an asynchronous request to irccd. * * \pre json.is_object @@ -206,6 +233,51 @@ void sync(fd_set &in, fd_set &out) override; }; +/** + * \brief TLS over IP connection. + */ +class TlsConnection : public Connection { +private: + enum { + HandshakeUndone, + HandshakeRead, + HandshakeWrite, + HandshakeReady + } m_handshake{HandshakeUndone}; + +private: + std::unique_ptr<net::TlsSocket> m_ssl; + + void handshake(); + +protected: + /** + * \copydoc Connection::recv + */ + virtual unsigned recv(char *buffer, unsigned length); + + /** + * \copydoc Connection::send + */ + virtual unsigned send(const char *buffer, unsigned length); + +public: + /** + * \copydoc Connection::connect + */ + void connect(const net::Address &address) override; + + /** + * \copydoc Service::prepare + */ + void prepare(fd_set &in, fd_set &out, net::Handle &max) override; + + /** + * \copydoc Service::sync + */ + void sync(fd_set &in, fd_set &out) override; +}; + } // !irccd #endif // !IRCCD_CONNECTION_HPP
--- a/lib/irccd/irccdctl.cpp Fri Aug 12 13:16:59 2016 +0200 +++ b/lib/irccd/irccdctl.cpp Fri Aug 12 15:01:16 2016 +0200 @@ -89,6 +89,7 @@ * host = "ip or hostname" * port = "port number or service" * domain = "ipv4 or ipv6" (Optional, default: ipv4) + * ssl = true | false */ void Irccdctl::readConnectIp(const ini::Section &sc) { @@ -118,7 +119,11 @@ } m_address = net::resolveOne(host, port, domain, SOCK_STREAM); - m_connection = std::make_unique<Connection>(); + + if ((it = sc.find("ssl")) != sc.end() && util::isBoolean(it->value())) + m_connection = std::make_unique<TlsConnection>(); + else + m_connection = std::make_unique<Connection>(); } /*