libt2n: (tomj) added exception handling to every child after fork(). This is needed...
[libt2n] / test / reentrant.cpp
1 /***************************************************************************
2  *   Copyright (C) 2004 by Intra2net AG                                    *
3  *   info@intra2net.com                                                    *
4  *                                                                         *
5  ***************************************************************************/
6
7 #include <sys/types.h>
8 #include <unistd.h>
9 #include <errno.h>
10 #include <signal.h>
11 #include <stdio.h>
12
13 #include <iostream>
14 #include <string>
15 #include <sstream>
16 #include <stdexcept>
17
18 #include <cppunit/extensions/TestFactoryRegistry.h>
19 #include <cppunit/ui/text/TestRunner.h>
20 #include <cppunit/extensions/HelperMacros.h>
21
22 #include <boost/archive/binary_oarchive.hpp>
23 #include <boost/archive/binary_iarchive.hpp>
24 #include <boost/archive/xml_oarchive.hpp>
25 #include <boost/archive/xml_iarchive.hpp>
26 #include <boost/serialization/serialization.hpp>
27
28 #include <container.hxx>
29 #include <socket_client.hxx>
30 #include <socket_server.hxx>
31 #include <command_client.hxx>
32 #include <command_server.hxx>
33
34 using namespace std;
35 using namespace CppUnit;
36 using namespace libt2n;
37
38 namespace
39 {
40
41 command_server *global_server = NULL;
42
43 int fork_count = 3;
44 int requests_per_child = 100;
45 int all_requests = (2 << (fork_count-1)) * requests_per_child;
46
47 int seen_client_requests = 0;
48
49 string testfunc(const string& str)
50 {
51     string ret;
52     ret=str+", testfunc() was here";
53
54     // call handle, eventually reentrant
55     if (global_server)
56         global_server->handle(10000);
57
58     ++seen_client_requests;
59
60     return ret;
61 }
62
63 class testfunc_res : public libt2n::result
64 {
65     private:
66         string res;
67
68         friend class boost::serialization::access;
69         template<class Archive>
70         void serialize(Archive & ar, const unsigned int version)
71         {
72             ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(libt2n::result);
73             ar & BOOST_SERIALIZATION_NVP(res);
74         }
75
76     public:
77         testfunc_res()
78             {
79             }
80
81         testfunc_res(const string& str)
82         {
83             res=str;
84         }
85
86         string get_data()
87         {
88             return res;
89         }
90 };
91
92
93 class testfunc_cmd : public libt2n::command
94 {
95     private:
96         string param;
97
98         friend class boost::serialization::access;
99         template<class Archive>
100         void serialize(Archive & ar, const unsigned int version)
101         {
102             ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(libt2n::command);
103             ar & BOOST_SERIALIZATION_NVP(param);
104         }
105
106     public:
107         testfunc_cmd()
108             {
109             }
110
111         testfunc_cmd(const string& str)
112         {
113             param=str;
114         }
115
116         libt2n::result* operator()()
117         {
118             return new testfunc_res(testfunc(param));
119         }
120 };
121
122 }
123
124
125 #include <boost/serialization/export.hpp>
126
127 BOOST_CLASS_EXPORT(testfunc_cmd)
128 BOOST_CLASS_EXPORT(testfunc_res)
129
130 class test_reentrant : public TestFixture
131 {
132     CPPUNIT_TEST_SUITE(test_reentrant);
133
134     CPPUNIT_TEST(ReentrantServer);
135
136     CPPUNIT_TEST_SUITE_END();
137
138     public:
139
140     void setUp()
141     { }
142
143     void tearDown()
144     { }
145
146     void ReentrantServer()
147     {
148         switch(fork())
149         {
150             case -1:
151             {
152                 CPPUNIT_FAIL("fork error");
153                 break;
154             }
155             case 0:
156             // child
157             {
158                 // wait till server is up
159                 sleep(2);
160
161                 // hammer the server
162                 for (int i = 0; i < fork_count; i++)
163                     fork();
164
165                 try
166                 {
167                     for (int i=0; i < requests_per_child; i++)
168                     {
169                         socket_client_connection sc("./socket");
170                         // sc.set_logging(&cerr,debug);
171                         command_client cc(&sc);
172
173                         result_container rc;
174                         cc.send_command(new testfunc_cmd("hello"),rc);
175
176                         testfunc_res *res = dynamic_cast<testfunc_res*>(rc.get_result());
177                         if (res)
178                         {
179                             string ret = res->get_data();
180                             if (ret != "hello, testfunc() was here")
181                                 std::cout << "ERROR reentrant server testfunc_res failed, res: \"" << ret << "\"\n";
182                         }
183                         else
184                         {
185                             std::cout << "ERROR result from reentrant server empty (" << rc.get_result() << ")\n";
186                         }
187                     }
188                 } catch (exception &e)
189                 {
190                     cerr << "caught exception: " << e.what() << endl;
191                 } catch(...)
192                 {
193                     std::cerr << "exception in child. ignoring\n";
194                 }
195
196                 // don't call atexit and stuff
197                 _exit(0);
198             }
199
200             default:
201             // parent
202             {
203                 // don't kill us on broken pipe
204                 signal(SIGPIPE, SIG_IGN);
205
206                 socket_server ss("./socket");
207                 command_server cs(ss);
208
209                 global_server=&cs;
210
211                 // Wait until all requests have successed
212                 int safety_check = 0;
213                 while (seen_client_requests < all_requests)
214                 {
215                     ++safety_check;
216                     if (safety_check > 10) {
217                         std::cerr << "reached safety check, aborting.\n";
218                         break;
219                     }
220
221                     long long maxtime=1000000;
222                     while(maxtime > 0)
223                         cs.handle(maxtime,&maxtime);
224                 }
225
226                 global_server = NULL;
227             }
228
229             // we are still alive, everything is ok
230             CPPUNIT_ASSERT_EQUAL(all_requests, seen_client_requests);
231         }
232     }
233
234 };
235
236
237 CPPUNIT_TEST_SUITE_REGISTRATION(test_reentrant);