iptables: (tomj) fixed badly broken memory handling, basic stuff now working
authorThomas Jarosch <thomas.jarosch@intra2net.com>
Thu, 8 Apr 2004 22:35:57 +0000 (22:35 +0000)
committerThomas Jarosch <thomas.jarosch@intra2net.com>
Thu, 8 Apr 2004 22:35:57 +0000 (22:35 +0000)
linux/include/linux/netfilter_ipv4/ipt_ACCOUNT.h
linux/net/ipv4/netfilter/ipt_ACCOUNT.c

index dc0ef91..d0976dc 100644 (file)
@@ -19,7 +19,7 @@ struct ipt_account_table
     char name[ACCOUNT_TABLE_NAME_LEN];        /* name of the table */
     unsigned int ip;                          /* base IP of network */
     unsigned int netmask;                     /* netmask of the network */
-    unsigned char netsize;                    /* Number of bits used in this netmask */
+    unsigned char depth;                      /* Size of network: 0: 8 bit, 1: 16bit, 2: 24 bit */
     unsigned int refcount;                    /* refcount of this table. if zero, destroy it */
     unsigned int itemcount;                   /* number of IPs in this table */
     void *data;                               /* pointer to the actual data, depending on netmask */
index 152c564..ec65b30 100644 (file)
@@ -40,22 +40,22 @@ ipt_account_target(struct sk_buff **pskb,
 }
 
 /* Recursive free of all data structures */
-void ipt_account_data_free(void *data, unsigned char netsize)
+void ipt_account_data_free(void *data, unsigned char depth)
 {
     // Empty data set
     if (!data)
         return;
         
-    // Free for 8 bit network. Special: 0.0.0.0/0
-    if (netsize >= 24 || netsize == 0)
+    // Free for 8 bit network
+    if (depth == 0)
     {
-        kfree(data);
+        free_page((unsigned long)data);
         data = NULL;
         return;
     }
     
     // Free for 16 bit network
-    if (netsize >= 16)
+    if (depth == 1)
     {
         struct ipt_account_mask_16 *mask_16 = (struct ipt_account_mask_16 *)data;
         unsigned char b;
@@ -63,17 +63,17 @@ void ipt_account_data_free(void *data, unsigned char netsize)
         {
             if (mask_16->mask_24[b] != 0)
             {
-                kfree(mask_16->mask_24[b]);
+                free_page((unsigned long)mask_16->mask_24[b]);
                 mask_16->mask_24[b] = NULL;
             }
         }
-        kfree(data);
+        free_page((unsigned long)data);
         data = NULL;
         return;
     } 
    
     // Free for 24 bit network
-    if (netsize >= 8)
+    if (depth == 3)
     {
         unsigned char a, b;
         for (a=0; a < 255; a++)
@@ -84,20 +84,20 @@ void ipt_account_data_free(void *data, unsigned char netsize)
                 for (b=0; b < 255; b++)
                 {
                     if (mask_16->mask_24[b]) {
-                        kfree(mask_16->mask_24[b]);
+                        free_page((unsigned long)mask_16->mask_24[b]);
                         mask_16->mask_24[b] = NULL;
                     }
                 }
-                kfree(mask_16);
+                free_page((unsigned long)mask_16);
                 mask_16 = NULL;
             }
         }
-        kfree(data);
+        free_page((unsigned long)data);
         data = NULL;
         return;
     }
     
-    printk("ACCOUNT: ipt_account_data_free called with broken netsize: %d\n", netsize);
+    printk("ACCOUNT: ipt_account_data_free called with unknown depth: %d\n", depth);
     return;
 }
 
@@ -106,19 +106,20 @@ int ipt_account_table_insert(char *name, unsigned int ip, unsigned int netmask)
 {
     unsigned int i;
 
-    DEBUGP("ACCOUNT: ipt_account_table_insert: %s, %u/%u\n", name, ip, netmask);
+    DEBUGP("ACCOUNT: ipt_account_table_insert: %s, %u.%u.%u.%u/%u.%u.%u.%u\n", name, NIPQUAD(ip), NIPQUAD(netmask));
 
     // Look for existing table
     for (i = 0; i < ACCOUNT_MAX_TABLES; i++)
     {
         if (strcmp(ipt_account_tables[i].name, name) == 0)
         {
-            DEBUGP("ACCOUT: Found existing slot: %d - %u/%u\n", i, ipt_account_tables[i].ip, ipt_account_tables[i].netmask);
+            DEBUGP("ACCOUNT: Found existing slot: %d - %u.%u.%u.%u/%u.%u.%u.%u\n", i,
+                   NIPQUAD(ipt_account_tables[i].ip), NIPQUAD(ipt_account_tables[i].netmask));
             
             if (ipt_account_tables[i].ip != ip || ipt_account_tables[i].netmask != netmask)
             {
-                printk("ACCOUNT: Table %s found, but IP/netmask mismatch. IP/netmask found: %u/%u\n",
-                        name, ipt_account_tables[i].ip, ipt_account_tables[i].netmask);
+                printk("ACCOUNT: Table %s found, but IP/netmask mismatch. IP/netmask found: %u.%u.%u.%u/%u.%u.%u.%u\n",
+                        name, NIPQUAD(ipt_account_tables[i].ip), NIPQUAD(ipt_account_tables[i].netmask));
                 return -1;
             }
 
@@ -141,25 +142,35 @@ int ipt_account_table_insert(char *name, unsigned int ip, unsigned int netmask)
             ipt_account_tables[i].ip = ip;
             ipt_account_tables[i].netmask = netmask;
             
-            // calculate netsize
-            unsigned int j, calc_mask;
+            // Calculate netsize
+            unsigned int j, calc_mask, netsize=0;
             calc_mask = htonl(netmask);
             for (j = 31; j > 0; j--)
             {
                 if (calc_mask&(1<<j))
-                    ipt_account_tables[i].netsize++;
+                    netsize++;
                 else
                     break;
             }
-            printk("ACCOUNT: calculated netsize: %u\n", ipt_account_tables[i].netsize);
+            
+            // Calculate depth from netsize
+            if (netsize >= 24)
+                ipt_account_tables[i].depth = 0;
+            else if (netsize >= 16)
+                ipt_account_tables[i].depth = 1;
+            else if(netsize >= 8)
+                ipt_account_tables[i].depth = 2;
+            
+            printk("ACCOUNT: calculated netsize: %u -> ipt_account_table depth %u\n", netsize, ipt_account_tables[i].depth);
                         
             ipt_account_tables[i].refcount++;
             if (!(ipt_account_tables[i].data = (void *)get_zeroed_page(GFP_KERNEL)))
             {
-                printk("ACCOUNT: Out of memory for data of table: %s\n", name);
+                printk("ACCOUNT: out of memory for data of table: %s\n", name);
                 memset(&ipt_account_tables[i], 0, sizeof(struct ipt_account_table));
                 return -1;
             }
+            
             return i;
         }
     }
@@ -183,40 +194,87 @@ static int ipt_account_checkentry(const char *tablename,
         return 0;
     }
 
+    spin_lock_bh(&ipt_account_lock);
     int table_nr = ipt_account_table_insert(info->table_name, info->net_ip, info->net_mask);
     if (table_nr == -1)
     {
         printk("ACCOUNT: Table insert problem. Aborting\n");
+        spin_unlock_bh(&ipt_account_lock);
         return 0;
     }
-
     // Table nr caching so we don't have to do an extra string compare for every packet
     info->table_nr = table_nr;
     
+    spin_unlock_bh(&ipt_account_lock);
+
     return 1;
 }
 
+void ipt_account_deleteentry(void *targinfo, unsigned int targinfosize)
+{
+    unsigned int i;
+    struct ipt_account_info *info = targinfo;
+    
+    if (targinfosize != IPT_ALIGN(sizeof(struct ipt_account_info))) {
+        DEBUGP("ACCOUNT: targinfosize %u != %u\n",
+                targinfosize, IPT_ALIGN(sizeof(struct ipt_account_info)));
+    }
+
+    spin_lock_bh(&ipt_account_lock);
+    
+    DEBUGP("ACCOUNT: ipt_account_deleteentry called for table: %s (#%d)\n", info->table_name, info->table_nr);
+    
+    info->table_nr = -1;    // Set back to original state
+    
+    // Look for table
+    for (i = 0; i < ACCOUNT_MAX_TABLES; i++)
+    {
+        if (strcmp(ipt_account_tables[i].name, info->table_name) == 0)
+        {
+            DEBUGP("ACCOUNT: Found table at slot: %d\n", i);
+            
+            ipt_account_tables[i].refcount--;
+            DEBUGP("ACCOUNT: Refcount left: %d\n", ipt_account_tables[i].refcount);
+
+            // Table not needed anymore?
+            if (ipt_account_tables[i].refcount == 0)
+            {
+                DEBUGP("ACCOUNT: Destroying table at slot: %d\n", i);
+                ipt_account_data_free(ipt_account_tables[i].data, ipt_account_tables[i].depth);
+                memset(&ipt_account_tables[i], 0, sizeof(struct ipt_account_table));
+            }
+            
+            spin_unlock_bh(&ipt_account_lock);
+            return;
+        }
+    }
+
+    // Table not found
+    printk("ACCOUNT: Table %s not found for destroy\n", info->table_name);
+    spin_unlock_bh(&ipt_account_lock);
+}
+
 static struct ipt_target ipt_account_reg
-= { { NULL, NULL }, "ACCOUNT", ipt_account_target, ipt_account_checkentry, NULL, 
+= { { NULL, NULL }, "ACCOUNT", ipt_account_target, ipt_account_checkentry, ipt_account_deleteentry, 
     THIS_MODULE };
 
 static int __init init(void)
 {
-    if (!(ipt_account_tables = kmalloc(ACCOUNT_MAX_TABLES, sizeof(struct ipt_account_table))))
+    if (!(ipt_account_tables = kmalloc(ACCOUNT_MAX_TABLES*sizeof(struct ipt_account_table), GFP_KERNEL)))
     {
             printk("ACCOUNT: Out of memory allocating account_tables structure");
             return -EINVAL;
     }
-    memset(ipt_account_tables, 0, sizeof(struct ipt_account_table));
+    memset(ipt_account_tables, 0, ACCOUNT_MAX_TABLES*sizeof(struct ipt_account_table));
                             
-    if (!(ipt_account_handles = kmalloc(ACCOUNT_MAX_HANDLES, sizeof(struct ipt_account_handle))))
+    if (!(ipt_account_handles = kmalloc(ACCOUNT_MAX_HANDLES*sizeof(struct ipt_account_handle), GFP_KERNEL)))
     {
             printk("ACCOUNT: Out of memory allocating account_handles structure");
             kfree (ipt_account_tables);
             ipt_account_tables = NULL;
             return -EINVAL;
     }
-    memset(ipt_account_handles, 0, sizeof(struct ipt_account_handle));
+    memset(ipt_account_handles, 0, ACCOUNT_MAX_HANDLES*sizeof(struct ipt_account_handle));
     
     if (ipt_register_target(&ipt_account_reg))
             return -EINVAL;