b3947eb38ae20489e650a71f581a7bf13fef9ea9
[qemu] / slirp / tftp.c
1 /*
2  * tftp.c - a simple, read-only tftp server for qemu
3  * 
4  * Copyright (c) 2004 Magnus Damm <damm@opensource.se>
5  * 
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to deal
8  * in the Software without restriction, including without limitation the rights
9  * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10  * copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in
14  * all copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
19  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22  * THE SOFTWARE.
23  */
24
25 #include <slirp.h>
26
27 struct tftp_session {
28     int in_use;
29     unsigned char filename[TFTP_FILENAME_MAX];
30     
31     struct in_addr client_ip;
32     u_int16_t client_port;
33     
34     int timestamp;
35 };
36
37 struct tftp_session tftp_sessions[TFTP_SESSIONS_MAX];
38
39 const char *tftp_prefix;
40
41 static void tftp_session_update(struct tftp_session *spt)
42 {
43     spt->timestamp = curtime;
44     spt->in_use = 1;
45 }
46
47 static void tftp_session_terminate(struct tftp_session *spt)
48 {
49   spt->in_use = 0;
50 }
51
52 static int tftp_session_allocate(struct tftp_t *tp)
53 {
54   struct tftp_session *spt;
55   int k;
56
57   for (k = 0; k < TFTP_SESSIONS_MAX; k++) {
58     spt = &tftp_sessions[k];
59
60     if (!spt->in_use)
61         goto found;
62
63     /* sessions time out after 5 inactive seconds */
64     if ((int)(curtime - spt->timestamp) > 5000)
65         goto found;
66   }
67
68   return -1;
69
70  found:
71   memset(spt, 0, sizeof(*spt));
72   memcpy(&spt->client_ip, &tp->ip.ip_src, sizeof(spt->client_ip));
73   spt->client_port = tp->udp.uh_sport;
74
75   tftp_session_update(spt);
76
77   return k;
78 }
79
80 static int tftp_session_find(struct tftp_t *tp)
81 {
82   struct tftp_session *spt;
83   int k;
84
85   for (k = 0; k < TFTP_SESSIONS_MAX; k++) {
86     spt = &tftp_sessions[k];
87
88     if (spt->in_use) {
89       if (!memcmp(&spt->client_ip, &tp->ip.ip_src, sizeof(spt->client_ip))) {
90         if (spt->client_port == tp->udp.uh_sport) {
91           return k;
92         }
93       }
94     }
95   }
96
97   return -1;
98 }
99
100 static int tftp_read_data(struct tftp_session *spt, u_int16_t block_nr,
101                           u_int8_t *buf, int len)
102 {
103   int fd;
104   int bytes_read = 0;
105
106   fd = open(spt->filename, O_RDONLY | O_BINARY);
107
108   if (fd < 0) {
109     return -1;
110   }
111
112   if (len) {
113     lseek(fd, block_nr * 512, SEEK_SET);
114
115     bytes_read = read(fd, buf, len);
116   }
117
118   close(fd);
119
120   return bytes_read;
121 }
122
123 static int tftp_send_oack(struct tftp_session *spt, 
124                           const char *key, uint32_t value,
125                           struct tftp_t *recv_tp)
126 {
127     struct sockaddr_in saddr, daddr;
128     struct mbuf *m;
129     struct tftp_t *tp;
130     int n = 0;
131
132     m = m_get();
133
134     if (!m)
135         return -1;
136
137     memset(m->m_data, 0, m->m_size);
138
139     m->m_data += if_maxlinkhdr;
140     tp = (void *)m->m_data;
141     m->m_data += sizeof(struct udpiphdr);
142     
143     tp->tp_op = htons(TFTP_OACK);
144     n += sprintf(tp->x.tp_buf + n, "%s", key) + 1;
145     n += sprintf(tp->x.tp_buf + n, "%u", value) + 1;
146
147     saddr.sin_addr = recv_tp->ip.ip_dst;
148     saddr.sin_port = recv_tp->udp.uh_dport;
149     
150     daddr.sin_addr = spt->client_ip;
151     daddr.sin_port = spt->client_port;
152
153     m->m_len = sizeof(struct tftp_t) - 514 + n - 
154         sizeof(struct ip) - sizeof(struct udphdr);
155     udp_output2(NULL, m, &saddr, &daddr, IPTOS_LOWDELAY);
156
157     return 0;
158 }
159
160
161
162 static int tftp_send_error(struct tftp_session *spt, 
163                            u_int16_t errorcode, const char *msg,
164                            struct tftp_t *recv_tp)
165 {
166   struct sockaddr_in saddr, daddr;
167   struct mbuf *m;
168   struct tftp_t *tp;
169   int nobytes;
170
171   m = m_get();
172
173   if (!m) {
174     return -1;
175   }
176
177   memset(m->m_data, 0, m->m_size);
178
179   m->m_data += if_maxlinkhdr;
180   tp = (void *)m->m_data;
181   m->m_data += sizeof(struct udpiphdr);
182   
183   tp->tp_op = htons(TFTP_ERROR);
184   tp->x.tp_error.tp_error_code = htons(errorcode);
185   strcpy(tp->x.tp_error.tp_msg, msg);
186
187   saddr.sin_addr = recv_tp->ip.ip_dst;
188   saddr.sin_port = recv_tp->udp.uh_dport;
189
190   daddr.sin_addr = spt->client_ip;
191   daddr.sin_port = spt->client_port;
192
193   nobytes = 2;
194
195   m->m_len = sizeof(struct tftp_t) - 514 + 3 + strlen(msg) - 
196         sizeof(struct ip) - sizeof(struct udphdr);
197
198   udp_output2(NULL, m, &saddr, &daddr, IPTOS_LOWDELAY);
199
200   tftp_session_terminate(spt);
201
202   return 0;
203 }
204
205 static int tftp_send_data(struct tftp_session *spt, 
206                           u_int16_t block_nr,
207                           struct tftp_t *recv_tp)
208 {
209   struct sockaddr_in saddr, daddr;
210   struct mbuf *m;
211   struct tftp_t *tp;
212   int nobytes;
213
214   if (block_nr < 1) {
215     return -1;
216   }
217
218   m = m_get();
219
220   if (!m) {
221     return -1;
222   }
223
224   memset(m->m_data, 0, m->m_size);
225
226   m->m_data += if_maxlinkhdr;
227   tp = (void *)m->m_data;
228   m->m_data += sizeof(struct udpiphdr);
229   
230   tp->tp_op = htons(TFTP_DATA);
231   tp->x.tp_data.tp_block_nr = htons(block_nr);
232
233   saddr.sin_addr = recv_tp->ip.ip_dst;
234   saddr.sin_port = recv_tp->udp.uh_dport;
235
236   daddr.sin_addr = spt->client_ip;
237   daddr.sin_port = spt->client_port;
238
239   nobytes = tftp_read_data(spt, block_nr - 1, tp->x.tp_data.tp_buf, 512);
240
241   if (nobytes < 0) {
242     m_free(m);
243
244     /* send "file not found" error back */
245
246     tftp_send_error(spt, 1, "File not found", tp);
247
248     return -1;
249   }
250
251   m->m_len = sizeof(struct tftp_t) - (512 - nobytes) - 
252         sizeof(struct ip) - sizeof(struct udphdr);
253
254   udp_output2(NULL, m, &saddr, &daddr, IPTOS_LOWDELAY);
255
256   if (nobytes == 512) {
257     tftp_session_update(spt);
258   }
259   else {
260     tftp_session_terminate(spt);
261   }
262
263   return 0;
264 }
265
266 static void tftp_handle_rrq(struct tftp_t *tp, int pktlen)
267 {
268   struct tftp_session *spt;
269   int s, k, n;
270   u_int8_t *src, *dst;
271
272   s = tftp_session_allocate(tp);
273
274   if (s < 0) {
275     return;
276   }
277
278   spt = &tftp_sessions[s];
279
280   src = tp->x.tp_buf;
281   dst = spt->filename;
282   n = pktlen - ((uint8_t *)&tp->x.tp_buf[0] - (uint8_t *)tp);
283
284   /* get name */
285
286   for (k = 0; k < n; k++) {
287     if (k < TFTP_FILENAME_MAX) {
288       dst[k] = src[k];
289     }
290     else {
291       return;
292     }
293     
294     if (src[k] == '\0') {
295       break;
296     }
297   }
298       
299   if (k >= n) {
300     return;
301   }
302   
303   k++;
304   
305   /* check mode */
306   if ((n - k) < 6) {
307     return;
308   }
309   
310   if (memcmp(&src[k], "octet\0", 6) != 0) {
311       tftp_send_error(spt, 4, "Unsupported transfer mode", tp);
312       return;
313   }
314
315   k += 6; /* skipping octet */
316
317   /* do sanity checks on the filename */
318
319   if ((spt->filename[0] != '/')
320       || (spt->filename[strlen(spt->filename) - 1] == '/')
321       ||  strstr(spt->filename, "/../")) {
322       tftp_send_error(spt, 2, "Access violation", tp);
323       return;
324   }
325
326   /* only allow exported prefixes */
327
328   if (!tftp_prefix
329       || (strncmp(spt->filename, tftp_prefix, strlen(tftp_prefix)) != 0)) {
330       tftp_send_error(spt, 2, "Access violation", tp);
331       return;
332   }
333
334   /* check if the file exists */
335   
336   if (tftp_read_data(spt, 0, spt->filename, 0) < 0) {
337       tftp_send_error(spt, 1, "File not found", tp);
338       return;
339   }
340
341   if (src[n - 1] != 0) {
342       tftp_send_error(spt, 2, "Access violation", tp);
343       return;
344   }
345
346   while (k < n) {
347       const char *key, *value;
348
349       key = src + k;
350       k += strlen(key) + 1;
351
352       if (k >= n) {
353           tftp_send_error(spt, 2, "Access violation", tp);
354           return;
355       }
356
357       value = src + k;
358       k += strlen(value) + 1;
359
360       if (strcmp(key, "tsize") == 0) {
361           int tsize = atoi(value);
362           struct stat stat_p;
363
364           if (tsize == 0 && tftp_prefix) {
365               char buffer[1024];
366               int len;
367
368               len = snprintf(buffer, sizeof(buffer), "%s/%s",
369                              tftp_prefix, spt->filename);
370
371               if (stat(buffer, &stat_p) == 0)
372                   tsize = stat_p.st_size;
373               else {
374                   tftp_send_error(spt, 1, "File not found", tp);
375                   return;
376               }
377           }
378
379           tftp_send_oack(spt, "tsize", tsize, tp);
380       }
381   }
382
383   tftp_send_data(spt, 1, tp);
384 }
385
386 static void tftp_handle_ack(struct tftp_t *tp, int pktlen)
387 {
388   int s;
389
390   s = tftp_session_find(tp);
391
392   if (s < 0) {
393     return;
394   }
395
396   if (tftp_send_data(&tftp_sessions[s], 
397                      ntohs(tp->x.tp_data.tp_block_nr) + 1, 
398                      tp) < 0) {
399     return;
400   }
401 }
402
403 void tftp_input(struct mbuf *m)
404 {
405   struct tftp_t *tp = (struct tftp_t *)m->m_data;
406
407   switch(ntohs(tp->tp_op)) {
408   case TFTP_RRQ:
409     tftp_handle_rrq(tp, m->m_len);
410     break;
411
412   case TFTP_ACK:
413     tftp_handle_ack(tp, m->m_len);
414     break;
415   }
416 }