libt2n: (gerd) bugfixes, better logging, unit tests for wrapper, ignore handler still...
[libt2n] / src / command_client.cpp
index 77daf0d..0f44203 100644 (file)
@@ -51,6 +51,35 @@ command_client::command_client(client_connection& _c, long long _command_timeout
     read_hello();
 }
 
+/** @brief replace the connection currently in use with a new one
+
+    @param _c reference of the new connection
+
+    @note the old connection must still be valid when this method is called,
+          it can safely be deleted after this method returned
+
+    @note all callbacks registered on the old connection will be copied over
+          to the new one
+*/
+void command_client::replace_connection(client_connection& _c)
+{
+    // copy all callbacks registered on the old connection
+    for(callback_event_type e=static_cast<callback_event_type>(0);
+        e < __events_end;
+        e=static_cast<callback_event_type>(static_cast<int>(e)+1))
+    {
+        list<boost::function<void ()> > evcb=c.get_callback_list(e);
+
+        for (list<boost::function<void ()> >::iterator i=evcb.begin(); i != evcb.end(); i++)
+            _c.add_callback(e,*i);
+    }
+
+    // replace the connection
+    c=_c;
+
+    read_hello();
+}
+
 std::string command_client::read_packet(const long long &usec_timeout)
 {
     string resultpacket;
@@ -67,30 +96,97 @@ std::string command_client::read_packet(const long long &usec_timeout)
 
 void command_client::read_hello()
 {
-    istringstream hello(read_packet(hello_timeout_usec));
+    string resultpacket;
+    bool got_packet=false;
+    long long my_timeout=hello_timeout_usec;
+    while(!(got_packet=c.get_packet(resultpacket)) && my_timeout > 0  && !c.is_closed())
+    {
+        c.fill_buffer(my_timeout,&my_timeout);
+
+        c.peek_packet(resultpacket);
+        check_hello(resultpacket);           // will throw before timeout if wrong data received
+    }
+
+    if (!got_packet)
+        throw t2n_transfer_error("timeout exceeded");
 
-    char chk[5];
-    hello.read(chk,4);
-    chk[4]=0;
-    if (hello.fail() || hello.eof() || string("T2Nv") != chk)
-        throw t2n_version_mismatch("illegal hello received (T2N)");
+    if (!check_hello(resultpacket))
+        throw t2n_version_mismatch("illegal hello received (incomplete): "+resultpacket);
+}
+
+bool command_client::check_hello(const string& hellostr)
+{
+    istringstream hello(hellostr);
+
+    char chk;
+
+    if (hello.read(&chk,1))
+    {
+        if (chk != 'T')
+            throw t2n_version_mismatch("illegal hello received (T2N)");
+    }
+    else
+        return false;
+
+    if (hello.read(&chk,1))
+    {
+        if (chk != '2')
+            throw t2n_version_mismatch("illegal hello received (T2N)");
+    }
+    else
+        return false;
+
+    if (hello.read(&chk,1))
+    {
+        if (chk != 'N')
+            throw t2n_version_mismatch("illegal hello received (T2N)");
+    }
+    else
+        return false;
+
+    if (hello.read(&chk,1))
+    {
+        if (chk != 'v')
+            throw t2n_version_mismatch("illegal hello received (T2N)");
+    }
+    else
+        return false;
 
     int prot_version;
-    hello >> prot_version;
-    if (hello.fail() || hello.eof() || prot_version != PROTOCOL_VERSION)
-        throw t2n_version_mismatch("not compatible with the server protocol version");
+    if (hello >> prot_version)
+    {
+        if (prot_version != PROTOCOL_VERSION)
+            throw t2n_version_mismatch("not compatible with the server protocol version");
+    }
+    else
+        return false;
 
-    hello.read(chk,1);
-    if (hello.fail() || hello.eof() || chk[0] != ';')
-        throw t2n_version_mismatch("illegal hello received (1. ;)");
+    if (hello.read(&chk,1))
+    {
+        if (chk != ';')
+            throw t2n_version_mismatch("illegal hello received (1. ;)");
+    }
+    else
+        return false;
 
-    hello.read(chk,4);
-    if (hello.fail() || hello.eof() || *((int*)chk) != 1)
-        throw t2n_version_mismatch("host byte order not matching");
+    unsigned int hbo;
+    if (hello.read((char*)&hbo,sizeof(hbo)))
+    {
+        if (hbo != 1)
+            throw t2n_version_mismatch("host byte order not matching");
+    }
+    else
+        return false;
+
+    if (hello.read(&chk,1))
+    {
+        if (chk != ';')
+            throw t2n_version_mismatch("illegal hello received (2. ;)");
+    }
+    else
+        return false;
 
-    hello.read(chk,1);
-    if (hello.fail() || hello.eof() || chk[0] != ';')
-        throw t2n_version_mismatch("illegal hello received (2. ;)");
+    return true;
 }
 
 void command_client::send_command(command* cmd, result_container &res)