ipt_ACCOUNT: (tomj) fix for kernel >= 2.6.19, compat glue for older kernels, simplifi...
[ipt_ACCOUNT] / linux-2.6 / net / ipv4 / netfilter / ipt_ACCOUNT.c
1 /***************************************************************************
2  *   This is a module which is used for counting packets.                  *
3  *   See http://www.intra2net.com/opensource/ipt_account                   *
4  *   for further information                                               *
5  *                                                                         * 
6  *   Copyright (C) 2004-2006 by Intra2net AG                               *
7  *   opensource@intra2net.com                                              *
8  *                                                                         *
9  *   This program is free software; you can redistribute it and/or modify  *
10  *   it under the terms of the GNU General Public License                  *
11  *   version 2 as published by the Free Software Foundation;               *
12  *                                                                         *
13  ***************************************************************************/
14
15 #include <linux/module.h>
16 #include <linux/skbuff.h>
17 #include <linux/ip.h>
18 #include <net/icmp.h>
19 #include <net/udp.h>
20 #include <net/tcp.h>
21 #include <linux/netfilter_ipv4/ip_tables.h>
22 #include <asm/semaphore.h>
23 #include <linux/kernel.h>
24 #include <linux/mm.h>
25 #include <linux/string.h>
26 #include <linux/spinlock.h>
27 #include <asm/uaccess.h>
28
29 #include <net/route.h>
30 #include <linux/netfilter_ipv4/ipt_ACCOUNT.h>
31
32 #if 0
33 #define DEBUGP printk
34 #else
35 #define DEBUGP(format, args...)
36 #endif
37
38 #if (PAGE_SIZE < 4096)
39 #error "ipt_ACCOUNT needs at least a PAGE_SIZE of 4096"
40 #endif
41
42 static struct ipt_acc_table *ipt_acc_tables = NULL;
43 static struct ipt_acc_handle *ipt_acc_handles = NULL;
44 static void *ipt_acc_tmpbuf = NULL;
45
46 /* Spinlock used for manipulating the current accounting tables/data */
47 static DEFINE_SPINLOCK(ipt_acc_lock);
48 /* Mutex (semaphore) used for manipulating userspace handles/snapshot data */
49 static struct semaphore ipt_acc_userspace_mutex;
50
51
52 /* Recursive free of all data structures */
53 static void ipt_acc_data_free(void *data, unsigned char depth)
54 {
55     /* Empty data set */
56     if (!data)
57         return;
58
59     /* Free for 8 bit network */
60     if (depth == 0) {
61         free_page((unsigned long)data);
62         return;
63     }
64
65     /* Free for 16 bit network */
66     if (depth == 1) {
67         struct ipt_acc_mask_16 *mask_16 = (struct ipt_acc_mask_16 *)data;
68         unsigned int b;
69         for (b=0; b <= 255; b++) {
70             if (mask_16->mask_24[b]) {
71                 free_page((unsigned long)mask_16->mask_24[b]);
72             }
73         }
74         free_page((unsigned long)data);
75         return;
76     }
77
78     /* Free for 24 bit network */
79     if (depth == 2) {
80         unsigned int a, b;
81         for (a=0; a <= 255; a++) {
82             if (((struct ipt_acc_mask_8 *)data)->mask_16[a]) {
83                 struct ipt_acc_mask_16 *mask_16 = (struct ipt_acc_mask_16*)
84                                    ((struct ipt_acc_mask_8 *)data)->mask_16[a];
85
86                 for (b=0; b <= 255; b++) {
87                     if (mask_16->mask_24[b]) {
88                         free_page((unsigned long)mask_16->mask_24[b]);
89                     }
90                 }
91                 free_page((unsigned long)mask_16);
92             }
93         }
94         free_page((unsigned long)data);
95         return;
96     }
97
98     printk("ACCOUNT: ipt_acc_data_free called with unknown depth: %d\n", 
99            depth);
100     return;
101 }
102
103 /* Look for existing table / insert new one. 
104    Return internal ID or -1 on error */
105 static int ipt_acc_table_insert(char *name, u_int32_t ip, u_int32_t netmask)
106 {
107     unsigned int i;
108
109     DEBUGP("ACCOUNT: ipt_acc_table_insert: %s, %u.%u.%u.%u/%u.%u.%u.%u\n",
110                                          name, NIPQUAD(ip), NIPQUAD(netmask));
111
112     /* Look for existing table */
113     for (i = 0; i < ACCOUNT_MAX_TABLES; i++) {
114         if (strncmp(ipt_acc_tables[i].name, name, 
115                     ACCOUNT_TABLE_NAME_LEN) == 0) {
116             DEBUGP("ACCOUNT: Found existing slot: %d - "
117                    "%u.%u.%u.%u/%u.%u.%u.%u\n", i, 
118                    NIPQUAD(ipt_acc_tables[i].ip), 
119                    NIPQUAD(ipt_acc_tables[i].netmask));
120
121             if (ipt_acc_tables[i].ip != ip 
122                 || ipt_acc_tables[i].netmask != netmask) {
123                 printk("ACCOUNT: Table %s found, but IP/netmask mismatch. "
124                        "IP/netmask found: %u.%u.%u.%u/%u.%u.%u.%u\n",
125                        name, NIPQUAD(ipt_acc_tables[i].ip), 
126                        NIPQUAD(ipt_acc_tables[i].netmask));
127                 return -1;
128             }
129
130             ipt_acc_tables[i].refcount++;
131             DEBUGP("ACCOUNT: Refcount: %d\n", ipt_acc_tables[i].refcount);
132             return i;
133         }
134     }
135
136     /* Insert new table */
137     for (i = 0; i < ACCOUNT_MAX_TABLES; i++) {
138         /* Found free slot */
139         if (ipt_acc_tables[i].name[0] == 0) {
140             unsigned int netsize=0;
141             u_int32_t calc_mask;
142             int j;  /* needs to be signed, otherwise we risk endless loop */
143
144             DEBUGP("ACCOUNT: Found free slot: %d\n", i);
145             strncpy (ipt_acc_tables[i].name, name, ACCOUNT_TABLE_NAME_LEN-1);
146
147             ipt_acc_tables[i].ip = ip;
148             ipt_acc_tables[i].netmask = netmask;
149
150             /* Calculate netsize */
151             calc_mask = htonl(netmask);
152             for (j = 31; j >= 0; j--) {
153                 if (calc_mask&(1<<j))
154                     netsize++;
155                 else
156                     break;
157             }
158
159             /* Calculate depth from netsize */
160             if (netsize >= 24)
161                 ipt_acc_tables[i].depth = 0;
162             else if (netsize >= 16)
163                 ipt_acc_tables[i].depth = 1;
164             else if(netsize >= 8)
165                 ipt_acc_tables[i].depth = 2;
166
167             DEBUGP("ACCOUNT: calculated netsize: %u -> "
168                    "ipt_acc_table depth %u\n", netsize, 
169                    ipt_acc_tables[i].depth);
170
171             ipt_acc_tables[i].refcount++;
172             if ((ipt_acc_tables[i].data
173                 = (void *)get_zeroed_page(GFP_ATOMIC)) == NULL) {
174                 printk("ACCOUNT: out of memory for data of table: %s\n", name);
175                 memset(&ipt_acc_tables[i], 0, 
176                        sizeof(struct ipt_acc_table));
177                 return -1;
178             }
179
180             return i;
181         }
182     }
183
184     /* No free slot found */
185     printk("ACCOUNT: No free table slot found (max: %d). "
186            "Please increase ACCOUNT_MAX_TABLES.\n", ACCOUNT_MAX_TABLES);
187     return -1;
188 }
189
190 static int ipt_acc_checkentry(const char *tablename,
191 #if LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,16)
192                               const void *e,
193 #else
194                               const struct ipt_entry *e,
195 #endif
196 #if LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,17)
197                               const struct xt_target *target,
198 #endif
199                               void *targinfo,
200 #if LINUX_VERSION_CODE < KERNEL_VERSION(2,6,19)
201                               unsigned int targinfosize,
202 #endif
203                               unsigned int hook_mask)
204 {
205     struct ipt_acc_info *info = targinfo;
206     int table_nr;
207
208 #if LINUX_VERSION_CODE < KERNEL_VERSION(2,6,17)
209     if (targinfosize != IPT_ALIGN(sizeof(struct ipt_acc_info))) {
210         DEBUGP("ACCOUNT: targinfosize %u != %u\n",
211                targinfosize, IPT_ALIGN(sizeof(struct ipt_acc_info)));
212         return 0;
213     }
214 #endif
215
216     spin_lock_bh(&ipt_acc_lock);
217     table_nr = ipt_acc_table_insert(info->table_name, info->net_ip,
218                                                       info->net_mask);
219     spin_unlock_bh(&ipt_acc_lock);
220
221     if (table_nr == -1) {
222         printk("ACCOUNT: Table insert problem. Aborting\n");
223         return 0;
224     }
225     /* Table nr caching so we don't have to do an extra string compare 
226        for every packet */
227     info->table_nr = table_nr;
228
229     return 1;
230 }
231
232 static void ipt_acc_destroy(
233 #if LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,17)
234                             const struct xt_target *target,
235 #endif
236 #if LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,19)
237                             void *targinfo)
238 #else
239                             void *targinfo,
240                             unsigned int targinfosize)
241 #endif
242 {
243     unsigned int i;
244     struct ipt_acc_info *info = targinfo;
245
246 #if LINUX_VERSION_CODE < KERNEL_VERSION(2,6,17)
247     if (targinfosize != IPT_ALIGN(sizeof(struct ipt_acc_info))) {
248         DEBUGP("ACCOUNT: targinfosize %u != %u\n",
249                targinfosize, IPT_ALIGN(sizeof(struct ipt_acc_info)));
250     }
251 #endif
252
253     spin_lock_bh(&ipt_acc_lock);
254
255     DEBUGP("ACCOUNT: ipt_acc_deleteentry called for table: %s (#%d)\n", 
256            info->table_name, info->table_nr);
257
258     info->table_nr = -1;    /* Set back to original state */
259
260     /* Look for table */
261     for (i = 0; i < ACCOUNT_MAX_TABLES; i++) {
262         if (strncmp(ipt_acc_tables[i].name, info->table_name, 
263                     ACCOUNT_TABLE_NAME_LEN) == 0) {
264             DEBUGP("ACCOUNT: Found table at slot: %d\n", i);
265
266             ipt_acc_tables[i].refcount--;
267             DEBUGP("ACCOUNT: Refcount left: %d\n", 
268                    ipt_acc_tables[i].refcount);
269
270             /* Table not needed anymore? */
271             if (ipt_acc_tables[i].refcount == 0) {
272                 DEBUGP("ACCOUNT: Destroying table at slot: %d\n", i);
273                 ipt_acc_data_free(ipt_acc_tables[i].data, 
274                                       ipt_acc_tables[i].depth);
275                 memset(&ipt_acc_tables[i], 0, 
276                        sizeof(struct ipt_acc_table));
277             }
278
279             spin_unlock_bh(&ipt_acc_lock);
280             return;
281         }
282     }
283
284     /* Table not found */
285     printk("ACCOUNT: Table %s not found for destroy\n", info->table_name);
286     spin_unlock_bh(&ipt_acc_lock);
287 }
288
289 static void ipt_acc_depth0_insert(struct ipt_acc_mask_24 *mask_24,
290                                u_int32_t net_ip, u_int32_t netmask,
291                                u_int32_t src_ip, u_int32_t dst_ip,
292                                u_int32_t size, u_int32_t *itemcount)
293 {
294     unsigned char is_src = 0, is_dst = 0, src_slot, dst_slot;
295     char is_src_new_ip = 0, is_dst_new_ip = 0; /* Check if this entry is new */
296
297     DEBUGP("ACCOUNT: ipt_acc_depth0_insert: %u.%u.%u.%u/%u.%u.%u.%u "
298            "for net %u.%u.%u.%u/%u.%u.%u.%u, size: %u\n", NIPQUAD(src_ip), 
299            NIPQUAD(dst_ip), NIPQUAD(net_ip), NIPQUAD(netmask), size);
300
301     /* Check if src/dst is inside our network. */
302     /* Special: net_ip = 0.0.0.0/0 gets stored as src in slot 0 */
303     if (!netmask)
304         src_ip = 0;
305     if ((net_ip&netmask) == (src_ip&netmask))
306         is_src = 1;
307     if ((net_ip&netmask) == (dst_ip&netmask) && netmask)
308         is_dst = 1;
309
310     if (!is_src && !is_dst) {
311         DEBUGP("ACCOUNT: Skipping packet %u.%u.%u.%u/%u.%u.%u.%u "
312                "for net %u.%u.%u.%u/%u.%u.%u.%u\n", NIPQUAD(src_ip), 
313                NIPQUAD(dst_ip), NIPQUAD(net_ip), NIPQUAD(netmask));
314         return;
315     }
316
317     /* Calculate array positions */
318     src_slot = (unsigned char)((src_ip&0xFF000000) >> 24);
319     dst_slot = (unsigned char)((dst_ip&0xFF000000) >> 24);
320
321     /* Increase size counters */
322     if (is_src) {
323         /* Calculate network slot */
324         DEBUGP("ACCOUNT: Calculated SRC 8 bit network slot: %d\n", src_slot);
325         if (!mask_24->ip[src_slot].src_packets 
326             && !mask_24->ip[src_slot].dst_packets)
327             is_src_new_ip = 1;
328
329         mask_24->ip[src_slot].src_packets++;
330         mask_24->ip[src_slot].src_bytes+=size;
331     }
332     if (is_dst) {
333         DEBUGP("ACCOUNT: Calculated DST 8 bit network slot: %d\n", dst_slot);
334         if (!mask_24->ip[dst_slot].src_packets 
335             && !mask_24->ip[dst_slot].dst_packets)
336             is_dst_new_ip = 1;
337
338         mask_24->ip[dst_slot].dst_packets++;
339         mask_24->ip[dst_slot].dst_bytes+=size;
340     }
341
342     /* Increase itemcounter */
343     DEBUGP("ACCOUNT: Itemcounter before: %d\n", *itemcount);
344     if (src_slot == dst_slot) {
345         if (is_src_new_ip || is_dst_new_ip) {
346             DEBUGP("ACCOUNT: src_slot == dst_slot: %d, %d\n", 
347                    is_src_new_ip, is_dst_new_ip);
348             (*itemcount)++;
349         }
350     } else {
351         if (is_src_new_ip) {
352             DEBUGP("ACCOUNT: New src_ip: %u.%u.%u.%u\n", NIPQUAD(src_ip));
353             (*itemcount)++;
354         }
355         if (is_dst_new_ip) {
356             DEBUGP("ACCOUNT: New dst_ip: %u.%u.%u.%u\n", NIPQUAD(dst_ip));
357             (*itemcount)++;
358         }
359     }
360     DEBUGP("ACCOUNT: Itemcounter after: %d\n", *itemcount);
361 }
362
363 static void ipt_acc_depth1_insert(struct ipt_acc_mask_16 *mask_16, 
364                                u_int32_t net_ip, u_int32_t netmask, 
365                                u_int32_t src_ip, u_int32_t dst_ip,
366                                u_int32_t size, u_int32_t *itemcount)
367 {
368     /* Do we need to process src IP? */
369     if ((net_ip&netmask) == (src_ip&netmask)) {
370         unsigned char slot = (unsigned char)((src_ip&0x00FF0000) >> 16);
371         DEBUGP("ACCOUNT: Calculated SRC 16 bit network slot: %d\n", slot);
372
373         /* Do we need to create a new mask_24 bucket? */
374         if (!mask_16->mask_24[slot] && (mask_16->mask_24[slot] = 
375              (void *)get_zeroed_page(GFP_ATOMIC)) == NULL) {
376             printk("ACCOUNT: Can't process packet because out of memory!\n");
377             return;
378         }
379
380         ipt_acc_depth0_insert((struct ipt_acc_mask_24 *)mask_16->mask_24[slot],
381                                   net_ip, netmask, src_ip, 0, size, itemcount);
382     }
383
384     /* Do we need to process dst IP? */
385     if ((net_ip&netmask) == (dst_ip&netmask)) {
386         unsigned char slot = (unsigned char)((dst_ip&0x00FF0000) >> 16);
387         DEBUGP("ACCOUNT: Calculated DST 16 bit network slot: %d\n", slot);
388
389         /* Do we need to create a new mask_24 bucket? */
390         if (!mask_16->mask_24[slot] && (mask_16->mask_24[slot] 
391             = (void *)get_zeroed_page(GFP_ATOMIC)) == NULL) {
392             printk("ACCOUT: Can't process packet because out of memory!\n");
393             return;
394         }
395
396         ipt_acc_depth0_insert((struct ipt_acc_mask_24 *)mask_16->mask_24[slot],
397                                   net_ip, netmask, 0, dst_ip, size, itemcount);
398     }
399 }
400
401 static void ipt_acc_depth2_insert(struct ipt_acc_mask_8 *mask_8, 
402                                u_int32_t net_ip, u_int32_t netmask,
403                                u_int32_t src_ip, u_int32_t dst_ip,
404                                u_int32_t size, u_int32_t *itemcount)
405 {
406     /* Do we need to process src IP? */
407     if ((net_ip&netmask) == (src_ip&netmask)) {
408         unsigned char slot = (unsigned char)((src_ip&0x0000FF00) >> 8);
409         DEBUGP("ACCOUNT: Calculated SRC 24 bit network slot: %d\n", slot);
410
411         /* Do we need to create a new mask_24 bucket? */
412         if (!mask_8->mask_16[slot] && (mask_8->mask_16[slot] 
413             = (void *)get_zeroed_page(GFP_ATOMIC)) == NULL) {
414             printk("ACCOUNT: Can't process packet because out of memory!\n");
415             return;
416         }
417
418         ipt_acc_depth1_insert((struct ipt_acc_mask_16 *)mask_8->mask_16[slot],
419                                   net_ip, netmask, src_ip, 0, size, itemcount);
420     }
421
422     /* Do we need to process dst IP? */
423     if ((net_ip&netmask) == (dst_ip&netmask)) {
424         unsigned char slot = (unsigned char)((dst_ip&0x0000FF00) >> 8);
425         DEBUGP("ACCOUNT: Calculated DST 24 bit network slot: %d\n", slot);
426
427         /* Do we need to create a new mask_24 bucket? */
428         if (!mask_8->mask_16[slot] && (mask_8->mask_16[slot] 
429             = (void *)get_zeroed_page(GFP_ATOMIC)) == NULL) {
430             printk("ACCOUNT: Can't process packet because out of memory!\n");
431             return;
432         }
433
434         ipt_acc_depth1_insert((struct ipt_acc_mask_16 *)mask_8->mask_16[slot],
435                                   net_ip, netmask, 0, dst_ip, size, itemcount);
436     }
437 }
438
439 static unsigned int ipt_acc_target(struct sk_buff **pskb,
440                                    const struct net_device *in,
441                                    const struct net_device *out,
442                                    unsigned int hooknum,
443 #if LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,17)
444                                    const struct xt_target *target,
445 #endif
446 #if LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,19)
447                                    const void *targinfo)
448 #else
449                                    const void *targinfo,
450                                    void *userinfo)
451 #endif
452 {
453     const struct ipt_acc_info *info = 
454         (const struct ipt_acc_info *)targinfo;
455     u_int32_t src_ip = (*pskb)->nh.iph->saddr;
456     u_int32_t dst_ip = (*pskb)->nh.iph->daddr;
457     u_int32_t size = ntohs((*pskb)->nh.iph->tot_len);
458
459     spin_lock_bh(&ipt_acc_lock);
460
461     if (ipt_acc_tables[info->table_nr].name[0] == 0) {
462         printk("ACCOUNT: ipt_acc_target: Invalid table id %u. "
463                "IPs %u.%u.%u.%u/%u.%u.%u.%u\n", info->table_nr, 
464                NIPQUAD(src_ip), NIPQUAD(dst_ip));
465         spin_unlock_bh(&ipt_acc_lock);
466         return IPT_CONTINUE;
467     }
468
469     /* 8 bit network or "any" network */
470     if (ipt_acc_tables[info->table_nr].depth == 0) {
471         /* Count packet and check if the IP is new */
472         ipt_acc_depth0_insert(
473             (struct ipt_acc_mask_24 *)ipt_acc_tables[info->table_nr].data,
474             ipt_acc_tables[info->table_nr].ip, 
475             ipt_acc_tables[info->table_nr].netmask,
476             src_ip, dst_ip, size, &ipt_acc_tables[info->table_nr].itemcount);
477         spin_unlock_bh(&ipt_acc_lock);
478         return IPT_CONTINUE;
479     }
480
481     /* 16 bit network */
482     if (ipt_acc_tables[info->table_nr].depth == 1) {
483         ipt_acc_depth1_insert(
484             (struct ipt_acc_mask_16 *)ipt_acc_tables[info->table_nr].data,
485             ipt_acc_tables[info->table_nr].ip, 
486             ipt_acc_tables[info->table_nr].netmask,
487             src_ip, dst_ip, size, &ipt_acc_tables[info->table_nr].itemcount);
488         spin_unlock_bh(&ipt_acc_lock);
489         return IPT_CONTINUE;
490     }
491
492     /* 24 bit network */
493     if (ipt_acc_tables[info->table_nr].depth == 2) {
494         ipt_acc_depth2_insert(
495             (struct ipt_acc_mask_8 *)ipt_acc_tables[info->table_nr].data,
496             ipt_acc_tables[info->table_nr].ip, 
497             ipt_acc_tables[info->table_nr].netmask,
498             src_ip, dst_ip, size, &ipt_acc_tables[info->table_nr].itemcount);
499         spin_unlock_bh(&ipt_acc_lock);
500         return IPT_CONTINUE;
501     }
502
503     printk("ACCOUNT: ipt_acc_target: Unable to process packet. "
504            "Table id %u. IPs %u.%u.%u.%u/%u.%u.%u.%u\n", 
505            info->table_nr, NIPQUAD(src_ip), NIPQUAD(dst_ip));
506
507     spin_unlock_bh(&ipt_acc_lock);
508     return IPT_CONTINUE;
509 }
510
511 /*
512     Functions dealing with "handles":
513     Handles are snapshots of a accounting state.
514     
515     read snapshots are only for debugging the code
516     and are very expensive concerning speed/memory
517     compared to read_and_flush.
518     
519     The functions aren't protected by spinlocks themselves
520     as this is done in the ioctl part of the code.
521 */
522
523 /*
524     Find a free handle slot. Normally only one should be used,
525     but there could be two or more applications accessing the data
526     at the same time.
527 */
528 static int ipt_acc_handle_find_slot(void)
529 {
530     unsigned int i;
531     /* Insert new table */
532     for (i = 0; i < ACCOUNT_MAX_HANDLES; i++) {
533         /* Found free slot */
534         if (ipt_acc_handles[i].data == NULL) {
535             /* Don't "mark" data as used as we are protected by a spinlock 
536                by the calling function. handle_find_slot() is only a function
537                to prevent code duplication. */
538             return i;
539         }
540     }
541
542     /* No free slot found */
543     printk("ACCOUNT: No free handle slot found (max: %u). "
544            "Please increase ACCOUNT_MAX_HANDLES.\n", ACCOUNT_MAX_HANDLES);
545     return -1;
546 }
547
548 static int ipt_acc_handle_free(unsigned int handle)
549 {
550     if (handle >= ACCOUNT_MAX_HANDLES) {
551         printk("ACCOUNT: Invalid handle for ipt_acc_handle_free() specified:"
552                " %u\n", handle);
553         return -EINVAL;
554     }
555
556     ipt_acc_data_free(ipt_acc_handles[handle].data, 
557                           ipt_acc_handles[handle].depth);
558     memset (&ipt_acc_handles[handle], 0, sizeof (struct ipt_acc_handle));
559     return 0;
560 }
561
562 /* Prepare data for read without flush. Use only for debugging!
563    Real applications should use read&flush as it's way more efficent */
564 static int ipt_acc_handle_prepare_read(char *tablename,
565          struct ipt_acc_handle *dest, u_int32_t *count)
566 {
567     int table_nr=-1;
568     unsigned char depth;
569
570     for (table_nr = 0; table_nr < ACCOUNT_MAX_TABLES; table_nr++)
571         if (strncmp(ipt_acc_tables[table_nr].name, tablename, 
572             ACCOUNT_TABLE_NAME_LEN) == 0)
573                 break;
574
575     if (table_nr == ACCOUNT_MAX_TABLES) {
576         printk("ACCOUNT: ipt_acc_handle_prepare_read(): "
577                "Table %s not found\n", tablename);
578         return -1;
579     }
580
581     /* Fill up handle structure */
582     dest->ip = ipt_acc_tables[table_nr].ip;
583     dest->depth = ipt_acc_tables[table_nr].depth;
584     dest->itemcount = ipt_acc_tables[table_nr].itemcount;
585
586     /* allocate "root" table */
587     if ((dest->data = (void*)get_zeroed_page(GFP_ATOMIC)) == NULL) {
588         printk("ACCOUNT: out of memory for root table "
589                "in ipt_acc_handle_prepare_read()\n");
590         return -1;
591     }
592
593     /* Recursive copy of complete data structure */
594     depth = dest->depth;
595     if (depth == 0) {
596         memcpy(dest->data, 
597                ipt_acc_tables[table_nr].data, 
598                sizeof(struct ipt_acc_mask_24));
599     } else if (depth == 1) {
600         struct ipt_acc_mask_16 *src_16 = 
601             (struct ipt_acc_mask_16 *)ipt_acc_tables[table_nr].data;
602         struct ipt_acc_mask_16 *network_16 =
603             (struct ipt_acc_mask_16 *)dest->data;
604         unsigned int b;
605
606         for (b = 0; b <= 255; b++) {
607             if (src_16->mask_24[b]) {
608                 if ((network_16->mask_24[b] = 
609                      (void*)get_zeroed_page(GFP_ATOMIC)) == NULL) {
610                     printk("ACCOUNT: out of memory during copy of 16 bit "
611                            "network in ipt_acc_handle_prepare_read()\n");
612                     ipt_acc_data_free(dest->data, depth);
613                     return -1;
614                 }
615
616                 memcpy(network_16->mask_24[b], src_16->mask_24[b], 
617                        sizeof(struct ipt_acc_mask_24));
618             }
619         }
620     } else if(depth == 2) {
621         struct ipt_acc_mask_8 *src_8 = 
622             (struct ipt_acc_mask_8 *)ipt_acc_tables[table_nr].data;
623         struct ipt_acc_mask_8 *network_8 = 
624             (struct ipt_acc_mask_8 *)dest->data;
625         struct ipt_acc_mask_16 *src_16, *network_16;
626         unsigned int a, b;
627
628         for (a = 0; a <= 255; a++) {
629             if (src_8->mask_16[a]) {
630                 if ((network_8->mask_16[a] = 
631                      (void*)get_zeroed_page(GFP_ATOMIC)) == NULL) {
632                     printk("ACCOUNT: out of memory during copy of 24 bit network"
633                            " in ipt_acc_handle_prepare_read()\n");
634                     ipt_acc_data_free(dest->data, depth);
635                     return -1;
636                 }
637
638                 memcpy(network_8->mask_16[a], src_8->mask_16[a], 
639                        sizeof(struct ipt_acc_mask_16));
640
641                 src_16 = src_8->mask_16[a];
642                 network_16 = network_8->mask_16[a];
643
644                 for (b = 0; b <= 255; b++) {
645                     if (src_16->mask_24[b]) {
646                         if ((network_16->mask_24[b] = 
647                              (void*)get_zeroed_page(GFP_ATOMIC)) == NULL) {
648                             printk("ACCOUNT: out of memory during copy of 16 bit"
649                                    " network in ipt_acc_handle_prepare_read()\n");
650                             ipt_acc_data_free(dest->data, depth);
651                             return -1;
652                         }
653
654                         memcpy(network_16->mask_24[b], src_16->mask_24[b], 
655                                sizeof(struct ipt_acc_mask_24));
656                     }
657                 }
658             }
659         }
660     }
661
662     *count = ipt_acc_tables[table_nr].itemcount;
663     
664     return 0;
665 }
666
667 /* Prepare data for read and flush it */
668 static int ipt_acc_handle_prepare_read_flush(char *tablename,
669                struct ipt_acc_handle *dest, u_int32_t *count)
670 {
671     int table_nr;
672     void *new_data_page;
673
674     for (table_nr = 0; table_nr < ACCOUNT_MAX_TABLES; table_nr++)
675         if (strncmp(ipt_acc_tables[table_nr].name, tablename, 
676             ACCOUNT_TABLE_NAME_LEN) == 0)
677                 break;
678
679     if (table_nr == ACCOUNT_MAX_TABLES) {
680         printk("ACCOUNT: ipt_acc_handle_prepare_read_flush(): "
681                "Table %s not found\n", tablename);
682         return -1;
683     }
684
685     /* Try to allocate memory */
686     if (!(new_data_page = (void*)get_zeroed_page(GFP_ATOMIC))) {
687         printk("ACCOUNT: ipt_acc_handle_prepare_read_flush(): "
688                "Out of memory!\n");
689         return -1;
690     }
691
692     /* Fill up handle structure */
693     dest->ip = ipt_acc_tables[table_nr].ip;
694     dest->depth = ipt_acc_tables[table_nr].depth;
695     dest->itemcount = ipt_acc_tables[table_nr].itemcount;
696     dest->data = ipt_acc_tables[table_nr].data;
697     *count = ipt_acc_tables[table_nr].itemcount;
698
699     /* "Flush" table data */
700     ipt_acc_tables[table_nr].data = new_data_page;
701     ipt_acc_tables[table_nr].itemcount = 0;
702
703     return 0;
704 }
705
706 /* Copy 8 bit network data into a prepared buffer.
707    We only copy entries != 0 to increase performance.
708 */
709 static int ipt_acc_handle_copy_data(void *to_user, unsigned long *to_user_pos,
710                                   unsigned long *tmpbuf_pos, 
711                                   struct ipt_acc_mask_24 *data,
712                                   u_int32_t net_ip, u_int32_t net_OR_mask)
713 {
714     struct ipt_acc_handle_ip handle_ip;
715     size_t handle_ip_size = sizeof (struct ipt_acc_handle_ip);
716     unsigned int i;
717     
718     for (i = 0; i <= 255; i++) {
719         if (data->ip[i].src_packets || data->ip[i].dst_packets) {
720             handle_ip.ip = net_ip | net_OR_mask | (i<<24);
721             
722             handle_ip.src_packets = data->ip[i].src_packets;
723             handle_ip.src_bytes = data->ip[i].src_bytes;
724             handle_ip.dst_packets = data->ip[i].dst_packets;
725             handle_ip.dst_bytes = data->ip[i].dst_bytes;
726
727             /* Temporary buffer full? Flush to userspace */
728             if (*tmpbuf_pos+handle_ip_size >= PAGE_SIZE) {
729                 if (copy_to_user(to_user + *to_user_pos, ipt_acc_tmpbuf,
730                                                            *tmpbuf_pos))
731                     return -EFAULT;
732                 *to_user_pos = *to_user_pos + *tmpbuf_pos;
733                 *tmpbuf_pos = 0;
734             }
735             memcpy(ipt_acc_tmpbuf+*tmpbuf_pos, &handle_ip, handle_ip_size);
736             *tmpbuf_pos += handle_ip_size;
737         }
738     }
739     
740     return 0;
741 }
742    
743 /* Copy the data from our internal structure 
744    We only copy entries != 0 to increase performance.
745    Overwrites ipt_acc_tmpbuf.
746 */
747 static int ipt_acc_handle_get_data(u_int32_t handle, void *to_user)
748 {
749     unsigned long to_user_pos=0, tmpbuf_pos=0;
750     u_int32_t net_ip;
751     unsigned char depth;
752
753     if (handle >= ACCOUNT_MAX_HANDLES) {
754         printk("ACCOUNT: invalid handle for ipt_acc_handle_get_data() "
755                "specified: %u\n", handle);
756         return -1;
757     }
758
759     if (ipt_acc_handles[handle].data == NULL) {
760         printk("ACCOUNT: handle %u is BROKEN: Contains no data\n", handle);
761         return -1;
762     }
763
764     net_ip = ipt_acc_handles[handle].ip;
765     depth = ipt_acc_handles[handle].depth;
766
767     /* 8 bit network */
768     if (depth == 0) {
769         struct ipt_acc_mask_24 *network = 
770             (struct ipt_acc_mask_24*)ipt_acc_handles[handle].data;
771         if (ipt_acc_handle_copy_data(to_user, &to_user_pos, &tmpbuf_pos,
772                                      network, net_ip, 0))
773             return -1;
774         
775         /* Flush remaining data to userspace */
776         if (tmpbuf_pos)
777             if (copy_to_user(to_user+to_user_pos, ipt_acc_tmpbuf, tmpbuf_pos))
778                 return -1;
779
780         return 0;
781     }
782
783     /* 16 bit network */
784     if (depth == 1) {
785         struct ipt_acc_mask_16 *network_16 = 
786             (struct ipt_acc_mask_16*)ipt_acc_handles[handle].data;
787         unsigned int b;
788         for (b = 0; b <= 255; b++) {
789             if (network_16->mask_24[b]) {
790                 struct ipt_acc_mask_24 *network = 
791                     (struct ipt_acc_mask_24*)network_16->mask_24[b];
792                 if (ipt_acc_handle_copy_data(to_user, &to_user_pos,
793                                       &tmpbuf_pos, network, net_ip, (b << 16)))
794                     return -1;
795             }
796         }
797
798         /* Flush remaining data to userspace */
799         if (tmpbuf_pos)
800             if (copy_to_user(to_user+to_user_pos, ipt_acc_tmpbuf, tmpbuf_pos))
801                 return -1;
802
803         return 0;
804     }
805
806     /* 24 bit network */
807     if (depth == 2) {
808         struct ipt_acc_mask_8 *network_8 = 
809             (struct ipt_acc_mask_8*)ipt_acc_handles[handle].data;
810         unsigned int a, b;
811         for (a = 0; a <= 255; a++) {
812             if (network_8->mask_16[a]) {
813                 struct ipt_acc_mask_16 *network_16 = 
814                     (struct ipt_acc_mask_16*)network_8->mask_16[a];
815                 for (b = 0; b <= 255; b++) {
816                     if (network_16->mask_24[b]) {
817                         struct ipt_acc_mask_24 *network = 
818                             (struct ipt_acc_mask_24*)network_16->mask_24[b];
819                         if (ipt_acc_handle_copy_data(to_user,
820                                        &to_user_pos, &tmpbuf_pos,
821                                        network, net_ip, (a << 8) | (b << 16)))
822                             return -1;
823                     }
824                 }
825             }
826         }
827
828         /* Flush remaining data to userspace */
829         if (tmpbuf_pos)
830             if (copy_to_user(to_user+to_user_pos, ipt_acc_tmpbuf, tmpbuf_pos))
831                 return -1;
832
833         return 0;
834     }
835     
836     return -1;
837 }
838
839 static int ipt_acc_set_ctl(struct sock *sk, int cmd, 
840                                void *user, unsigned int len)
841 {
842     struct ipt_acc_handle_sockopt handle;
843     int ret = -EINVAL;
844
845     if (!capable(CAP_NET_ADMIN))
846         return -EPERM;
847
848     switch (cmd) {
849     case IPT_SO_SET_ACCOUNT_HANDLE_FREE:
850         if (len != sizeof(struct ipt_acc_handle_sockopt)) {
851             printk("ACCOUNT: ipt_acc_set_ctl: wrong data size (%u != %lu) "
852                    "for IPT_SO_SET_HANDLE_FREE\n", 
853                    len, sizeof(struct ipt_acc_handle_sockopt));
854             break;
855         }
856
857         if (copy_from_user (&handle, user, len)) {
858             printk("ACCOUNT: ipt_acc_set_ctl: copy_from_user failed for "
859                    "IPT_SO_SET_HANDLE_FREE\n");
860             break;
861         }
862
863         down(&ipt_acc_userspace_mutex);
864         ret = ipt_acc_handle_free(handle.handle_nr);
865         up(&ipt_acc_userspace_mutex);
866         break;
867     case IPT_SO_SET_ACCOUNT_HANDLE_FREE_ALL: {
868             unsigned int i;
869             down(&ipt_acc_userspace_mutex);
870             for (i = 0; i < ACCOUNT_MAX_HANDLES; i++)
871                 ipt_acc_handle_free(i);
872             up(&ipt_acc_userspace_mutex);
873             ret = 0;
874             break;
875         }
876     default:
877         printk("ACCOUNT: ipt_acc_set_ctl: unknown request %i\n", cmd);
878     }
879
880     return ret;
881 }
882
883 static int ipt_acc_get_ctl(struct sock *sk, int cmd, void *user, int *len)
884 {
885     struct ipt_acc_handle_sockopt handle;
886     int ret = -EINVAL;
887
888     if (!capable(CAP_NET_ADMIN))
889         return -EPERM;
890
891     switch (cmd) {
892     case IPT_SO_GET_ACCOUNT_PREPARE_READ_FLUSH:
893     case IPT_SO_GET_ACCOUNT_PREPARE_READ: {
894             struct ipt_acc_handle dest;
895
896             if (*len < sizeof(struct ipt_acc_handle_sockopt)) {
897                 printk("ACCOUNT: ipt_acc_get_ctl: wrong data size (%u != %lu) "
898                     "for IPT_SO_GET_ACCOUNT_PREPARE_READ/READ_FLUSH\n",
899                     *len, sizeof(struct ipt_acc_handle_sockopt));
900                 break;
901             }
902
903             if (copy_from_user (&handle, user, 
904                                 sizeof(struct ipt_acc_handle_sockopt))) {
905                 return -EFAULT;
906                 break;
907             }
908
909             spin_lock_bh(&ipt_acc_lock);
910             if (cmd == IPT_SO_GET_ACCOUNT_PREPARE_READ_FLUSH)
911                 ret = ipt_acc_handle_prepare_read_flush(
912                                     handle.name, &dest, &handle.itemcount);
913             else
914                 ret = ipt_acc_handle_prepare_read(
915                                     handle.name, &dest, &handle.itemcount);
916             spin_unlock_bh(&ipt_acc_lock);
917             // Error occured during prepare_read?
918            if (ret == -1)
919                 return -EINVAL;
920
921             /* Allocate a userspace handle */
922             down(&ipt_acc_userspace_mutex);
923             if ((handle.handle_nr = ipt_acc_handle_find_slot()) == -1) {
924                 ipt_acc_data_free(dest.data, dest.depth);
925                 up(&ipt_acc_userspace_mutex);
926                 return -EINVAL;
927             }
928             memcpy(&ipt_acc_handles[handle.handle_nr], &dest,
929                              sizeof(struct ipt_acc_handle));
930             up(&ipt_acc_userspace_mutex);
931
932             if (copy_to_user(user, &handle, 
933                             sizeof(struct ipt_acc_handle_sockopt))) {
934                 return -EFAULT;
935                 break;
936             }
937             ret = 0;
938             break;
939         }
940     case IPT_SO_GET_ACCOUNT_GET_DATA:
941         if (*len < sizeof(struct ipt_acc_handle_sockopt)) {
942             printk("ACCOUNT: ipt_acc_get_ctl: wrong data size (%u != %lu)"
943                    " for IPT_SO_GET_ACCOUNT_PREPARE_READ/READ_FLUSH\n",
944                    *len, sizeof(struct ipt_acc_handle_sockopt));
945             break;
946         }
947
948         if (copy_from_user (&handle, user, 
949                             sizeof(struct ipt_acc_handle_sockopt))) {
950             return -EFAULT;
951             break;
952         }
953
954         if (handle.handle_nr >= ACCOUNT_MAX_HANDLES) {
955             return -EINVAL;
956             break;
957         }
958
959         if (*len < ipt_acc_handles[handle.handle_nr].itemcount
960                    * sizeof(struct ipt_acc_handle_ip)) {
961             printk("ACCOUNT: ipt_acc_get_ctl: not enough space (%u < %lu)"
962                    " to store data from IPT_SO_GET_ACCOUNT_GET_DATA\n",
963                    *len, ipt_acc_handles[handle.handle_nr].itemcount
964                    * sizeof(struct ipt_acc_handle_ip));
965             ret = -ENOMEM;
966             break;
967         }
968
969         down(&ipt_acc_userspace_mutex);
970         ret = ipt_acc_handle_get_data(handle.handle_nr, user);
971         up(&ipt_acc_userspace_mutex);
972         if (ret) {
973             printk("ACCOUNT: ipt_acc_get_ctl: ipt_acc_handle_get_data"
974                    " failed for handle %u\n", handle.handle_nr);
975             break;
976         }
977
978         ret = 0;
979         break;
980     case IPT_SO_GET_ACCOUNT_GET_HANDLE_USAGE: {
981             unsigned int i;
982             if (*len < sizeof(struct ipt_acc_handle_sockopt)) {
983                 printk("ACCOUNT: ipt_acc_get_ctl: wrong data size (%u != %lu)"
984                        " for IPT_SO_GET_ACCOUNT_GET_HANDLE_USAGE\n",
985                        *len, sizeof(struct ipt_acc_handle_sockopt));
986                 break;
987             }
988
989             /* Find out how many handles are in use */
990             handle.itemcount = 0;
991             down(&ipt_acc_userspace_mutex);
992             for (i = 0; i < ACCOUNT_MAX_HANDLES; i++)
993                 if (ipt_acc_handles[i].data)
994                     handle.itemcount++;
995             up(&ipt_acc_userspace_mutex);
996
997             if (copy_to_user(user, &handle, 
998                              sizeof(struct ipt_acc_handle_sockopt))) {
999                 return -EFAULT;
1000                 break;
1001             }
1002             ret = 0;
1003             break;
1004         }
1005     case IPT_SO_GET_ACCOUNT_GET_TABLE_NAMES: {
1006             u_int32_t size = 0, i, name_len;
1007             char *tnames;
1008
1009             spin_lock_bh(&ipt_acc_lock);
1010
1011             /* Determine size of table names */
1012             for (i = 0; i < ACCOUNT_MAX_TABLES; i++) {
1013                 if (ipt_acc_tables[i].name[0] != 0)
1014                     size += strlen (ipt_acc_tables[i].name) + 1;
1015             }
1016             size += 1;    /* Terminating NULL character */
1017
1018             if (*len < size || size > PAGE_SIZE) {
1019                 spin_unlock_bh(&ipt_acc_lock);
1020                 printk("ACCOUNT: ipt_acc_get_ctl: not enough space (%u < %u < %lu)"
1021                        " to store table names\n", *len, size, PAGE_SIZE);
1022                 ret = -ENOMEM;
1023                 break;
1024             }
1025             /* Copy table names to userspace */
1026             tnames = ipt_acc_tmpbuf;
1027             for (i = 0; i < ACCOUNT_MAX_TABLES; i++) {
1028                 if (ipt_acc_tables[i].name[0] != 0) {
1029                     name_len = strlen (ipt_acc_tables[i].name) + 1;
1030                     memcpy(tnames, ipt_acc_tables[i].name, name_len);
1031                     tnames += name_len;
1032                 }
1033             }
1034             spin_unlock_bh(&ipt_acc_lock);
1035
1036             /* Terminating NULL character */
1037             *tnames = 0;
1038
1039             /* Transfer to userspace */
1040             if (copy_to_user(user, ipt_acc_tmpbuf, size))
1041                 return -EFAULT;
1042
1043             ret = 0;
1044             break;
1045         }
1046     default:
1047         printk("ACCOUNT: ipt_acc_get_ctl: unknown request %i\n", cmd);
1048     }
1049
1050     return ret;
1051 }
1052
1053 static struct ipt_target ipt_acc_reg = {
1054     .name = "ACCOUNT",
1055     .target = ipt_acc_target,
1056 #if LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,17)
1057     .targetsize = sizeof(struct ipt_acc_info),
1058 #endif
1059     .checkentry = ipt_acc_checkentry,
1060     .destroy = ipt_acc_destroy,
1061     .me = THIS_MODULE
1062 };
1063
1064 static struct nf_sockopt_ops ipt_acc_sockopts = {
1065     .pf = PF_INET,
1066     .set_optmin = IPT_SO_SET_ACCOUNT_HANDLE_FREE,
1067     .set_optmax = IPT_SO_SET_ACCOUNT_MAX+1,
1068     .set = ipt_acc_set_ctl,
1069     .get_optmin = IPT_SO_GET_ACCOUNT_PREPARE_READ,
1070     .get_optmax = IPT_SO_GET_ACCOUNT_MAX+1,
1071     .get = ipt_acc_get_ctl
1072 };
1073
1074 static int __init init(void)
1075 {
1076     init_MUTEX(&ipt_acc_userspace_mutex);
1077
1078     if ((ipt_acc_tables = 
1079          kmalloc(ACCOUNT_MAX_TABLES * 
1080                  sizeof(struct ipt_acc_table), GFP_KERNEL)) == NULL) {
1081         printk("ACCOUNT: Out of memory allocating account_tables structure");
1082         goto error_cleanup;
1083     }
1084     memset(ipt_acc_tables, 0, 
1085            ACCOUNT_MAX_TABLES * sizeof(struct ipt_acc_table));
1086
1087     if ((ipt_acc_handles = 
1088          kmalloc(ACCOUNT_MAX_HANDLES * 
1089                  sizeof(struct ipt_acc_handle), GFP_KERNEL)) == NULL) {
1090         printk("ACCOUNT: Out of memory allocating account_handles structure");
1091         goto error_cleanup;
1092     }
1093     memset(ipt_acc_handles, 0, 
1094            ACCOUNT_MAX_HANDLES * sizeof(struct ipt_acc_handle));
1095
1096     /* Allocate one page as temporary storage */
1097     if ((ipt_acc_tmpbuf = (void*)__get_free_page(GFP_KERNEL)) == NULL) {
1098         printk("ACCOUNT: Out of memory for temporary buffer page\n");
1099         goto error_cleanup;
1100     }
1101
1102     /* Register setsockopt */
1103     if (nf_register_sockopt(&ipt_acc_sockopts) < 0) {
1104         printk("ACCOUNT: Can't register sockopts. Aborting\n");
1105         goto error_cleanup;
1106     }
1107
1108     if (ipt_register_target(&ipt_acc_reg))
1109         goto error_cleanup;
1110
1111     return 0;
1112
1113 error_cleanup:
1114     if(ipt_acc_tables)
1115         kfree(ipt_acc_tables);
1116     if(ipt_acc_handles)
1117         kfree(ipt_acc_handles);
1118     if (ipt_acc_tmpbuf)
1119         free_page((unsigned long)ipt_acc_tmpbuf);
1120
1121     return -EINVAL;
1122 }
1123
1124 static void __exit fini(void)
1125 {
1126     ipt_unregister_target(&ipt_acc_reg);
1127
1128     nf_unregister_sockopt(&ipt_acc_sockopts);
1129
1130     kfree(ipt_acc_tables);
1131     kfree(ipt_acc_handles);
1132     free_page((unsigned long)ipt_acc_tmpbuf);
1133 }
1134
1135 module_init(init);
1136 module_exit(fini);
1137 MODULE_LICENSE("GPL");