{
command_server::command_server(server& _s)
- : s(_s)
+ : s(_s), guard_handle(0)
{
// register callback
s.add_callback(new_connection,bind(&command_server::send_hello, boost::ref(*this), _1));
*/
void command_server::handle(long long usec_timeout, long long* usec_timeout_remaining)
{
- if (s.fill_buffer(usec_timeout,usec_timeout_remaining))
+ guard_handle++;
+ try
{
- std::string packet;
- unsigned int conn_id;
-
- while (s.get_packet(packet,conn_id))
+ if (s.fill_buffer(usec_timeout,usec_timeout_remaining))
{
- server_connection* conn=s.get_connection(conn_id);
- if (!conn)
- EXCEPTIONSTREAM(error,logic_error,"illegal connection id " << conn_id << " received");
+ std::string packet;
+ unsigned int conn_id;
- try
- { handle_packet(packet,conn); }
- catch (t2n_transfer_error &e)
+ while (s.get_packet(packet,conn_id))
{
- // shut down a connection with transfer errors (usually write errors)
- conn->close();
+ server_connection* conn=s.get_connection(conn_id);
+ if (!conn)
+ EXCEPTIONSTREAM(error,logic_error,"illegal connection id " << conn_id << " received");
+
+ try
+ { handle_packet(packet,conn); }
+ catch (t2n_transfer_error &e)
+ {
+ // shut down a connection with transfer errors (usually write errors)
+ conn->close();
+ }
+ catch(...)
+ { throw; }
}
- catch(...)
- { throw; }
}
}
- s.cleanup();
+ catch(...)
+ {
+ guard_handle--;
+ throw;
+ }
+ guard_handle--;
+
+ // don't call cleanup on re-entered handle-calls
+ if (guard_handle == 0)
+ s.cleanup();
}
}
--- /dev/null
+/***************************************************************************
+ * Copyright (C) 2004 by Intra2net AG *
+ * info@intra2net.com *
+ * *
+ ***************************************************************************/
+
+#include <sys/types.h>
+#include <unistd.h>
+#include <errno.h>
+#include <signal.h>
+#include <stdio.h>
+
+#include <iostream>
+#include <string>
+#include <sstream>
+#include <stdexcept>
+
+#include <cppunit/extensions/TestFactoryRegistry.h>
+#include <cppunit/ui/text/TestRunner.h>
+#include <cppunit/extensions/HelperMacros.h>
+
+#include <boost/archive/binary_oarchive.hpp>
+#include <boost/archive/binary_iarchive.hpp>
+#include <boost/archive/xml_oarchive.hpp>
+#include <boost/archive/xml_iarchive.hpp>
+#include <boost/serialization/serialization.hpp>
+
+#include <container.hxx>
+#include <socket_client.hxx>
+#include <socket_server.hxx>
+#include <command_client.hxx>
+#include <command_server.hxx>
+
+using namespace std;
+using namespace CppUnit;
+using namespace libt2n;
+
+namespace
+{
+
+command_server *global_server;
+
+string testfunc(const string& str)
+{
+ string ret;
+ ret=str+", testfunc() was here";
+
+ // call handle, eventually reentrant
+ global_server->handle(1000);
+
+ return ret;
+}
+
+class testfunc_res : public libt2n::result
+{
+ private:
+ string res;
+
+ friend class boost::serialization::access;
+ template<class Archive>
+ void serialize(Archive & ar, const unsigned int version)
+ {
+ ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(libt2n::result);
+ ar & BOOST_SERIALIZATION_NVP(res);
+ }
+
+ public:
+ testfunc_res()
+ { }
+
+ testfunc_res(const string& str)
+ {
+ res=str;
+ }
+
+ string get_data()
+ {
+ return res;
+ }
+};
+
+
+class testfunc_cmd : public libt2n::command
+{
+ private:
+ string param;
+
+ friend class boost::serialization::access;
+ template<class Archive>
+ void serialize(Archive & ar, const unsigned int version)
+ {
+ ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(libt2n::command);
+ ar & BOOST_SERIALIZATION_NVP(param);
+ }
+
+ public:
+ testfunc_cmd()
+ { }
+
+ testfunc_cmd(const string& str)
+ {
+ param=str;
+ }
+
+ libt2n::result* operator()()
+ {
+ return new testfunc_res(testfunc(param));
+ }
+};
+
+}
+
+#include <boost/serialization/export.hpp>
+
+BOOST_CLASS_EXPORT(testfunc_cmd)
+BOOST_CLASS_EXPORT(testfunc_res)
+
+class test_reentrant : public TestFixture
+{
+ CPPUNIT_TEST_SUITE(test_reentrant);
+
+ CPPUNIT_TEST(ReentrantServer);
+
+ CPPUNIT_TEST_SUITE_END();
+
+ pid_t child_pid;
+
+ public:
+
+ void setUp()
+ { }
+
+ void tearDown()
+ { }
+
+ void ReentrantServer()
+ {
+ switch(child_pid=fork())
+ {
+ case -1:
+ {
+ CPPUNIT_FAIL("fork error");
+ break;
+ }
+ case 0:
+ // child
+ {
+ // wait till server is up
+ sleep(1);
+
+ // we want 8 identical childs hammering the server
+ fork();
+ fork();
+ fork();
+
+ for (int i=0; i < 100; i++)
+ {
+ socket_client_connection sc("./socket");
+ command_client cc(&sc);
+
+ result_container rc;
+ cc.send_command(new testfunc_cmd("hello"),rc);
+
+ string ret=dynamic_cast<testfunc_res*>(rc.get_result())->get_data();
+
+ CPPUNIT_ASSERT_EQUAL(string("hello, testfunc() was here"),ret);
+ }
+
+ // don't call atexit and stuff
+ _exit(0);
+ }
+
+ default:
+ // parent
+ {
+ socket_server ss("./socket");
+ command_server cs(ss);
+
+ global_server=&cs;
+
+ // max 10 sec
+ long long maxtime=5000000;
+ while(maxtime > 0)
+ cs.handle(maxtime,&maxtime);
+ }
+
+ // we are still alive, everything is ok
+ CPPUNIT_ASSERT_EQUAL(1,1);
+ }
+ }
+
+};
+
+
+CPPUNIT_TEST_SUITE_REGISTRATION(test_reentrant);