diff options
| -rw-r--r-- | contrib/Socket.cpp | 120 | ||||
| -rw-r--r-- | contrib/Socket.h | 1 | 
2 files changed, 120 insertions, 1 deletions
| diff --git a/contrib/Socket.cpp b/contrib/Socket.cpp index 7ff6b5e..d12c970 100644 --- a/contrib/Socket.cpp +++ b/contrib/Socket.cpp @@ -409,6 +409,121 @@ bool TCPSocket::valid() const      return m_sock != -1;  } +void TCPSocket::connect(const std::string& hostname, int port, int timeout_ms) +{ +    if (m_sock != INVALID_SOCKET) { +        throw std::logic_error("You may only connect an invalid TCPSocket"); +    } + +    char service[NI_MAXSERV]; +    snprintf(service, NI_MAXSERV-1, "%d", port); + +    /* Obtain address(es) matching host/port */ +    struct addrinfo hints; +    memset(&hints, 0, sizeof(struct addrinfo)); +    hints.ai_family = AF_INET; +    hints.ai_socktype = SOCK_STREAM; +    hints.ai_flags = 0; +    hints.ai_protocol = 0; + +    struct addrinfo *result, *rp; +    int s = getaddrinfo(hostname.c_str(), service, &hints, &result); +    if (s != 0) { +        throw runtime_error(string("getaddrinfo failed: ") + gai_strerror(s)); +    } + +    int flags = 0; + +    /* getaddrinfo() returns a list of address structures. +       Try each address until we successfully connect(2). +       If socket(2) (or connect(2)) fails, we (close the socket +       and) try the next address. */ + +    for (rp = result; rp != nullptr; rp = rp->ai_next) { +        int sfd = ::socket(rp->ai_family, rp->ai_socktype, +                rp->ai_protocol); +        if (sfd == -1) +            continue; + +        flags = fcntl(sfd, F_GETFL); +        if (flags == -1) { +            std::string errstr(strerror(errno)); +            throw std::runtime_error("TCP: Could not get socket flags: " + errstr); +        } + +        if (fcntl(sfd, F_SETFL, flags | O_NONBLOCK) == -1) { +            std::string errstr(strerror(errno)); +            throw std::runtime_error("TCP: Could not set O_NONBLOCK: " + errstr); +        } + +        int ret = ::connect(sfd, rp->ai_addr, rp->ai_addrlen); +        if (ret == 0) { +            m_sock = sfd; +            break; +        } +        if (ret == -1 and errno == EINPROGRESS) { +            m_sock = sfd; +            struct pollfd fds[1]; +            fds[0].fd = m_sock; +            fds[0].events = POLLOUT; + +            int retval = poll(fds, 1, timeout_ms); + +            if (retval == -1) { +                std::string errstr(strerror(errno)); +                ::close(m_sock); +                freeaddrinfo(result); +                throw runtime_error("TCP: connect error on poll: " + errstr); +            } +            else if (retval > 0) { +                int so_error = 0; +                socklen_t len = sizeof(so_error); + +                if (getsockopt(m_sock, SOL_SOCKET, SO_ERROR, &so_error, &len) == -1) { +                    std::string errstr(strerror(errno)); +                    ::close(m_sock); +                    freeaddrinfo(result); +                    throw runtime_error("TCP: getsockopt error connect: " + errstr); +                } + +                if (so_error == 0) { +                    break; +                } +            } +            else { +                ::close(m_sock); +                freeaddrinfo(result); +                throw runtime_error("Timeout on connect"); +            } +            break; +        } + +        ::close(sfd); +    } + +    if (m_sock != INVALID_SOCKET) { +#if defined(HAVE_SO_NOSIGPIPE) +        int val = 1; +        if (setsockopt(m_sock, SOL_SOCKET, SO_NOSIGPIPE, &val, sizeof(val)) +                == SOCKET_ERROR) { +            throw runtime_error("Can't set SO_NOSIGPIPE"); +        } +#endif +    } + +    // Don't keep the socket blocking +    if (fcntl(m_sock, F_SETFL, flags) == -1) { +        std::string errstr(strerror(errno)); +        throw std::runtime_error("TCP: Could not set O_NONBLOCK: " + errstr); +    } + +    freeaddrinfo(result); + +    if (rp == nullptr) { +        throw runtime_error("Could not connect"); +    } +} +  void TCPSocket::connect(const std::string& hostname, int port, bool nonblock)  {      if (m_sock != INVALID_SOCKET) { @@ -447,11 +562,15 @@ void TCPSocket::connect(const std::string& hostname, int port, bool nonblock)              int flags = fcntl(sfd, F_GETFL);              if (flags == -1) {                  std::string errstr(strerror(errno)); +                freeaddrinfo(result); +                ::close(sfd);                  throw std::runtime_error("TCP: Could not get socket flags: " + errstr);              }              if (fcntl(sfd, F_SETFL, flags | O_NONBLOCK) == -1) {                  std::string errstr(strerror(errno)); +                freeaddrinfo(result); +                ::close(sfd);                  throw std::runtime_error("TCP: Could not set O_NONBLOCK: " + errstr);              }          } @@ -480,7 +599,6 @@ void TCPSocket::connect(const std::string& hostname, int port, bool nonblock)      if (rp == nullptr) {          throw runtime_error("Could not connect");      } -  }  void TCPSocket::listen(int port, const string& name) diff --git a/contrib/Socket.h b/contrib/Socket.h index 33cdc05..08607a5 100644 --- a/contrib/Socket.h +++ b/contrib/Socket.h @@ -168,6 +168,7 @@ class TCPSocket {          bool valid(void) const;          void connect(const std::string& hostname, int port, bool nonblock = false); +        void connect(const std::string& hostname, int port, int timeout_ms);          void listen(int port, const std::string& name);          void close(void); | 
