libt2n: (gerd) add hello peek
[libt2n] / src / command_client.cpp
index 0d387b9..376d3f9 100644 (file)
 #include <boost/archive/xml_oarchive.hpp>
 #include <boost/archive/xml_iarchive.hpp>
 #include <boost/serialization/serialization.hpp>
-#include <boost/serialization/export.hpp>
+
+#include <boost/bind.hpp>
 
 #include "command_client.hxx"
 
+#ifdef HAVE_CONFIG_H
+#include <config.h>
+#endif
+
 using namespace std;
 
 namespace libt2n
 {
 
+command_client::command_client(client_connection& _c, long long _command_timeout_usec, long long _hello_timeout_usec)
+    : c(_c)
+{
+    command_timeout_usec=_command_timeout_usec;
+    hello_timeout_usec=_hello_timeout_usec;
+
+    // for reconnects
+    c.add_callback(new_connection,bind(&command_client::read_hello, boost::ref(*this)));
+
+    read_hello();
+}
+
+std::string command_client::read_packet(const long long &usec_timeout)
+{
+    string resultpacket;
+    bool got_packet=false;
+    long long my_timeout=usec_timeout;
+    while(!(got_packet=c.get_packet(resultpacket)) && my_timeout > 0  && !c.is_closed())
+        c.fill_buffer(my_timeout,&my_timeout);
+
+    if (!got_packet)
+        throw t2n_transfer_error("timeout exceeded");
+
+    return resultpacket;
+}
+
+void command_client::read_hello()
+{
+    string resultpacket;
+    bool got_packet=false;
+    long long my_timeout=hello_timeout_usec;
+    while(!(got_packet=c.get_packet(resultpacket)) && my_timeout > 0  && !c.is_closed())
+    {
+        c.fill_buffer(my_timeout,&my_timeout);
+
+        c.peek_packet(resultpacket);
+        check_hello(resultpacket);           // will throw before timeout if wrong data received
+    }
+
+    if (!got_packet)
+        throw t2n_transfer_error("timeout exceeded");
+
+    if (!check_hello(resultpacket))
+        throw t2n_version_mismatch("illegal hello received (incomplete): "+resultpacket);
+}
+
+bool command_client::check_hello(const string& hellostr)
+{
+    istringstream hello(hellostr);
+
+    char chk;
+
+    if (hello.read(&chk,1))
+    {
+        if (chk != 'T')
+            throw t2n_version_mismatch("illegal hello received (T2N)");
+    }
+    else
+        return false;
+
+    if (hello.read(&chk,1))
+    {
+        if (chk != '2')
+            throw t2n_version_mismatch("illegal hello received (T2N)");
+    }
+    else
+        return false;
+
+    if (hello.read(&chk,1))
+    {
+        if (chk != 'N')
+            throw t2n_version_mismatch("illegal hello received (T2N)");
+    }
+    else
+        return false;
+
+    if (hello.read(&chk,1))
+    {
+        if (chk != 'v')
+            throw t2n_version_mismatch("illegal hello received (T2N)");
+    }
+    else
+        return false;
+
+    int prot_version;
+    if (hello >> prot_version)
+    {
+        if (prot_version != PROTOCOL_VERSION)
+            throw t2n_version_mismatch("not compatible with the server protocol version");
+    }
+    else
+        return false;
+
+    if (hello.read(&chk,1))
+    {
+        if (chk != ';')
+            throw t2n_version_mismatch("illegal hello received (1. ;)");
+    }
+    else
+        return false;
+
+    unsigned int hbo;
+    if (hello.read((char*)&hbo,sizeof(hbo)))
+    {
+        if (hbo != 1)
+            throw t2n_version_mismatch("host byte order not matching");
+    }
+    else
+        return false;
+
+    if (hello.read(&chk,1))
+    {
+        if (chk != ';')
+            throw t2n_version_mismatch("illegal hello received (2. ;)");
+    }
+    else
+        return false;
+
+    return true;
+}
+
 void command_client::send_command(command* cmd, result_container &res)
 {
     ostringstream ofs;
     command_container cc(cmd);
     boost::archive::binary_oarchive oa(ofs);
 
-    // TODO: exceptions
-    oa << cc;
+    try
+    {
+        oa << cc;
+    }
+    catch(boost::archive::archive_exception &e)
+    {
+        ostringstream msg;
+        msg << "archive_exception while serializing on client-side, code " << e.code << " (" << e.what() << ")";
+        throw t2n_serialization_error(msg.str());
+    }
+    catch(...)
+        { throw; }
 
-    c.write(ofs.str());
+    std::ostream* ostr;
+    if ((ostr=c.get_logstream(fulldebug))!=NULL)
+    {
+        (*ostr) << "sending command, decoded data: " << std::endl;
+        boost::archive::xml_oarchive xo(*ostr);
+        xo << BOOST_SERIALIZATION_NVP(cc);
+    }
 
-    // TODO: fix timeout
-    string resultpacket;
-    while(!c.get_packet(resultpacket))
-        c.fill_buffer();
+    c.write(ofs.str());
 
-    istringstream ifs(resultpacket);
+    istringstream ifs(read_packet(command_timeout_usec));
     boost::archive::binary_iarchive ia(ifs);
 
-    // TODO: exceptions
-    ia >> res;
+    try
+    {
+        ia >> res;
+    }
+    catch(boost::archive::archive_exception &e)
+    {
+        ostringstream msg;
+        msg << "archive_exception while deserializing on client-side, code " << e.code << " (" << e.what() << ")";
+        throw t2n_serialization_error(msg.str());
+    }
+    catch(...)
+        { throw; }
+
+    if ((ostr=c.get_logstream(fulldebug))!=NULL)
+    {
+        (*ostr) << "received result, decoded data: " << std::endl;
+        boost::archive::xml_oarchive xo(*ostr);
+        xo << BOOST_SERIALIZATION_NVP(res);
+    }
 }
 
 }