polished documentation for first public release
[ipt_ACCOUNT] / linux / net / ipv4 / netfilter / ipt_ACCOUNT.c
index e1a5262..edac8ff 100644 (file)
@@ -1,6 +1,17 @@
-/*
- * This is a module which is used for counting packets.
- */
+/***************************************************************************
+ *   This is a module which is used for counting packets.                  *
+ *   See http://www.intra2net.com/opensource/ipt_ACCOUNT                   *
+ *   for further information                                               *
+ *                                                                         * 
+ *   Copyright (C) 2004 by Intra2net AG                                    *
+ *   opensource@intra2net.com                                              *
+ *                                                                         *
+ *   This program is free software; you can redistribute it and/or modify  *
+ *   it under the terms of the GNU General Public License                  *
+ *   version 2 as published by the Free Software Foundation;               *
+ *                                                                         *
+ ***************************************************************************/
+
 #include <linux/module.h>
 #include <linux/skbuff.h>
 #include <linux/ip.h>
@@ -14,11 +25,11 @@ struct in_device;
 #include <net/route.h>
 #include <linux/netfilter_ipv4/ipt_ACCOUNT.h>
 
-//#if 0
+#if 0
 #define DEBUGP printk
-//#else
-//#define DEBUGP(format, args...)
-//#endif
+#else
+#define DEBUGP(format, args...)
+#endif
 
 struct ipt_account_table *ipt_account_tables = NULL;
 struct ipt_account_handle *ipt_account_handles = NULL;
@@ -155,7 +166,7 @@ int ipt_account_table_insert(char *name, unsigned int ip, unsigned int netmask)
             printk("ACCOUNT: calculated netsize: %u -> ipt_account_table depth %u\n", netsize, ipt_account_tables[i].depth);
                         
             ipt_account_tables[i].refcount++;
-            if (!(ipt_account_tables[i].data = (void *)get_zeroed_page(GFP_KERNEL)))
+            if ((ipt_account_tables[i].data = (void *)get_zeroed_page(GFP_KERNEL)) == NULL)
             {
                 printk("ACCOUNT: out of memory for data of table: %s\n", name);
                 memset(&ipt_account_tables[i], 0, sizeof(struct ipt_account_table));
@@ -328,7 +339,7 @@ void ipt_account_depth1_insert(struct ipt_account_mask_16 *mask_16, unsigned int
         DEBUGP("ACCOUNT: Calculated SRC 16 bit network slot: %d\n", slot);
         
         // Do we need to create a new mask_24 bucket?
-        if (!mask_16->mask_24[slot] && !(mask_16->mask_24[slot] = (void *)get_zeroed_page(GFP_KERNEL)))
+        if (!mask_16->mask_24[slot] && (mask_16->mask_24[slot] = (void *)get_zeroed_page(GFP_KERNEL)) == NULL)
         {
             printk("ACCOUNT: Can't process packet because out of memory!\n");
             return;
@@ -345,7 +356,7 @@ void ipt_account_depth1_insert(struct ipt_account_mask_16 *mask_16, unsigned int
         DEBUGP("ACCOUNT: Calculated DST 16 bit network slot: %d\n", slot);
         
         // Do we need to create a new mask_24 bucket?
-        if (!mask_16->mask_24[slot] && !(mask_16->mask_24[slot] = (void *)get_zeroed_page(GFP_KERNEL)))
+        if (!mask_16->mask_24[slot] && (mask_16->mask_24[slot] = (void *)get_zeroed_page(GFP_KERNEL)) == NULL)
         {
             printk("ACCOUT: Can't process packet because out of memory!\n");
             return;
@@ -366,7 +377,7 @@ void ipt_account_depth2_insert(struct ipt_account_mask_8 *mask_8, unsigned int n
         DEBUGP("ACCOUNT: Calculated SRC 24 bit network slot: %d\n", slot);
                 
         // Do we need to create a new mask_24 bucket?
-        if (!mask_8->mask_16[slot] && !(mask_8->mask_16[slot] = (void *)get_zeroed_page(GFP_KERNEL)))
+        if (!mask_8->mask_16[slot] && (mask_8->mask_16[slot] = (void *)get_zeroed_page(GFP_KERNEL)) == NULL)
         {
             printk("ACCOUNT: Can't process packet because out of memory!\n");
             return;
@@ -383,7 +394,7 @@ void ipt_account_depth2_insert(struct ipt_account_mask_8 *mask_8, unsigned int n
         DEBUGP("ACCOUNT: Calculated DST 24 bit network slot: %d\n", slot);
         
         // Do we need to create a new mask_24 bucket?
-        if (!mask_8->mask_16[slot] && !(mask_8->mask_16[slot] = (void *)get_zeroed_page(GFP_KERNEL)))
+        if (!mask_8->mask_16[slot] && (mask_8->mask_16[slot] = (void *)get_zeroed_page(GFP_KERNEL)) == NULL)
         {
             printk("ACCOUNT: Can't process packet because out of memory!\n");
             return;
@@ -533,7 +544,7 @@ int ipt_account_handle_prepare_read(char *tablename, unsigned int *count)
     ipt_account_handles[handle].itemcount = ipt_account_tables[table_nr].itemcount;
     
     // allocate "root" table
-    if (!(ipt_account_handles[handle].data = (void*)get_zeroed_page(GFP_KERNEL)))
+    if ((ipt_account_handles[handle].data = (void*)get_zeroed_page(GFP_KERNEL)) == NULL)
     {
         printk("ACCOUNT: out of memory for root table in ipt_account_handle_prepare_read()\n");
         memset (&ipt_account_handles[handle], 0, sizeof(struct ipt_account_handle));
@@ -554,7 +565,7 @@ int ipt_account_handle_prepare_read(char *tablename, unsigned int *count)
         {
             if (src_16->mask_24[b])
             {
-                if (!(network_16->mask_24[b] = (void*)get_zeroed_page(GFP_KERNEL)))
+                if ((network_16->mask_24[b] = (void*)get_zeroed_page(GFP_KERNEL)) == NULL)
                 {
                     printk("ACCOUNT: out of memory during copy of 16 bit network in ipt_account_handle_prepare_read()\n");
                     ipt_account_data_free(ipt_account_handles[handle].data, depth);
@@ -574,7 +585,7 @@ int ipt_account_handle_prepare_read(char *tablename, unsigned int *count)
         {
             if (src_8->mask_16[a])
             {
-                if (!(network_8->mask_16[a] = (void*)get_zeroed_page(GFP_KERNEL)))
+                if ((network_8->mask_16[a] = (void*)get_zeroed_page(GFP_KERNEL)) == NULL)
                 {
                     printk("ACCOUNT: out of memory during copy of 24 bit network in ipt_account_handle_prepare_read()\n");
                     ipt_account_data_free(ipt_account_handles[handle].data, depth);
@@ -592,7 +603,7 @@ int ipt_account_handle_prepare_read(char *tablename, unsigned int *count)
                 {
                     if (src_16->mask_24[b])
                     {
-                        if (!(network_16->mask_24[b] = (void*)get_zeroed_page(GFP_KERNEL)))
+                        if ((network_16->mask_24[b] = (void*)get_zeroed_page(GFP_KERNEL)) == NULL)
                         {
                             printk("ACCOUNT: out of memory during copy of 16 bit network in ipt_account_handle_prepare_read()\n");
                             ipt_account_data_free(ipt_account_handles[handle].data, depth);
@@ -821,6 +832,16 @@ static int ipt_account_set_ctl(struct sock *sk, int cmd, void *user, unsigned in
             ret = ipt_account_handle_free(handle.handle_nr);
             spin_unlock_bh(&ipt_account_userspace_lock);
             break;
+        case IPT_SO_SET_ACCOUNT_HANDLE_FREE_ALL:
+        {
+            unsigned int i;
+            spin_lock_bh(&ipt_account_userspace_lock);
+            for (i = 0; i < ACCOUNT_MAX_HANDLES; i++)
+                ipt_account_handle_free(i);
+            spin_unlock_bh(&ipt_account_userspace_lock);
+            ret = 0;
+            break;
+        }
         default:
             printk("ACCOUNT: ipt_account_set_ctl: unknown request %i\n", cmd);
     }
@@ -899,6 +920,7 @@ static int ipt_account_get_ctl(struct sock *sk, int cmd, void *user, int *len)
             {
                 printk("ACCOUNT: ipt_account_get_ctl: not enough space (%u < %u) to store data from IPT_SO_GET_ACCOUNT_GET_DATA\n",
                        *len, ipt_account_handles[handle.handle_nr].itemcount*sizeof(struct ipt_account_handle_ip));
+                ret = -ENOMEM;
                 break;
             }   
             
@@ -913,7 +935,70 @@ static int ipt_account_get_ctl(struct sock *sk, int cmd, void *user, int *len)
             
             ret = 0;
             break;
-        
+        case IPT_SO_GET_ACCOUNT_GET_HANDLE_USAGE:
+        {
+            if (*len < sizeof(struct ipt_account_handle_sockopt))
+            {
+                printk("ACCOUNT: ipt_account_get_ctl: wrong data size (%u != %u) for IPT_SO_GET_ACCOUNT_GET_HANDLE_USAGE\n",
+                       *len, sizeof(struct ipt_account_handle_sockopt));
+                break;
+            }   
+            
+            // Find out how many handles are in use
+            unsigned int i;
+            handle.itemcount = 0;
+            spin_lock_bh(&ipt_account_userspace_lock);
+            for (i = 0; i < ACCOUNT_MAX_HANDLES; i++)
+                if (ipt_account_handles[i].data)
+                    handle.itemcount++;
+            spin_unlock_bh(&ipt_account_userspace_lock);
+            
+            if (copy_to_user(user, &handle, sizeof(struct ipt_account_handle_sockopt)))
+            {
+                printk("ACCOUNT: ipt_account_set_ctl: copy_to_user failed for IPT_SO_GET_ACCOUNT_GET_HANDLE_USAGE\n");
+                break;
+            }
+            ret = 0;
+            break;
+        }
+        case IPT_SO_GET_ACCOUNT_GET_TABLE_NAMES:
+        {
+            spin_lock_bh(&ipt_account_lock);
+            
+            // Determine size of table names
+            unsigned int size = 0, i;
+            for (i = 0; i < ACCOUNT_MAX_TABLES; i++)
+            {
+                if (ipt_account_tables[i].name[0] != 0)
+                    size += strlen (ipt_account_tables[i].name) + 1;
+            }
+            size += 1;    // Terminating NULL character
+            
+            if (*len < size)
+            {
+                spin_unlock_bh(&ipt_account_lock);
+                printk("ACCOUNT: ipt_account_get_ctl: not enough space (%u < %u) to store table names\n", *len, size);
+                ret = -ENOMEM;
+                break;
+            }
+            // Copy table names to userspace
+            char *tnames = user;
+            for (i = 0; i < ACCOUNT_MAX_TABLES; i++)
+            {
+                if (ipt_account_tables[i].name[0] != 0)
+                {
+                    int len = strlen (ipt_account_tables[i].name) + 1;
+                    copy_to_user(tnames, ipt_account_tables[i].name, len);    // copy string + terminating zero
+                    tnames += len;
+                }
+            }
+            // Append terminating zero
+            i = 0;
+            copy_to_user(tnames, &i, 1);    
+            spin_unlock_bh(&ipt_account_lock);
+            ret = 0;
+            break;
+        }
         default:
             printk("ACCOUNT: ipt_account_get_ctl: unknown request %i\n", cmd);
     }
@@ -937,14 +1022,14 @@ static int __init init(void)
         return -EINVAL;
     }
 
-    if (!(ipt_account_tables = kmalloc(ACCOUNT_MAX_TABLES*sizeof(struct ipt_account_table), GFP_KERNEL)))
+    if ((ipt_account_tables = kmalloc(ACCOUNT_MAX_TABLES*sizeof(struct ipt_account_table), GFP_KERNEL)) == NULL)
     {
         printk("ACCOUNT: Out of memory allocating account_tables structure");
         return -EINVAL;
     }
     memset(ipt_account_tables, 0, ACCOUNT_MAX_TABLES*sizeof(struct ipt_account_table));
                             
-    if (!(ipt_account_handles = kmalloc(ACCOUNT_MAX_HANDLES*sizeof(struct ipt_account_handle), GFP_KERNEL)))
+    if ((ipt_account_handles = kmalloc(ACCOUNT_MAX_HANDLES*sizeof(struct ipt_account_handle), GFP_KERNEL)) == NULL)
     {
         printk("ACCOUNT: Out of memory allocating account_handles structure");
         kfree (ipt_account_tables);
@@ -957,8 +1042,8 @@ static int __init init(void)
     if ((ipt_account_tmpbuf = (void*)__get_free_page(GFP_KERNEL)) == NULL)
     {
         printk("ACCOUNT: Out of memory for temporary buffer page\n");
-        kfree (ipt_account_tables);
-        kfree (ipt_account_handles);
+        kfree(ipt_account_tables);
+        kfree(ipt_account_handles);
         ipt_account_tables = NULL;
         ipt_account_handles = NULL;
         return -EINVAL;
@@ -969,10 +1054,12 @@ static int __init init(void)
     {
         printk("ACCOUNT: Can't register sockopts. Aborting\n");
         
-        kfree (ipt_account_tables);
-        ipt_account_tables = NULL;
+        kfree(ipt_account_tables);
         kfree(ipt_account_handles);
+        free_page((unsigned long)ipt_account_tmpbuf);
+        ipt_account_tables = NULL;
         ipt_account_handles = NULL;
+        ipt_account_tmpbuf = NULL;
         
         return -EINVAL;
     }
@@ -990,10 +1077,12 @@ static void __exit fini(void)
     nf_unregister_sockopt(&ipt_account_sockopts);
     
     kfree(ipt_account_tables);
-    ipt_account_tables = NULL;
-        
     kfree(ipt_account_handles);
+    free_page((unsigned long)ipt_account_tmpbuf);
+    
+    ipt_account_tables = NULL;
     ipt_account_handles = NULL;
+    ipt_account_tmpbuf = NULL;
 }
 
 module_init(init);