From 3b2543e7dfd705d6e624560dd5a681898c0f242c Mon Sep 17 00:00:00 2001 From: Gerd v. Egidy Date: Fri, 5 Sep 2008 11:42:16 +0000 Subject: [PATCH] libt2n: (gerd) make handle-function on server reentrant --- src/command_server.cpp | 47 ++++++++---- src/command_server.hxx | 2 + test/Makefile.am | 2 +- test/reentrant.cpp | 195 ++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 228 insertions(+), 18 deletions(-) create mode 100644 test/reentrant.cpp diff --git a/src/command_server.cpp b/src/command_server.cpp index f596860..6146981 100644 --- a/src/command_server.cpp +++ b/src/command_server.cpp @@ -44,7 +44,7 @@ namespace libt2n { command_server::command_server(server& _s) - : s(_s) + : s(_s), guard_handle(0) { // register callback s.add_callback(new_connection,bind(&command_server::send_hello, boost::ref(*this), _1)); @@ -160,29 +160,42 @@ void command_server::handle_packet(const std::string& packet, server_connection* */ void command_server::handle(long long usec_timeout, long long* usec_timeout_remaining) { - if (s.fill_buffer(usec_timeout,usec_timeout_remaining)) + guard_handle++; + try { - std::string packet; - unsigned int conn_id; - - while (s.get_packet(packet,conn_id)) + if (s.fill_buffer(usec_timeout,usec_timeout_remaining)) { - server_connection* conn=s.get_connection(conn_id); - if (!conn) - EXCEPTIONSTREAM(error,logic_error,"illegal connection id " << conn_id << " received"); + std::string packet; + unsigned int conn_id; - try - { handle_packet(packet,conn); } - catch (t2n_transfer_error &e) + while (s.get_packet(packet,conn_id)) { - // shut down a connection with transfer errors (usually write errors) - conn->close(); + server_connection* conn=s.get_connection(conn_id); + if (!conn) + EXCEPTIONSTREAM(error,logic_error,"illegal connection id " << conn_id << " received"); + + try + { handle_packet(packet,conn); } + catch (t2n_transfer_error &e) + { + // shut down a connection with transfer errors (usually write errors) + conn->close(); + } + catch(...) + { throw; } } - catch(...) - { throw; } } } - s.cleanup(); + catch(...) + { + guard_handle--; + throw; + } + guard_handle--; + + // don't call cleanup on re-entered handle-calls + if (guard_handle == 0) + s.cleanup(); } } diff --git a/src/command_server.hxx b/src/command_server.hxx index 378ef08..c839e78 100644 --- a/src/command_server.hxx +++ b/src/command_server.hxx @@ -33,6 +33,8 @@ class command_server void handle_packet(const std::string& packet, server_connection* conn); + int guard_handle; + protected: virtual command* cast_command(command* input) { return input; } diff --git a/test/Makefile.am b/test/Makefile.am index abeb415..cbe0440 100644 --- a/test/Makefile.am +++ b/test/Makefile.am @@ -5,6 +5,6 @@ check_PROGRAMS = test test_LDADD = $(top_builddir)/src/libt2n.la @BOOST_SERIALIZATION_LIB@ \ @BOOST_LDFLAGS@ @CPPUNIT_LIBS@ test_SOURCES = callback.cpp cmdgroup.cpp comm.cpp hello.cpp reconnect.cpp \ - serialize.cpp simplecmd.cpp test.cpp timeout.cpp wrapper.cpp + reentrant.cpp serialize.cpp simplecmd.cpp test.cpp timeout.cpp wrapper.cpp TESTS = test diff --git a/test/reentrant.cpp b/test/reentrant.cpp new file mode 100644 index 0000000..bb9420a --- /dev/null +++ b/test/reentrant.cpp @@ -0,0 +1,195 @@ +/*************************************************************************** + * Copyright (C) 2004 by Intra2net AG * + * info@intra2net.com * + * * + ***************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +using namespace std; +using namespace CppUnit; +using namespace libt2n; + +namespace +{ + +command_server *global_server; + +string testfunc(const string& str) +{ + string ret; + ret=str+", testfunc() was here"; + + // call handle, eventually reentrant + global_server->handle(1000); + + return ret; +} + +class testfunc_res : public libt2n::result +{ + private: + string res; + + friend class boost::serialization::access; + template + void serialize(Archive & ar, const unsigned int version) + { + ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(libt2n::result); + ar & BOOST_SERIALIZATION_NVP(res); + } + + public: + testfunc_res() + { } + + testfunc_res(const string& str) + { + res=str; + } + + string get_data() + { + return res; + } +}; + + +class testfunc_cmd : public libt2n::command +{ + private: + string param; + + friend class boost::serialization::access; + template + void serialize(Archive & ar, const unsigned int version) + { + ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(libt2n::command); + ar & BOOST_SERIALIZATION_NVP(param); + } + + public: + testfunc_cmd() + { } + + testfunc_cmd(const string& str) + { + param=str; + } + + libt2n::result* operator()() + { + return new testfunc_res(testfunc(param)); + } +}; + +} + +#include + +BOOST_CLASS_EXPORT(testfunc_cmd) +BOOST_CLASS_EXPORT(testfunc_res) + +class test_reentrant : public TestFixture +{ + CPPUNIT_TEST_SUITE(test_reentrant); + + CPPUNIT_TEST(ReentrantServer); + + CPPUNIT_TEST_SUITE_END(); + + pid_t child_pid; + + public: + + void setUp() + { } + + void tearDown() + { } + + void ReentrantServer() + { + switch(child_pid=fork()) + { + case -1: + { + CPPUNIT_FAIL("fork error"); + break; + } + case 0: + // child + { + // wait till server is up + sleep(1); + + // we want 8 identical childs hammering the server + fork(); + fork(); + fork(); + + for (int i=0; i < 100; i++) + { + socket_client_connection sc("./socket"); + command_client cc(&sc); + + result_container rc; + cc.send_command(new testfunc_cmd("hello"),rc); + + string ret=dynamic_cast(rc.get_result())->get_data(); + + CPPUNIT_ASSERT_EQUAL(string("hello, testfunc() was here"),ret); + } + + // don't call atexit and stuff + _exit(0); + } + + default: + // parent + { + socket_server ss("./socket"); + command_server cs(ss); + + global_server=&cs; + + // max 10 sec + long long maxtime=5000000; + while(maxtime > 0) + cs.handle(maxtime,&maxtime); + } + + // we are still alive, everything is ok + CPPUNIT_ASSERT_EQUAL(1,1); + } + } + +}; + + +CPPUNIT_TEST_SUITE_REGISTRATION(test_reentrant); -- 1.7.1