ipt_ACCOUNT: (tomj) fixed & greatly improved locking
[ipt_ACCOUNT] / linux / net / ipv4 / netfilter / ipt_ACCOUNT.c
1 /***************************************************************************
2  *   This is a module which is used for counting packets.                  *
3  *   See http://www.intra2net.com/opensource/ipt_account                   *
4  *   for further information                                               *
5  *                                                                         * 
6  *   Copyright (C) 2004-2005 by Intra2net AG                               *
7  *   opensource@intra2net.com                                              *
8  *                                                                         *
9  *   This program is free software; you can redistribute it and/or modify  *
10  *   it under the terms of the GNU General Public License                  *
11  *   version 2 as published by the Free Software Foundation;               *
12  *                                                                         *
13  ***************************************************************************/
14
15 #include <linux/module.h>
16 #include <linux/skbuff.h>
17 #include <linux/ip.h>
18 #include <net/icmp.h>
19 #include <net/udp.h>
20 #include <net/tcp.h>
21 #include <linux/netfilter_ipv4/ip_tables.h>
22 #include <linux/netfilter_ipv4/lockhelp.h>
23 #include <asm/semaphore.h>
24 #include <linux/kernel.h>
25 #include <linux/mm.h>
26 #include <linux/string.h>
27 #include <asm/uaccess.h>
28
29 #include <net/route.h>
30 #include <linux/netfilter_ipv4/ipt_ACCOUNT.h>
31
32 #if 0
33 #define DEBUGP printk
34 #else
35 #define DEBUGP(format, args...)
36 #endif
37
38 #if (PAGE_SIZE < 4096)
39 #error "ipt_ACCOUNT needs at least a PAGE_SIZE of 4096"
40 #endif
41
42 static struct ipt_acc_table *ipt_acc_tables = NULL;
43 static struct ipt_acc_handle *ipt_acc_handles = NULL;
44 static void *ipt_acc_tmpbuf = NULL;
45
46 /* Spinlock used for manipulating the current accounting tables/data */
47 DECLARE_LOCK(ipt_acc_lock);
48 /* Mutex (semaphore) used for manipulating userspace handles/snapshot data */
49 static struct semaphore ipt_acc_userspace_mutex;
50
51
52 /* Recursive free of all data structures */
53 static void ipt_acc_data_free(void *data, unsigned char depth)
54 {
55     /* Empty data set */
56     if (!data)
57         return;
58
59     /* Free for 8 bit network */
60     if (depth == 0) {
61         free_page((unsigned long)data);
62         return;
63     }
64
65     /* Free for 16 bit network */
66     if (depth == 1) {
67         struct ipt_acc_mask_16 *mask_16 = (struct ipt_acc_mask_16 *)data;
68         u_int32_t b;
69         for (b=0; b <= 255; b++) {
70             if (mask_16->mask_24[b] != 0) {
71                 free_page((unsigned long)mask_16->mask_24[b]);
72             }
73         }
74         free_page((unsigned long)data);
75         return;
76     }
77
78     /* Free for 24 bit network */
79     if (depth == 2) {
80         u_int32_t a, b;
81         for (a=0; a <= 255; a++) {
82             if (((struct ipt_acc_mask_8 *)data)->mask_16[a]) {
83                 struct ipt_acc_mask_16 *mask_16 = (struct ipt_acc_mask_16*)
84                                    ((struct ipt_acc_mask_8 *)data)->mask_16[a];
85                 
86                 for (b=0; b <= 255; b++) {
87                     if (mask_16->mask_24[b]) {
88                         free_page((unsigned long)mask_16->mask_24[b]);
89                     }
90                 }
91                 free_page((unsigned long)mask_16);
92             }
93         }
94         free_page((unsigned long)data);
95         return;
96     }
97
98     printk("ACCOUNT: ipt_acc_data_free called with unknown depth: %d\n", 
99            depth);
100     return;
101 }
102
103 /* Look for existing table / insert new one. 
104    Return internal ID or -1 on error */
105 static int ipt_acc_table_insert(char *name, u_int32_t ip, u_int32_t netmask)
106 {
107     u_int32_t i;
108
109     DEBUGP("ACCOUNT: ipt_acc_table_insert: %s, %u.%u.%u.%u/%u.%u.%u.%u\n",
110                                          name, NIPQUAD(ip), NIPQUAD(netmask));
111
112     /* Look for existing table */
113     for (i = 0; i < ACCOUNT_MAX_TABLES; i++) {
114         if (strncmp(ipt_acc_tables[i].name, name, 
115                     ACCOUNT_TABLE_NAME_LEN) == 0) {
116             DEBUGP("ACCOUNT: Found existing slot: %d - "
117                    "%u.%u.%u.%u/%u.%u.%u.%u\n", i, 
118                    NIPQUAD(ipt_acc_tables[i].ip), 
119                    NIPQUAD(ipt_acc_tables[i].netmask));
120
121             if (ipt_acc_tables[i].ip != ip 
122                 || ipt_acc_tables[i].netmask != netmask) {
123                 printk("ACCOUNT: Table %s found, but IP/netmask mismatch. "
124                        "IP/netmask found: %u.%u.%u.%u/%u.%u.%u.%u\n",
125                        name, NIPQUAD(ipt_acc_tables[i].ip), 
126                        NIPQUAD(ipt_acc_tables[i].netmask));
127                 return -1;
128             }
129
130             ipt_acc_tables[i].refcount++;
131             DEBUGP("ACCOUNT: Refcount: %d\n", ipt_acc_tables[i].refcount);
132             return i;
133         }
134     }
135
136     /* Insert new table */
137     for (i = 0; i < ACCOUNT_MAX_TABLES; i++) {
138         /* Found free slot */
139         if (ipt_acc_tables[i].name[0] == 0) {
140             u_int32_t j, calc_mask, netsize=0;
141             
142             DEBUGP("ACCOUNT: Found free slot: %d\n", i);
143             strncpy (ipt_acc_tables[i].name, name, ACCOUNT_TABLE_NAME_LEN-1);
144
145             ipt_acc_tables[i].ip = ip;
146             ipt_acc_tables[i].netmask = netmask;
147
148             /* Calculate netsize */
149             calc_mask = htonl(netmask);
150             for (j = 31; j >= 0; j--) {
151                 if (calc_mask&(1<<j))
152                     netsize++;
153                 else
154                     break;
155             }
156
157             /* Calculate depth from netsize */
158             if (netsize >= 24)
159                 ipt_acc_tables[i].depth = 0;
160             else if (netsize >= 16)
161                 ipt_acc_tables[i].depth = 1;
162             else if(netsize >= 8)
163                 ipt_acc_tables[i].depth = 2;
164
165             DEBUGP("ACCOUNT: calculated netsize: %u -> "
166                    "ipt_acc_table depth %u\n", netsize, 
167                    ipt_acc_tables[i].depth);
168
169             ipt_acc_tables[i].refcount++;
170             if ((ipt_acc_tables[i].data
171                 = (void *)get_zeroed_page(GFP_ATOMIC)) == NULL) {
172                 printk("ACCOUNT: out of memory for data of table: %s\n", name);
173                 memset(&ipt_acc_tables[i], 0, 
174                        sizeof(struct ipt_acc_table));
175                 return -1;
176             }
177
178             return i;
179         }
180     }
181
182     /* No free slot found */
183     printk("ACCOUNT: No free table slot found (max: %d). "
184            "Please increase ACCOUNT_MAX_TABLES.\n", ACCOUNT_MAX_TABLES);
185     return -1;
186 }
187
188 static int ipt_acc_checkentry(const char *tablename,
189                                   const struct ipt_entry *e,
190                                   void *targinfo,
191                                   unsigned int targinfosize,
192                                   unsigned int hook_mask)
193 {
194     struct ipt_acc_info *info = targinfo;
195     int table_nr;
196
197     if (targinfosize != IPT_ALIGN(sizeof(struct ipt_acc_info))) {
198         DEBUGP("ACCOUNT: targinfosize %u != %u\n",
199                targinfosize, IPT_ALIGN(sizeof(struct ipt_acc_info)));
200         return 0;
201     }
202
203     LOCK_BH(&ipt_acc_lock);
204     table_nr = ipt_acc_table_insert(info->table_name, info->net_ip,
205                                                       info->net_mask);
206     UNLOCK_BH(&ipt_acc_lock);
207     
208     if (table_nr == -1) {
209         printk("ACCOUNT: Table insert problem. Aborting\n");
210         return 0;
211     }
212     /* Table nr caching so we don't have to do an extra string compare 
213        for every packet */
214     info->table_nr = table_nr;
215
216     return 1;
217 }
218
219 static void ipt_acc_deleteentry(void *targinfo, unsigned int targinfosize)
220 {
221     u_int32_t i;
222     struct ipt_acc_info *info = targinfo;
223
224     if (targinfosize != IPT_ALIGN(sizeof(struct ipt_acc_info))) {
225         DEBUGP("ACCOUNT: targinfosize %u != %u\n",
226                targinfosize, IPT_ALIGN(sizeof(struct ipt_acc_info)));
227     }
228
229     LOCK_BH(&ipt_acc_lock);
230
231     DEBUGP("ACCOUNT: ipt_acc_deleteentry called for table: %s (#%d)\n", 
232            info->table_name, info->table_nr);
233
234     info->table_nr = -1;    /* Set back to original state */
235
236     /* Look for table */
237     for (i = 0; i < ACCOUNT_MAX_TABLES; i++) {
238         if (strncmp(ipt_acc_tables[i].name, info->table_name, 
239                     ACCOUNT_TABLE_NAME_LEN) == 0) {
240             DEBUGP("ACCOUNT: Found table at slot: %d\n", i);
241
242             ipt_acc_tables[i].refcount--;
243             DEBUGP("ACCOUNT: Refcount left: %d\n", 
244                    ipt_acc_tables[i].refcount);
245
246             /* Table not needed anymore? */
247             if (ipt_acc_tables[i].refcount == 0) {
248                 DEBUGP("ACCOUNT: Destroying table at slot: %d\n", i);
249                 ipt_acc_data_free(ipt_acc_tables[i].data, 
250                                       ipt_acc_tables[i].depth);
251                 memset(&ipt_acc_tables[i], 0, 
252                        sizeof(struct ipt_acc_table));
253             }
254
255             UNLOCK_BH(&ipt_acc_lock);
256             return;
257         }
258     }
259
260     /* Table not found */
261     printk("ACCOUNT: Table %s not found for destroy\n", info->table_name);
262     UNLOCK_BH(&ipt_acc_lock);
263 }
264
265 static void ipt_acc_depth0_insert(struct ipt_acc_mask_24 *mask_24,
266                                u_int32_t net_ip, u_int32_t netmask,
267                                u_int32_t src_ip, u_int32_t dst_ip,
268                                u_int32_t size, u_int32_t *itemcount)
269 {
270     unsigned char is_src = 0, is_dst = 0, src_slot, dst_slot;
271     char is_src_new_ip = 0, is_dst_new_ip = 0; /* Check if this entry is new */
272
273     DEBUGP("ACCOUNT: ipt_acc_depth0_insert: %u.%u.%u.%u/%u.%u.%u.%u "
274            "for net %u.%u.%u.%u/%u.%u.%u.%u, size: %u\n", NIPQUAD(src_ip), 
275            NIPQUAD(dst_ip), NIPQUAD(net_ip), NIPQUAD(netmask), size);
276
277     /* Check if src/dst is inside our network. */
278     /* Special: net_ip = 0.0.0.0/0 gets stored as src in slot 0 */
279     if (!netmask)
280         src_ip = 0;
281     if ((net_ip&netmask) == (src_ip&netmask))
282         is_src = 1;
283     if ((net_ip&netmask) == (dst_ip&netmask) && netmask)
284         is_dst = 1;
285
286     if (!is_src && !is_dst) {
287         DEBUGP("ACCOUNT: Skipping packet %u.%u.%u.%u/%u.%u.%u.%u "
288                "for net %u.%u.%u.%u/%u.%u.%u.%u\n", NIPQUAD(src_ip), 
289                NIPQUAD(dst_ip), NIPQUAD(net_ip), NIPQUAD(netmask));
290         return;
291     }
292
293     /* Calculate array positions */
294     src_slot = (unsigned char)((src_ip&0xFF000000) >> 24);
295     dst_slot = (unsigned char)((dst_ip&0xFF000000) >> 24);
296
297     /* Increase size counters */
298     if (is_src) {
299         /* Calculate network slot */
300         DEBUGP("ACCOUNT: Calculated SRC 8 bit network slot: %d\n", src_slot);
301         if (!mask_24->ip[src_slot].src_packets 
302             && !mask_24->ip[src_slot].dst_packets)
303             is_src_new_ip = 1;
304
305         mask_24->ip[src_slot].src_packets++;
306         mask_24->ip[src_slot].src_bytes+=size;
307     }
308     if (is_dst) {
309         DEBUGP("ACCOUNT: Calculated DST 8 bit network slot: %d\n", dst_slot);
310         if (!mask_24->ip[dst_slot].src_packets 
311             && !mask_24->ip[dst_slot].dst_packets)
312             is_dst_new_ip = 1;
313
314         mask_24->ip[dst_slot].dst_packets++;
315         mask_24->ip[dst_slot].dst_bytes+=size;
316     }
317
318     /* Increase itemcounter */
319     DEBUGP("ACCOUNT: Itemcounter before: %d\n", *itemcount);
320     if (src_slot == dst_slot) {
321         if (is_src_new_ip || is_dst_new_ip) {
322             DEBUGP("ACCOUNT: src_slot == dst_slot: %d, %d\n", 
323                    is_src_new_ip, is_dst_new_ip);
324             (*itemcount)++;
325         }
326     } else {
327         if (is_src_new_ip) {
328             DEBUGP("ACCOUNT: New src_ip: %u.%u.%u.%u\n", NIPQUAD(src_ip));
329             (*itemcount)++;
330         }
331         if (is_dst_new_ip) {
332             DEBUGP("ACCOUNT: New dst_ip: %u.%u.%u.%u\n", NIPQUAD(dst_ip));
333             (*itemcount)++;
334         }
335     }
336     DEBUGP("ACCOUNT: Itemcounter after: %d\n", *itemcount);
337 }
338
339 static void ipt_acc_depth1_insert(struct ipt_acc_mask_16 *mask_16, 
340                                u_int32_t net_ip, u_int32_t netmask, 
341                                u_int32_t src_ip, u_int32_t dst_ip,
342                                u_int32_t size, u_int32_t *itemcount)
343 {
344     /* Do we need to process src IP? */
345     if ((net_ip&netmask) == (src_ip&netmask)) {
346         unsigned char slot = (unsigned char)((src_ip&0x00FF0000) >> 16);
347         DEBUGP("ACCOUNT: Calculated SRC 16 bit network slot: %d\n", slot);
348
349         /* Do we need to create a new mask_24 bucket? */
350         if (!mask_16->mask_24[slot] && (mask_16->mask_24[slot] = 
351              (void *)get_zeroed_page(GFP_ATOMIC)) == NULL) {
352             printk("ACCOUNT: Can't process packet because out of memory!\n");
353             return;
354         }
355
356         ipt_acc_depth0_insert((struct ipt_acc_mask_24 *)mask_16->mask_24[slot],
357                                   net_ip, netmask, src_ip, 0, size, itemcount);
358     }
359
360     /* Do we need to process dst IP? */
361     if ((net_ip&netmask) == (dst_ip&netmask)) {
362         unsigned char slot = (unsigned char)((dst_ip&0x00FF0000) >> 16);
363         DEBUGP("ACCOUNT: Calculated DST 16 bit network slot: %d\n", slot);
364
365         /* Do we need to create a new mask_24 bucket? */
366         if (!mask_16->mask_24[slot] && (mask_16->mask_24[slot] 
367             = (void *)get_zeroed_page(GFP_ATOMIC)) == NULL) {
368             printk("ACCOUT: Can't process packet because out of memory!\n");
369             return;
370         }
371
372         ipt_acc_depth0_insert((struct ipt_acc_mask_24 *)mask_16->mask_24[slot],
373                                   net_ip, netmask, 0, dst_ip, size, itemcount);
374     }
375 }
376
377 static void ipt_acc_depth2_insert(struct ipt_acc_mask_8 *mask_8, 
378                                u_int32_t net_ip, u_int32_t netmask,
379                                u_int32_t src_ip, u_int32_t dst_ip,
380                                u_int32_t size, u_int32_t *itemcount)
381 {
382     /* Do we need to process src IP? */
383     if ((net_ip&netmask) == (src_ip&netmask)) {
384         unsigned char slot = (unsigned char)((src_ip&0x0000FF00) >> 8);
385         DEBUGP("ACCOUNT: Calculated SRC 24 bit network slot: %d\n", slot);
386
387         /* Do we need to create a new mask_24 bucket? */
388         if (!mask_8->mask_16[slot] && (mask_8->mask_16[slot] 
389             = (void *)get_zeroed_page(GFP_ATOMIC)) == NULL) {
390             printk("ACCOUNT: Can't process packet because out of memory!\n");
391             return;
392         }
393
394         ipt_acc_depth1_insert((struct ipt_acc_mask_16 *)mask_8->mask_16[slot],
395                                   net_ip, netmask, src_ip, 0, size, itemcount);
396     }
397
398     /* Do we need to process dst IP? */
399     if ((net_ip&netmask) == (dst_ip&netmask)) {
400         unsigned char slot = (unsigned char)((dst_ip&0x0000FF00) >> 8);
401         DEBUGP("ACCOUNT: Calculated DST 24 bit network slot: %d\n", slot);
402
403         /* Do we need to create a new mask_24 bucket? */
404         if (!mask_8->mask_16[slot] && (mask_8->mask_16[slot] 
405             = (void *)get_zeroed_page(GFP_ATOMIC)) == NULL) {
406             printk("ACCOUNT: Can't process packet because out of memory!\n");
407             return;
408         }
409
410         ipt_acc_depth1_insert((struct ipt_acc_mask_16 *)mask_8->mask_16[slot],
411                                   net_ip, netmask, 0, dst_ip, size, itemcount);
412     }
413 }
414
415 static unsigned int ipt_acc_target(struct sk_buff **pskb,
416                                        unsigned int hooknum,
417                                        const struct net_device *in,
418                                        const struct net_device *out,
419                                        const void *targinfo,
420                                        void *userinfo)
421 {
422     const struct ipt_acc_info *info = 
423         (const struct ipt_acc_info *)targinfo;
424     u_int32_t src_ip = (*pskb)->nh.iph->saddr;
425     u_int32_t dst_ip = (*pskb)->nh.iph->daddr;
426     u_int32_t size = ntohs((*pskb)->nh.iph->tot_len);
427
428     LOCK_BH(&ipt_acc_lock);
429
430     if (ipt_acc_tables[info->table_nr].name[0] == 0) {
431         printk("ACCOUNT: ipt_acc_target: Invalid table id %u. "
432                "IPs %u.%u.%u.%u/%u.%u.%u.%u\n", info->table_nr, 
433                NIPQUAD(src_ip), NIPQUAD(dst_ip));
434         UNLOCK_BH(&ipt_acc_lock);
435         return IPT_CONTINUE;
436     }
437
438     /* 8 bit network or "any" network */
439     if (ipt_acc_tables[info->table_nr].depth == 0) {
440         /* Count packet and check if the IP is new */
441         ipt_acc_depth0_insert(
442             (struct ipt_acc_mask_24 *)ipt_acc_tables[info->table_nr].data,
443             ipt_acc_tables[info->table_nr].ip, 
444             ipt_acc_tables[info->table_nr].netmask,
445             src_ip, dst_ip, size, &ipt_acc_tables[info->table_nr].itemcount);
446         UNLOCK_BH(&ipt_acc_lock);
447         return IPT_CONTINUE;
448     }
449
450     /* 16 bit network */
451     if (ipt_acc_tables[info->table_nr].depth == 1) {
452         ipt_acc_depth1_insert(
453             (struct ipt_acc_mask_16 *)ipt_acc_tables[info->table_nr].data,
454             ipt_acc_tables[info->table_nr].ip, 
455             ipt_acc_tables[info->table_nr].netmask,
456             src_ip, dst_ip, size, &ipt_acc_tables[info->table_nr].itemcount);
457         UNLOCK_BH(&ipt_acc_lock);
458         return IPT_CONTINUE;
459     }
460
461     /* 24 bit network */
462     if (ipt_acc_tables[info->table_nr].depth == 2) {
463         ipt_acc_depth2_insert(
464             (struct ipt_acc_mask_8 *)ipt_acc_tables[info->table_nr].data,
465             ipt_acc_tables[info->table_nr].ip, 
466             ipt_acc_tables[info->table_nr].netmask,
467             src_ip, dst_ip, size, &ipt_acc_tables[info->table_nr].itemcount);
468         UNLOCK_BH(&ipt_acc_lock);
469         return IPT_CONTINUE;
470     }
471
472     printk("ACCOUNT: ipt_acc_target: Unable to process packet. "
473            "Table id %u. IPs %u.%u.%u.%u/%u.%u.%u.%u\n", 
474            info->table_nr, NIPQUAD(src_ip), NIPQUAD(dst_ip));
475
476     UNLOCK_BH(&ipt_acc_lock);
477     return IPT_CONTINUE;
478 }
479
480 /*
481     Functions dealing with "handles":
482     Handles are snapshots of a accounting state.
483     
484     read snapshots are only for debugging the code
485     and are very expensive concerning speed/memory
486     compared to read_and_flush.
487     
488     The functions aren't protected by spinlocks themselves
489     as this is done in the ioctl part of the code.
490 */
491
492 /*
493     Find a free handle slot. Normally only one should be used,
494     but there could be two or more applications accessing the data
495     at the same time.
496 */
497 static int ipt_acc_handle_find_slot(void)
498 {
499     u_int32_t i;
500     /* Insert new table */
501     for (i = 0; i < ACCOUNT_MAX_HANDLES; i++) {
502         /* Found free slot */
503         if (ipt_acc_handles[i].data == NULL) {
504             /* Don't "mark" data as used as we are protected by a spinlock 
505                by the calling function. handle_find_slot() is only a function
506                to prevent code duplication. */
507             return i;
508         }
509     }
510
511     /* No free slot found */
512     printk("ACCOUNT: No free handle slot found (max: %u). "
513            "Please increase ACCOUNT_MAX_HANDLES.\n", ACCOUNT_MAX_HANDLES);
514     return -1;
515 }
516
517 static int ipt_acc_handle_free(u_int32_t handle)
518 {
519     if (handle >= ACCOUNT_MAX_HANDLES) {
520         printk("ACCOUNT: Invalid handle for ipt_acc_handle_free() specified:"
521                " %u\n", handle);
522         return -EINVAL;
523     }
524
525     ipt_acc_data_free(ipt_acc_handles[handle].data, 
526                           ipt_acc_handles[handle].depth);
527     memset (&ipt_acc_handles[handle], 0, sizeof (struct ipt_acc_handle));
528     return 0;
529 }
530
531 /* Prepare data for read without flush. Use only for debugging!
532    Real applications should use read&flush as it's way more efficent */
533 static int ipt_acc_handle_prepare_read(char *tablename,
534          struct ipt_acc_handle *dest, u_int32_t *count)
535 {
536     int table_nr=-1;
537     unsigned char depth;
538
539     for (table_nr = 0; table_nr < ACCOUNT_MAX_TABLES; table_nr++)
540         if (strncmp(ipt_acc_tables[table_nr].name, tablename, 
541             ACCOUNT_TABLE_NAME_LEN) == 0)
542                 break;
543
544     if (table_nr == ACCOUNT_MAX_TABLES) {
545         printk("ACCOUNT: ipt_acc_handle_prepare_read(): "
546                "Table %s not found\n", tablename);
547         return -1;
548     }
549
550     /* Fill up handle structure */
551     dest->ip = ipt_acc_tables[table_nr].ip;
552     dest->depth = ipt_acc_tables[table_nr].depth;
553     dest->itemcount = ipt_acc_tables[table_nr].itemcount;
554
555     /* allocate "root" table */
556     if ((dest->data = (void*)get_zeroed_page(GFP_ATOMIC)) == NULL) {
557         printk("ACCOUNT: out of memory for root table "
558                "in ipt_acc_handle_prepare_read()\n");
559         return -1;
560     }
561
562     /* Recursive copy of complete data structure */
563     depth = dest->depth;
564     if (depth == 0) {
565         memcpy(dest->data, 
566                ipt_acc_tables[table_nr].data, 
567                sizeof(struct ipt_acc_mask_24));
568     } else if (depth == 1) {
569         struct ipt_acc_mask_16 *src_16 = 
570             (struct ipt_acc_mask_16 *)ipt_acc_tables[table_nr].data;
571         struct ipt_acc_mask_16 *network_16 =
572             (struct ipt_acc_mask_16 *)dest->data;
573         u_int32_t b;
574
575         for (b = 0; b <= 255; b++) {
576             if (src_16->mask_24[b]) {
577                 if ((network_16->mask_24[b] = 
578                      (void*)get_zeroed_page(GFP_ATOMIC)) == NULL) {
579                     printk("ACCOUNT: out of memory during copy of 16 bit "
580                            "network in ipt_acc_handle_prepare_read()\n");
581                     ipt_acc_data_free(dest->data, depth);
582                     return -1;
583                 }
584
585                 memcpy(network_16->mask_24[b], src_16->mask_24[b], 
586                        sizeof(struct ipt_acc_mask_24));
587             }
588         }
589     } else if(depth == 2) {
590         struct ipt_acc_mask_8 *src_8 = 
591             (struct ipt_acc_mask_8 *)ipt_acc_tables[table_nr].data;
592         struct ipt_acc_mask_8 *network_8 = 
593             (struct ipt_acc_mask_8 *)dest->data;
594         u_int32_t a;
595
596         for (a = 0; a <= 255; a++) {
597             if (src_8->mask_16[a]) {
598                 if ((network_8->mask_16[a] = 
599                      (void*)get_zeroed_page(GFP_ATOMIC)) == NULL) {
600                     printk("ACCOUNT: out of memory during copy of 24 bit network"
601                            " in ipt_acc_handle_prepare_read()\n");
602                     ipt_acc_data_free(dest->data, depth);
603                     return -1;
604                 }
605
606                 memcpy(network_8->mask_16[a], src_8->mask_16[a], 
607                        sizeof(struct ipt_acc_mask_16));
608
609                 struct ipt_acc_mask_16 *src_16 = src_8->mask_16[a];
610                 struct ipt_acc_mask_16 *network_16 = network_8->mask_16[a];
611                 u_int32_t b;
612
613                 for (b = 0; b <= 255; b++) {
614                     if (src_16->mask_24[b]) {
615                         if ((network_16->mask_24[b] = 
616                              (void*)get_zeroed_page(GFP_ATOMIC)) == NULL) {
617                             printk("ACCOUNT: out of memory during copy of 16 bit"
618                                    " network in ipt_acc_handle_prepare_read()\n");
619                             ipt_acc_data_free(dest->data, depth);
620                             return -1;
621                         }
622
623                         memcpy(network_16->mask_24[b], src_16->mask_24[b], 
624                                sizeof(struct ipt_acc_mask_24));
625                     }
626                 }
627             }
628         }
629     }
630
631     *count = ipt_acc_tables[table_nr].itemcount;
632     
633     return 0;
634 }
635
636 /* Prepare data for read and flush it */
637 static int ipt_acc_handle_prepare_read_flush(char *tablename,
638                struct ipt_acc_handle *dest, u_int32_t *count)
639 {
640     int table_nr;
641     void *new_data_page;
642
643     for (table_nr = 0; table_nr < ACCOUNT_MAX_TABLES; table_nr++)
644         if (strncmp(ipt_acc_tables[table_nr].name, tablename, 
645             ACCOUNT_TABLE_NAME_LEN) == 0)
646                 break;
647
648     if (table_nr == ACCOUNT_MAX_TABLES) {
649         printk("ACCOUNT: ipt_acc_handle_prepare_read_flush(): "
650                "Table %s not found\n", tablename);
651         return -1;
652     }
653
654     /* Try to allocate memory */
655     if (!(new_data_page = (void*)get_zeroed_page(GFP_ATOMIC))) {
656         printk("ACCOUNT: ipt_acc_handle_prepare_read_flush(): "
657                "Out of memory!\n");
658         return -1;
659     }
660
661     /* Fill up handle structure */
662     dest->ip = ipt_acc_tables[table_nr].ip;
663     dest->depth = ipt_acc_tables[table_nr].depth;
664     dest->itemcount = ipt_acc_tables[table_nr].itemcount;
665     dest->data = ipt_acc_tables[table_nr].data;
666     *count = ipt_acc_tables[table_nr].itemcount;
667
668     /* "Flush" table data */
669     ipt_acc_tables[table_nr].data = new_data_page;
670     ipt_acc_tables[table_nr].itemcount = 0;
671
672     return 0;
673 }
674
675 /* Copy 8 bit network data into a prepared buffer.
676    We only copy entries != 0 to increase performance.
677 */
678 static int ipt_acc_handle_copy_data(void *to_user, u_int32_t *to_user_pos,
679                                   u_int32_t *tmpbuf_pos, 
680                                   struct ipt_acc_mask_24 *data,
681                                   u_int32_t net_ip, u_int32_t net_OR_mask)
682 {
683     struct ipt_acc_handle_ip handle_ip;
684     u_int32_t handle_ip_size = sizeof (struct ipt_acc_handle_ip);
685     u_int32_t i;
686     
687     for (i = 0; i <= 255; i++) {
688         if (data->ip[i].src_packets || data->ip[i].dst_packets) {
689             handle_ip.ip = net_ip | net_OR_mask | (i<<24);
690             
691             handle_ip.src_packets = data->ip[i].src_packets;
692             handle_ip.src_bytes = data->ip[i].src_bytes;
693             handle_ip.dst_packets = data->ip[i].dst_packets;
694             handle_ip.dst_bytes = data->ip[i].dst_bytes;
695
696             /* Temporary buffer full? Flush to userspace */
697             if (*tmpbuf_pos+handle_ip_size >= PAGE_SIZE) {
698                 if (copy_to_user(to_user + *to_user_pos, ipt_acc_tmpbuf,
699                                                            *tmpbuf_pos))
700                     return -EFAULT;
701                 *to_user_pos = *to_user_pos + *tmpbuf_pos;
702                 *tmpbuf_pos = 0;
703             }
704             memcpy(ipt_acc_tmpbuf+*tmpbuf_pos, &handle_ip, handle_ip_size);
705             *tmpbuf_pos += handle_ip_size;
706         }
707     }
708     
709     return 0;
710 }
711    
712 /* Copy the data from our internal structure 
713    We only copy entries != 0 to increase performance.
714    Overwrites ipt_acc_tmpbuf.
715 */
716 static int ipt_acc_handle_get_data(u_int32_t handle, void *to_user)
717 {
718     u_int32_t to_user_pos=0, tmpbuf_pos=0, net_ip;
719     unsigned char depth;
720
721     if (handle >= ACCOUNT_MAX_HANDLES) {
722         printk("ACCOUNT: invalid handle for ipt_acc_handle_get_data() "
723                "specified: %u\n", handle);
724         return -1;
725     }
726
727     if (ipt_acc_handles[handle].data == NULL) {
728         printk("ACCOUNT: handle %u is BROKEN: Contains no data\n", handle);
729         return -1;
730     }
731
732     net_ip = ipt_acc_handles[handle].ip;
733     depth = ipt_acc_handles[handle].depth;
734
735     /* 8 bit network */
736     if (depth == 0) {
737         struct ipt_acc_mask_24 *network = 
738             (struct ipt_acc_mask_24*)ipt_acc_handles[handle].data;
739         if (ipt_acc_handle_copy_data(to_user, &to_user_pos, &tmpbuf_pos,
740                                      network, net_ip, 0))
741             return -1;
742         
743         /* Flush remaining data to userspace */
744         if (tmpbuf_pos)
745             if (copy_to_user(to_user+to_user_pos, ipt_acc_tmpbuf, tmpbuf_pos))
746                 return -1;
747
748         return 0;
749     }
750
751     /* 16 bit network */
752     if (depth == 1) {
753         struct ipt_acc_mask_16 *network_16 = 
754             (struct ipt_acc_mask_16*)ipt_acc_handles[handle].data;
755         u_int32_t b;
756         for (b = 0; b <= 255; b++) {
757             if (network_16->mask_24[b]) {
758                 struct ipt_acc_mask_24 *network = 
759                     (struct ipt_acc_mask_24*)network_16->mask_24[b];
760                 if (ipt_acc_handle_copy_data(to_user, &to_user_pos,
761                                       &tmpbuf_pos, network, net_ip, (b << 16)))
762                     return -1;
763             }
764         }
765
766         /* Flush remaining data to userspace */
767         if (tmpbuf_pos)
768             if (copy_to_user(to_user+to_user_pos, ipt_acc_tmpbuf, tmpbuf_pos))
769                 return -1;
770
771         return 0;
772     }
773
774     /* 24 bit network */
775     if (depth == 2) {
776         struct ipt_acc_mask_8 *network_8 = 
777             (struct ipt_acc_mask_8*)ipt_acc_handles[handle].data;
778         u_int32_t a, b;
779         for (a = 0; a <= 255; a++) {
780             if (network_8->mask_16[a]) {
781                 struct ipt_acc_mask_16 *network_16 = 
782                     (struct ipt_acc_mask_16*)network_8->mask_16[a];
783                 for (b = 0; b <= 255; b++) {
784                     if (network_16->mask_24[b]) {
785                         struct ipt_acc_mask_24 *network = 
786                             (struct ipt_acc_mask_24*)network_16->mask_24[b];
787                         if (ipt_acc_handle_copy_data(to_user,
788                                        &to_user_pos, &tmpbuf_pos,
789                                        network, net_ip, (a << 8) | (b << 16)))
790                             return -1;
791                     }
792                 }
793             }
794         }
795
796         /* Flush remaining data to userspace */
797         if (tmpbuf_pos)
798             if (copy_to_user(to_user+to_user_pos, ipt_acc_tmpbuf, tmpbuf_pos))
799                 return -1;
800
801         return 0;
802     }
803     
804     return -1;
805 }
806
807 static int ipt_acc_set_ctl(struct sock *sk, int cmd, 
808                                void *user, u_int32_t len)
809 {
810     struct ipt_acc_handle_sockopt handle;
811     int ret = -EINVAL;
812
813     if (!capable(CAP_NET_ADMIN))
814         return -EPERM;
815
816     switch (cmd) {
817     case IPT_SO_SET_ACCOUNT_HANDLE_FREE:
818         if (len != sizeof(struct ipt_acc_handle_sockopt)) {
819             printk("ACCOUNT: ipt_acc_set_ctl: wrong data size (%u != %u) "
820                    "for IPT_SO_SET_HANDLE_FREE\n", 
821                    len, sizeof(struct ipt_acc_handle_sockopt));
822             break;
823         }
824
825         if (copy_from_user (&handle, user, len)) {
826             printk("ACCOUNT: ipt_acc_set_ctl: copy_from_user failed for "
827                    "IPT_SO_SET_HANDLE_FREE\n");
828             break;
829         }
830
831         down(&ipt_acc_userspace_mutex);
832         ret = ipt_acc_handle_free(handle.handle_nr);
833         up(&ipt_acc_userspace_mutex);
834         break;
835     case IPT_SO_SET_ACCOUNT_HANDLE_FREE_ALL: {
836             u_int32_t i;
837             down(&ipt_acc_userspace_mutex);
838             for (i = 0; i < ACCOUNT_MAX_HANDLES; i++)
839                 ipt_acc_handle_free(i);
840             up(&ipt_acc_userspace_mutex);
841             ret = 0;
842             break;
843         }
844     default:
845         printk("ACCOUNT: ipt_acc_set_ctl: unknown request %i\n", cmd);
846     }
847
848     return ret;
849 }
850
851 static int ipt_acc_get_ctl(struct sock *sk, int cmd, void *user, int *len)
852 {
853     struct ipt_acc_handle_sockopt handle;
854     int ret = -EINVAL;
855
856     if (!capable(CAP_NET_ADMIN))
857         return -EPERM;
858
859     switch (cmd) {
860     case IPT_SO_GET_ACCOUNT_PREPARE_READ_FLUSH:
861     case IPT_SO_GET_ACCOUNT_PREPARE_READ: {
862             struct ipt_acc_handle dest;
863     
864             if (*len < sizeof(struct ipt_acc_handle_sockopt)) {
865                 printk("ACCOUNT: ipt_acc_get_ctl: wrong data size (%u != %u) "
866                     "for IPT_SO_GET_ACCOUNT_PREPARE_READ/READ_FLUSH\n",
867                     *len, sizeof(struct ipt_acc_handle_sockopt));
868                 break;
869             }
870     
871             if (copy_from_user (&handle, user, 
872                                 sizeof(struct ipt_acc_handle_sockopt))) {
873                 return -EFAULT;
874                 break;
875             }
876     
877             LOCK_BH(&ipt_acc_lock);
878             if (cmd == IPT_SO_GET_ACCOUNT_PREPARE_READ_FLUSH)
879                 ret = ipt_acc_handle_prepare_read_flush(
880                                     handle.name, &dest, &handle.itemcount);
881             else
882                 ret = ipt_acc_handle_prepare_read(
883                                     handle.name, &dest, &handle.itemcount);
884             UNLOCK_BH(&ipt_acc_lock);
885             // Error occured during prepare_read?
886            if (ret == -1)
887                 return -EINVAL;
888             
889             /* Allocate a userspace handle */
890             down(&ipt_acc_userspace_mutex);
891             if ((handle.handle_nr = ipt_acc_handle_find_slot()) == -1) {
892                 ipt_acc_data_free(dest.data, dest.depth);
893                 return -EINVAL;
894             }
895             memcpy(&ipt_acc_handles[handle.handle_nr], &dest,
896                              sizeof(struct ipt_acc_handle));
897             up(&ipt_acc_userspace_mutex);
898             
899             if (copy_to_user(user, &handle, 
900                             sizeof(struct ipt_acc_handle_sockopt))) {
901                 return -EFAULT;
902                 break;
903             }
904             ret = 0;
905             break;
906         }
907     case IPT_SO_GET_ACCOUNT_GET_DATA:
908         if (*len < sizeof(struct ipt_acc_handle_sockopt)) {
909             printk("ACCOUNT: ipt_acc_get_ctl: wrong data size (%u != %u)"
910                    " for IPT_SO_GET_ACCOUNT_PREPARE_READ/READ_FLUSH\n",
911                    *len, sizeof(struct ipt_acc_handle_sockopt));
912             break;
913         }
914
915         if (copy_from_user (&handle, user, 
916                             sizeof(struct ipt_acc_handle_sockopt))) {
917             return -EFAULT;
918             break;
919         }
920
921         if (handle.handle_nr >= ACCOUNT_MAX_HANDLES) {
922             return -EINVAL;
923             break;
924         }
925
926         if (*len < ipt_acc_handles[handle.handle_nr].itemcount
927                    * sizeof(struct ipt_acc_handle_ip)) {
928             printk("ACCOUNT: ipt_acc_get_ctl: not enough space (%u < %u)"
929                    " to store data from IPT_SO_GET_ACCOUNT_GET_DATA\n",
930                    *len, ipt_acc_handles[handle.handle_nr].itemcount
931                    * sizeof(struct ipt_acc_handle_ip));
932             ret = -ENOMEM;
933             break;
934         }
935
936         down(&ipt_acc_userspace_mutex);
937         ret = ipt_acc_handle_get_data(handle.handle_nr, user);
938         up(&ipt_acc_userspace_mutex);
939         if (ret) {
940             printk("ACCOUNT: ipt_acc_get_ctl: ipt_acc_handle_get_data"
941                    " failed for handle %u\n", handle.handle_nr);
942             break;
943         }
944
945         ret = 0;
946         break;
947     case IPT_SO_GET_ACCOUNT_GET_HANDLE_USAGE: {
948             u_int32_t i;
949             if (*len < sizeof(struct ipt_acc_handle_sockopt)) {
950                 printk("ACCOUNT: ipt_acc_get_ctl: wrong data size (%u != %u)"
951                        " for IPT_SO_GET_ACCOUNT_GET_HANDLE_USAGE\n",
952                        *len, sizeof(struct ipt_acc_handle_sockopt));
953                 break;
954             }
955
956             /* Find out how many handles are in use */
957             handle.itemcount = 0;
958             down(&ipt_acc_userspace_mutex);
959             for (i = 0; i < ACCOUNT_MAX_HANDLES; i++)
960                 if (ipt_acc_handles[i].data)
961                     handle.itemcount++;
962             up(&ipt_acc_userspace_mutex);
963
964             if (copy_to_user(user, &handle, 
965                              sizeof(struct ipt_acc_handle_sockopt))) {
966                 return -EFAULT;
967                 break;
968             }
969             ret = 0;
970             break;
971         }
972     case IPT_SO_GET_ACCOUNT_GET_TABLE_NAMES: {
973             u_int32_t size = 0, i, name_len;
974             char *tnames;
975
976             LOCK_BH(&ipt_acc_lock);
977
978             /* Determine size of table names */
979             for (i = 0; i < ACCOUNT_MAX_TABLES; i++) {
980                 if (ipt_acc_tables[i].name[0] != 0)
981                     size += strlen (ipt_acc_tables[i].name) + 1;
982             }
983             size += 1;    /* Terminating NULL character */
984
985             if (*len < size || size > PAGE_SIZE) {
986                 UNLOCK_BH(&ipt_acc_lock);
987                 printk("ACCOUNT: ipt_acc_get_ctl: not enough space (%u < %u < %lu)"
988                        " to store table names\n", *len, size, PAGE_SIZE);
989                 ret = -ENOMEM;
990                 break;
991             }
992             /* Copy table names to userspace */
993             tnames = ipt_acc_tmpbuf;
994             for (i = 0; i < ACCOUNT_MAX_TABLES; i++) {
995                 if (ipt_acc_tables[i].name[0] != 0) {
996                     name_len = strlen (ipt_acc_tables[i].name) + 1;
997                     memcpy(tnames, ipt_acc_tables[i].name, name_len);
998                     tnames += name_len;
999                 }
1000             }
1001             UNLOCK_BH(&ipt_acc_lock);
1002             
1003             /* Terminating NULL character */
1004             *tnames = 0;
1005             
1006             /* Transfer to userspace */                    
1007             if (copy_to_user(user, ipt_acc_tmpbuf, size))
1008                 return -EFAULT;
1009             
1010             ret = 0;
1011             break;
1012         }
1013     default:
1014         printk("ACCOUNT: ipt_acc_get_ctl: unknown request %i\n", cmd);
1015     }
1016
1017     return ret;
1018 }
1019
1020 static struct ipt_target ipt_acc_reg = {
1021     .name = "ACCOUNT",
1022     .target = ipt_acc_target,
1023     .checkentry = ipt_acc_checkentry,
1024     .destroy = ipt_acc_deleteentry,
1025     .me = THIS_MODULE
1026 };
1027
1028 static struct nf_sockopt_ops ipt_acc_sockopts = {
1029     .pf = PF_INET,
1030     .set_optmin = IPT_SO_SET_ACCOUNT_HANDLE_FREE,
1031     .set_optmax = IPT_SO_SET_ACCOUNT_MAX+1,
1032     .set = ipt_acc_set_ctl,
1033     .get_optmin = IPT_SO_GET_ACCOUNT_PREPARE_READ,
1034     .get_optmax = IPT_SO_GET_ACCOUNT_MAX+1,
1035     .get = ipt_acc_get_ctl
1036 };
1037
1038 static int __init init(void)
1039 {
1040     init_MUTEX(&ipt_acc_userspace_mutex);    
1041
1042     if ((ipt_acc_tables = 
1043          kmalloc(ACCOUNT_MAX_TABLES * 
1044                  sizeof(struct ipt_acc_table), GFP_KERNEL)) == NULL) {
1045         printk("ACCOUNT: Out of memory allocating account_tables structure");
1046         goto error_cleanup;
1047     }
1048     memset(ipt_acc_tables, 0, 
1049            ACCOUNT_MAX_TABLES * sizeof(struct ipt_acc_table));
1050
1051     if ((ipt_acc_handles = 
1052          kmalloc(ACCOUNT_MAX_HANDLES * 
1053                  sizeof(struct ipt_acc_handle), GFP_KERNEL)) == NULL) {
1054         printk("ACCOUNT: Out of memory allocating account_handles structure");
1055         goto error_cleanup;
1056     }
1057     memset(ipt_acc_handles, 0, 
1058            ACCOUNT_MAX_HANDLES * sizeof(struct ipt_acc_handle));
1059
1060     /* Allocate one page as temporary storage */
1061     if ((ipt_acc_tmpbuf = (void*)__get_free_page(GFP_KERNEL)) == NULL) {
1062         printk("ACCOUNT: Out of memory for temporary buffer page\n");
1063         goto error_cleanup;
1064     }
1065
1066     /* Register setsockopt */
1067     if (nf_register_sockopt(&ipt_acc_sockopts) < 0) {
1068         printk("ACCOUNT: Can't register sockopts. Aborting\n");
1069         goto error_cleanup;
1070     }
1071
1072     if (ipt_register_target(&ipt_acc_reg))
1073         goto error_cleanup;
1074         
1075     return 0;
1076         
1077 error_cleanup:
1078     if(ipt_acc_tables)
1079         kfree(ipt_acc_tables);
1080     if(ipt_acc_handles)
1081         kfree(ipt_acc_handles);
1082     if (ipt_acc_tmpbuf)
1083         free_page((unsigned long)ipt_acc_tmpbuf);
1084         
1085     return -EINVAL;
1086 }
1087
1088 static void __exit fini(void)
1089 {
1090     ipt_unregister_target(&ipt_acc_reg);
1091
1092     nf_unregister_sockopt(&ipt_acc_sockopts);
1093
1094     kfree(ipt_acc_tables);
1095     kfree(ipt_acc_handles);
1096     free_page((unsigned long)ipt_acc_tmpbuf);
1097 }
1098
1099 module_init(init);
1100 module_exit(fini);
1101 MODULE_LICENSE("GPL");