iptables: (tomj) small header, added ipt_account.c here until we have patch-o-matic
[ipt_ACCOUNT] / linux / net / ipv4 / netfilter / ipt_ACCOUNT.c
1 /*
2  * This is a module which is used for counting packets.
3  */
4 #include <linux/module.h>
5 #include <linux/skbuff.h>
6 #include <linux/ip.h>
7 #include <linux/spinlock.h>
8 #include <net/icmp.h>
9 #include <net/udp.h>
10 #include <net/tcp.h>
11 #include <linux/netfilter_ipv4/ip_tables.h>
12
13 struct in_device;
14 #include <net/route.h>
15 #include <linux/netfilter_ipv4/ipt_ACCOUNT.h>
16
17 //#if 0
18 #define DEBUGP printk
19 //#else
20 //#define DEBUGP(format, args...)
21 //#endif
22
23 struct ipt_account_table *ipt_account_tables = NULL;
24 struct ipt_account_handle *ipt_account_handles = NULL;
25
26 static spinlock_t ipt_account_lock = SPIN_LOCK_UNLOCKED;
27
28 static unsigned int
29 ipt_account_target(struct sk_buff **pskb,
30     unsigned int hooknum,
31     const struct net_device *in,
32     const struct net_device *out,
33     const void *targinfo,
34     void *userinfo)
35 {
36     spin_lock_bh(&ipt_account_lock);
37     spin_unlock_bh(&ipt_account_lock);
38
39     return IPT_CONTINUE;
40 }
41
42 /* Recursive free of all data structures */
43 void ipt_account_data_free(void *data, unsigned char netsize)
44 {
45     // Empty data set
46     if (!data)
47         return;
48         
49     // Free for 8 bit network. Special: 0.0.0.0/0
50     if (netsize >= 24 || netsize == 0)
51     {
52         kfree(data);
53         data = NULL;
54         return;
55     }
56     
57     // Free for 16 bit network
58     if (netsize >= 16)
59     {
60         struct ipt_account_mask_16 *mask_16 = (struct ipt_account_mask_16 *)data;
61         unsigned char b;
62         for (b=0; b < 255; b++)
63         {
64             if (mask_16->mask_24[b] != 0)
65             {
66                 kfree(mask_16->mask_24[b]);
67                 mask_16->mask_24[b] = NULL;
68             }
69         }
70         kfree(data);
71         data = NULL;
72         return;
73     } 
74    
75     // Free for 24 bit network
76     if (netsize >= 8)
77     {
78         unsigned char a, b;
79         for (a=0; a < 255; a++)
80         {
81             if (((struct ipt_account_mask_8 *)data)->mask_16[a])
82             {
83                 struct ipt_account_mask_16 *mask_16 = (struct ipt_account_mask_16*)((struct ipt_account_mask_8 *)data)->mask_16[a];
84                 for (b=0; b < 255; b++)
85                 {
86                     if (mask_16->mask_24[b]) {
87                         kfree(mask_16->mask_24[b]);
88                         mask_16->mask_24[b] = NULL;
89                     }
90                 }
91                 kfree(mask_16);
92                 mask_16 = NULL;
93             }
94         }
95         kfree(data);
96         data = NULL;
97         return;
98     }
99     
100     printk("ACCOUNT: ipt_account_data_free called with broken netsize: %d\n", netsize);
101     return;
102 }
103
104 /* Look for existing table / insert new one. Return internal ID or -1 on error */
105 int ipt_account_table_insert(char *name, unsigned int ip, unsigned int netmask)
106 {
107     unsigned int i;
108
109     DEBUGP("ACCOUNT: ipt_account_table_insert: %s, %u/%u\n", name, ip, netmask);
110
111     // Look for existing table
112     for (i = 0; i < ACCOUNT_MAX_TABLES; i++)
113     {
114         if (strcmp(ipt_account_tables[i].name, name) == 0)
115         {
116             DEBUGP("ACCOUT: Found existing slot: %d - %u/%u\n", i, ipt_account_tables[i].ip, ipt_account_tables[i].netmask);
117             
118             if (ipt_account_tables[i].ip != ip || ipt_account_tables[i].netmask != netmask)
119             {
120                 printk("ACCOUNT: Table %s found, but IP/netmask mismatch. IP/netmask found: %u/%u\n",
121                         name, ipt_account_tables[i].ip, ipt_account_tables[i].netmask);
122                 return -1;
123             }
124
125             ipt_account_tables[i].refcount++;
126             DEBUGP("ACCOUNT: Refcount: %d\n", ipt_account_tables[i].refcount);
127             return i;
128         }
129     }
130
131     // Insert new table
132     for (i = 0; i < ACCOUNT_MAX_TABLES; i++)
133     {
134         // Found free slot
135         if (ipt_account_tables[i].name[0] == 0)
136         {
137             DEBUGP("ACCOUNT: Found free slot: %d\n", i);
138         
139             strncpy (ipt_account_tables[i].name, name, ACCOUNT_TABLE_NAME_LEN-1);
140             
141             ipt_account_tables[i].ip = ip;
142             ipt_account_tables[i].netmask = netmask;
143             
144             // calculate netsize
145             unsigned int j, calc_mask;
146             calc_mask = htonl(netmask);
147             for (j = 31; j > 0; j--)
148             {
149                 if (calc_mask&(1<<j))
150                     ipt_account_tables[i].netsize++;
151                 else
152                     break;
153             }
154             printk("ACCOUNT: calculated netsize: %u\n", ipt_account_tables[i].netsize);
155                         
156             ipt_account_tables[i].refcount++;
157             if (!(ipt_account_tables[i].data = (void *)get_zeroed_page(GFP_KERNEL)))
158             {
159                 printk("ACCOUNT: Out of memory for data of table: %s\n", name);
160                 memset(&ipt_account_tables[i], 0, sizeof(struct ipt_account_table));
161                 return -1;
162             }
163             return i;
164         }
165     }
166         
167     // No free slot found
168     printk("ACCOUNT: No free table slot found (max: %d). Please increase ACCOUNT_MAX_TABLES.\n", ACCOUNT_MAX_TABLES);
169     return -1;
170 }
171
172 static int ipt_account_checkentry(const char *tablename,
173     const struct ipt_entry *e,
174     void *targinfo,
175     unsigned int targinfosize,
176     unsigned int hook_mask)
177 {
178     struct ipt_account_info *info = targinfo;
179     
180     if (targinfosize != IPT_ALIGN(sizeof(struct ipt_account_info))) {
181         DEBUGP("ACCOUNT: targinfosize %u != %u\n",
182                 targinfosize, IPT_ALIGN(sizeof(struct ipt_account_info)));
183         return 0;
184     }
185
186     int table_nr = ipt_account_table_insert(info->table_name, info->net_ip, info->net_mask);
187     if (table_nr == -1)
188     {
189         printk("ACCOUNT: Table insert problem. Aborting\n");
190         return 0;
191     }
192
193     // Table nr caching so we don't have to do an extra string compare for every packet
194     info->table_nr = table_nr;
195     
196     return 1;
197 }
198
199 static struct ipt_target ipt_account_reg
200 = { { NULL, NULL }, "ACCOUNT", ipt_account_target, ipt_account_checkentry, NULL, 
201     THIS_MODULE };
202
203 static int __init init(void)
204 {
205     if (!(ipt_account_tables = kmalloc(ACCOUNT_MAX_TABLES, sizeof(struct ipt_account_table))))
206     {
207             printk("ACCOUNT: Out of memory allocating account_tables structure");
208             return -EINVAL;
209     }
210     memset(ipt_account_tables, 0, sizeof(struct ipt_account_table));
211                             
212     if (!(ipt_account_handles = kmalloc(ACCOUNT_MAX_HANDLES, sizeof(struct ipt_account_handle))))
213     {
214             printk("ACCOUNT: Out of memory allocating account_handles structure");
215             kfree (ipt_account_tables);
216             ipt_account_tables = NULL;
217             return -EINVAL;
218     }
219     memset(ipt_account_handles, 0, sizeof(struct ipt_account_handle));
220     
221     if (ipt_register_target(&ipt_account_reg))
222             return -EINVAL;
223
224     return 0;
225 }
226
227 static void __exit fini(void)
228 {
229     ipt_unregister_target(&ipt_account_reg);
230
231     kfree(ipt_account_tables);
232     ipt_account_tables = NULL;
233         
234     kfree(ipt_account_handles);
235     ipt_account_handles = NULL;
236 }
237
238 module_init(init);
239 module_exit(fini);
240 MODULE_LICENSE("GPL");