libt2n: (gerd) add lots of error handling code, unit tests for this error handling...
[libt2n] / src / command_server.cpp
index 2a4a7b4..f596860 100644 (file)
 #include "container.hxx"
 #include "log.hxx"
 
+#ifdef HAVE_CONFIG_H
+#include <config.h>
+#endif
+
 using namespace std;
 
 namespace libt2n
@@ -43,12 +47,23 @@ command_server::command_server(server& _s)
     : s(_s)
 {
     // register callback
-    s.add_callback(new_connection,bind(&command_server::new_connection_callback, boost::ref(*this), _1));
+    s.add_callback(new_connection,bind(&command_server::send_hello, boost::ref(*this), _1));
 }
 
-void command_server::new_connection_callback(unsigned int conn_id)
+void command_server::send_hello(unsigned int conn_id)
 {
-    cerr << "new connection callback: " << conn_id << endl;
+    server_connection* sc=s.get_connection(conn_id);
+
+    std::ostringstream hello;
+
+    hello << "T2Nv" << PROTOCOL_VERSION << ';';
+
+    int byteordercheck=1;
+    hello.write((char*)&byteordercheck,sizeof(byteordercheck));
+
+    hello << ';';
+
+    sc->write(hello.str());
 }
 
 /// handle a command including deserialization and answering
@@ -57,46 +72,78 @@ void command_server::handle_packet(const std::string& packet, server_connection*
     OBJLOGSTREAM(s,debug,"handling packet from connection " << conn->get_id());
 
     // deserialize packet
-    istringstream ifs(packet);
+    std::istringstream ifs(packet);
     boost::archive::binary_iarchive ia(ifs);
     command_container ccont;
+    result_container res;
 
-    // TODO: catch
-    ia >> ccont;
-
-    std::ostream* ostr;
-    if ((ostr=s.get_logstream(fulldebug))!=NULL)
+    try
     {
-        (*ostr) << "decoded packet data: " << std::endl;
-        boost::archive::xml_oarchive xo(*ostr);
-        xo << BOOST_SERIALIZATION_NVP(ccont);
+        ia >> ccont;
     }
+    catch(boost::archive::archive_exception &e)
+    {
+        std::ostringstream msg;
+        msg << "archive_exception while deserializing on server-side, "
+               "code " << e.code << " (" << e.what() << ")";
+        res.set_exception(new t2n_serialization_error(msg.str()));
+    }
+    catch(...)
+        { throw; }
 
-    // TODO: cast to command subclass (template)
-    command *cmd=ccont.get_command();
+    if (!res.has_exception())
+    {
+        std::ostream* ostr;
+        if ((ostr=s.get_logstream(fulldebug))!=NULL)
+        {
+            (*ostr) << "decoded packet data: " << std::endl;
+            boost::archive::xml_oarchive xo(*ostr);
+            xo << BOOST_SERIALIZATION_NVP(ccont);
+        }
 
-    result_container res;
+        command* cmd=cast_command(ccont.get_command());
 
-    if (cmd)
-    {
-        try
+        if (cmd)
         {
-            res.set_result((*cmd)());
+            try
+            {
+                res.set_result((*cmd)());
+            }
+            catch (t2n_exception &e)
+                { res.set_exception(e.clone()); }
+            catch (...)
+                { throw; }
+        }
+        else
+        {
+            std::ostringstream msg;
+            if (ccont.get_command()!=NULL)
+                msg << "illegal command of type " << typeid(ccont.get_command()).name() << " called";
+            else
+                msg << "NULL command called";
+            res.set_exception(new t2n_command_error(msg.str()));
         }
-        catch (t2n_exception &e)
-            { res.set_exception(e.clone()); }
-        catch (...)
-            { throw; }
     }
-    else
-        throw logic_error("uninitialized command called");
 
-    ostringstream ofs;
+    std::ostringstream ofs;
     boost::archive::binary_oarchive oa(ofs);
 
-    // TODO: catch
-    oa << res;
+    try
+    {
+        oa << res;
+    }
+    catch(boost::archive::archive_exception &e)
+    {
+        std::ostringstream msg;
+        msg << "archive_exception while serializing on server-side, "
+               "code " << e.code << " (" << e.what() << ")";
+        res.set_exception(new t2n_serialization_error(msg.str()));
+        oa << res;
+    }
+    catch(...)
+        { throw; }
 
+    std::ostream* ostr;
     if ((ostr=s.get_logstream(fulldebug))!=NULL)
     {
         (*ostr) << "returning result, decoded data: " << std::endl;
@@ -108,18 +155,32 @@ void command_server::handle_packet(const std::string& packet, server_connection*
 }
 
 /** @brief handle incoming commands
-    @param usec_timeout wait until new data is found, max timeout usecs.
-            -1: wait endless, 0: no timeout
+    @param[in,out] usec_timeout wait until new data is found, max timeout usecs.
+            -1: wait endless, 0: instant return
 */
-void command_server::handle(long long usec_timeout)
+void command_server::handle(long long usec_timeout, long long* usec_timeout_remaining)
 {
-    if (s.fill_buffer(usec_timeout))
+    if (s.fill_buffer(usec_timeout,usec_timeout_remaining))
     {
-        string packet;
+        std::string packet;
         unsigned int conn_id;
 
         while (s.get_packet(packet,conn_id))
-            handle_packet(packet,s.get_connection(conn_id)); 
+        {
+            server_connection* conn=s.get_connection(conn_id);
+            if (!conn)
+                EXCEPTIONSTREAM(error,logic_error,"illegal connection id " << conn_id << " received");
+
+            try
+                { handle_packet(packet,conn); }
+            catch (t2n_transfer_error &e)
+            {
+                // shut down a connection with transfer errors (usually write errors)
+                conn->close();
+            }
+            catch(...)
+                { throw; }
+        }
     }
     s.cleanup();
 }