Initial 2.6 submission
[ipt_ACCOUNT] / linux-2.6 / net / ipv4 / netfilter / ipt_ACCOUNT.c
1 /***************************************************************************
2  *   This is a module which is used for counting packets.                  *
3  *   See http://www.intra2net.com/opensource/ipt_account                   *
4  *   for further information                                               *
5  *                                                                         * 
6  *   Copyright (C) 2004-2006 by Intra2net AG                               *
7  *   opensource@intra2net.com                                              *
8  *                                                                         *
9  *   This program is free software; you can redistribute it and/or modify  *
10  *   it under the terms of the GNU General Public License                  *
11  *   version 2 as published by the Free Software Foundation;               *
12  *                                                                         *
13  ***************************************************************************/
14
15 #include <linux/module.h>
16 #include <linux/skbuff.h>
17 #include <linux/ip.h>
18 #include <net/icmp.h>
19 #include <net/udp.h>
20 #include <net/tcp.h>
21 #include <linux/netfilter_ipv4/ip_tables.h>
22 #include <asm/semaphore.h>
23 #include <linux/kernel.h>
24 #include <linux/mm.h>
25 #include <linux/string.h>
26 #include <linux/spinlock.h>
27 #include <asm/uaccess.h>
28
29 #include <net/route.h>
30 #include <linux/netfilter_ipv4/ipt_ACCOUNT.h>
31
32 #if 0
33 #define DEBUGP printk
34 #else
35 #define DEBUGP(format, args...)
36 #endif
37
38 #if (PAGE_SIZE < 4096)
39 #error "ipt_ACCOUNT needs at least a PAGE_SIZE of 4096"
40 #endif
41
42 static struct ipt_acc_table *ipt_acc_tables = NULL;
43 static struct ipt_acc_handle *ipt_acc_handles = NULL;
44 static void *ipt_acc_tmpbuf = NULL;
45
46 /* Spinlock used for manipulating the current accounting tables/data */
47 static DEFINE_SPINLOCK(ipt_acc_lock);
48 /* Mutex (semaphore) used for manipulating userspace handles/snapshot data */
49 static struct semaphore ipt_acc_userspace_mutex;
50
51
52 /* Recursive free of all data structures */
53 static void ipt_acc_data_free(void *data, unsigned char depth)
54 {
55     /* Empty data set */
56     if (!data)
57         return;
58
59     /* Free for 8 bit network */
60     if (depth == 0) {
61         free_page((unsigned long)data);
62         return;
63     }
64
65     /* Free for 16 bit network */
66     if (depth == 1) {
67         struct ipt_acc_mask_16 *mask_16 = (struct ipt_acc_mask_16 *)data;
68         u_int32_t b;
69         for (b=0; b <= 255; b++) {
70             if (mask_16->mask_24[b] != 0) {
71                 free_page((unsigned long)mask_16->mask_24[b]);
72             }
73         }
74         free_page((unsigned long)data);
75         return;
76     }
77
78     /* Free for 24 bit network */
79     if (depth == 2) {
80         u_int32_t a, b;
81         for (a=0; a <= 255; a++) {
82             if (((struct ipt_acc_mask_8 *)data)->mask_16[a]) {
83                 struct ipt_acc_mask_16 *mask_16 = (struct ipt_acc_mask_16*)
84                                    ((struct ipt_acc_mask_8 *)data)->mask_16[a];
85
86                 for (b=0; b <= 255; b++) {
87                     if (mask_16->mask_24[b]) {
88                         free_page((unsigned long)mask_16->mask_24[b]);
89                     }
90                 }
91                 free_page((unsigned long)mask_16);
92             }
93         }
94         free_page((unsigned long)data);
95         return;
96     }
97
98     printk("ACCOUNT: ipt_acc_data_free called with unknown depth: %d\n", 
99            depth);
100     return;
101 }
102
103 /* Look for existing table / insert new one. 
104    Return internal ID or -1 on error */
105 static int ipt_acc_table_insert(char *name, u_int32_t ip, u_int32_t netmask)
106 {
107     u_int32_t i;
108
109     DEBUGP("ACCOUNT: ipt_acc_table_insert: %s, %u.%u.%u.%u/%u.%u.%u.%u\n",
110                                          name, NIPQUAD(ip), NIPQUAD(netmask));
111
112     /* Look for existing table */
113     for (i = 0; i < ACCOUNT_MAX_TABLES; i++) {
114         if (strncmp(ipt_acc_tables[i].name, name, 
115                     ACCOUNT_TABLE_NAME_LEN) == 0) {
116             DEBUGP("ACCOUNT: Found existing slot: %d - "
117                    "%u.%u.%u.%u/%u.%u.%u.%u\n", i, 
118                    NIPQUAD(ipt_acc_tables[i].ip), 
119                    NIPQUAD(ipt_acc_tables[i].netmask));
120
121             if (ipt_acc_tables[i].ip != ip 
122                 || ipt_acc_tables[i].netmask != netmask) {
123                 printk("ACCOUNT: Table %s found, but IP/netmask mismatch. "
124                        "IP/netmask found: %u.%u.%u.%u/%u.%u.%u.%u\n",
125                        name, NIPQUAD(ipt_acc_tables[i].ip), 
126                        NIPQUAD(ipt_acc_tables[i].netmask));
127                 return -1;
128             }
129
130             ipt_acc_tables[i].refcount++;
131             DEBUGP("ACCOUNT: Refcount: %d\n", ipt_acc_tables[i].refcount);
132             return i;
133         }
134     }
135
136     /* Insert new table */
137     for (i = 0; i < ACCOUNT_MAX_TABLES; i++) {
138         /* Found free slot */
139         if (ipt_acc_tables[i].name[0] == 0) {
140             u_int32_t calc_mask, netsize=0;
141             int j;  /* needs to be signed, otherwise we risk endless loop */
142
143             DEBUGP("ACCOUNT: Found free slot: %d\n", i);
144             strncpy (ipt_acc_tables[i].name, name, ACCOUNT_TABLE_NAME_LEN-1);
145
146             ipt_acc_tables[i].ip = ip;
147             ipt_acc_tables[i].netmask = netmask;
148
149             /* Calculate netsize */
150             calc_mask = htonl(netmask);
151             for (j = 31; j >= 0; j--) {
152                 if (calc_mask&(1<<j))
153                     netsize++;
154                 else
155                     break;
156             }
157
158             /* Calculate depth from netsize */
159             if (netsize >= 24)
160                 ipt_acc_tables[i].depth = 0;
161             else if (netsize >= 16)
162                 ipt_acc_tables[i].depth = 1;
163             else if(netsize >= 8)
164                 ipt_acc_tables[i].depth = 2;
165
166             DEBUGP("ACCOUNT: calculated netsize: %u -> "
167                    "ipt_acc_table depth %u\n", netsize, 
168                    ipt_acc_tables[i].depth);
169
170             ipt_acc_tables[i].refcount++;
171             if ((ipt_acc_tables[i].data
172                 = (void *)get_zeroed_page(GFP_ATOMIC)) == NULL) {
173                 printk("ACCOUNT: out of memory for data of table: %s\n", name);
174                 memset(&ipt_acc_tables[i], 0, 
175                        sizeof(struct ipt_acc_table));
176                 return -1;
177             }
178
179             return i;
180         }
181     }
182
183     /* No free slot found */
184     printk("ACCOUNT: No free table slot found (max: %d). "
185            "Please increase ACCOUNT_MAX_TABLES.\n", ACCOUNT_MAX_TABLES);
186     return -1;
187 }
188
189 static int ipt_acc_checkentry(const char *tablename,
190                               const void *e,
191                               const struct xt_target *target,
192                               void *targinfo,
193                               unsigned int targinfosize,
194                               unsigned int hook_mask)
195 {
196     struct ipt_acc_info *info = targinfo;
197     int table_nr;
198
199     spin_lock_bh(&ipt_acc_lock);
200     table_nr = ipt_acc_table_insert(info->table_name, info->net_ip,
201                                                       info->net_mask);
202     spin_unlock_bh(&ipt_acc_lock);
203
204     if (table_nr == -1) {
205         printk("ACCOUNT: Table insert problem. Aborting\n");
206         return 0;
207     }
208     /* Table nr caching so we don't have to do an extra string compare 
209        for every packet */
210     info->table_nr = table_nr;
211
212     return 1;
213 }
214
215 static void ipt_acc_destroy(const struct xt_target *target, void *targinfo, unsigned int targinfosize)
216 {
217     u_int32_t i;
218     struct ipt_acc_info *info = targinfo;
219
220     if (targinfosize != IPT_ALIGN(sizeof(struct ipt_acc_info))) {
221         DEBUGP("ACCOUNT: targinfosize %u != %u\n",
222                targinfosize, IPT_ALIGN(sizeof(struct ipt_acc_info)));
223     }
224
225     spin_lock_bh(&ipt_acc_lock);
226
227     DEBUGP("ACCOUNT: ipt_acc_deleteentry called for table: %s (#%d)\n", 
228            info->table_name, info->table_nr);
229
230     info->table_nr = -1;    /* Set back to original state */
231
232     /* Look for table */
233     for (i = 0; i < ACCOUNT_MAX_TABLES; i++) {
234         if (strncmp(ipt_acc_tables[i].name, info->table_name, 
235                     ACCOUNT_TABLE_NAME_LEN) == 0) {
236             DEBUGP("ACCOUNT: Found table at slot: %d\n", i);
237
238             ipt_acc_tables[i].refcount--;
239             DEBUGP("ACCOUNT: Refcount left: %d\n", 
240                    ipt_acc_tables[i].refcount);
241
242             /* Table not needed anymore? */
243             if (ipt_acc_tables[i].refcount == 0) {
244                 DEBUGP("ACCOUNT: Destroying table at slot: %d\n", i);
245                 ipt_acc_data_free(ipt_acc_tables[i].data, 
246                                       ipt_acc_tables[i].depth);
247                 memset(&ipt_acc_tables[i], 0, 
248                        sizeof(struct ipt_acc_table));
249             }
250
251             spin_unlock_bh(&ipt_acc_lock);
252             return;
253         }
254     }
255
256     /* Table not found */
257     printk("ACCOUNT: Table %s not found for destroy\n", info->table_name);
258     spin_unlock_bh(&ipt_acc_lock);
259 }
260
261 static void ipt_acc_depth0_insert(struct ipt_acc_mask_24 *mask_24,
262                                u_int32_t net_ip, u_int32_t netmask,
263                                u_int32_t src_ip, u_int32_t dst_ip,
264                                u_int32_t size, u_int32_t *itemcount)
265 {
266     unsigned char is_src = 0, is_dst = 0, src_slot, dst_slot;
267     char is_src_new_ip = 0, is_dst_new_ip = 0; /* Check if this entry is new */
268
269     DEBUGP("ACCOUNT: ipt_acc_depth0_insert: %u.%u.%u.%u/%u.%u.%u.%u "
270            "for net %u.%u.%u.%u/%u.%u.%u.%u, size: %u\n", NIPQUAD(src_ip), 
271            NIPQUAD(dst_ip), NIPQUAD(net_ip), NIPQUAD(netmask), size);
272
273     /* Check if src/dst is inside our network. */
274     /* Special: net_ip = 0.0.0.0/0 gets stored as src in slot 0 */
275     if (!netmask)
276         src_ip = 0;
277     if ((net_ip&netmask) == (src_ip&netmask))
278         is_src = 1;
279     if ((net_ip&netmask) == (dst_ip&netmask) && netmask)
280         is_dst = 1;
281
282     if (!is_src && !is_dst) {
283         DEBUGP("ACCOUNT: Skipping packet %u.%u.%u.%u/%u.%u.%u.%u "
284                "for net %u.%u.%u.%u/%u.%u.%u.%u\n", NIPQUAD(src_ip), 
285                NIPQUAD(dst_ip), NIPQUAD(net_ip), NIPQUAD(netmask));
286         return;
287     }
288
289     /* Calculate array positions */
290     src_slot = (unsigned char)((src_ip&0xFF000000) >> 24);
291     dst_slot = (unsigned char)((dst_ip&0xFF000000) >> 24);
292
293     /* Increase size counters */
294     if (is_src) {
295         /* Calculate network slot */
296         DEBUGP("ACCOUNT: Calculated SRC 8 bit network slot: %d\n", src_slot);
297         if (!mask_24->ip[src_slot].src_packets 
298             && !mask_24->ip[src_slot].dst_packets)
299             is_src_new_ip = 1;
300
301         mask_24->ip[src_slot].src_packets++;
302         mask_24->ip[src_slot].src_bytes+=size;
303     }
304     if (is_dst) {
305         DEBUGP("ACCOUNT: Calculated DST 8 bit network slot: %d\n", dst_slot);
306         if (!mask_24->ip[dst_slot].src_packets 
307             && !mask_24->ip[dst_slot].dst_packets)
308             is_dst_new_ip = 1;
309
310         mask_24->ip[dst_slot].dst_packets++;
311         mask_24->ip[dst_slot].dst_bytes+=size;
312     }
313
314     /* Increase itemcounter */
315     DEBUGP("ACCOUNT: Itemcounter before: %d\n", *itemcount);
316     if (src_slot == dst_slot) {
317         if (is_src_new_ip || is_dst_new_ip) {
318             DEBUGP("ACCOUNT: src_slot == dst_slot: %d, %d\n", 
319                    is_src_new_ip, is_dst_new_ip);
320             (*itemcount)++;
321         }
322     } else {
323         if (is_src_new_ip) {
324             DEBUGP("ACCOUNT: New src_ip: %u.%u.%u.%u\n", NIPQUAD(src_ip));
325             (*itemcount)++;
326         }
327         if (is_dst_new_ip) {
328             DEBUGP("ACCOUNT: New dst_ip: %u.%u.%u.%u\n", NIPQUAD(dst_ip));
329             (*itemcount)++;
330         }
331     }
332     DEBUGP("ACCOUNT: Itemcounter after: %d\n", *itemcount);
333 }
334
335 static void ipt_acc_depth1_insert(struct ipt_acc_mask_16 *mask_16, 
336                                u_int32_t net_ip, u_int32_t netmask, 
337                                u_int32_t src_ip, u_int32_t dst_ip,
338                                u_int32_t size, u_int32_t *itemcount)
339 {
340     /* Do we need to process src IP? */
341     if ((net_ip&netmask) == (src_ip&netmask)) {
342         unsigned char slot = (unsigned char)((src_ip&0x00FF0000) >> 16);
343         DEBUGP("ACCOUNT: Calculated SRC 16 bit network slot: %d\n", slot);
344
345         /* Do we need to create a new mask_24 bucket? */
346         if (!mask_16->mask_24[slot] && (mask_16->mask_24[slot] = 
347              (void *)get_zeroed_page(GFP_ATOMIC)) == NULL) {
348             printk("ACCOUNT: Can't process packet because out of memory!\n");
349             return;
350         }
351
352         ipt_acc_depth0_insert((struct ipt_acc_mask_24 *)mask_16->mask_24[slot],
353                                   net_ip, netmask, src_ip, 0, size, itemcount);
354     }
355
356     /* Do we need to process dst IP? */
357     if ((net_ip&netmask) == (dst_ip&netmask)) {
358         unsigned char slot = (unsigned char)((dst_ip&0x00FF0000) >> 16);
359         DEBUGP("ACCOUNT: Calculated DST 16 bit network slot: %d\n", slot);
360
361         /* Do we need to create a new mask_24 bucket? */
362         if (!mask_16->mask_24[slot] && (mask_16->mask_24[slot] 
363             = (void *)get_zeroed_page(GFP_ATOMIC)) == NULL) {
364             printk("ACCOUT: Can't process packet because out of memory!\n");
365             return;
366         }
367
368         ipt_acc_depth0_insert((struct ipt_acc_mask_24 *)mask_16->mask_24[slot],
369                                   net_ip, netmask, 0, dst_ip, size, itemcount);
370     }
371 }
372
373 static void ipt_acc_depth2_insert(struct ipt_acc_mask_8 *mask_8, 
374                                u_int32_t net_ip, u_int32_t netmask,
375                                u_int32_t src_ip, u_int32_t dst_ip,
376                                u_int32_t size, u_int32_t *itemcount)
377 {
378     /* Do we need to process src IP? */
379     if ((net_ip&netmask) == (src_ip&netmask)) {
380         unsigned char slot = (unsigned char)((src_ip&0x0000FF00) >> 8);
381         DEBUGP("ACCOUNT: Calculated SRC 24 bit network slot: %d\n", slot);
382
383         /* Do we need to create a new mask_24 bucket? */
384         if (!mask_8->mask_16[slot] && (mask_8->mask_16[slot] 
385             = (void *)get_zeroed_page(GFP_ATOMIC)) == NULL) {
386             printk("ACCOUNT: Can't process packet because out of memory!\n");
387             return;
388         }
389
390         ipt_acc_depth1_insert((struct ipt_acc_mask_16 *)mask_8->mask_16[slot],
391                                   net_ip, netmask, src_ip, 0, size, itemcount);
392     }
393
394     /* Do we need to process dst IP? */
395     if ((net_ip&netmask) == (dst_ip&netmask)) {
396         unsigned char slot = (unsigned char)((dst_ip&0x0000FF00) >> 8);
397         DEBUGP("ACCOUNT: Calculated DST 24 bit network slot: %d\n", slot);
398
399         /* Do we need to create a new mask_24 bucket? */
400         if (!mask_8->mask_16[slot] && (mask_8->mask_16[slot] 
401             = (void *)get_zeroed_page(GFP_ATOMIC)) == NULL) {
402             printk("ACCOUNT: Can't process packet because out of memory!\n");
403             return;
404         }
405
406         ipt_acc_depth1_insert((struct ipt_acc_mask_16 *)mask_8->mask_16[slot],
407                                   net_ip, netmask, 0, dst_ip, size, itemcount);
408     }
409 }
410
411 static unsigned int ipt_acc_target(struct sk_buff **pskb,
412                                    const struct net_device *in,
413                                    const struct net_device *out,
414                                    unsigned int hooknum,
415                                    const struct xt_target *target,
416                                    const void *targinfo,
417                                    void *userinfo)
418 {
419     const struct ipt_acc_info *info = 
420         (const struct ipt_acc_info *)targinfo;
421     u_int32_t src_ip = (*pskb)->nh.iph->saddr;
422     u_int32_t dst_ip = (*pskb)->nh.iph->daddr;
423     u_int32_t size = ntohs((*pskb)->nh.iph->tot_len);
424
425     spin_lock_bh(&ipt_acc_lock);
426
427     if (ipt_acc_tables[info->table_nr].name[0] == 0) {
428         printk("ACCOUNT: ipt_acc_target: Invalid table id %u. "
429                "IPs %u.%u.%u.%u/%u.%u.%u.%u\n", info->table_nr, 
430                NIPQUAD(src_ip), NIPQUAD(dst_ip));
431         spin_unlock_bh(&ipt_acc_lock);
432         return IPT_CONTINUE;
433     }
434
435     /* 8 bit network or "any" network */
436     if (ipt_acc_tables[info->table_nr].depth == 0) {
437         /* Count packet and check if the IP is new */
438         ipt_acc_depth0_insert(
439             (struct ipt_acc_mask_24 *)ipt_acc_tables[info->table_nr].data,
440             ipt_acc_tables[info->table_nr].ip, 
441             ipt_acc_tables[info->table_nr].netmask,
442             src_ip, dst_ip, size, &ipt_acc_tables[info->table_nr].itemcount);
443         spin_unlock_bh(&ipt_acc_lock);
444         return IPT_CONTINUE;
445     }
446
447     /* 16 bit network */
448     if (ipt_acc_tables[info->table_nr].depth == 1) {
449         ipt_acc_depth1_insert(
450             (struct ipt_acc_mask_16 *)ipt_acc_tables[info->table_nr].data,
451             ipt_acc_tables[info->table_nr].ip, 
452             ipt_acc_tables[info->table_nr].netmask,
453             src_ip, dst_ip, size, &ipt_acc_tables[info->table_nr].itemcount);
454         spin_unlock_bh(&ipt_acc_lock);
455         return IPT_CONTINUE;
456     }
457
458     /* 24 bit network */
459     if (ipt_acc_tables[info->table_nr].depth == 2) {
460         ipt_acc_depth2_insert(
461             (struct ipt_acc_mask_8 *)ipt_acc_tables[info->table_nr].data,
462             ipt_acc_tables[info->table_nr].ip, 
463             ipt_acc_tables[info->table_nr].netmask,
464             src_ip, dst_ip, size, &ipt_acc_tables[info->table_nr].itemcount);
465         spin_unlock_bh(&ipt_acc_lock);
466         return IPT_CONTINUE;
467     }
468
469     printk("ACCOUNT: ipt_acc_target: Unable to process packet. "
470            "Table id %u. IPs %u.%u.%u.%u/%u.%u.%u.%u\n", 
471            info->table_nr, NIPQUAD(src_ip), NIPQUAD(dst_ip));
472
473     spin_unlock_bh(&ipt_acc_lock);
474     return IPT_CONTINUE;
475 }
476
477 /*
478     Functions dealing with "handles":
479     Handles are snapshots of a accounting state.
480     
481     read snapshots are only for debugging the code
482     and are very expensive concerning speed/memory
483     compared to read_and_flush.
484     
485     The functions aren't protected by spinlocks themselves
486     as this is done in the ioctl part of the code.
487 */
488
489 /*
490     Find a free handle slot. Normally only one should be used,
491     but there could be two or more applications accessing the data
492     at the same time.
493 */
494 static int ipt_acc_handle_find_slot(void)
495 {
496     u_int32_t i;
497     /* Insert new table */
498     for (i = 0; i < ACCOUNT_MAX_HANDLES; i++) {
499         /* Found free slot */
500         if (ipt_acc_handles[i].data == NULL) {
501             /* Don't "mark" data as used as we are protected by a spinlock 
502                by the calling function. handle_find_slot() is only a function
503                to prevent code duplication. */
504             return i;
505         }
506     }
507
508     /* No free slot found */
509     printk("ACCOUNT: No free handle slot found (max: %u). "
510            "Please increase ACCOUNT_MAX_HANDLES.\n", ACCOUNT_MAX_HANDLES);
511     return -1;
512 }
513
514 static int ipt_acc_handle_free(u_int32_t handle)
515 {
516     if (handle >= ACCOUNT_MAX_HANDLES) {
517         printk("ACCOUNT: Invalid handle for ipt_acc_handle_free() specified:"
518                " %u\n", handle);
519         return -EINVAL;
520     }
521
522     ipt_acc_data_free(ipt_acc_handles[handle].data, 
523                           ipt_acc_handles[handle].depth);
524     memset (&ipt_acc_handles[handle], 0, sizeof (struct ipt_acc_handle));
525     return 0;
526 }
527
528 /* Prepare data for read without flush. Use only for debugging!
529    Real applications should use read&flush as it's way more efficent */
530 static int ipt_acc_handle_prepare_read(char *tablename,
531          struct ipt_acc_handle *dest, u_int32_t *count)
532 {
533     int table_nr=-1;
534     unsigned char depth;
535
536     for (table_nr = 0; table_nr < ACCOUNT_MAX_TABLES; table_nr++)
537         if (strncmp(ipt_acc_tables[table_nr].name, tablename, 
538             ACCOUNT_TABLE_NAME_LEN) == 0)
539                 break;
540
541     if (table_nr == ACCOUNT_MAX_TABLES) {
542         printk("ACCOUNT: ipt_acc_handle_prepare_read(): "
543                "Table %s not found\n", tablename);
544         return -1;
545     }
546
547     /* Fill up handle structure */
548     dest->ip = ipt_acc_tables[table_nr].ip;
549     dest->depth = ipt_acc_tables[table_nr].depth;
550     dest->itemcount = ipt_acc_tables[table_nr].itemcount;
551
552     /* allocate "root" table */
553     if ((dest->data = (void*)get_zeroed_page(GFP_ATOMIC)) == NULL) {
554         printk("ACCOUNT: out of memory for root table "
555                "in ipt_acc_handle_prepare_read()\n");
556         return -1;
557     }
558
559     /* Recursive copy of complete data structure */
560     depth = dest->depth;
561     if (depth == 0) {
562         memcpy(dest->data, 
563                ipt_acc_tables[table_nr].data, 
564                sizeof(struct ipt_acc_mask_24));
565     } else if (depth == 1) {
566         struct ipt_acc_mask_16 *src_16 = 
567             (struct ipt_acc_mask_16 *)ipt_acc_tables[table_nr].data;
568         struct ipt_acc_mask_16 *network_16 =
569             (struct ipt_acc_mask_16 *)dest->data;
570         u_int32_t b;
571
572         for (b = 0; b <= 255; b++) {
573             if (src_16->mask_24[b]) {
574                 if ((network_16->mask_24[b] = 
575                      (void*)get_zeroed_page(GFP_ATOMIC)) == NULL) {
576                     printk("ACCOUNT: out of memory during copy of 16 bit "
577                            "network in ipt_acc_handle_prepare_read()\n");
578                     ipt_acc_data_free(dest->data, depth);
579                     return -1;
580                 }
581
582                 memcpy(network_16->mask_24[b], src_16->mask_24[b], 
583                        sizeof(struct ipt_acc_mask_24));
584             }
585         }
586     } else if(depth == 2) {
587         struct ipt_acc_mask_8 *src_8 = 
588             (struct ipt_acc_mask_8 *)ipt_acc_tables[table_nr].data;
589         struct ipt_acc_mask_8 *network_8 = 
590             (struct ipt_acc_mask_8 *)dest->data;
591         struct ipt_acc_mask_16 *src_16, *network_16;
592         u_int32_t a, b;
593
594         for (a = 0; a <= 255; a++) {
595             if (src_8->mask_16[a]) {
596                 if ((network_8->mask_16[a] = 
597                      (void*)get_zeroed_page(GFP_ATOMIC)) == NULL) {
598                     printk("ACCOUNT: out of memory during copy of 24 bit network"
599                            " in ipt_acc_handle_prepare_read()\n");
600                     ipt_acc_data_free(dest->data, depth);
601                     return -1;
602                 }
603
604                 memcpy(network_8->mask_16[a], src_8->mask_16[a], 
605                        sizeof(struct ipt_acc_mask_16));
606
607                 src_16 = src_8->mask_16[a];
608                 network_16 = network_8->mask_16[a];
609
610                 for (b = 0; b <= 255; b++) {
611                     if (src_16->mask_24[b]) {
612                         if ((network_16->mask_24[b] = 
613                              (void*)get_zeroed_page(GFP_ATOMIC)) == NULL) {
614                             printk("ACCOUNT: out of memory during copy of 16 bit"
615                                    " network in ipt_acc_handle_prepare_read()\n");
616                             ipt_acc_data_free(dest->data, depth);
617                             return -1;
618                         }
619
620                         memcpy(network_16->mask_24[b], src_16->mask_24[b], 
621                                sizeof(struct ipt_acc_mask_24));
622                     }
623                 }
624             }
625         }
626     }
627
628     *count = ipt_acc_tables[table_nr].itemcount;
629     
630     return 0;
631 }
632
633 /* Prepare data for read and flush it */
634 static int ipt_acc_handle_prepare_read_flush(char *tablename,
635                struct ipt_acc_handle *dest, u_int32_t *count)
636 {
637     int table_nr;
638     void *new_data_page;
639
640     for (table_nr = 0; table_nr < ACCOUNT_MAX_TABLES; table_nr++)
641         if (strncmp(ipt_acc_tables[table_nr].name, tablename, 
642             ACCOUNT_TABLE_NAME_LEN) == 0)
643                 break;
644
645     if (table_nr == ACCOUNT_MAX_TABLES) {
646         printk("ACCOUNT: ipt_acc_handle_prepare_read_flush(): "
647                "Table %s not found\n", tablename);
648         return -1;
649     }
650
651     /* Try to allocate memory */
652     if (!(new_data_page = (void*)get_zeroed_page(GFP_ATOMIC))) {
653         printk("ACCOUNT: ipt_acc_handle_prepare_read_flush(): "
654                "Out of memory!\n");
655         return -1;
656     }
657
658     /* Fill up handle structure */
659     dest->ip = ipt_acc_tables[table_nr].ip;
660     dest->depth = ipt_acc_tables[table_nr].depth;
661     dest->itemcount = ipt_acc_tables[table_nr].itemcount;
662     dest->data = ipt_acc_tables[table_nr].data;
663     *count = ipt_acc_tables[table_nr].itemcount;
664
665     /* "Flush" table data */
666     ipt_acc_tables[table_nr].data = new_data_page;
667     ipt_acc_tables[table_nr].itemcount = 0;
668
669     return 0;
670 }
671
672 /* Copy 8 bit network data into a prepared buffer.
673    We only copy entries != 0 to increase performance.
674 */
675 static int ipt_acc_handle_copy_data(void *to_user, u_int32_t *to_user_pos,
676                                   u_int32_t *tmpbuf_pos, 
677                                   struct ipt_acc_mask_24 *data,
678                                   u_int32_t net_ip, u_int32_t net_OR_mask)
679 {
680     struct ipt_acc_handle_ip handle_ip;
681     u_int32_t handle_ip_size = sizeof (struct ipt_acc_handle_ip);
682     u_int32_t i;
683     
684     for (i = 0; i <= 255; i++) {
685         if (data->ip[i].src_packets || data->ip[i].dst_packets) {
686             handle_ip.ip = net_ip | net_OR_mask | (i<<24);
687             
688             handle_ip.src_packets = data->ip[i].src_packets;
689             handle_ip.src_bytes = data->ip[i].src_bytes;
690             handle_ip.dst_packets = data->ip[i].dst_packets;
691             handle_ip.dst_bytes = data->ip[i].dst_bytes;
692
693             /* Temporary buffer full? Flush to userspace */
694             if (*tmpbuf_pos+handle_ip_size >= PAGE_SIZE) {
695                 if (copy_to_user(to_user + *to_user_pos, ipt_acc_tmpbuf,
696                                                            *tmpbuf_pos))
697                     return -EFAULT;
698                 *to_user_pos = *to_user_pos + *tmpbuf_pos;
699                 *tmpbuf_pos = 0;
700             }
701             memcpy(ipt_acc_tmpbuf+*tmpbuf_pos, &handle_ip, handle_ip_size);
702             *tmpbuf_pos += handle_ip_size;
703         }
704     }
705     
706     return 0;
707 }
708    
709 /* Copy the data from our internal structure 
710    We only copy entries != 0 to increase performance.
711    Overwrites ipt_acc_tmpbuf.
712 */
713 static int ipt_acc_handle_get_data(u_int32_t handle, void *to_user)
714 {
715     u_int32_t to_user_pos=0, tmpbuf_pos=0, net_ip;
716     unsigned char depth;
717
718     if (handle >= ACCOUNT_MAX_HANDLES) {
719         printk("ACCOUNT: invalid handle for ipt_acc_handle_get_data() "
720                "specified: %u\n", handle);
721         return -1;
722     }
723
724     if (ipt_acc_handles[handle].data == NULL) {
725         printk("ACCOUNT: handle %u is BROKEN: Contains no data\n", handle);
726         return -1;
727     }
728
729     net_ip = ipt_acc_handles[handle].ip;
730     depth = ipt_acc_handles[handle].depth;
731
732     /* 8 bit network */
733     if (depth == 0) {
734         struct ipt_acc_mask_24 *network = 
735             (struct ipt_acc_mask_24*)ipt_acc_handles[handle].data;
736         if (ipt_acc_handle_copy_data(to_user, &to_user_pos, &tmpbuf_pos,
737                                      network, net_ip, 0))
738             return -1;
739         
740         /* Flush remaining data to userspace */
741         if (tmpbuf_pos)
742             if (copy_to_user(to_user+to_user_pos, ipt_acc_tmpbuf, tmpbuf_pos))
743                 return -1;
744
745         return 0;
746     }
747
748     /* 16 bit network */
749     if (depth == 1) {
750         struct ipt_acc_mask_16 *network_16 = 
751             (struct ipt_acc_mask_16*)ipt_acc_handles[handle].data;
752         u_int32_t b;
753         for (b = 0; b <= 255; b++) {
754             if (network_16->mask_24[b]) {
755                 struct ipt_acc_mask_24 *network = 
756                     (struct ipt_acc_mask_24*)network_16->mask_24[b];
757                 if (ipt_acc_handle_copy_data(to_user, &to_user_pos,
758                                       &tmpbuf_pos, network, net_ip, (b << 16)))
759                     return -1;
760             }
761         }
762
763         /* Flush remaining data to userspace */
764         if (tmpbuf_pos)
765             if (copy_to_user(to_user+to_user_pos, ipt_acc_tmpbuf, tmpbuf_pos))
766                 return -1;
767
768         return 0;
769     }
770
771     /* 24 bit network */
772     if (depth == 2) {
773         struct ipt_acc_mask_8 *network_8 = 
774             (struct ipt_acc_mask_8*)ipt_acc_handles[handle].data;
775         u_int32_t a, b;
776         for (a = 0; a <= 255; a++) {
777             if (network_8->mask_16[a]) {
778                 struct ipt_acc_mask_16 *network_16 = 
779                     (struct ipt_acc_mask_16*)network_8->mask_16[a];
780                 for (b = 0; b <= 255; b++) {
781                     if (network_16->mask_24[b]) {
782                         struct ipt_acc_mask_24 *network = 
783                             (struct ipt_acc_mask_24*)network_16->mask_24[b];
784                         if (ipt_acc_handle_copy_data(to_user,
785                                        &to_user_pos, &tmpbuf_pos,
786                                        network, net_ip, (a << 8) | (b << 16)))
787                             return -1;
788                     }
789                 }
790             }
791         }
792
793         /* Flush remaining data to userspace */
794         if (tmpbuf_pos)
795             if (copy_to_user(to_user+to_user_pos, ipt_acc_tmpbuf, tmpbuf_pos))
796                 return -1;
797
798         return 0;
799     }
800     
801     return -1;
802 }
803
804 static int ipt_acc_set_ctl(struct sock *sk, int cmd, 
805                                void *user, u_int32_t len)
806 {
807     struct ipt_acc_handle_sockopt handle;
808     int ret = -EINVAL;
809
810     if (!capable(CAP_NET_ADMIN))
811         return -EPERM;
812
813     switch (cmd) {
814     case IPT_SO_SET_ACCOUNT_HANDLE_FREE:
815         if (len != sizeof(struct ipt_acc_handle_sockopt)) {
816             printk("ACCOUNT: ipt_acc_set_ctl: wrong data size (%u != %u) "
817                    "for IPT_SO_SET_HANDLE_FREE\n", 
818                    len, sizeof(struct ipt_acc_handle_sockopt));
819             break;
820         }
821
822         if (copy_from_user (&handle, user, len)) {
823             printk("ACCOUNT: ipt_acc_set_ctl: copy_from_user failed for "
824                    "IPT_SO_SET_HANDLE_FREE\n");
825             break;
826         }
827
828         down(&ipt_acc_userspace_mutex);
829         ret = ipt_acc_handle_free(handle.handle_nr);
830         up(&ipt_acc_userspace_mutex);
831         break;
832     case IPT_SO_SET_ACCOUNT_HANDLE_FREE_ALL: {
833             u_int32_t i;
834             down(&ipt_acc_userspace_mutex);
835             for (i = 0; i < ACCOUNT_MAX_HANDLES; i++)
836                 ipt_acc_handle_free(i);
837             up(&ipt_acc_userspace_mutex);
838             ret = 0;
839             break;
840         }
841     default:
842         printk("ACCOUNT: ipt_acc_set_ctl: unknown request %i\n", cmd);
843     }
844
845     return ret;
846 }
847
848 static int ipt_acc_get_ctl(struct sock *sk, int cmd, void *user, int *len)
849 {
850     struct ipt_acc_handle_sockopt handle;
851     int ret = -EINVAL;
852
853     if (!capable(CAP_NET_ADMIN))
854         return -EPERM;
855
856     switch (cmd) {
857     case IPT_SO_GET_ACCOUNT_PREPARE_READ_FLUSH:
858     case IPT_SO_GET_ACCOUNT_PREPARE_READ: {
859             struct ipt_acc_handle dest;
860
861             if (*len < sizeof(struct ipt_acc_handle_sockopt)) {
862                 printk("ACCOUNT: ipt_acc_get_ctl: wrong data size (%u != %u) "
863                     "for IPT_SO_GET_ACCOUNT_PREPARE_READ/READ_FLUSH\n",
864                     *len, sizeof(struct ipt_acc_handle_sockopt));
865                 break;
866             }
867
868             if (copy_from_user (&handle, user, 
869                                 sizeof(struct ipt_acc_handle_sockopt))) {
870                 return -EFAULT;
871                 break;
872             }
873
874             spin_lock_bh(&ipt_acc_lock);
875             if (cmd == IPT_SO_GET_ACCOUNT_PREPARE_READ_FLUSH)
876                 ret = ipt_acc_handle_prepare_read_flush(
877                                     handle.name, &dest, &handle.itemcount);
878             else
879                 ret = ipt_acc_handle_prepare_read(
880                                     handle.name, &dest, &handle.itemcount);
881             spin_unlock_bh(&ipt_acc_lock);
882             // Error occured during prepare_read?
883            if (ret == -1)
884                 return -EINVAL;
885
886             /* Allocate a userspace handle */
887             down(&ipt_acc_userspace_mutex);
888             if ((handle.handle_nr = ipt_acc_handle_find_slot()) == -1) {
889                 ipt_acc_data_free(dest.data, dest.depth);
890                 up(&ipt_acc_userspace_mutex);
891                 return -EINVAL;
892             }
893             memcpy(&ipt_acc_handles[handle.handle_nr], &dest,
894                              sizeof(struct ipt_acc_handle));
895             up(&ipt_acc_userspace_mutex);
896
897             if (copy_to_user(user, &handle, 
898                             sizeof(struct ipt_acc_handle_sockopt))) {
899                 return -EFAULT;
900                 break;
901             }
902             ret = 0;
903             break;
904         }
905     case IPT_SO_GET_ACCOUNT_GET_DATA:
906         if (*len < sizeof(struct ipt_acc_handle_sockopt)) {
907             printk("ACCOUNT: ipt_acc_get_ctl: wrong data size (%u != %u)"
908                    " for IPT_SO_GET_ACCOUNT_PREPARE_READ/READ_FLUSH\n",
909                    *len, sizeof(struct ipt_acc_handle_sockopt));
910             break;
911         }
912
913         if (copy_from_user (&handle, user, 
914                             sizeof(struct ipt_acc_handle_sockopt))) {
915             return -EFAULT;
916             break;
917         }
918
919         if (handle.handle_nr >= ACCOUNT_MAX_HANDLES) {
920             return -EINVAL;
921             break;
922         }
923
924         if (*len < ipt_acc_handles[handle.handle_nr].itemcount
925                    * sizeof(struct ipt_acc_handle_ip)) {
926             printk("ACCOUNT: ipt_acc_get_ctl: not enough space (%u < %u)"
927                    " to store data from IPT_SO_GET_ACCOUNT_GET_DATA\n",
928                    *len, ipt_acc_handles[handle.handle_nr].itemcount
929                    * sizeof(struct ipt_acc_handle_ip));
930             ret = -ENOMEM;
931             break;
932         }
933
934         down(&ipt_acc_userspace_mutex);
935         ret = ipt_acc_handle_get_data(handle.handle_nr, user);
936         up(&ipt_acc_userspace_mutex);
937         if (ret) {
938             printk("ACCOUNT: ipt_acc_get_ctl: ipt_acc_handle_get_data"
939                    " failed for handle %u\n", handle.handle_nr);
940             break;
941         }
942
943         ret = 0;
944         break;
945     case IPT_SO_GET_ACCOUNT_GET_HANDLE_USAGE: {
946             u_int32_t i;
947             if (*len < sizeof(struct ipt_acc_handle_sockopt)) {
948                 printk("ACCOUNT: ipt_acc_get_ctl: wrong data size (%u != %u)"
949                        " for IPT_SO_GET_ACCOUNT_GET_HANDLE_USAGE\n",
950                        *len, sizeof(struct ipt_acc_handle_sockopt));
951                 break;
952             }
953
954             /* Find out how many handles are in use */
955             handle.itemcount = 0;
956             down(&ipt_acc_userspace_mutex);
957             for (i = 0; i < ACCOUNT_MAX_HANDLES; i++)
958                 if (ipt_acc_handles[i].data)
959                     handle.itemcount++;
960             up(&ipt_acc_userspace_mutex);
961
962             if (copy_to_user(user, &handle, 
963                              sizeof(struct ipt_acc_handle_sockopt))) {
964                 return -EFAULT;
965                 break;
966             }
967             ret = 0;
968             break;
969         }
970     case IPT_SO_GET_ACCOUNT_GET_TABLE_NAMES: {
971             u_int32_t size = 0, i, name_len;
972             char *tnames;
973
974             spin_lock_bh(&ipt_acc_lock);
975
976             /* Determine size of table names */
977             for (i = 0; i < ACCOUNT_MAX_TABLES; i++) {
978                 if (ipt_acc_tables[i].name[0] != 0)
979                     size += strlen (ipt_acc_tables[i].name) + 1;
980             }
981             size += 1;    /* Terminating NULL character */
982
983             if (*len < size || size > PAGE_SIZE) {
984                 spin_unlock_bh(&ipt_acc_lock);
985                 printk("ACCOUNT: ipt_acc_get_ctl: not enough space (%u < %u < %lu)"
986                        " to store table names\n", *len, size, PAGE_SIZE);
987                 ret = -ENOMEM;
988                 break;
989             }
990             /* Copy table names to userspace */
991             tnames = ipt_acc_tmpbuf;
992             for (i = 0; i < ACCOUNT_MAX_TABLES; i++) {
993                 if (ipt_acc_tables[i].name[0] != 0) {
994                     name_len = strlen (ipt_acc_tables[i].name) + 1;
995                     memcpy(tnames, ipt_acc_tables[i].name, name_len);
996                     tnames += name_len;
997                 }
998             }
999             spin_unlock_bh(&ipt_acc_lock);
1000
1001             /* Terminating NULL character */
1002             *tnames = 0;
1003
1004             /* Transfer to userspace */
1005             if (copy_to_user(user, ipt_acc_tmpbuf, size))
1006                 return -EFAULT;
1007
1008             ret = 0;
1009             break;
1010         }
1011     default:
1012         printk("ACCOUNT: ipt_acc_get_ctl: unknown request %i\n", cmd);
1013     }
1014
1015     return ret;
1016 }
1017
1018 static struct ipt_target ipt_acc_reg = {
1019     .name = "ACCOUNT",
1020     .target = ipt_acc_target,
1021     .targetsize = sizeof(struct ipt_acc_info),
1022     .checkentry = ipt_acc_checkentry,
1023     .destroy = ipt_acc_destroy,
1024     .me = THIS_MODULE
1025 };
1026
1027 static struct nf_sockopt_ops ipt_acc_sockopts = {
1028     .pf = PF_INET,
1029     .set_optmin = IPT_SO_SET_ACCOUNT_HANDLE_FREE,
1030     .set_optmax = IPT_SO_SET_ACCOUNT_MAX+1,
1031     .set = ipt_acc_set_ctl,
1032     .get_optmin = IPT_SO_GET_ACCOUNT_PREPARE_READ,
1033     .get_optmax = IPT_SO_GET_ACCOUNT_MAX+1,
1034     .get = ipt_acc_get_ctl
1035 };
1036
1037 static int __init init(void)
1038 {
1039     init_MUTEX(&ipt_acc_userspace_mutex);
1040
1041     if ((ipt_acc_tables = 
1042          kmalloc(ACCOUNT_MAX_TABLES * 
1043                  sizeof(struct ipt_acc_table), GFP_KERNEL)) == NULL) {
1044         printk("ACCOUNT: Out of memory allocating account_tables structure");
1045         goto error_cleanup;
1046     }
1047     memset(ipt_acc_tables, 0, 
1048            ACCOUNT_MAX_TABLES * sizeof(struct ipt_acc_table));
1049
1050     if ((ipt_acc_handles = 
1051          kmalloc(ACCOUNT_MAX_HANDLES * 
1052                  sizeof(struct ipt_acc_handle), GFP_KERNEL)) == NULL) {
1053         printk("ACCOUNT: Out of memory allocating account_handles structure");
1054         goto error_cleanup;
1055     }
1056     memset(ipt_acc_handles, 0, 
1057            ACCOUNT_MAX_HANDLES * sizeof(struct ipt_acc_handle));
1058
1059     /* Allocate one page as temporary storage */
1060     if ((ipt_acc_tmpbuf = (void*)__get_free_page(GFP_KERNEL)) == NULL) {
1061         printk("ACCOUNT: Out of memory for temporary buffer page\n");
1062         goto error_cleanup;
1063     }
1064
1065     /* Register setsockopt */
1066     if (nf_register_sockopt(&ipt_acc_sockopts) < 0) {
1067         printk("ACCOUNT: Can't register sockopts. Aborting\n");
1068         goto error_cleanup;
1069     }
1070
1071     if (ipt_register_target(&ipt_acc_reg))
1072         goto error_cleanup;
1073
1074     return 0;
1075
1076 error_cleanup:
1077     if(ipt_acc_tables)
1078         kfree(ipt_acc_tables);
1079     if(ipt_acc_handles)
1080         kfree(ipt_acc_handles);
1081     if (ipt_acc_tmpbuf)
1082         free_page((unsigned long)ipt_acc_tmpbuf);
1083
1084     return -EINVAL;
1085 }
1086
1087 static void __exit fini(void)
1088 {
1089     ipt_unregister_target(&ipt_acc_reg);
1090
1091     nf_unregister_sockopt(&ipt_acc_sockopts);
1092
1093     kfree(ipt_acc_tables);
1094     kfree(ipt_acc_handles);
1095     free_page((unsigned long)ipt_acc_tmpbuf);
1096 }
1097
1098 module_init(init);
1099 module_exit(fini);
1100 MODULE_LICENSE("GPL");