libt2n: (gerd) add lots of error handling code, unit tests for this error handling...
[libt2n] / src / command_server.cpp
index 55c4cc1..f596860 100644 (file)
@@ -54,7 +54,7 @@ void command_server::send_hello(unsigned int conn_id)
 {
     server_connection* sc=s.get_connection(conn_id);
 
-    ostringstream hello;
+    std::ostringstream hello;
 
     hello << "T2Nv" << PROTOCOL_VERSION << ';';
 
@@ -72,7 +72,7 @@ void command_server::handle_packet(const std::string& packet, server_connection*
     OBJLOGSTREAM(s,debug,"handling packet from connection " << conn->get_id());
 
     // deserialize packet
-    istringstream ifs(packet);
+    std::istringstream ifs(packet);
     boost::archive::binary_iarchive ia(ifs);
     command_container ccont;
     result_container res;
@@ -83,7 +83,7 @@ void command_server::handle_packet(const std::string& packet, server_connection*
     }
     catch(boost::archive::archive_exception &e)
     {
-        ostringstream msg;
+        std::ostringstream msg;
         msg << "archive_exception while deserializing on server-side, "
                "code " << e.code << " (" << e.what() << ")";
         res.set_exception(new t2n_serialization_error(msg.str()));
@@ -101,8 +101,7 @@ void command_server::handle_packet(const std::string& packet, server_connection*
             xo << BOOST_SERIALIZATION_NVP(ccont);
         }
 
-        // TODO: cast to command subclass (template)
-        command *cmd=ccont.get_command();
+        command* cmd=cast_command(ccont.get_command());
 
         if (cmd)
         {
@@ -116,10 +115,17 @@ void command_server::handle_packet(const std::string& packet, server_connection*
                 { throw; }
         }
         else
-            throw logic_error("uninitialized command called");
+        {
+            std::ostringstream msg;
+            if (ccont.get_command()!=NULL)
+                msg << "illegal command of type " << typeid(ccont.get_command()).name() << " called";
+            else
+                msg << "NULL command called";
+            res.set_exception(new t2n_command_error(msg.str()));
+        }
     }
 
-    ostringstream ofs;
+    std::ostringstream ofs;
     boost::archive::binary_oarchive oa(ofs);
 
     try
@@ -128,7 +134,7 @@ void command_server::handle_packet(const std::string& packet, server_connection*
     }
     catch(boost::archive::archive_exception &e)
     {
-        ostringstream msg;
+        std::ostringstream msg;
         msg << "archive_exception while serializing on server-side, "
                "code " << e.code << " (" << e.what() << ")";
         res.set_exception(new t2n_serialization_error(msg.str()));
@@ -156,11 +162,25 @@ void command_server::handle(long long usec_timeout, long long* usec_timeout_rema
 {
     if (s.fill_buffer(usec_timeout,usec_timeout_remaining))
     {
-        string packet;
+        std::string packet;
         unsigned int conn_id;
 
         while (s.get_packet(packet,conn_id))
-            handle_packet(packet,s.get_connection(conn_id)); 
+        {
+            server_connection* conn=s.get_connection(conn_id);
+            if (!conn)
+                EXCEPTIONSTREAM(error,logic_error,"illegal connection id " << conn_id << " received");
+
+            try
+                { handle_packet(packet,conn); }
+            catch (t2n_transfer_error &e)
+            {
+                // shut down a connection with transfer errors (usually write errors)
+                conn->close();
+            }
+            catch(...)
+                { throw; }
+        }
     }
     s.cleanup();
 }