8b1ce2a67ed4f2292c0ebb2af551ddcad529f9cc
[ipt_ACCOUNT] / linux / net / ipv4 / netfilter / ipt_ACCOUNT.c
1 /***************************************************************************
2  *   This is a module which is used for counting packets.                  *
3  *   See http://www.intra2net.com/opensource/ipt_account                   *
4  *   for further information                                               *
5  *                                                                         * 
6  *   Copyright (C) 2004 by Intra2net AG                                    *
7  *   opensource@intra2net.com                                              *
8  *                                                                         *
9  *   This program is free software; you can redistribute it and/or modify  *
10  *   it under the terms of the GNU General Public License                  *
11  *   version 2 as published by the Free Software Foundation;               *
12  *                                                                         *
13  ***************************************************************************/
14
15 #include <linux/module.h>
16 #include <linux/skbuff.h>
17 #include <linux/ip.h>
18 #include <linux/spinlock.h>
19 #include <net/icmp.h>
20 #include <net/udp.h>
21 #include <net/tcp.h>
22 #include <linux/netfilter_ipv4/ip_tables.h>
23
24 struct in_device;
25 #include <net/route.h>
26 #include <linux/netfilter_ipv4/ipt_ACCOUNT.h>
27
28 #if 0
29 #define DEBUGP printk
30 #else
31 #define DEBUGP(format, args...)
32 #endif
33
34 struct ipt_account_table *ipt_account_tables = NULL;
35 struct ipt_account_handle *ipt_account_handles = NULL;
36 void *ipt_account_tmpbuf = NULL;
37
38 // Spinlock used for manipulating the current accounting tables/data
39 static spinlock_t ipt_account_lock = SPIN_LOCK_UNLOCKED;
40 // Spinlock used for manipulating userspace handles/snapshot data
41 static spinlock_t ipt_account_userspace_lock = SPIN_LOCK_UNLOCKED;
42
43
44 /* Recursive free of all data structures */
45 void ipt_account_data_free(void *data, unsigned char depth)
46 {
47     // Empty data set
48     if (!data)
49         return;
50         
51     // Free for 8 bit network
52     if (depth == 0)
53     {
54         free_page((unsigned long)data);
55         data = NULL;
56         return;
57     }
58     
59     // Free for 16 bit network
60     if (depth == 1)
61     {
62         struct ipt_account_mask_16 *mask_16 = (struct ipt_account_mask_16 *)data;
63         unsigned int b;
64         for (b=0; b <= 255; b++)
65         {
66             if (mask_16->mask_24[b] != 0)
67             {
68                 free_page((unsigned long)mask_16->mask_24[b]);
69                 mask_16->mask_24[b] = NULL;
70             }
71         }
72         free_page((unsigned long)data);
73         data = NULL;
74         return;
75     } 
76    
77     // Free for 24 bit network
78     if (depth == 3)
79     {
80         unsigned int a, b;
81         for (a=0; a <= 255; a++)
82         {
83             if (((struct ipt_account_mask_8 *)data)->mask_16[a])
84             {
85                 struct ipt_account_mask_16 *mask_16 = (struct ipt_account_mask_16*)((struct ipt_account_mask_8 *)data)->mask_16[a];
86                 for (b=0; b <= 255; b++)
87                 {
88                     if (mask_16->mask_24[b]) {
89                         free_page((unsigned long)mask_16->mask_24[b]);
90                         mask_16->mask_24[b] = NULL;
91                     }
92                 }
93                 free_page((unsigned long)mask_16);
94                 mask_16 = NULL;
95             }
96         }
97         free_page((unsigned long)data);
98         data = NULL;
99         return;
100     }
101     
102     printk("ACCOUNT: ipt_account_data_free called with unknown depth: %d\n", depth);
103     return;
104 }
105
106 /* Look for existing table / insert new one. Return internal ID or -1 on error */
107 int ipt_account_table_insert(char *name, unsigned int ip, unsigned int netmask)
108 {
109     unsigned int i;
110
111     DEBUGP("ACCOUNT: ipt_account_table_insert: %s, %u.%u.%u.%u/%u.%u.%u.%u\n", name, NIPQUAD(ip), NIPQUAD(netmask));
112
113     // Look for existing table
114     for (i = 0; i < ACCOUNT_MAX_TABLES; i++)
115     {
116         if (strncmp(ipt_account_tables[i].name, name, ACCOUNT_TABLE_NAME_LEN) == 0)
117         {
118             DEBUGP("ACCOUNT: Found existing slot: %d - %u.%u.%u.%u/%u.%u.%u.%u\n", i,
119                    NIPQUAD(ipt_account_tables[i].ip), NIPQUAD(ipt_account_tables[i].netmask));
120             
121             if (ipt_account_tables[i].ip != ip || ipt_account_tables[i].netmask != netmask)
122             {
123                 printk("ACCOUNT: Table %s found, but IP/netmask mismatch. IP/netmask found: %u.%u.%u.%u/%u.%u.%u.%u\n",
124                         name, NIPQUAD(ipt_account_tables[i].ip), NIPQUAD(ipt_account_tables[i].netmask));
125                 return -1;
126             }
127
128             ipt_account_tables[i].refcount++;
129             DEBUGP("ACCOUNT: Refcount: %d\n", ipt_account_tables[i].refcount);
130             return i;
131         }
132     }
133
134     // Insert new table
135     for (i = 0; i < ACCOUNT_MAX_TABLES; i++)
136     {
137         // Found free slot
138         if (ipt_account_tables[i].name[0] == 0)
139         {
140             DEBUGP("ACCOUNT: Found free slot: %d\n", i);
141         
142             strncpy (ipt_account_tables[i].name, name, ACCOUNT_TABLE_NAME_LEN-1);
143             
144             ipt_account_tables[i].ip = ip;
145             ipt_account_tables[i].netmask = netmask;
146             
147             // Calculate netsize
148             unsigned int j, calc_mask, netsize=0;
149             calc_mask = htonl(netmask);
150             for (j = 31; j > 0; j--)
151             {
152                 if (calc_mask&(1<<j))
153                     netsize++;
154                 else
155                     break;
156             }
157             
158             // Calculate depth from netsize
159             if (netsize >= 24)
160                 ipt_account_tables[i].depth = 0;
161             else if (netsize >= 16)
162                 ipt_account_tables[i].depth = 1;
163             else if(netsize >= 8)
164                 ipt_account_tables[i].depth = 2;
165             
166             printk("ACCOUNT: calculated netsize: %u -> ipt_account_table depth %u\n", netsize, ipt_account_tables[i].depth);
167                         
168             ipt_account_tables[i].refcount++;
169             if ((ipt_account_tables[i].data = (void *)get_zeroed_page(GFP_KERNEL)) == NULL)
170             {
171                 printk("ACCOUNT: out of memory for data of table: %s\n", name);
172                 memset(&ipt_account_tables[i], 0, sizeof(struct ipt_account_table));
173                 return -1;
174             }
175             
176             return i;
177         }
178     }
179         
180     // No free slot found
181     printk("ACCOUNT: No free table slot found (max: %d). Please increase ACCOUNT_MAX_TABLES.\n", ACCOUNT_MAX_TABLES);
182     return -1;
183 }
184
185 static int ipt_account_checkentry(const char *tablename,
186     const struct ipt_entry *e,
187     void *targinfo,
188     unsigned int targinfosize,
189     unsigned int hook_mask)
190 {
191     struct ipt_account_info *info = targinfo;
192     
193     if (targinfosize != IPT_ALIGN(sizeof(struct ipt_account_info))) {
194         DEBUGP("ACCOUNT: targinfosize %u != %u\n",
195                 targinfosize, IPT_ALIGN(sizeof(struct ipt_account_info)));
196         return 0;
197     }
198
199     spin_lock_bh(&ipt_account_lock);
200     int table_nr = ipt_account_table_insert(info->table_name, info->net_ip, info->net_mask);
201     if (table_nr == -1)
202     {
203         printk("ACCOUNT: Table insert problem. Aborting\n");
204         spin_unlock_bh(&ipt_account_lock);
205         return 0;
206     }
207     // Table nr caching so we don't have to do an extra string compare for every packet
208     info->table_nr = table_nr;
209     
210     spin_unlock_bh(&ipt_account_lock);
211
212     return 1;
213 }
214
215 void ipt_account_deleteentry(void *targinfo, unsigned int targinfosize)
216 {
217     unsigned int i;
218     struct ipt_account_info *info = targinfo;
219     
220     if (targinfosize != IPT_ALIGN(sizeof(struct ipt_account_info))) {
221         DEBUGP("ACCOUNT: targinfosize %u != %u\n",
222                 targinfosize, IPT_ALIGN(sizeof(struct ipt_account_info)));
223     }
224
225     spin_lock_bh(&ipt_account_lock);
226     
227     DEBUGP("ACCOUNT: ipt_account_deleteentry called for table: %s (#%d)\n", info->table_name, info->table_nr);
228     
229     info->table_nr = -1;    // Set back to original state
230     
231     // Look for table
232     for (i = 0; i < ACCOUNT_MAX_TABLES; i++)
233     {
234         if (strncmp(ipt_account_tables[i].name, info->table_name, ACCOUNT_TABLE_NAME_LEN) == 0)
235         {
236             DEBUGP("ACCOUNT: Found table at slot: %d\n", i);
237             
238             ipt_account_tables[i].refcount--;
239             DEBUGP("ACCOUNT: Refcount left: %d\n", ipt_account_tables[i].refcount);
240
241             // Table not needed anymore?
242             if (ipt_account_tables[i].refcount == 0)
243             {
244                 DEBUGP("ACCOUNT: Destroying table at slot: %d\n", i);
245                 ipt_account_data_free(ipt_account_tables[i].data, ipt_account_tables[i].depth);
246                 memset(&ipt_account_tables[i], 0, sizeof(struct ipt_account_table));
247             }
248             
249             spin_unlock_bh(&ipt_account_lock);
250             return;
251         }
252     }
253
254     // Table not found
255     printk("ACCOUNT: Table %s not found for destroy\n", info->table_name);
256     spin_unlock_bh(&ipt_account_lock);
257 }
258
259 void ipt_account_depth0_insert(struct ipt_account_mask_24 *mask_24, unsigned int net_ip, unsigned int netmask,
260                                unsigned int src_ip, unsigned int dst_ip, unsigned int size, unsigned int *itemcount)
261 {
262     unsigned char is_src = 0, is_dst = 0;
263     
264     DEBUGP("ACCOUNT: ipt_account_depth0_insert: %u.%u.%u.%u/%u.%u.%u.%u for net %u.%u.%u.%u/%u.%u.%u.%u, size: %u\n",
265             NIPQUAD(src_ip), NIPQUAD(dst_ip), NIPQUAD(net_ip), NIPQUAD(netmask), size);
266         
267     // Check if src/dst is inside our network.
268     // Special: net_ip = 0.0.0.0/0 gets stored as src in slot 0
269     if (!netmask)
270         src_ip = 0;
271     if ((net_ip&netmask) == (src_ip&netmask))
272         is_src = 1;
273     if ((net_ip&netmask) == (dst_ip&netmask) && netmask)
274         is_dst = 1;
275     
276     if (!is_src && !is_dst)
277     {
278         DEBUGP("ACCOUNT: Skipping packet %u.%u.%u.%u/%u.%u.%u.%u for net %u.%u.%u.%u/%u.%u.%u.%u\n",
279                 NIPQUAD(src_ip), NIPQUAD(dst_ip), NIPQUAD(net_ip), NIPQUAD(netmask));
280         return;
281     }
282         
283     // Check if this entry is new
284     char is_src_new_ip = 0, is_dst_new_ip = 0;
285
286     // Calculate array positions
287     unsigned char src_slot = (unsigned char)((src_ip&0xFF000000) >> 24);
288     unsigned char dst_slot = (unsigned char)((dst_ip&0xFF000000) >> 24);
289     
290     // Increase size counters
291     if (is_src)
292     {
293         // Calculate network slot
294         DEBUGP("ACCOUNT: Calculated SRC 8 bit network slot: %d\n", src_slot);
295         if (!mask_24->ip[src_slot].src_packets && !mask_24->ip[src_slot].dst_packets)
296             is_src_new_ip = 1;
297         
298         mask_24->ip[src_slot].src_packets++;
299         mask_24->ip[src_slot].src_bytes+=size;
300     }
301     if (is_dst)
302     {
303         DEBUGP("ACCOUNT: Calculated DST 8 bit network slot: %d\n", dst_slot);
304         if (!mask_24->ip[dst_slot].src_packets && !mask_24->ip[dst_slot].dst_packets)
305             is_dst_new_ip = 1;
306         
307         mask_24->ip[dst_slot].dst_packets++;
308         mask_24->ip[dst_slot].dst_bytes+=size;
309     }
310     
311     // Increase itemcounter
312     DEBUGP("ACCOUNT: Itemcounter before: %d\n", *itemcount);
313     if (src_slot == dst_slot)
314     {
315         if (is_src_new_ip || is_dst_new_ip) {
316             DEBUGP("ACCOUNT: src_slot == dst_slot: %d, %d\n", is_src_new_ip, is_dst_new_ip);
317             (*itemcount)++;
318         }
319     } else {
320         if (is_src_new_ip) {
321             DEBUGP("ACCOUNT: New src_ip: %u.%u.%u.%u\n", NIPQUAD(src_ip));
322             (*itemcount)++;
323         }
324         if (is_dst_new_ip) {
325             DEBUGP("ACCOUNT: New dst_ip: %u.%u.%u.%u\n", NIPQUAD(dst_ip));
326             (*itemcount)++;
327         }
328     }
329     DEBUGP("ACCOUNT: Itemcounter after: %d\n", *itemcount);
330 }
331
332 void ipt_account_depth1_insert(struct ipt_account_mask_16 *mask_16, unsigned int net_ip, unsigned int netmask,
333                                unsigned int src_ip, unsigned int dst_ip, unsigned int size, unsigned int *itemcount)
334 {
335     // Do we need to process src IP?
336     if ((net_ip&netmask) == (src_ip&netmask))
337     {
338         unsigned char slot = (unsigned char)((src_ip&0x00FF0000) >> 16);
339         DEBUGP("ACCOUNT: Calculated SRC 16 bit network slot: %d\n", slot);
340         
341         // Do we need to create a new mask_24 bucket?
342         if (!mask_16->mask_24[slot] && (mask_16->mask_24[slot] = (void *)get_zeroed_page(GFP_KERNEL)) == NULL)
343         {
344             printk("ACCOUNT: Can't process packet because out of memory!\n");
345             return;
346         }
347         
348         ipt_account_depth0_insert((struct ipt_account_mask_24 *)mask_16->mask_24[slot], net_ip, netmask,
349                                     src_ip, 0, size, itemcount);
350     }
351     
352     // Do we need to process dst IP?
353     if ((net_ip&netmask) == (dst_ip&netmask))
354     {
355         unsigned char slot = (unsigned char)((dst_ip&0x00FF0000) >> 16);
356         DEBUGP("ACCOUNT: Calculated DST 16 bit network slot: %d\n", slot);
357         
358         // Do we need to create a new mask_24 bucket?
359         if (!mask_16->mask_24[slot] && (mask_16->mask_24[slot] = (void *)get_zeroed_page(GFP_KERNEL)) == NULL)
360         {
361             printk("ACCOUT: Can't process packet because out of memory!\n");
362             return;
363         }
364         
365         ipt_account_depth0_insert((struct ipt_account_mask_24 *)mask_16->mask_24[slot], net_ip, netmask,
366                                     0, dst_ip, size, itemcount);
367     }
368 }
369
370 void ipt_account_depth2_insert(struct ipt_account_mask_8 *mask_8, unsigned int net_ip, unsigned int netmask,
371                                unsigned int src_ip, unsigned int dst_ip, unsigned int size, unsigned int *itemcount)
372 {
373     // Do we need to process src IP?
374     if ((net_ip&netmask) == (src_ip&netmask))
375     {
376         unsigned char slot = (unsigned char)((src_ip&0x0000FF00) >> 8);
377         DEBUGP("ACCOUNT: Calculated SRC 24 bit network slot: %d\n", slot);
378                 
379         // Do we need to create a new mask_24 bucket?
380         if (!mask_8->mask_16[slot] && (mask_8->mask_16[slot] = (void *)get_zeroed_page(GFP_KERNEL)) == NULL)
381         {
382             printk("ACCOUNT: Can't process packet because out of memory!\n");
383             return;
384         }
385         
386         ipt_account_depth1_insert((struct ipt_account_mask_16 *)mask_8->mask_16[slot], net_ip, netmask,
387                                     src_ip, 0, size, itemcount);
388     }
389     
390     // Do we need to process dst IP?
391     if ((net_ip&netmask) == (dst_ip&netmask))
392     {
393         unsigned char slot = (unsigned char)((dst_ip&0x0000FF00) >> 8);
394         DEBUGP("ACCOUNT: Calculated DST 24 bit network slot: %d\n", slot);
395         
396         // Do we need to create a new mask_24 bucket?
397         if (!mask_8->mask_16[slot] && (mask_8->mask_16[slot] = (void *)get_zeroed_page(GFP_KERNEL)) == NULL)
398         {
399             printk("ACCOUNT: Can't process packet because out of memory!\n");
400             return;
401         }
402         
403         ipt_account_depth1_insert((struct ipt_account_mask_16 *)mask_8->mask_16[slot], net_ip, netmask,
404                                     0, dst_ip, size, itemcount);
405     }
406 }
407
408 static unsigned int ipt_account_target(struct sk_buff **pskb,
409     unsigned int hooknum,
410     const struct net_device *in,
411     const struct net_device *out,
412     const void *targinfo,
413     void *userinfo)
414 {
415     const struct ipt_account_info *info = (const struct ipt_account_info *)targinfo;
416     unsigned int src_ip = (*pskb)->nh.iph->saddr;
417     unsigned int dst_ip = (*pskb)->nh.iph->daddr;
418     unsigned int size = ntohs((*pskb)->nh.iph->tot_len);
419     
420     spin_lock_bh(&ipt_account_lock);
421     
422     if (ipt_account_tables[info->table_nr].name[0] == 0)
423     {
424         printk("ACCOUNT: ipt_account_target: Invalid table id %u. IPs %u.%u.%u.%u/%u.%u.%u.%u\n",
425                info->table_nr, NIPQUAD(src_ip), NIPQUAD(dst_ip));
426         spin_unlock_bh(&ipt_account_lock);
427         return IPT_CONTINUE;
428     }
429     
430     // 8 bit network or "any" network
431     if (ipt_account_tables[info->table_nr].depth == 0)
432     {
433         // Count packet and check if the IP is new
434         ipt_account_depth0_insert((struct ipt_account_mask_24 *)ipt_account_tables[info->table_nr].data,
435                                   ipt_account_tables[info->table_nr].ip, ipt_account_tables[info->table_nr].netmask,
436                                   src_ip, dst_ip, size, &ipt_account_tables[info->table_nr].itemcount);
437         spin_unlock_bh(&ipt_account_lock);
438         return IPT_CONTINUE;
439     }    
440     
441     // 16 bit network
442     if (ipt_account_tables[info->table_nr].depth == 1)
443     {
444         ipt_account_depth1_insert((struct ipt_account_mask_16 *)ipt_account_tables[info->table_nr].data,
445                                   ipt_account_tables[info->table_nr].ip, ipt_account_tables[info->table_nr].netmask,
446                                   src_ip, dst_ip, size, &ipt_account_tables[info->table_nr].itemcount);
447         spin_unlock_bh(&ipt_account_lock);
448         return IPT_CONTINUE;
449     }
450     
451     // 24 bit network
452     if (ipt_account_tables[info->table_nr].depth == 2)
453     {
454         ipt_account_depth2_insert((struct ipt_account_mask_8 *)ipt_account_tables[info->table_nr].data,
455                                   ipt_account_tables[info->table_nr].ip, ipt_account_tables[info->table_nr].netmask,
456                                   src_ip, dst_ip, size, &ipt_account_tables[info->table_nr].itemcount);
457         spin_unlock_bh(&ipt_account_lock);
458         return IPT_CONTINUE;
459     }
460     
461     printk("ACCOUNT: ipt_account_target: Unable to process packet. Table id %u. IPs %u.%u.%u.%u/%u.%u.%u.%u\n",
462             info->table_nr, NIPQUAD(src_ip), NIPQUAD(dst_ip));
463     
464     spin_unlock_bh(&ipt_account_lock);
465     return IPT_CONTINUE;
466 }
467
/*
    Functions dealing with "handles":
    Handles are snapshots of an accounting state.

    Read snapshots are only meant for debugging the code
    and are very expensive in speed and memory
    compared to read_and_flush.

    The functions aren't protected by spinlocks themselves,
    as this is done in the ioctl part of the code.
*/
479
480 /*
481     Find a free handle slot. Normally only one should be used,
482     but there could be two or more applications accessing the data
483     at the same time.
484 */
485 int ipt_account_handle_find_slot(void)
486 {
487     unsigned int i;
488     // Insert new table
489     for (i = 0; i < ACCOUNT_MAX_HANDLES; i++)
490     {
491         // Found free slot
492         if (ipt_account_handles[i].data == NULL)
493         {
494             // Don't "mark" data as used as we are protected by a spinlock by the calling function.
495             // handle_find_slot() is only a function to prevent code duplication.
496             return i;
497         }
498     }
499         
500     // No free slot found
501     printk("ACCOUNT: No free handle slot found (max: %u). Please increase ACCOUNT_MAX_HANDLES.\n", ACCOUNT_MAX_HANDLES);
502     return -1;
503 }
504
505 int ipt_account_handle_free(unsigned int handle)
506 {
507     if (handle >= ACCOUNT_MAX_HANDLES)
508     {
509         printk("ACCOUNT: Invalid handle for ipt_account_handle_free() specified: %u\n", handle);
510         return -EINVAL;
511     }
512     
513     ipt_account_data_free(ipt_account_handles[handle].data, ipt_account_handles[handle].depth);
514     memset (&ipt_account_handles[handle], 0, sizeof (struct ipt_account_handle));
515     return 0;
516 }
517
518 /* Prepare data for read without flush. Use only for debugging!
519    Real applications should use read&flush as it's way more efficent */
520 int ipt_account_handle_prepare_read(char *tablename, unsigned int *count)
521 {
522     int handle, i, table_nr=-1;
523     
524     for (i = 0; i < ACCOUNT_MAX_TABLES; i++)
525         if (strncmp(ipt_account_tables[i].name, tablename, ACCOUNT_TABLE_NAME_LEN) == 0)
526         {
527             table_nr = i;
528             break;
529         }
530
531     if (table_nr == -1)
532     {
533         printk("ACCOUNT: ipt_account_handle_prepare_read(): Table %s not found\n", tablename);
534         return -1;
535     }
536             
537     // Can't find a free handle slot?
538     if ((handle = ipt_account_handle_find_slot()) == -1)
539         return -1;
540     
541     // Fill up handle structure
542     ipt_account_handles[handle].ip = ipt_account_tables[table_nr].ip;
543     ipt_account_handles[handle].depth = ipt_account_tables[table_nr].depth;
544     ipt_account_handles[handle].itemcount = ipt_account_tables[table_nr].itemcount;
545     
546     // allocate "root" table
547     if ((ipt_account_handles[handle].data = (void*)get_zeroed_page(GFP_KERNEL)) == NULL)
548     {
549         printk("ACCOUNT: out of memory for root table in ipt_account_handle_prepare_read()\n");
550         memset (&ipt_account_handles[handle], 0, sizeof(struct ipt_account_handle));
551         return -1;        
552     }
553     
554     // Recursive copy of complete data structure
555     unsigned int depth = ipt_account_handles[handle].depth;
556     if (depth == 0)
557     {
558         memcpy(ipt_account_handles[handle].data, ipt_account_tables[table_nr].data, sizeof(struct ipt_account_mask_24));
559     } else if (depth == 1) {
560         struct ipt_account_mask_16 *src_16 = (struct ipt_account_mask_16 *)ipt_account_tables[table_nr].data;
561         struct ipt_account_mask_16 *network_16 = (struct ipt_account_mask_16 *)ipt_account_handles[handle].data;
562         unsigned int b;
563         
564         for (b = 0; b <= 255; b++)
565         {
566             if (src_16->mask_24[b])
567             {
568                 if ((network_16->mask_24[b] = (void*)get_zeroed_page(GFP_KERNEL)) == NULL)
569                 {
570                     printk("ACCOUNT: out of memory during copy of 16 bit network in ipt_account_handle_prepare_read()\n");
571                     ipt_account_data_free(ipt_account_handles[handle].data, depth);
572                     memset (&ipt_account_handles[handle], 0, sizeof(struct ipt_account_handle));
573                     return -1;                    
574                 }
575                 
576                 memcpy(network_16->mask_24[b], src_16->mask_24[b], sizeof(struct ipt_account_mask_24));
577             }
578         }
579     } else if(depth == 2) {
580         struct ipt_account_mask_8 *src_8 = (struct ipt_account_mask_8 *)ipt_account_tables[table_nr].data;
581         struct ipt_account_mask_8 *network_8 = (struct ipt_account_mask_8 *)ipt_account_handles[handle].data;
582         unsigned int a;
583         
584         for (a = 0; a <= 255; a++)
585         {
586             if (src_8->mask_16[a])
587             {
588                 if ((network_8->mask_16[a] = (void*)get_zeroed_page(GFP_KERNEL)) == NULL)
589                 {
590                     printk("ACCOUNT: out of memory during copy of 24 bit network in ipt_account_handle_prepare_read()\n");
591                     ipt_account_data_free(ipt_account_handles[handle].data, depth);
592                     memset (&ipt_account_handles[handle], 0, sizeof(struct ipt_account_handle));
593                     return -1;                    
594                 }
595
596                 memcpy(network_8->mask_16[a], src_8->mask_16[a], sizeof(struct ipt_account_mask_16));
597                     
598                 struct ipt_account_mask_16 *src_16 = src_8->mask_16[a];
599                 struct ipt_account_mask_16 *network_16 = network_8->mask_16[a];
600                 unsigned int b;
601                 
602                 for (b = 0; b <= 255; b++)
603                 {
604                     if (src_16->mask_24[b])
605                     {
606                         if ((network_16->mask_24[b] = (void*)get_zeroed_page(GFP_KERNEL)) == NULL)
607                         {
608                             printk("ACCOUNT: out of memory during copy of 16 bit network in ipt_account_handle_prepare_read()\n");
609                             ipt_account_data_free(ipt_account_handles[handle].data, depth);
610                             memset (&ipt_account_handles[handle], 0, sizeof(struct ipt_account_handle));
611                             return -1;                    
612                         }
613                         
614                         memcpy(network_16->mask_24[b], src_16->mask_24[b], sizeof(struct ipt_account_mask_24));
615                     }
616                 }
617             }
618         }
619     }
620
621     *count = ipt_account_tables[table_nr].itemcount;
622     return handle;
623 }
624
625 /* Prepare data for read and flush it */
626 int ipt_account_handle_prepare_read_flush(char *tablename, unsigned int *count)
627 {
628     int handle, i, table_nr=-1;
629     
630     for (i = 0; i < ACCOUNT_MAX_TABLES; i++)
631         if (strncmp(ipt_account_tables[i].name, tablename, ACCOUNT_TABLE_NAME_LEN) == 0)
632         {
633             table_nr = i;
634             break;
635         }
636
637     if (table_nr == -1)
638     {
639         printk("ACCOUNT: ipt_account_handle_prepare_read_flush(): Table %s not found\n", tablename);
640         return -1;
641     }
642             
643     // Can't find a free handle slot?
644     if ((handle = ipt_account_handle_find_slot()) == -1)
645         return -1;
646     
647     // Fill up handle structure
648     ipt_account_handles[handle].ip = ipt_account_tables[table_nr].ip;
649     ipt_account_handles[handle].depth = ipt_account_tables[table_nr].depth;
650     ipt_account_handles[handle].itemcount = ipt_account_tables[table_nr].itemcount;
651     ipt_account_handles[handle].data = ipt_account_tables[table_nr].data;
652     *count = ipt_account_tables[table_nr].itemcount;
653
654     // "Flush" table data
655     ipt_account_tables[table_nr].data = (void*)get_zeroed_page(GFP_KERNEL);
656     ipt_account_tables[table_nr].itemcount = 0;
657
658     return handle;
659 }
660
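/* Copy the actual data into a prepared buffer.
   We only copy entries != 0 to increase performance.
   The memory gets freed again in ipt_account_handle_free().
   NOTE(review): every copy_to_user() below writes to the start of @buffer
   (the pointer is never advanced), and its return value is ignored.
   Output larger than one PAGE_SIZE batch therefore appears to overwrite
   earlier records - confirm.
*/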
665 int ipt_account_handle_get_data(unsigned int handle, void *buffer)
666 {
667     struct ipt_account_handle_ip handle_ip;    
668     unsigned int handle_ip_size = sizeof (struct ipt_account_handle_ip);
669     unsigned int i, tmpbuf_pos=0;
670
671     if (handle >= ACCOUNT_MAX_HANDLES)
672     {
673         printk("ACCOUNT: invalid handle for ipt_account_handle_get_data() specified: %u\n", handle);
674         return -1;
675     }
676     
677     if (ipt_account_handles[handle].data == NULL)
678     {
679         printk("ACCOUNT: handle %u is BROKEN: Contains no data\n", handle);
680         return -1;
681     }
682     
683     unsigned int net_ip = ipt_account_handles[handle].ip;
684     unsigned int depth = ipt_account_handles[handle].depth;
685     
686     // 8 bit network
687     if (depth == 0)
688     {
689         struct ipt_account_mask_24 *network = (struct ipt_account_mask_24*)ipt_account_handles[handle].data;
690         for (i = 0; i <= 255; i++)
691         {
692             if (network->ip[i].src_packets || network->ip[i].dst_packets)
693             {
694                 handle_ip.ip = net_ip | (i<<24);
695                 handle_ip.src_packets = network->ip[i].src_packets;
696                 handle_ip.src_bytes = network->ip[i].src_bytes;
697                 handle_ip.dst_packets = network->ip[i].dst_packets;
698                 handle_ip.dst_bytes = network->ip[i].dst_bytes;
699
700                 // Temporary buffer full? Flush to userspace
701                 if (tmpbuf_pos+handle_ip_size >= PAGE_SIZE)
702                 {
703                     copy_to_user(buffer, ipt_account_tmpbuf, tmpbuf_pos);
704                     tmpbuf_pos = 0;
705                 }
706                 memcpy(ipt_account_tmpbuf+tmpbuf_pos, &handle_ip, handle_ip_size);
707                 tmpbuf_pos += handle_ip_size;
708             }
709         }
710
711         // Flush remaining data to userspace
712         if (tmpbuf_pos)
713             copy_to_user(buffer, ipt_account_tmpbuf, tmpbuf_pos);
714         
715         return 0;
716     }
717     
718     // 16 bit network
719     if (depth == 1)
720     {
721         struct ipt_account_mask_16 *network_16 = (struct ipt_account_mask_16*)ipt_account_handles[handle].data;
722         unsigned int b;
723         for (b = 0; b <= 255; b++)
724         {
725             if (network_16->mask_24[b])
726             {
727                 struct ipt_account_mask_24 *network = (struct ipt_account_mask_24*)network_16->mask_24[b];
728                 for (i = 0; i <= 255; i++)
729                 {
730                     if (network->ip[i].src_packets || network->ip[i].dst_packets)
731                     {
732                         handle_ip.ip = net_ip | (b << 16) | (i<<24);
733                         handle_ip.src_packets = network->ip[i].src_packets;
734                         handle_ip.src_bytes = network->ip[i].src_bytes;
735                         handle_ip.dst_packets = network->ip[i].dst_packets;
736                         handle_ip.dst_bytes = network->ip[i].dst_bytes;
737         
738                         // Temporary buffer full? Flush to userspace
739                         if (tmpbuf_pos+handle_ip_size >= PAGE_SIZE)
740                         {
741                             copy_to_user(buffer, ipt_account_tmpbuf, tmpbuf_pos);
742                             tmpbuf_pos = 0;
743                         }
744                         memcpy(ipt_account_tmpbuf+tmpbuf_pos, &handle_ip, handle_ip_size);
745                         tmpbuf_pos += handle_ip_size;
746                     }
747                 }
748             }
749         }
750         
751         // Flush remaining data to userspace
752         if (tmpbuf_pos)
753             copy_to_user(buffer, ipt_account_tmpbuf, tmpbuf_pos);
754         
755         return 0;
756     }
757     
758     // 24 bit network
759     if (depth == 2)
760     {
761         struct ipt_account_mask_8 *network_8 = (struct ipt_account_mask_8*)ipt_account_handles[handle].data;
762         unsigned int a, b;
763         for (a = 0; a <= 255; a++)
764         {
765             if (network_8->mask_16[a])
766             {
767                 struct ipt_account_mask_16 *network_16 = (struct ipt_account_mask_16*)network_8->mask_16[a];
768                 for (b = 0; b <= 255; b++)
769                 {
770                     if (network_16->mask_24[b])
771                     {
772                         struct ipt_account_mask_24 *network = (struct ipt_account_mask_24*)network_16->mask_24[b];
773                         for (i = 0; i <= 255; i++)
774                         {
775                             if (network->ip[i].src_packets || network->ip[i].dst_packets)
776                             {
777                                 handle_ip.ip = net_ip | (a << 8) | (b << 16) | (i<<24);
778                                 handle_ip.src_packets = network->ip[i].src_packets;
779                                 handle_ip.src_bytes = network->ip[i].src_bytes;
780                                 handle_ip.dst_packets = network->ip[i].dst_packets;
781                                 handle_ip.dst_bytes = network->ip[i].dst_bytes;
782                 
783                                 // Temporary buffer full? Flush to userspace
784                                 if (tmpbuf_pos+handle_ip_size >= PAGE_SIZE)
785                                 {
786                                     copy_to_user(buffer, ipt_account_tmpbuf, tmpbuf_pos);
787                                     tmpbuf_pos = 0;
788                                 }
789                                 memcpy(ipt_account_tmpbuf+tmpbuf_pos, &handle_ip, handle_ip_size);
790                                 tmpbuf_pos += handle_ip_size;
791                             }
792                         }
793                     }
794                 }
795             }
796         }
797         
798         // Flush remaining data to userspace
799         if (tmpbuf_pos)
800             copy_to_user(buffer, ipt_account_tmpbuf, tmpbuf_pos);
801         
802         return 0;
803     }
804     
805     return -1;
806 }
807
808 static int ipt_account_set_ctl(struct sock *sk, int cmd, void *user, unsigned int len)
809 {
810     struct ipt_account_handle_sockopt handle;
811     int ret = -EINVAL;
812
813     if (!capable(CAP_NET_ADMIN))
814             return -EPERM;
815             
816     switch (cmd)
817     {
818         case IPT_SO_SET_ACCOUNT_HANDLE_FREE:
819             if (len != sizeof(struct ipt_account_handle_sockopt))
820             {
821                 printk("ACCOUNT: ipt_account_set_ctl: wrong data size (%u != %u) for IPT_SO_SET_HANDLE_FREE\n", len, sizeof(struct ipt_account_handle_sockopt));
822                 break;
823             }   
824             
825             if (copy_from_user (&handle, user, len))
826             {
827                 printk("ACCOUNT: ipt_account_set_ctl: copy_from_user failed for IPT_SO_SET_HANDLE_FREE\n");
828                 break;
829             }
830             
831             spin_lock_bh(&ipt_account_userspace_lock);
832             ret = ipt_account_handle_free(handle.handle_nr);
833             spin_unlock_bh(&ipt_account_userspace_lock);
834             break;
835         case IPT_SO_SET_ACCOUNT_HANDLE_FREE_ALL:
836         {
837             unsigned int i;
838             spin_lock_bh(&ipt_account_userspace_lock);
839             for (i = 0; i < ACCOUNT_MAX_HANDLES; i++)
840                 ipt_account_handle_free(i);
841             spin_unlock_bh(&ipt_account_userspace_lock);
842             ret = 0;
843             break;
844         }
845         default:
846             printk("ACCOUNT: ipt_account_set_ctl: unknown request %i\n", cmd);
847     }
848
849     return ret;
850 }
851
852 static int ipt_account_get_ctl(struct sock *sk, int cmd, void *user, int *len)
853 {
854     struct ipt_account_handle_sockopt handle;
855     int ret = -EINVAL;
856
857     if (!capable(CAP_NET_ADMIN))
858             return -EPERM;
859
860     switch (cmd)
861     {
862         case IPT_SO_GET_ACCOUNT_PREPARE_READ_FLUSH:
863         case IPT_SO_GET_ACCOUNT_PREPARE_READ:
864             if (*len < sizeof(struct ipt_account_handle_sockopt))
865             {
866                 printk("ACCOUNT: ipt_account_get_ctl: wrong data size (%u != %u) for IPT_SO_GET_ACCOUNT_PREPARE_READ/READ_FLUSH\n",
867                        *len, sizeof(struct ipt_account_handle_sockopt));
868                 break;
869             }   
870             
871             if (copy_from_user (&handle, user, sizeof(struct ipt_account_handle_sockopt)))
872             {
873                 printk("ACCOUNT: ipt_account_get_ctl: copy_from_user failed for IPT_SO_GET_ACCOUNT_PREPARE_READ/READ_FLUSH\n");
874                 break;
875             }
876             
877             spin_lock_bh(&ipt_account_lock);
878             spin_lock_bh(&ipt_account_userspace_lock);
879             if (cmd == IPT_SO_GET_ACCOUNT_PREPARE_READ_FLUSH)
880                 handle.handle_nr = ipt_account_handle_prepare_read_flush(handle.name, &handle.itemcount);
881             else
882                 handle.handle_nr = ipt_account_handle_prepare_read(handle.name, &handle.itemcount);
883             spin_unlock_bh(&ipt_account_userspace_lock);
884             spin_unlock_bh(&ipt_account_lock);
885                 
886             if (handle.handle_nr == -1)
887             {
888                 printk("ACCOUNT: ipt_account_get_ctl: ipt_account_handle_prepare_read failed\n");
889                 break;
890             }
891             
892             if (copy_to_user(user, &handle, sizeof(struct ipt_account_handle_sockopt)))
893             {
894                 printk("ACCOUNT: ipt_account_set_ctl: copy_to_user failed for IPT_SO_GET_ACCOUNT_PREPARE_READ/READ_FLUSH\n");
895                 break;
896             }
897             ret = 0;
898             break;
899         case IPT_SO_GET_ACCOUNT_GET_DATA:
900             if (*len < sizeof(struct ipt_account_handle_sockopt))
901             {
902                 printk("ACCOUNT: ipt_account_get_ctl: wrong data size (%u != %u) for IPT_SO_GET_ACCOUNT_PREPARE_READ/READ_FLUSH\n",
903                        *len, sizeof(struct ipt_account_handle_sockopt));
904                 break;
905             }   
906             
907             if (copy_from_user (&handle, user, sizeof(struct ipt_account_handle_sockopt)))
908             {
909                 printk("ACCOUNT: ipt_account_get_ctl: copy_from_user failed for IPT_SO_GET_ACCOUNT_PREPARE_READ/READ_FLUSH\n");
910                 break;
911             }
912             
913             if (handle.handle_nr >= ACCOUNT_MAX_HANDLES)
914             {
915                 printk("ACCOUNT: Invalid handle for IPT_SO_GET_ACCOUNT_GET_DATA: %u\n", handle.handle_nr);
916                 break;
917             }
918             
919             if (*len < ipt_account_handles[handle.handle_nr].itemcount*sizeof(struct ipt_account_handle_ip))
920             {
921                 printk("ACCOUNT: ipt_account_get_ctl: not enough space (%u < %u) to store data from IPT_SO_GET_ACCOUNT_GET_DATA\n",
922                        *len, ipt_account_handles[handle.handle_nr].itemcount*sizeof(struct ipt_account_handle_ip));
923                 ret = -ENOMEM;
924                 break;
925             }   
926             
927             spin_lock_bh(&ipt_account_userspace_lock);
928             ret = ipt_account_handle_get_data(handle.handle_nr, user);
929             spin_unlock_bh(&ipt_account_userspace_lock);
930             if (ret)
931             {
932                 printk("ACCOUNT: ipt_account_get_ctl: ipt_account_handle_get_data failed for handle %u\n", handle.handle_nr);
933                 break;
934             }
935             
936             ret = 0;
937             break;
938         case IPT_SO_GET_ACCOUNT_GET_HANDLE_USAGE:
939         {
940             if (*len < sizeof(struct ipt_account_handle_sockopt))
941             {
942                 printk("ACCOUNT: ipt_account_get_ctl: wrong data size (%u != %u) for IPT_SO_GET_ACCOUNT_GET_HANDLE_USAGE\n",
943                        *len, sizeof(struct ipt_account_handle_sockopt));
944                 break;
945             }   
946             
947             // Find out how many handles are in use
948             unsigned int i;
949             handle.itemcount = 0;
950             spin_lock_bh(&ipt_account_userspace_lock);
951             for (i = 0; i < ACCOUNT_MAX_HANDLES; i++)
952                 if (ipt_account_handles[i].data)
953                     handle.itemcount++;
954             spin_unlock_bh(&ipt_account_userspace_lock);
955             
956             if (copy_to_user(user, &handle, sizeof(struct ipt_account_handle_sockopt)))
957             {
958                 printk("ACCOUNT: ipt_account_set_ctl: copy_to_user failed for IPT_SO_GET_ACCOUNT_GET_HANDLE_USAGE\n");
959                 break;
960             }
961             ret = 0;
962             break;
963         }
964         case IPT_SO_GET_ACCOUNT_GET_TABLE_NAMES:
965         {
966             spin_lock_bh(&ipt_account_lock);
967             
968             // Determine size of table names
969             unsigned int size = 0, i;
970             for (i = 0; i < ACCOUNT_MAX_TABLES; i++)
971             {
972                 if (ipt_account_tables[i].name[0] != 0)
973                     size += strlen (ipt_account_tables[i].name) + 1;
974             }
975             size += 1;    // Terminating NULL character
976             
977             if (*len < size)
978             {
979                 spin_unlock_bh(&ipt_account_lock);
980                 printk("ACCOUNT: ipt_account_get_ctl: not enough space (%u < %u) to store table names\n", *len, size);
981                 ret = -ENOMEM;
982                 break;
983             }
984             // Copy table names to userspace
985             char *tnames = user;
986             for (i = 0; i < ACCOUNT_MAX_TABLES; i++)
987             {
988                 if (ipt_account_tables[i].name[0] != 0)
989                 {
990                     int len = strlen (ipt_account_tables[i].name) + 1;
991                     copy_to_user(tnames, ipt_account_tables[i].name, len);    // copy string + terminating zero
992                     tnames += len;
993                 }
994             }
995             // Append terminating zero
996             i = 0;
997             copy_to_user(tnames, &i, 1);    
998             spin_unlock_bh(&ipt_account_lock);
999             ret = 0;
1000             break;
1001         }
1002         default:
1003             printk("ACCOUNT: ipt_account_get_ctl: unknown request %i\n", cmd);
1004     }
1005
1006     return ret;
1007 }
1008
1009 static struct ipt_target ipt_account_reg
1010 = { { NULL, NULL }, "ACCOUNT", ipt_account_target, ipt_account_checkentry, ipt_account_deleteentry, 
1011     THIS_MODULE };
1012
1013 static struct nf_sockopt_ops ipt_account_sockopts
1014 = { { NULL, NULL }, PF_INET, IPT_SO_SET_ACCOUNT_HANDLE_FREE, IPT_SO_SET_ACCOUNT_MAX+1, ipt_account_set_ctl,
1015     IPT_SO_GET_ACCOUNT_PREPARE_READ, IPT_SO_GET_ACCOUNT_MAX+1, ipt_account_get_ctl, 0, NULL  };
1016     
1017 static int __init init(void)
1018 {
1019     if (PAGE_SIZE < 4096)
1020     {
1021         printk("ACCOUNT: Sorry we need at least a PAGE_SIZE of 4096. Found: %lu\n", PAGE_SIZE);
1022         return -EINVAL;
1023     }
1024
1025     if ((ipt_account_tables = kmalloc(ACCOUNT_MAX_TABLES*sizeof(struct ipt_account_table), GFP_KERNEL)) == NULL)
1026     {
1027         printk("ACCOUNT: Out of memory allocating account_tables structure");
1028         return -EINVAL;
1029     }
1030     memset(ipt_account_tables, 0, ACCOUNT_MAX_TABLES*sizeof(struct ipt_account_table));
1031                             
1032     if ((ipt_account_handles = kmalloc(ACCOUNT_MAX_HANDLES*sizeof(struct ipt_account_handle), GFP_KERNEL)) == NULL)
1033     {
1034         printk("ACCOUNT: Out of memory allocating account_handles structure");
1035         kfree (ipt_account_tables);
1036         ipt_account_tables = NULL;
1037         return -EINVAL;
1038     }
1039     memset(ipt_account_handles, 0, ACCOUNT_MAX_HANDLES*sizeof(struct ipt_account_handle));
1040
1041     // Allocate one page as temporary storage
1042     if ((ipt_account_tmpbuf = (void*)__get_free_page(GFP_KERNEL)) == NULL)
1043     {
1044         printk("ACCOUNT: Out of memory for temporary buffer page\n");
1045         kfree(ipt_account_tables);
1046         kfree(ipt_account_handles);
1047         ipt_account_tables = NULL;
1048         ipt_account_handles = NULL;
1049         return -EINVAL;
1050     }
1051     
1052     /* Register setsockopt */
1053     if (nf_register_sockopt(&ipt_account_sockopts) < 0)
1054     {
1055         printk("ACCOUNT: Can't register sockopts. Aborting\n");
1056         
1057         kfree(ipt_account_tables);
1058         kfree(ipt_account_handles);
1059         free_page((unsigned long)ipt_account_tmpbuf);
1060         ipt_account_tables = NULL;
1061         ipt_account_handles = NULL;
1062         ipt_account_tmpbuf = NULL;
1063         
1064         return -EINVAL;
1065     }
1066     
1067     if (ipt_register_target(&ipt_account_reg))
1068             return -EINVAL;
1069
1070     return 0;
1071 }
1072
1073 static void __exit fini(void)
1074 {
1075     ipt_unregister_target(&ipt_account_reg);
1076
1077     nf_unregister_sockopt(&ipt_account_sockopts);
1078     
1079     kfree(ipt_account_tables);
1080     kfree(ipt_account_handles);
1081     free_page((unsigned long)ipt_account_tmpbuf);
1082     
1083     ipt_account_tables = NULL;
1084     ipt_account_handles = NULL;
1085     ipt_account_tmpbuf = NULL;
1086 }
1087
1088 module_init(init);
1089 module_exit(fini);
1090 MODULE_LICENSE("GPL");