view libirccd/irccd/daemon/transport_service.cpp @ 627:8c0942ee6e63

Irccd: do not use ranges for error codes
author David Demelier <markand@malikania.fr>
date Tue, 02 Jan 2018 15:53:07 +0100
parents 4515082ee83f
children 27587ff92a64
line wrap: on
line source

/*
 * transport_service.cpp -- transport service
 *
 * Copyright (c) 2013-2017 David Demelier <markand@malikania.fr>
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */

#include <irccd/sysconfig.hpp>

#include <cassert>

#include <irccd/string_util.hpp>

#include "command_service.hpp"
#include "ip_transport_server.hpp"
#include "irccd.hpp"
#include "logger.hpp"
#include "transport_client.hpp"
#include "transport_service.hpp"

#if !defined(IRCCD_SYSTEM_WINDOWS)
#   include "local_transport_server.hpp"
#endif

#if defined(HAVE_SSL)
#   include "tls_transport_server.hpp"
#endif

namespace irccd {

namespace {

/*
 * Create a TCP transport server from a [transport] section of type "ip".
 *
 * Options read from the section:
 *
 *   - port (mandatory): TCP port to listen on,
 *   - address (optional, default "*"): address to bind, "*" means any address,
 *   - domain / family (optional, default ipv4): list containing ipv4 and/or ipv6,
 *   - ssl (optional): when true, certificate and key become mandatory and a
 *     TLS server is created.
 *
 * Throws std::invalid_argument on missing or invalid options, and lets errors
 * from parsing, address resolution and socket binding propagate.
 */
std::unique_ptr<transport_server> load_transport_ip(boost::asio::io_service& service, const ini::section& sc)
{
    assert(sc.key() == "transport");

    ini::section::const_iterator it;

    // Port.
    if ((it = sc.find("port")) == sc.end())
        throw std::invalid_argument("missing 'port' parameter");

    auto port = string_util::to_uint<std::uint16_t>(it->value());

    // Address.
    std::string address = "*";

    if ((it = sc.find("address")) != sc.end())
        address = it->value();

    // Bit mask of requested families:
    //
    // 0011
    //    ^ 0x1: IPv4
    //   ^  0x2: IPv6
    //
    // Only IPv4 is enabled when no family is given.
    auto mode = 0x1U;

    /*
     * Documentation stated family but code checked for 'domain' option.
     *
     * As irccdctl uses domain, accept both and unify the option name to 'family'.
     *
     * See #637
     */
    if ((it = sc.find("domain")) != sc.end() || (it = sc.find("family")) != sc.end()) {
        mode = 0U;

        for (const auto& v : *it) {
            if (v == "ipv4")
                mode |= 0x1U;
            if (v == "ipv6")
                mode |= 0x2U;
        }
    }

    if (mode == 0U)
        throw std::invalid_argument("family must at least have ipv4 or ipv6");

    // Use an IPv6 socket whenever IPv6 was requested, IPv4 otherwise. The
    // protocol is only used for the wildcard address "*".
    //
    // NOTE(review): when both families are requested, whether the IPv6 socket
    // also accepts IPv4 clients depends on the platform's IPV6_V6ONLY default.
    auto protocol = (mode & 0x2U)
        ? boost::asio::ip::tcp::v6()
        : boost::asio::ip::tcp::v4();

    // Optional SSL. The 'ssl' option decides whether TLS is used, not the
    // content of the key/certificate strings.
    auto ssl = false;
    std::string pkey;
    std::string cert;

    if ((it = sc.find("ssl")) != sc.end() && string_util::is_boolean(it->value())) {
        ssl = true;

        if ((it = sc.find("certificate")) == sc.end())
            throw std::invalid_argument("missing 'certificate' parameter");

        cert = it->value();

        if ((it = sc.find("key")) == sc.end())
            throw std::invalid_argument("missing 'key' parameter");

        pkey = it->value();
    }

    auto endpoint = (address == "*")
        ? boost::asio::ip::tcp::endpoint(protocol, port)
        : boost::asio::ip::tcp::endpoint(boost::asio::ip::address::from_string(address), port);

    // Opens, sets SO_REUSEADDR (third argument), binds and listens.
    boost::asio::ip::tcp::acceptor acceptor(service, endpoint, true);

    if (!ssl)
        return std::make_unique<ip_transport_server>(std::move(acceptor));

#if defined(HAVE_SSL)
    boost::asio::ssl::context ctx(boost::asio::ssl::context::sslv23);

    ctx.use_private_key_file(pkey, boost::asio::ssl::context::pem);
    ctx.use_certificate_file(cert, boost::asio::ssl::context::pem);

    return std::make_unique<tls_transport_server>(std::move(acceptor), std::move(ctx));
#else
    throw std::invalid_argument("SSL disabled");
#endif
}

/*
 * Create a local (unix domain socket) transport server from a [transport]
 * section of type "unix".
 *
 * The mandatory 'path' option names the socket file. Any existing file at
 * that path is removed before binding.
 *
 * Throws std::invalid_argument when 'path' is missing or when the platform
 * does not support unix transports (Windows).
 */
std::unique_ptr<transport_server> load_transport_unix(boost::asio::io_service& service, const ini::section& sc)
{
    assert(sc.key() == "transport");

#if !defined(IRCCD_SYSTEM_WINDOWS)
    // Local sockets only exist on non-Windows platforms, so the protocol name
    // must only be referenced inside this branch.
    using boost::asio::local::stream_protocol;

    ini::section::const_iterator it = sc.find("path");

    if (it == sc.end())
        throw std::invalid_argument("missing 'path' parameter");

    // Remove a leftover socket file first so that binding does not fail on it.
    // The result is ignored on purpose: the file usually does not exist.
    std::remove(it->value().c_str());

    stream_protocol::endpoint endpoint(it->value());
    stream_protocol::acceptor acceptor(service, std::move(endpoint));

    return std::make_unique<local_transport_server>(std::move(acceptor));
#else
    (void)service;
    (void)sc;

    throw std::invalid_argument("unix transports not supported on this platform");
#endif
}

/*
 * Create a transport server from a [transport] section.
 *
 * The mandatory 'type' option selects the kind of server ("ip" or "unix").
 * The optional 'password' option is applied to the created server.
 *
 * Throws std::invalid_argument when the type is missing or unknown; errors
 * raised by the type-specific loaders propagate unchanged.
 */
std::unique_ptr<transport_server> load_transport(boost::asio::io_service& service, const ini::section& sc)
{
    assert(sc.key() == "transport");

    const auto type = sc.find("type");

    if (type == sc.end())
        throw std::invalid_argument("missing 'type' parameter");

    std::unique_ptr<transport_server> server;

    if (type->value() == "ip")
        server = load_transport_ip(service, sc);
    else if (type->value() == "unix")
        server = load_transport_unix(service, sc);
    else
        throw std::invalid_argument(string_util::sprintf("invalid type given: %s", type->value()));

    const auto password = sc.find("password");

    if (password != sc.end())
        server->set_password(password->value());

    return server;
}

} // !namespace

/*
 * Dispatch one JSON request received from a client.
 *
 * The request must contain a string "command" property naming a registered
 * command, otherwise an error is sent back to the client.
 *
 * Errors thrown as boost::system::system_error are forwarded to the client.
 * Any other std::exception is only logged.
 */
void transport_service::handle_command(std::shared_ptr<transport_client> tc, const nlohmann::json& object)
{
    assert(object.is_object());

    const auto command = object.find("command");

    if (command == object.end() || !command->is_string()) {
        tc->error(irccd_error::invalid_message);
        return;
    }

    auto cmd = irccd_.commands().find(*command);

    if (!cmd) {
        tc->error(irccd_error::invalid_command, command->get<std::string>());
        return;
    }

    try {
        cmd->exec(irccd_, *tc, object);
    } catch (const boost::system::system_error& ex) {
        tc->error(ex.code(), cmd->get_name());
    } catch (const std::exception& ex) {
        irccd_.log().warning() << "transport: unknown error not reported" << std::endl;
        irccd_.log().warning() << "transport: " << ex.what() << std::endl;
    }
}

void transport_service::do_recv(std::shared_ptr<transport_client> tc)
{
    tc->recv([this, tc] (auto code, auto json) {
        switch (code.value()) {
        case boost::system::errc::network_down:
            irccd_.log().warning("transport: client disconnected");
            break;
            case boost::system::errc::invalid_argument:
            tc->error(irccd_error::invalid_message);
            break;
        default:
            handle_command(tc, json);

            if (tc->state() == transport_client::state_t::ready)
                do_recv(std::move(tc));

            break;
        }
    });
}

/*
 * Wait asynchronously for a new client on the given server.
 *
 * The accept is re-armed after every completion, including errors, so that a
 * failure related to one client does not stop the server from accepting
 * further clients. The only exception is an aborted operation, which happens
 * when the server is closed.
 *
 * NOTE(review): a persistent error (e.g. descriptor exhaustion) makes this
 * retry immediately. Consider a backoff timer if that becomes a problem.
 */
void transport_service::do_accept(transport_server& ts)
{
    ts.accept([this, &ts] (auto code, auto client) {
        if (code == boost::asio::error::operation_aborted)
            return;

        do_accept(ts);

        if (code)
            irccd_.log().warning() << "transport: new client error: " << code.message() << std::endl;
        else {
            do_recv(std::move(client));

            irccd_.log().info() << "transport: new client connected" << std::endl;
        }
    });
}

transport_service::transport_service(irccd& irccd) noexcept
    : irccd_(irccd)
{
}

transport_service::~transport_service() noexcept = default;

/*
 * Take ownership of a transport server and start accepting clients on it.
 *
 * The server is stored before the asynchronous accept is started. Otherwise,
 * if storing it threw, the server would be destroyed while the pending accept
 * still refers to it.
 */
void transport_service::add(std::unique_ptr<transport_server> ts)
{
    assert(ts);

    // The heap object does not move when the owning pointer is moved into
    // the container, so this reference stays valid.
    auto& server = *ts;

    servers_.push_back(std::move(ts));
    do_accept(server);
}

/*
 * Send a JSON object to every client of every registered transport server.
 */
void transport_service::broadcast(const nlohmann::json& json)
{
    assert(json.is_object());

    for (const auto& server : servers_) {
        for (const auto& client : server->clients())
            client->send(json);
    }
}

/*
 * Create and register a transport server for each [transport] section of the
 * configuration.
 *
 * A section that fails to load is logged as a warning and skipped. Loading
 * continues with the remaining sections.
 */
void transport_service::load(const config& cfg) noexcept
{
    for (const auto& section : cfg.doc()) {
        if (section.key() == "transport") {
            try {
                add(load_transport(irccd_.service(), section));
            } catch (const std::exception& ex) {
                irccd_.log().warning() << "transport: " << ex.what() << std::endl;
            }
        }
    }
}

} // !irccd