Mercurial > code
view C++/modules/Socket/Sockets.h @ 441:21f1a6ed0570
Socket: the full Socket constructor must also call create
author | David Demelier <markand@malikania.fr> |
---|---|
date | Fri, 23 Oct 2015 08:19:09 +0200 |
parents | aaf975293996 |
children | 44887104242a |
line wrap: on
line source
/* * Sockets.h -- portable C++ socket wrappers * * Copyright (c) 2013-2015 David Demelier <markand@malikania.fr> * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above * copyright notice and this permission notice appear in all copies. * * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ #ifndef _SOCKETS_H_ #define _SOCKETS_H_ /** * @file Sockets.h * @brief Portable socket abstraction * * This file is a portable network library. * * ### User definable options * * User may set the following variables before compiling these files: * * - **SOCKET_NO_AUTO_INIT**: (bool) Set to 0 if you don't want Socket class to * automatically calls net::init function and net::finish functions. * - **SOCKET_NO_SSL**: (bool) Set to 0 if you don't have access to OpenSSL library. * - **SOCKET_NO_AUTO_SSL_INIT**: (bool) Set to 0 if you don't want Socket class with Tls to automatically init * the OpenSSL library. You will need to call net::ssl::init and net::ssl::finish. * * ### Options for Listener class * * Feature detection, multiple implementations may be avaible, for example, * Linux has poll, select and epoll. * * We assume that select(2) is always available. * * Of course, you can set the variables yourself if you test it with your * build system. * * - **SOCKET_HAVE_POLL**: Defined on all BSD, Linux. Also defined on Windows * if _WIN32_WINNT is set to 0x0600 or greater. * - **SOCKET_HAVE_KQUEUE**: Defined on all BSD and Apple. * - **SOCKET_HAVE_EPOLL**: Defined on Linux only. */ #if defined(_WIN32) # if _WIN32_WINNT >= 0x0600 # define SOCKET_HAVE_POLL # endif #elif defined(__FreeBSD__) || defined(__OpenBSD__) || defined(__NetBSD__) || defined(__APPLE__) # define SOCKET_HAVE_KQUEUE # define SOCKET_HAVE_POLL #elif defined(__linux__) # define SOCKET_HAVE_EPOLL # define SOCKET_HAVE_POLL #endif /** * This sets the default backend to use depending on the system. The following * table summaries. * * The preference priority is ordered from left to right. * * | System | Backend | Class name | * |---------------|-------------------------|--------------| * | Linux | epoll(7) | Epoll | * | *BSD | kqueue(2) | Kqueue | * | Windows | poll(2), select(2) | Poll, Select | * | Mac OS X | kqueue(2) | Kqueue | */ #if defined(_WIN32) # if defined(SOCKET_HAVE_POLL) # define SOCKET_DEFAULT_BACKEND Poll # else # define SOCKET_DEFAULT_BACKEND Select # endif #elif defined(__linux__) # include <sys/epoll.h> # define SOCKET_DEFAULT_BACKEND Epoll #elif defined(__FreeBSD__) || defined(__OpenBSD__) || defined(__NetBSD__) || defined(__DragonFly__) || defined(__APPLE__) # include <sys/types.h> # include <sys/event.h> # include <sys/time.h> # define SOCKET_DEFAULT_BACKEND Kqueue #else # define SOCKET_DEFAULT_BACKEND Select #endif #if defined(SOCKET_HAVE_POLL) && !defined(_WIN32) # include <poll.h> #endif #if defined(_WIN32) # include <cstdlib> # include <WinSock2.h> # include <WS2tcpip.h> #else # include <cerrno> # include <sys/ioctl.h> # include <sys/types.h> # include <sys/socket.h> # include <sys/un.h> # include <arpa/inet.h> # include <netinet/in.h> # include <fcntl.h> # include <netdb.h> # include <unistd.h> #endif #if !defined(SOCKET_NO_SSL) # include <openssl/err.h> # include <openssl/evp.h> # include <openssl/ssl.h> #endif #include <chrono> #include <cstdlib> #include <cstring> #include <exception> #include <map> #include <memory> #include <string> #include <vector> /** * General network namespace. */ namespace net { /* * Portables types * ------------------------------------------------------------------ * * The following types are defined differently between Unix and Windows. */ /* {{{ Types */ #if defined(_WIN32) /** * Socket type, SOCKET. */ using Handle = SOCKET; /** * Argument to pass to set. */ using ConstArg = const char *; /** * Argument to pass to get. */ using Arg = char *; #else /** * Socket type, int. */ using Handle = int; /** * Argument to pass to set. */ using ConstArg = const void *; /** * Argument to pass to get. */ using Arg = void *; #endif /* }}} */ /* * Portable constants * ------------------------------------------------------------------ * * These constants are needed to check functions return codes, they are rarely needed in end user code. */ /* {{{ Constants */ /* * The following constants are defined differently from Unix * to Windows. */ #if defined(_WIN32) /** * Socket creation failure or invalidation. */ extern const Handle Invalid; /** * Socket operation failure. */ extern const int Failure; #else /** * Socket creation failure or invalidation. */ extern const int Invalid; /** * Socket operation failure. */ extern const int Failure; #endif /* }}} */ /* * Portable functions * ------------------------------------------------------------------ * * The following free functions can be used to initialize the library or to get the last system error. */ /* {{{ Functions */ /** * Initialize the socket library. Except if you defined SOCKET_NO_AUTO_INIT, you don't need to call this * function manually. */ void init() noexcept; /** * Close the socket library. */ void finish() noexcept; #if !defined(SOCKET_NO_SSL) /** * OpenSSL namespace. */ namespace ssl { /** * Initialize the OpenSSL library. Except if you defined SOCKET_NO_AUTO_SSL_INIT, you don't need to call this function * manually. */ void init() noexcept; /** * Close the OpenSSL library. */ void finish() noexcept; } // !ssl #endif // SOCKET_NO_SSL /** * Get the last socket system error. The error is set from errno or from * WSAGetLastError on Windows. * * @return a string message */ std::string error(); /** * Get the last system error. * * @param errn the error number (errno or WSAGetLastError) * @return the error */ std::string error(int errn); /* }}} */ /* * Error class * ------------------------------------------------------------------ * * This is the main exception thrown on socket operations. */ /* {{{ Error */ /** * @class Error * @brief Base class for sockets error */ class Error : public std::exception { public: /** * @enum Code * @brief Which kind of error */ enum Code { WouldBlockRead, ///!< The operation would block for reading WouldBlockWrite, ///!< The operation would block for writing Timeout, ///!< The action did timeout System ///!< There is a system error }; private: Code m_code; std::string m_function; std::string m_error; public: /** * Constructor that use the last system error. * * @param code which kind of error * @param function the function name */ Error(Code code, std::string function); /** * Constructor that use the system error set by the user. * * @param code which kind of error * @param function the function name * @param error the error */ Error(Code code, std::string function, int error); /** * Constructor that set the error specified by the user. * * @param code which kind of error * @param function the function name * @param error the error */ Error(Code code, std::string function, std::string error); /** * Get which function has triggered the error. * * @return the function name (e.g connect) */ inline const std::string &function() const noexcept { return m_function; } /** * The error code. * * @return the code */ inline Code code() const noexcept { return m_code; } /** * Get the error (only the error content). * * @return the error */ const char *what() const noexcept { return m_error.c_str(); } }; /* }}} */ /* * Base Socket class * ------------------------------------------------------------------ * * This base class has operations that are common to all types of sockets but you usually instanciate * a SocketTcp or SocketUdp */ /* {{{ Socket */ /** * @class Socket * @brief Base socket class for socket operations */ template <typename Address, typename Type> class Socket { private: Type m_type; protected: /** * The native handle. */ Handle m_handle{Invalid}; public: /** * This tries to create a socket. * * @param type the type instance */ inline Socket(Type type = Type{}) noexcept : Socket{Address::domain(), Type::type(), 0} { /* Some implementation requires more things */ m_type = std::move(type); m_type.create(*this); } /** * Construct a socket with an already created descriptor. * * @param handle the native descriptor * @param type the type of socket implementation */ explicit inline Socket(Handle handle, Type type = Type{}) noexcept : m_type{std::move(type)} , m_handle{handle} { } /** * Create a socket handle. * * @param domain the domain AF_* * @param type the type SOCK_* * @param protocol the protocol * @param iface the implementation * @throw Error on failures */ Socket(int domain, int type, int protocol, Type iface = Type{}) : m_type{std::move(iface)} { #if !defined(SOCKET_NO_AUTO_INIT) init(); #endif m_handle = ::socket(domain, type, protocol); if (m_handle == Invalid) { throw Error{Error::System, "socket"}; } m_type.create(*this); } /** * Copy constructor deleted. */ Socket(const Socket &) = delete; /** * Transfer ownership from other to this. * * @param other the other socket */ inline Socket(Socket &&other) noexcept : m_handle{other.m_handle} { /* Invalidate other */ other.m_handle = -1; } /** * Default destructor. */ virtual ~Socket() { close(); } /** * Set an option for the socket. * * @param level the setting level * @param name the name * @param arg the value * @throw Error on error */ template <typename Argument> inline void set(int level, int name, const Argument &arg) { #if defined(_WIN32) if (setsockopt(m_handle, level, name, (ConstArg)&arg, sizeof (arg)) == Failure) #else if (setsockopt(m_handle, level, name, (ConstArg)&arg, sizeof (arg)) < 0) #endif throw Error{Error::System, "set"}; } /** * Get an option for the socket. * * @param level the setting level * @param name the name * @throw Error on error */ template <typename Argument> inline Argument get(int level, int name) { Argument desired, result{}; socklen_t size = sizeof (result); #if defined(_WIN32) if (getsockopt(m_handle, level, name, (Arg)&desired, &size) == Failure) #else if (getsockopt(m_handle, level, name, (Arg)&desired, &size) < 0) #endif throw Error{Error::System, "get"}; std::memcpy(&result, &desired, size); return result; } /** * Get the native handle. * * @return the handle * @warning Not portable */ inline Handle handle() const noexcept { return m_handle; } /** * Set the blocking mode, if set to false, the socket will be marked * **non-blocking**. * * @param block set to false to mark **non-blocking** * @throw Error on any error */ void setBlockMode(bool block) { #if defined(O_NONBLOCK) && !defined(_WIN32) int flags; if ((flags = fcntl(m_handle, F_GETFL, 0)) == -1) { flags = 0; } if (block) { flags &= ~(O_NONBLOCK); } else { flags |= O_NONBLOCK; } if (fcntl(m_handle, F_SETFL, flags) == Failure) { throw Error{Error::System, "setBlockMode"}; } #else unsigned long flags = (block) ? 0 : 1; if (ioctlsocket(m_handle, FIONBIO, &flags) == Failure) { throw Error{Error::System, "setBlockMode"}; } #endif } /** * Bind using a native address. * * @param address the address * @param length the size */ inline void bind(const sockaddr *address, socklen_t length) { if (::bind(m_handle, address, length) == Failure) { throw Error{Error::System, "bind"}; } } /** * Overload that takes an address. * * @param address the address * @throw Error on errors */ inline void bind(const Address &address) { bind(address.address(), address.length()); } /** * Listen for pending connection. * * @param max the maximum number */ inline void listen(int max = 128) { if (::listen(this->m_handle, max) == Failure) { throw Error{Error::System, "listen"}; } } /** * Connect to an address. * * @param address the address * @throw Error on errors */ inline void connect(const Address &address) { m_type.connect(*this, address); } /** * Accept a pending connection. * * @param info the client address * @return the new socket * @throw Error on errors */ inline Socket<Address, Type> accept(Address &info) { return m_type.accept(*this, info); } /** * Overloaded function without information. * * @return the new socket * @throw Error on errors */ inline Socket<Address, Type> accept() { Address dummy; return accept(dummy); } /** * Get the local name. This is a wrapper of getsockname(). * * @return the address * @throw Error on failures */ Address address() const { // TODO: to reimplement return {}; } /** * Receive some data. * * @param data the destination buffer * @param length the buffer length * @throw Error on error */ inline unsigned recv(void *data, unsigned length) { return m_type.recv(*this, data, length); } /** * Overloaded function. * * @param count the number of bytes to receive * @return the string * @throw Error on error */ inline std::string recv(unsigned count) { std::string result; result.resize(count); auto n = recv(const_cast<char *>(result.data()), count); result.resize(n); return result; } /** * Send some data. * * @param data the data buffer * @param length the buffer length * @throw Error on error */ inline unsigned send(const void *data, unsigned length) { return m_type.send(*this, data, length); } /** * Overloaded function. * * @param data the string to send * @return the number of bytes sent * @throw Error on error */ inline unsigned send(const std::string &data) { return send(data.c_str(), data.size()); } /** * Send data to an end point. * * @param data the buffer * @param length the buffer length * @param address the client address * @return the number of bytes sent * @throw Error on error */ inline unsigned sendto(const void *data, unsigned length, const Address &address) { return m_type.sendto(*this, data, length, address); } /** * Overloaded function. * * @param data the data * @param address the address * @return the number of bytes sent * @throw Error on error */ inline unsigned sendto(const std::string &data, const Address &address) { return sendto(data.c_str(), data.length(), address); } /** * Receive data from an end point. * * @param data the destination buffer * @param length the buffer length * @param info the client information * @return the number of bytes received * @throw Error on error */ unsigned recvfrom(void *data, unsigned length, Address &info) { return m_type.recvfrom(*this, data, length, info); } /** * Overloaded function. * * @param count the maximum number of bytes to receive * @param info the client information * @return the string * @throw Error on error */ inline std::string recvfrom(unsigned count, Address &info) { std::string result; result.resize(count); auto n = recvfrom(const_cast<char *>(result.data()), count, info); result.resize(n); return result; } /** * Overloaded function. * * @param count the number of bytes to read * @return the string * @throw Error on errors */ inline std::string recvfrom(unsigned count) { Address dummy; return recvfrom(count, dummy); } /** * Close the socket. * * Automatically called from the destructor. */ virtual void close() { if (m_handle != Invalid) { #if defined(_WIN32) ::closesocket(m_handle); #else ::close(m_handle); #endif m_handle = Invalid; } } /** * Assignment operator forbidden. * * @return *this */ Socket &operator=(const Socket &) = delete; /** * Transfer ownership from other to this. The other socket is left * invalid and will not be closed. * * @param other the other socket * @return this */ Socket &operator=(Socket &&other) noexcept { m_handle = other.m_handle; /* Invalidate other */ other.m_handle = Invalid; return *this; } }; /** * Compare two sockets. * * @param s1 the first socket * @param s2 the second socket * @return true if they equals */ template <typename Address, typename Type> bool operator==(const Socket<Address, Type> &s1, const Socket<Address, Type> &s2) { return s1.handle() == s2.handle(); } /** * Compare two sockets. * * @param s1 the first socket * @param s2 the second socket * @return true if they are different */ template <typename Address, typename Type> bool operator!=(const Socket<Address, Type> &s1, const Socket<Address, Type> &s2) { return s1.handle() != s2.handle(); } /** * Compare two sockets. * * @param s1 the first socket * @param s2 the second socket * @return true if s1 < s2 */ template <typename Address, typename Type> bool operator<(const Socket<Address, Type> &s1, const Socket<Address, Type> &s2) { return s1.handle() < s2.handle(); } /** * Compare two sockets. * * @param s1 the first socket * @param s2 the second socket * @return true if s1 > s2 */ template <typename Address, typename Type> bool operator>(const Socket<Address, Type> &s1, const Socket<Address, Type> &s2) { return s1.handle() > s2.handle(); } /** * Compare two sockets. * * @param s1 the first socket * @param s2 the second socket * @return true if s1 <= s2 */ template <typename Address, typename Type> bool operator<=(const Socket<Address, Type> &s1, const Socket<Address, Type> &s2) { return s1.handle() <= s2.handle(); } /** * Compare two sockets. * * @param s1 the first socket * @param s2 the second socket * @return true if s1 >= s2 */ template <typename Address, typename Type> bool operator>=(const Socket<Address, Type> &s1, const Socket<Address, Type> &s2) { return s1.handle() >= s2.handle(); } /* }}} */ /* * Predefine addressed to be used * ------------------------------------------------------------------ * * - Ipv6, * - Ipv4, * - Local. */ /* {{{ Addresses */ /** * @class Ip * @brief Base class for IPv6 and IPv4, don't use it directly */ class Ip { private: friend class Ipv6; friend class Ipv4; union { sockaddr_in m_sin; sockaddr_in6 m_sin6; }; socklen_t m_length{0}; int m_domain{AF_INET}; Ip(int domain) noexcept; Ip(const std::string &host, int port, int domain); Ip(const struct sockaddr_storage *ss, socklen_t length); public: /** * Return the underlying address, either sockaddr_in6 or sockaddr_in. * * @return the address */ inline const sockaddr *address() const noexcept { if (m_domain == AF_INET6) { return reinterpret_cast<const sockaddr *>(&m_sin6); } return reinterpret_cast<const sockaddr *>(&m_sin); } /** * Return the underlying address length. * * @return the length */ inline socklen_t length() const noexcept { return m_length; } }; /** * @class Ipv6 * @brief Use IPv6 address */ class Ipv6 : public Ip { public: /** * Get the domain AF_INET6. * * @return AF_INET6 */ static inline int domain() noexcept { return AF_INET6; } /** * Construct an empty address. */ inline Ipv6() noexcept : Ip{AF_INET6} { } /** * Construct an address on a specific host. * * @param host the host ("*" for any) * @param port the port * @throw Error on errors */ inline Ipv6(const std::string &host, int port) : Ip{host, port, AF_INET6} { } /** * Construct an address from a storage. * * @param ss the storage * @param length the length */ inline Ipv6(const struct sockaddr_storage *ss, socklen_t length) : Ip{ss, length} { } }; /** * @class Ipv4 * @brief Use IPv4 address */ class Ipv4 : public Ip { public: /** * Get the domain AF_INET. * * @return AF_INET */ static inline int domain() noexcept { return AF_INET; } /** * Construct an empty address. */ inline Ipv4() noexcept : Ip{AF_INET} { } /** * Construct an address on a specific host. * * @param host the host ("*" for any) * @param port the port * @throw Error on errors */ inline Ipv4(const std::string &host, int port) : Ip{host, port, AF_INET} { } /** * Construct an address from a storage. * * @param ss the storage * @param length the length */ inline Ipv4(const struct sockaddr_storage *ss, socklen_t length) : Ip{ss, length} { } }; #if !defined(_WIN32) /** * @class Local * @brief unix family sockets * * Create an address to a specific path. Only available on Unix. */ class Local { private: sockaddr_un m_sun; std::string m_path; public: /** * Get the domain AF_LOCAL. * * @return AF_LOCAL */ static inline int domain() noexcept { return AF_LOCAL; } /** * Default constructor. */ Local() = default; /** * Construct an address to a path. * * @param path the path * @param rm remove the file before (default: false) */ Local(std::string path, bool rm = false); /** * Construct an unix address from a storage address. * * @param ss the storage * @param length the length */ Local(const sockaddr_storage &ss, socklen_t length); /** * Get the sockaddr_un. * * @return the address */ inline const sockaddr *address() const noexcept { return reinterpret_cast<const sockaddr *>(&m_sun); } /** * Get the address length. * * @return the length */ inline socklen_t length() const noexcept { #if defined(SOCKET_HAVE_SUN_LEN) return SUN_LEN(&m_sun); #else return sizeof (m_sun); #endif } }; #endif // !_WIN32 /* }}} */ /* * Predefined types * ------------------------------------------------------------------ * * - Tcp, for standard stream connections, * - Udp, for standard datagram connections, * - Tls, for secure stream connections. */ /* {{{ Types */ /* {{{ Tcp */ /** * @class Tcp * @brief Clear TCP implementation. */ class Tcp { protected: /** * Standard accept. * * @param sc the socket * @param address the address destination * @param length the address initial length * @return the client handle * @throw Error on errors */ Handle accept(Handle sc, sockaddr *address, socklen_t *length) { Handle client = ::accept(sc, address, length); if (client == Invalid) { #if defined(_WIN32) int error = WSAGetLastError(); if (error == WSAEWOULDBLOCK) { throw Error{Error::WouldBlockRead, "accept", error}; } throw Error{Error::System, "accept", error}; #else if (errno == EAGAIN || errno == EWOULDBLOCK) { throw Error{Error::WouldBlockRead, "accept"}; } throw Error{Error::System, "accept"}; #endif } return client; } /** * Standard connect. * * @param sc the socket * @param address the address * @param length the length */ void connect(Handle sc, const sockaddr *address, socklen_t length) { if (::connect(sc, address, length) == Failure) { /* * Determine if the error comes from a non-blocking connect that cannot be * accomplished yet. */ #if defined(_WIN32) int error = WSAGetLastError(); if (error == WSAEWOULDBLOCK) { throw Error{Error::WouldBlockWrite, "connect", error}; } throw Error{Error::System, "connect", error}; #else if (errno == EINPROGRESS) { throw Error{Error::WouldBlockWrite, "connect"}; } throw Error{Error::System, "connect"}; #endif } } public: /** * Socket type. * * @return SOCK_STREAM */ static inline int type() noexcept { return SOCK_STREAM; } /** * Do nothing. */ template <typename Address> inline void create(Socket<Address, Tcp> &) noexcept { /* No-op */ } /** * Accept a clear client. Wrapper of accept(2). * * @param sc the socket * @param address the address destination * @return the socket * @throw Error on errors */ template <typename Address> Socket<Address, Tcp> accept(Socket<Address, Tcp> &sc, Address &address) { sockaddr_storage ss; socklen_t length = sizeof (sockaddr_storage); Handle handle = accept(sc.handle(), reinterpret_cast<sockaddr *>(&ss), &length); address = Address{&ss, length}; return Socket<Address, Tcp>{handle}; } /** * Connect to the end point. Wrapper for connect(2). * * @param sc the socket * @param address the address * @throw Error on errors */ template <typename Address> void connect(Socket<Address, Tcp> &sc, const Address &address) { connect(sc.handle(), address.address(), address.length()); } /** * Receive data. Wrapper of recv(2). * * @param sc the socket * @param data the destination * @param length the destination length * @return the number of bytes read * @throw Error on errors */ template <typename Address> unsigned recv(Socket<Address, Tcp> &sc, void *data, unsigned length) { int nbread; nbread = ::recv(sc.handle(), (Arg)data, length, 0); if (nbread == Failure) { #if defined(_WIN32) int error = WSAGetLastError(); if (error == WSAEWOULDBLOCK) { throw Error{Error::WouldBlockRead, "recv", error}; } throw Error{Error::System, "recv", error}; #else if (errno == EAGAIN || errno == EWOULDBLOCK) { throw Error{Error::WouldBlockRead, "recv"}; } throw Error{Error::System, "recv"}; #endif } return static_cast<unsigned>(nbread); } /** * Send some data. Wrapper for send(2). * * @param sc the socket * @param data the buffer to send * @param length the buffer length * @return the number of bytes sent * @throw Error on errors */ template <typename Address> unsigned send(Socket<Address, Tcp> &sc, const void *data, unsigned length) { int nbsent; nbsent = ::send(sc.handle(), (ConstArg)data, length, 0); if (nbsent == Failure) { #if defined(_WIN32) int error = WSAGetLastError(); if (error == WSAEWOULDBLOCK) { throw Error{Error::WouldBlockWrite, "send", error}; } throw Error{Error::System, "send", error}; #else if (errno == EAGAIN || errno == EWOULDBLOCK) { throw Error{Error::WouldBlockWrite, "send"}; } throw Error{Error::System, "send"}; #endif } return static_cast<unsigned>(nbsent); } }; /* }}} */ /* {{{ Udp */ /** * @class Udp * @brief Clear UDP type. * * This class is the basic implementation of UDP sockets. */ class Udp { public: /** * Socket type. * * @return SOCK_DGRAM */ static inline int type() noexcept { return SOCK_DGRAM; } /** * Do nothing. */ template <typename Address> inline void create(Socket<Address, Udp> &) noexcept { /* No-op */ } /** * Receive data from an end point. * * @param sc the socket * @param data the destination buffer * @param length the buffer length * @param info the client information * @return the number of bytes received * @throw Error on error */ template <typename Address> unsigned recvfrom(Socket<Address, Udp> &sc, void *data, unsigned length, Address &info) { int nbread; /* Store information */ sockaddr_storage address; socklen_t addrlen = sizeof (struct sockaddr_storage); nbread = ::recvfrom(sc.handle(), (Arg)data, length, 0, reinterpret_cast<sockaddr *>(&address), &addrlen); info = Address{&address, addrlen}; if (nbread == Failure) { #if defined(_WIN32) int error = WSAGetLastError(); if (error == WSAEWOULDBLOCK) { throw Error{Error::WouldBlockRead, "recvfrom", error}; } throw Error{Error::System, "recvfrom", error}; #else if (errno == EAGAIN || errno == EWOULDBLOCK) { throw Error{Error::WouldBlockRead, "recvfrom"}; } throw Error{Error::System, "recvfrom"}; #endif } return static_cast<unsigned>(nbread); } /** * Send data to an end point. * * @param sc the socket * @param data the buffer * @param length the buffer length * @param address the client address * @return the number of bytes sent * @throw Error on error */ template <typename Address> unsigned sendto(Socket<Address, Udp> &sc, const void *data, unsigned length, const Address &address) { int nbsent; nbsent = ::sendto(sc.handle(), (ConstArg)data, length, 0, address.address(), address.length()); if (nbsent == Failure) { #if defined(_WIN32) int error = WSAGetLastError(); if (error == WSAEWOULDBLOCK) { throw Error{Error::WouldBlockWrite, "sendto", error}; } throw Error{Error::System, "sendto", error}; #else if (errno == EAGAIN || errno == EWOULDBLOCK) { throw Error{Error::WouldBlockWrite, "sendto"}; } throw Error{Error::System, "sendto"}; #endif } return static_cast<unsigned>(nbsent); } }; /* }}} */ /* {{{ Tls */ #if !defined(SOCKET_NO_SSL) /** * @class Tls * @brief OpenSSL secure layer for TCP */ class Tls : private Tcp { public: /** * OpenSSL method to use. */ enum Method { Tlsv1, //!< Tlsv1 (recommended) Sslv3 //!< SSL v3 }; private: using Context = std::unique_ptr<SSL_CTX, void (*)(SSL_CTX *)>; using Ssl = std::unique_ptr<SSL, void (*)(SSL *)>; /* OpenSSL objects */ Context m_context{nullptr, nullptr}; Ssl m_ssl{nullptr, nullptr}; /* Parameters */ Method m_method; std::string m_key; std::string m_certificate; bool m_verify{false}; Tls(Context context, Ssl ssl) : m_context{std::move(context)} , m_ssl{std::move(ssl)} { } inline std::string error(int error) { return ERR_reason_error_string(error); } public: /** * @copydoc Tcp::type */ static inline int type() noexcept { return SOCK_STREAM; } /** * Empty TLS constructor. */ Tls() { #if !defined(SOCKET_NO_SSL_AUTO_INIT) ssl::init(); #endif } /** * Construct a specific Tls object. * * @param method the method to use * @param verify true to verify the certificate * @param key the private key * @param certificate the certificate file */ Tls(Method method, bool verify = true, std::string key = "", std::string certificate = "") : Tls() { m_method = method; m_verify = verify; m_key = std::move(key); m_certificate = std::move(certificate); } /** * Initialize the SSL objects after have created. * * @param sc the socket */ template <typename Address> inline void create(Socket<Address, Tls> &sc) { auto method = (m_method == Tlsv1) ? TLSv1_method() : SSLv3_method(); m_context = {SSL_CTX_new(method), SSL_CTX_free}; m_ssl = {SSL_new(m_context.get()), SSL_free}; SSL_set_fd(m_ssl.get(), sc.handle()); /* Load certificates */ if (m_certificate.size() > 0) SSL_CTX_use_certificate_file(m_context.get(), m_certificate.c_str(), SSL_FILETYPE_PEM); if (m_key.size() > 0) SSL_CTX_use_PrivateKey_file(m_context.get(), m_key.c_str(), SSL_FILETYPE_PEM); if (m_verify && !SSL_CTX_check_private_key(m_context.get())) { throw Error(Error::System, "accept", "certificate failure"); } } /** * Connect to a secure host. * * @param sc the socket * @param address the address * @throw Error on errors */ template <typename Address> void connect(Socket<Address, Tls> &sc, const Address &address) { /* 1. Standard connect */ Tcp::connect(sc.handle(), address.address(), address.length()); /* 2. OpenSSL handshake */ auto ret = SSL_connect(m_ssl.get()); if (ret <= 0) { auto no = SSL_get_error(m_ssl.get(), ret); if (no == SSL_ERROR_WANT_READ) { throw Error{Error::WouldBlockRead, "connect", "Operation in progress"}; } else if (no == SSL_ERROR_WANT_WRITE) { throw Error{Error::WouldBlockWrite, "connect", "Operation in progress"}; } else { throw Error{Error::System, "connect", error(no)}; } } } /** * Accept a secure client. * * @param sc the socket * @param address the address destination * @return the client */ template <typename Address> Socket<Address, Tls> accept(Socket<Address, Tls> &sc, Address &address) { /* 1. Do standard accept */ sockaddr_storage ss; socklen_t length = sizeof (sockaddr_storage); Handle handle = Tcp::accept(sc.handle(), reinterpret_cast<sockaddr *>(&ss), &length); address = Address{&ss, length}; /* 2. Create OpenSSL related stuff */ auto method = (m_method == Tlsv1) ? TLSv1_method() : SSLv3_method(); auto context = Context{SSL_CTX_new(method), SSL_CTX_free}; auto ssl = Ssl{SSL_new(context.get()), SSL_free}; SSL_set_fd(ssl.get(), handle); /* 3. Do the OpenSSL accept */ auto ret = SSL_accept(ssl.get()); if (ret <= 0) { auto no = SSL_get_error(ssl.get(), ret); if (no == SSL_ERROR_WANT_READ) { throw Error(Error::WouldBlockRead, "accept", "Operation would block"); } else if (no == SSL_ERROR_WANT_WRITE) { throw Error(Error::WouldBlockWrite, "accept", "Operation would block"); } else { throw Error(Error::System, "accept", error(no)); } } return Socket<Address, Tls>{handle, Tls{std::move(context), std::move(ssl)}}; } /** * Receive some secure data. * * @param data the destination * @param len the buffer length * @return the number of bytes read * @throw Error on errors */ template <typename Address> unsigned recv(Socket<Address, Tls> &, void *data, unsigned len) { auto nbread = SSL_read(m_ssl.get(), data, len); if (nbread <= 0) { auto no = SSL_get_error(m_ssl.get(), nbread); if (no == SSL_ERROR_WANT_READ) { throw Error{Error::WouldBlockRead, "recv", "Operation would block"}; } else if (no == SSL_ERROR_WANT_WRITE) { throw Error{Error::WouldBlockWrite, "recv", "Operation would block"}; } else { throw Error{Error::System, "recv", error(no)}; } } return nbread; } /** * Send some data. * * @param data the data to send * @param len the buffer length * @return the number of bytes sent * @throw Error on errors */ template <typename Address> unsigned send(Socket<Address, Tls> &, const void *data, unsigned len) { auto nbread = SSL_write(m_ssl.get(), data, len); if (nbread <= 0) { auto no = SSL_get_error(m_ssl.get(), nbread); if (no == SSL_ERROR_WANT_READ) { throw Error{Error::WouldBlockRead, "send", "Operation would block"}; } else if (no == SSL_ERROR_WANT_WRITE) { throw Error{Error::WouldBlockWrite, "send", "Operation would block"}; } else { throw Error{Error::System, "send", error(no)}; } } return nbread; } }; #endif // !SOCKET_NO_SSL /* }}} */ /* }}} */ /* * Convenient helpers * ------------------------------------------------------------------ * * - SocketTcp<Address>, for TCP sockets, * - SocketUdp<Address>, for UDP sockets, * - SocketTls<Address>, for secure TCP sockets. */ /* {{{ Helpers */ /** * Helper to create TCP sockets. */ template <typename Address> using SocketTcp = Socket<Address, Tcp>; /** * Helper to create UDP sockets. */ template <typename Address> using SocketUdp = Socket<Address, Udp>; #if !defined(SOCKET_NO_SSL) /** * Helper to create OpenSSL TCP sockets. */ template <typename Address> using SocketTls = Socket<Address, Tls>; #endif /* }}} */ /* * Select wrapper * ------------------------------------------------------------------ * * Wrapper for select(2) and other various implementations. */ /* {{{ Listener */ /** * @struct ListenerStatus * @brief Result of polling * * Result of a select call, returns the first ready socket found with its * flags. */ class ListenerStatus { public: Handle socket; //!< which socket is ready int flags; //!< the flags }; /** * Table used in the socket listener to store which sockets have been * set in which directions. */ using ListenerTable = std::map<Handle, int>; /** * @class Select * @brief Implements select(2) * * This class is the fallback of any other method, it is not preferred at all for many reasons. */ class Select { public: /** * Backend identifier */ inline const char *name() const noexcept { return "select"; } /** * No-op, uses the ListenerTable directly. */ inline void set(const ListenerTable &, Handle, int, bool) noexcept {} /** * No-op, uses the ListenerTable directly. */ inline void unset(const ListenerTable &, Handle, int, bool) noexcept {} /** * Return the sockets */ std::vector<ListenerStatus> wait(const ListenerTable &table, int ms); }; #if defined(SOCKET_HAVE_POLL) /** * @class Poll * @brief Implements poll(2) * * Poll is widely supported and is better than select(2). It is still not the * best option as selecting the sockets is O(n). */ class Poll { private: std::vector<pollfd> m_fds; short topoll(int flags) const noexcept; int toflags(short &event) const noexcept; public: void set(const ListenerTable &, Handle sc, int flags, bool add); void unset(const ListenerTable &, Handle sc, int flags, bool remove); std::vector<ListenerStatus> wait(const ListenerTable &, int ms); /** * Backend identifier */ inline const char *name() const noexcept { return "poll"; } }; #endif #if defined(SOCKET_HAVE_EPOLL) class Epoll { private: int m_handle; std::vector<struct epoll_event> m_events; Epoll(const Epoll &) = delete; Epoll &operator=(const Epoll &) = delete; Epoll(const Epoll &&) = delete; Epoll &operator=(const Epoll &&) = delete; uint32_t toepoll(int flags) const noexcept; int toflags(uint32_t events) const noexcept; void update(Handle sc, int op, int flags); public: Epoll(); ~Epoll(); void set(const ListenerTable &, Handle sc, int flags, bool add); void unset(const ListenerTable &, Handle sc, int flags, bool remove); std::vector<ListenerStatus> wait(const ListenerTable &table, int ms); /** * Backend identifier */ inline const char *name() const noexcept { return "epoll"; } }; #endif #if defined(SOCKET_HAVE_KQUEUE) /** * @class Kqueue * @brief Implements kqueue(2) * * This implementation is available on all BSD and Mac OS X. It is better than * poll(2) because it's O(1), however it's a bit more memory consuming. */ class Kqueue { private: std::vector<struct kevent> m_result; int m_handle; Kqueue(const Kqueue &) = delete; Kqueue &operator=(const Kqueue &) = delete; Kqueue(Kqueue &&) = delete; Kqueue &operator=(Kqueue &&) = delete; void update(Handle sc, int filter, int flags); public: Kqueue(); ~Kqueue(); void set(const ListenerTable &, Handle sc, int flags, bool add); void unset(const ListenerTable &, Handle sc, int flags, bool remove); std::vector<ListenerStatus> wait(const ListenerTable &, int ms); /** * Backend identifier */ inline const char *name() const noexcept { return "kqueue"; } }; #endif /** * Mark the socket for read operation. */ extern const int FlagRead; /** * Mark the socket for write operation. */ extern const int FlagWrite; /** * @class Listener * @brief Synchronous multiplexing * * Convenient wrapper around the select() system call. * * This class is implemented using a bridge pattern to allow different uses * of listener implementation. * * You should not reinstanciate a new Listener at each iteartion of your * main loop as it can be extremely costly. Instead use the same listener that * you can safely modify on the fly. * * Currently, poll, epoll, select and kqueue are available. * * To implement the backend, the following functions must be available: * * ### Set * * @code * void set(const ListenerTable &, Handle sc, int flags, bool add); * @endcode * * This function, takes the socket to be added and the flags. The flags are * always guaranteed to be correct and the function will never be called twice * even if the user tries to set the same flag again. * * An optional add argument is added for backends which needs to do different * operation depending if the socket was already set before or if it is the * first time (e.g EPOLL_CTL_ADD vs EPOLL_CTL_MOD for epoll(7). * * ### Unset * * @code * void unset(const ListenerTable &, Handle sc, int flags, bool remove); * @endcode * * Like set, this function is only called if the flags are actually set and will * not be called multiple times. * * Also like set, an optional remove argument is set if the socket is being * completely removed (e.g no more flags are set for this socket). * * ### Wait * * @code * std::vector<ListenerStatus> wait(const ListenerTable &, int ms); * @endcode * * Wait for the sockets to be ready with the specified milliseconds. Must return a list of ListenerStatus, * may throw any exceptions. * * ### Name * * @code * inline const char *name() const noexcept * @endcode * * Returns the backend name. Usually the class in lower case. */ template <typename Backend = SOCKET_DEFAULT_BACKEND> class Listener { public: private: Backend m_backend; ListenerTable m_table; public: /** * Construct an empty listener. */ Listener() = default; /** * Get the backend. * * @return the backend */ inline const Backend &backend() const noexcept { return m_backend; } /** * Get the non-modifiable table. * * @return the table */ inline const ListenerTable &table() const noexcept { return m_table; } /** * Overloaded function. * * @return the iterator */ inline ListenerTable::const_iterator begin() const noexcept { return m_table.begin(); } /** * Overloaded function. * * @return the iterator */ inline ListenerTable::const_iterator cbegin() const noexcept { return m_table.cbegin(); } /** * Overloaded function. * * @return the iterator */ inline ListenerTable::const_iterator end() const noexcept { return m_table.end(); } /** * Overloaded function. * * @return the iterator */ inline ListenerTable::const_iterator cend() const noexcept { return m_table.cend(); } /** * Add or update a socket to the listener. * * If the socket is already placed with the appropriate flags, the * function is a no-op. * * If incorrect flags are passed, the function does nothing. * * @param sc the socket * @param flags (may be OR'ed) * @throw Error if the backend failed to set */ void set(Handle sc, int flags) { /* Invalid or useless flags */ if (flags == 0 || flags > 0x3) return; auto it = m_table.find(sc); /* * Do not update the table if the backend failed to add * or update. */ if (it == m_table.end()) { m_backend.set(m_table, sc, flags, true); m_table.emplace(sc, flags); } else { if ((flags & FlagRead) && (it->second & FlagRead)) { flags &= ~(FlagRead); } if ((flags & FlagWrite) && (it->second & FlagWrite)) { flags &= ~(FlagWrite); } /* Still need a call? */ if (flags != 0) { m_backend.set(m_table, sc, flags, false); it->second |= flags; } } } /** * Unset a socket from the listener, only the flags is removed * unless the two flagss are requested. * * For example, if you added a socket for both reading and writing, * unsetting the write flags will keep the socket for reading. * * @param sc the socket * @param flags the flags (may be OR'ed) * @see remove */ void unset(Handle sc, int flags) { auto it = m_table.find(sc); /* Invalid or useless flags */ if (flags == 0 || flags > 0x3 || it == m_table.end()) return; /* * Like set, do not update if the socket is already at the appropriate * state. */ if ((flags & FlagRead) && !(it->second & FlagRead)) { flags &= ~(FlagRead); } if ((flags & FlagWrite) && !(it->second & FlagWrite)) { flags &= ~(FlagWrite); } if (flags != 0) { /* Determine if it's a complete removal */ bool removal = ((it->second) & ~(flags)) == 0; m_backend.unset(m_table, sc, flags, removal); if (removal) { m_table.erase(it); } else { it->second &= ~(flags); } } } /** * Remove completely the socket from the listener. * * It is a shorthand for unset(sc, FlagRead | FlagWrite); * * @param sc the socket */ inline void remove(Handle sc) { unset(sc, FlagRead | FlagWrite); } /** * Remove all sockets. */ inline void clear() { while (!m_table.empty()) { remove(*m_table.begin()); } } /** * Get the number of sockets in the listener. */ inline ListenerTable::size_type size() const noexcept { return m_table.size(); } /** * Select a socket. Waits for a specific amount of time specified as the duration. * * @param duration the duration * @return the socket ready */ template <typename Rep, typename Ratio> inline ListenerStatus wait(const std::chrono::duration<Rep, Ratio> &duration) { auto cvt = std::chrono::duration_cast<std::chrono::milliseconds>(duration); return m_backend.wait(m_table, cvt.count())[0]; } /** * Overload with milliseconds. * * @param timeout the optional timeout in milliseconds * @return the socket ready */ inline ListenerStatus wait(int timeout = -1) { return wait(std::chrono::milliseconds(timeout)); } /** * Select multiple sockets. * * @param duration the duration * @return the socket ready */ template <typename Rep, typename Ratio> inline std::vector<ListenerStatus> waitMultiple(const std::chrono::duration<Rep, Ratio> &duration) { auto cvt = std::chrono::duration_cast<std::chrono::milliseconds>(duration); return m_backend.wait(m_table, cvt.count()); } /** * Overload with milliseconds. * * @return the socket ready */ inline std::vector<ListenerStatus> waitMultiple(int timeout = -1) { return waitMultiple(std::chrono::milliseconds(timeout)); } }; /* }}} */ } // !net #endif // !_SOCKETS_H_