socket_client.cpp: prevent buffer overflow in creation of unix socket
[libt2n] / src / socket_client.cpp
1 /*
2 Copyright (C) 2006 by Intra2net AG - Gerd v. Egidy
3
4 The software in this package is distributed under the GNU General
5 Public License version 2 (with a special exception described below).
6
7 A copy of GNU General Public License (GPL) is included in this distribution,
8 in the file COPYING.GPL.
9
10 As a special exception, if other files instantiate templates or use macros
11 or inline functions from this file, or you compile this file and link it
12 with other works to produce a work based on this file, this file
13 does not by itself cause the resulting work to be covered
14 by the GNU General Public License.
15
16 However the source code for this file must still be made available
17 in accordance with section (3) of the GNU General Public License.
18
19 This exception does not invalidate any other reasons why a work based
20 on this file might be covered by the GNU General Public License.
21 */
22
23 #include <stdio.h>
24 #include <errno.h>
25 #include <stdlib.h>
26 #include <unistd.h>
27 #include <sys/types.h>
28 #include <sys/socket.h>
29 #include <sys/un.h>
30 #include <sys/time.h>
31 #include <arpa/inet.h>
32 #include <netinet/in.h>
33 #include <netdb.h>
34 #include <fcntl.h>
35 #include <time.h>
36 #include <pwd.h>
37 #include <grp.h>
38
39 #include <sstream>
40
41 #include "socket_client.hxx"
42 #include "t2n_exception.hxx"
43 #include "log.hxx"
44
45 using namespace std;
46
47 namespace libt2n
48 {
49
50 /// returns a closed connection if connection could not be established, call get_last_error_msg() for details
51 socket_client_connection::socket_client_connection(int _port, const std::string& _server, 
52             long long _connect_timeout_usec, int _max_retries,
53             std::ostream *_logstream, log_level_values _log_level)
54     : client_connection(), socket_handler(0,tcp_s)
55 {
56     max_retries=_max_retries;
57     connect_timeout_usec=_connect_timeout_usec;
58
59     server=_server;
60     port=_port;
61
62     set_logging(_logstream,_log_level);
63
64     try
65     {
66         tcp_connect(max_retries);
67     }
68     catch (t2n_communication_error &e)
69     {
70         lastErrorMsg=e.what();
71         LOGSTREAM(debug,"tcp connect error: " << lastErrorMsg);
72         // FIXME: Don't call virtual function in constructor. Currently not dangerous but bad design.
73         close();
74     }
75
76     if (!connection::is_closed())
77         do_callbacks(new_connection);
78 }
79
80 /// returns a closed connection if connection could not be established, call get_last_error_msg() for details
81 socket_client_connection::socket_client_connection(const std::string& _path,
82             long long _connect_timeout_usec, int _max_retries,
83             std::ostream *_logstream, log_level_values _log_level)
84     : client_connection(), socket_handler(0,unix_s)
85 {
86     max_retries=_max_retries;
87     connect_timeout_usec=_connect_timeout_usec;
88
89     path=_path;
90
91     set_logging(_logstream,_log_level);
92
93     try
94     {
95         unix_connect(max_retries);
96     }
97     catch (t2n_communication_error &e)
98     {
99         lastErrorMsg=e.what();
100         LOGSTREAM(debug,"unix connect error: " << lastErrorMsg);
101         // FIXME: Don't call virtual function in constructor. Currently not dangerous
102         close();
103     }
104
105     if (!connection::is_closed())
106         do_callbacks(new_connection);
107 }
108
109 /**
110  * Destructor. Closes an open connection.
111  */
112 socket_client_connection::~socket_client_connection()
113 {
114     // Destructor of socket_handler will close the socket!
115 }
116
117
118 /// establish a connection via tcp
119 void socket_client_connection::tcp_connect(int max_retries)
120 {
121     struct sockaddr_in sock_addr;
122
123     sock_addr.sin_family = AF_INET;
124     sock_addr.sin_port = htons(port);
125
126     // find the target ip
127     if (inet_aton(server.c_str(),&sock_addr.sin_addr)==0)
128     {
129         struct hostent *server_hent;
130         server_hent=gethostbyname(server.c_str());
131         if (server_hent == NULL)
132             throw t2n_connect_error(string("can't find server ")+server);
133
134         memcpy(&sock_addr.sin_addr,server_hent->h_addr_list[0],sizeof(sock_addr.sin_addr));
135     }
136
137     sock = socket(PF_INET, SOCK_STREAM, 0);
138     if (!sock)
139         throw t2n_connect_error(string("socket() error: ")+strerror(errno));
140
141     try
142     {
143         connect_with_timeout((struct sockaddr *) &sock_addr,sizeof(sock_addr));
144     }
145     catch (t2n_connect_error &e)
146     {
147         // recurse if retries left
148         if (max_retries > 0)
149         {
150             LOGSTREAM(debug,"retrying connect after connect error");
151             tcp_connect(max_retries-1);
152         }
153         else
154             throw t2n_connect_error("no more retries left after connect error");
155     }
156 }
157
158 /// establish a connection via unix-socket
159 void socket_client_connection::unix_connect(int max_retries)
160 {
161     struct sockaddr_un unix_addr;
162     size_t path_size = path.size();
163
164     unix_addr.sun_family = AF_UNIX;
165
166     if (path_size >= sizeof(unix_addr.sun_path))
167     {
168         throw t2n_connect_error((std::string)"path '"
169                                 + path
170                                 + "' exceeds permissible UNIX socket path length");
171     }
172
173     memcpy(unix_addr.sun_path, path.c_str(), path_size);
174     unix_addr.sun_path[path_size] = '\0';
175
176     sock = socket(PF_UNIX, SOCK_STREAM, 0);
177     if (!sock)
178         throw t2n_connect_error(string("socket() error: ")+strerror(errno));
179
180     try
181     {
182         connect_with_timeout((struct sockaddr *) &unix_addr, sizeof(unix_addr));
183     }
184     catch (t2n_connect_error &e)
185     {
186         // recurse if retries left
187         if (max_retries > 0)
188         {
189             LOGSTREAM(debug,"retrying connect after connect error");
190             unix_connect(max_retries-1);
191         }
192         else
193             throw t2n_connect_error("no more retries left after connect error");
194     }
195 }
196
197 /// execute a connect on a prepared socket (tcp or unix) respecting timeouts
198 void socket_client_connection::connect_with_timeout(struct sockaddr *sock_addr,unsigned int sockaddr_size)
199 {
200     set_socket_options(sock);
201
202    /* non-blocking mode */
203     int flflags;
204     flflags=fcntl(sock,F_GETFL,0);
205     if (flflags < 0)
206         EXCEPTIONSTREAM(error,t2n_communication_error,"fcntl error on socket: " << strerror(errno));
207
208     flflags &= (O_NONBLOCK ^ 0xFFFF);
209     if (fcntl(sock,F_SETFL,flflags) < 0)
210         EXCEPTIONSTREAM(error,t2n_communication_error,"fcntl error on socket: " << strerror(errno));
211
212
213     LOGSTREAM(debug,"connect_with_timeout()");
214     int ret=::connect(sock,sock_addr, sockaddr_size);
215
216     if (ret < 0)
217     {
218         if (errno==EINPROGRESS)
219         {
220             LOGSTREAM(debug,"connect_with_timeout(): EINPROGRESS");
221
222             /* set timeout */
223             struct timeval tval;
224             struct timeval *timeout_ptr;
225
226             if (connect_timeout_usec == -1)
227                 timeout_ptr = NULL;
228             else
229             {
230                 timeout_ptr = &tval;
231
232                 // convert timeout from long long usec to int sec + int usec
233                 tval.tv_sec = connect_timeout_usec / 1000000;
234                 tval.tv_usec = connect_timeout_usec % 1000000;
235             }
236
237             fd_set connect_socket_set;
238             FD_ZERO(&connect_socket_set);
239             FD_SET(sock,&connect_socket_set);
240
241             int ret;
242             while ((ret=select(FD_SETSIZE, NULL, &connect_socket_set, NULL, timeout_ptr)) &&
243                     ret < 0 && errno==EINTR);
244
245             if (ret < 0)
246                 throw t2n_connect_error(string("connect() error (select): ")+strerror(errno));
247
248             socklen_t sopt=sizeof(int);
249             int valopt;
250             ret=getsockopt(sock, SOL_SOCKET, SO_ERROR, (void*)(&valopt), &sopt);
251             if (ret < 0 || valopt)
252                 throw t2n_connect_error(string("connect() error (getsockopt): ")+strerror(errno));
253         }
254         else
255             throw t2n_connect_error(string("connect() error: ")+strerror(errno));
256     }
257
258     LOGSTREAM(debug,"connect_with_timeout(): success");
259 }
260
261 void socket_client_connection::close()
262 {
263     if (!client_connection::is_closed())
264     {
265         socket_handler::close();
266         client_connection::close();
267     }
268 }
269
270 /** @brief try to reconnect the current connection with the same connection credentials (host and port or path)
271
272     @note will throw an exeption if reconnecting not possible
273 */
274 void socket_client_connection::reconnect()
275 {
276     LOGSTREAM(debug,"reconnect()");
277
278     // close the current connection if still open
279     close();
280
281     socket_type_value type=get_type();
282
283     if (type == tcp_s)
284         tcp_connect(max_retries);
285     else if (type == unix_s)
286         unix_connect(max_retries);
287
288     // connection is open now, otherwise an execption would have been thrown
289     reopen();
290
291     LOGSTREAM(debug,"reconnect() done, client_connection::is_closed() now " << client_connection::is_closed());
292 }
293
294 }