diff options
-rw-r--r-- | raul/Socket.hpp | 259 | ||||
-rw-r--r-- | test/socket_test.cpp | 107 | ||||
-rw-r--r-- | wscript | 1 |
3 files changed, 367 insertions, 0 deletions
diff --git a/raul/Socket.hpp b/raul/Socket.hpp new file mode 100644 index 0000000..6523ef1 --- /dev/null +++ b/raul/Socket.hpp @@ -0,0 +1,259 @@ +/* + This file is part of Raul. + Copyright 2007-2013 David Robillard <http://drobilla.net> + + Raul is free software: you can redistribute it and/or modify it under the + terms of the GNU General Public License as published by the Free Software + Foundation, either version 3 of the License, or any later version. + + Raul is distributed in the hope that it will be useful, but WITHOUT ANY + WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR + A PARTICULAR PURPOSE. See the GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Raul. If not, see <http://www.gnu.org/licenses/>. +*/ + +#ifndef RAUL_SOCKET_HPP +#define RAUL_SOCKET_HPP + +#include <memory> + +#include <errno.h> +#include <netdb.h> +#include <netinet/in.h> +#include <poll.h> +#include <stdint.h> +#include <stdlib.h> +#include <string.h> +#include <sys/socket.h> +#include <sys/un.h> +#include <unistd.h> + +#include "raul/Noncopyable.hpp" +#include "raul/URI.hpp" + +namespace Raul { + +/** A safe and simple interface for UNIX or TCP sockets. */ +class Socket : public Raul::Noncopyable { +public: + enum class Type { + UNIX, + TCP + }; + + /** Create a new unbound/unconnected socket of a given type. */ + explicit Socket(Type t); + + /** Wrap an existing open socket. */ + Socket(Type t, + const Raul::URI& uri, + struct sockaddr* addr, + socklen_t addr_len, + int fd); + + ~Socket(); + + /** Bind a server socket to an address. + * @param uri Address URI, e.g. unix:///tmp/foo or tcp://somehost:1234 + * @return True on success. + */ + bool bind(const Raul::URI& uri); + + /** Connect a client socket to a server address. + * @param uri Address URI, e.g. unix:///tmp/foo or tcp://somehost:1234 + * @return True on success. + */ + bool connect(const Raul::URI& uri); + + /** Mark server socket as passive to listen for incoming connections. + * @return True on success. + */ + bool listen(); + + /** Accept a connection. + * @return An new open socket for the connection. + */ + std::shared_ptr<Socket> accept(); + + /** Return the file descriptor for the socket. */ + int fd() { return _sock; } + + const Raul::URI& uri() const { return _uri; } + + /** Close the socket. */ + void close(); + + /** Shut down the socket. + * This terminates any connections associated with the sockets, and will + * (unlike close()) cause a poll on the socket to return. + */ + void shutdown(); + +private: + bool set_addr(const Raul::URI& uri); + + Type _type; + Raul::URI _uri; + struct sockaddr* _addr; + socklen_t _addr_len; + int _sock; +}; + +#ifndef NI_MAXHOST +# define NI_MAXHOST 1025 +#endif + +inline +Socket::Socket(Type t) + : _type(t) + , _uri(t == Type::UNIX ? "unix:" : "tcp:") + , _addr(NULL) + , _addr_len(0) + , _sock(-1) +{ + switch (t) { + case Type::UNIX: + _sock = socket(AF_UNIX, SOCK_STREAM, 0); + break; + case Type::TCP: + _sock = socket(AF_INET, SOCK_STREAM, 0); + break; + } +} + +inline +Socket::Socket(Type t, + const Raul::URI& uri, + struct sockaddr* addr, + socklen_t addr_len, + int fd) + : _type(t) + , _uri(uri) + , _addr(addr) + , _addr_len(addr_len) + , _sock(fd) +{ +} + +inline +Socket::~Socket() +{ + free(_addr); + close(); +} + +inline bool +Socket::set_addr(const Raul::URI& uri) +{ + free(_addr); + if (_type == Type::UNIX && uri.substr(0, strlen("unix://")) == "unix://") { + const std::string path = uri.substr(strlen("unix://")); + struct sockaddr_un* uaddr = (struct sockaddr_un*)calloc( + 1, sizeof(struct sockaddr_un)); + uaddr->sun_family = AF_UNIX; + strncpy(uaddr->sun_path, path.c_str(), sizeof(uaddr->sun_path) - 1); + _uri = uri; + _addr = (sockaddr*)uaddr; + _addr_len = sizeof(struct sockaddr_un); + return true; + } else if (_type == Type::TCP && uri.find("://") != std::string::npos) { + const std::string authority = uri.substr(uri.find("://") + 3); + const size_t port_sep = authority.find(':'); + if (port_sep == std::string::npos) { + return false; + } + + const std::string host = authority.substr(0, port_sep); + const std::string port = authority.substr(port_sep + 1).c_str(); + + struct addrinfo* ainfo; + int st = 0; + if ((st = getaddrinfo(host.c_str(), port.c_str(), NULL, &ainfo))) { + return false; + } + + _uri = uri; + _addr = (struct sockaddr*)malloc(ainfo->ai_addrlen); + _addr_len = ainfo->ai_addrlen; + memcpy(_addr, ainfo->ai_addr, ainfo->ai_addrlen); + freeaddrinfo(ainfo); + return true; + } + return false; +} + +inline bool +Socket::bind(const Raul::URI& uri) +{ + if (set_addr(uri) && ::bind(_sock, _addr, _addr_len) != -1) { + return true; + } + + return false; +} + +inline bool +Socket::connect(const Raul::URI& uri) +{ + if (set_addr(uri) && ::connect(_sock, _addr, _addr_len) != -1) { + return true; + } + + return false; +} + +inline bool +Socket::listen() +{ + if (::listen(_sock, 64) == -1) { + return false; + } else { + return true; + } +} + +inline std::shared_ptr<Socket> +Socket::accept() +{ + socklen_t client_addr_len = _addr_len; + struct sockaddr* client_addr = (struct sockaddr*)calloc( + 1, client_addr_len); + + int conn = ::accept(_sock, client_addr, &client_addr_len); + if (conn == -1) { + return std::shared_ptr<Socket>(); + } + + Raul::URI client_uri = _uri; + char host[NI_MAXHOST]; + if (getnameinfo(client_addr, client_addr_len, + host, sizeof(host), NULL, 0, 0)) { + client_uri = Raul::URI(_uri.scheme() + "://" + host); + } + + return std::shared_ptr<Socket>( + new Socket(_type, client_uri, client_addr, client_addr_len, conn)); +} + +inline void +Socket::close() +{ + if (_sock != -1) { + ::close(_sock); + _sock = -1; + } +} + +inline void +Socket::shutdown() +{ + if (_sock != -1) { + ::shutdown(_sock, SHUT_RDWR); + } +} + +} // namespace Raul + +#endif // RAUL_SOCKET_HPP diff --git a/test/socket_test.cpp b/test/socket_test.cpp new file mode 100644 index 0000000..5c7ffae --- /dev/null +++ b/test/socket_test.cpp @@ -0,0 +1,107 @@ +/* + This file is part of Raul. + Copyright 2007-2013 David Robillard <http://drobilla.net> + + Raul is free software: you can redistribute it and/or modify it under the + terms of the GNU General Public License as published by the Free Software + Foundation, either version 3 of the License, or any later version. + + Raul is distributed in the hope that it will be useful, but WITHOUT ANY + WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR + A PARTICULAR PURPOSE. See the GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Raul. If not, see <http://www.gnu.org/licenses/>. +*/ + +#include <stdio.h> +#include <sys/types.h> +#include <sys/wait.h> +#include <unistd.h> + +#include "raul/Socket.hpp" + +using namespace std; +using namespace Raul; + +int +main(int argc, char** argv) +{ + Raul::URI unix_uri("unix:///tmp/raul_test_sock"); + Raul::URI tcp_uri("tcp://localhost:12345"); + + Raul::Socket unix_server_sock(Socket::Type::UNIX); + Raul::Socket tcp_server_sock(Socket::Type::TCP); + if (!unix_server_sock.bind(unix_uri)) { + fprintf(stderr, "Failed to bind UNIX server socket\n"); + return 1; + } else if (!unix_server_sock.listen()) { + fprintf(stderr, "Failed to listen on UNIX server socket\n"); + return 1; + } else if (!tcp_server_sock.bind(tcp_uri)) { + fprintf(stderr, "Failed to bind TCP server socket\n"); + return 1; + } else if (!tcp_server_sock.listen()) { + fprintf(stderr, "Failed to listen on TCP server socket\n"); + return 1; + } + + const pid_t child_pid = fork(); + if (child_pid) { + // This is the parent (server) process + int status = 0; + waitpid(child_pid, &status, 0); + + struct pollfd pfds[2]; + pfds[0].fd = unix_server_sock.fd(); + pfds[0].events = POLLIN; + pfds[0].revents = 0; + pfds[1].fd = tcp_server_sock.fd(); + pfds[1].events = POLLIN; + pfds[1].revents = 0; + + unsigned n_received = 0; + while (n_received < 2) { + const int ret = poll(pfds, 2, -1); + if (ret == -1) { + fprintf(stderr, "poll error (%s)\n", strerror(errno)); + break; + } else if ((pfds[0].revents & POLLHUP) || pfds[1].revents & POLLHUP) { + break; + } else if (ret == 0) { + fprintf(stderr, "poll returned with no data\n"); + continue; + } + + if (pfds[0].revents & POLLIN) { + std::shared_ptr<Socket> conn = unix_server_sock.accept(); + ++n_received; + } + + if (pfds[1].revents & POLLIN) { + std::shared_ptr<Socket> conn = tcp_server_sock.accept(); + ++n_received; + } + } + + unix_server_sock.shutdown(); + tcp_server_sock.shutdown(); + unlink("/tmp/raul_test_sock"); + fprintf(stderr, "n received: %d\n", n_received); + return n_received != 2; + } + + // This is the child (client) process + Raul::Socket unix_sock(Socket::Type::UNIX); + Raul::Socket tcp_sock(Socket::Type::TCP); + + if (!unix_sock.connect(unix_uri)) { + fprintf(stderr, "Failed to connect to UNIX socket\n"); + return 1; + } else if (!tcp_sock.connect(tcp_uri)) { + fprintf(stderr, "Failed to connect to TCP socket\n"); + return 1; + } + + return 0; +} @@ -76,6 +76,7 @@ tests = ''' test/queue_test test/ringbuffer_test test/sem_test + test/socket_test test/symbol_test test/thread_test test/time_test |