Fix some warnings that would be generated by gcc -Wredundant-decls
[qemu] / slirp / tftp.c
1 /*
2  * tftp.c - a simple, read-only tftp server for qemu
3  *
4  * Copyright (c) 2004 Magnus Damm <damm@opensource.se>
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to deal
8  * in the Software without restriction, including without limitation the rights
9  * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10  * copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in
14  * all copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
19  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22  * THE SOFTWARE.
23  */
24
25 #include <slirp.h>
26 #include "qemu-common.h" // for pstrcpy
27
28 struct tftp_session {
29     int in_use;
30     unsigned char filename[TFTP_FILENAME_MAX];
31
32     struct in_addr client_ip;
33     u_int16_t client_port;
34
35     int timestamp;
36 };
37
38 static struct tftp_session tftp_sessions[TFTP_SESSIONS_MAX];
39
40 const char *tftp_prefix;
41
42 static void tftp_session_update(struct tftp_session *spt)
43 {
44     spt->timestamp = curtime;
45     spt->in_use = 1;
46 }
47
48 static void tftp_session_terminate(struct tftp_session *spt)
49 {
50   spt->in_use = 0;
51 }
52
53 static int tftp_session_allocate(struct tftp_t *tp)
54 {
55   struct tftp_session *spt;
56   int k;
57
58   for (k = 0; k < TFTP_SESSIONS_MAX; k++) {
59     spt = &tftp_sessions[k];
60
61     if (!spt->in_use)
62         goto found;
63
64     /* sessions time out after 5 inactive seconds */
65     if ((int)(curtime - spt->timestamp) > 5000)
66         goto found;
67   }
68
69   return -1;
70
71  found:
72   memset(spt, 0, sizeof(*spt));
73   memcpy(&spt->client_ip, &tp->ip.ip_src, sizeof(spt->client_ip));
74   spt->client_port = tp->udp.uh_sport;
75
76   tftp_session_update(spt);
77
78   return k;
79 }
80
81 static int tftp_session_find(struct tftp_t *tp)
82 {
83   struct tftp_session *spt;
84   int k;
85
86   for (k = 0; k < TFTP_SESSIONS_MAX; k++) {
87     spt = &tftp_sessions[k];
88
89     if (spt->in_use) {
90       if (!memcmp(&spt->client_ip, &tp->ip.ip_src, sizeof(spt->client_ip))) {
91         if (spt->client_port == tp->udp.uh_sport) {
92           return k;
93         }
94       }
95     }
96   }
97
98   return -1;
99 }
100
101 static int tftp_read_data(struct tftp_session *spt, u_int16_t block_nr,
102                           u_int8_t *buf, int len)
103 {
104   int fd;
105   int bytes_read = 0;
106   char buffer[1024];
107   int n;
108
109   n = snprintf(buffer, sizeof(buffer), "%s/%s",
110                tftp_prefix, spt->filename);
111   if (n >= sizeof(buffer))
112     return -1;
113
114   fd = open(buffer, O_RDONLY | O_BINARY);
115
116   if (fd < 0) {
117     return -1;
118   }
119
120   if (len) {
121     lseek(fd, block_nr * 512, SEEK_SET);
122
123     bytes_read = read(fd, buf, len);
124   }
125
126   close(fd);
127
128   return bytes_read;
129 }
130
131 static int tftp_send_oack(struct tftp_session *spt,
132                           const char *key, uint32_t value,
133                           struct tftp_t *recv_tp)
134 {
135     struct sockaddr_in saddr, daddr;
136     struct mbuf *m;
137     struct tftp_t *tp;
138     int n = 0;
139
140     m = m_get();
141
142     if (!m)
143         return -1;
144
145     memset(m->m_data, 0, m->m_size);
146
147     m->m_data += IF_MAXLINKHDR;
148     tp = (void *)m->m_data;
149     m->m_data += sizeof(struct udpiphdr);
150
151     tp->tp_op = htons(TFTP_OACK);
152     n += snprintf(tp->x.tp_buf + n, sizeof(tp->x.tp_buf) - n, "%s", key) + 1;
153     n += snprintf(tp->x.tp_buf + n, sizeof(tp->x.tp_buf) - n, "%u", value) + 1;
154
155     saddr.sin_addr = recv_tp->ip.ip_dst;
156     saddr.sin_port = recv_tp->udp.uh_dport;
157
158     daddr.sin_addr = spt->client_ip;
159     daddr.sin_port = spt->client_port;
160
161     m->m_len = sizeof(struct tftp_t) - 514 + n -
162         sizeof(struct ip) - sizeof(struct udphdr);
163     udp_output2(NULL, m, &saddr, &daddr, IPTOS_LOWDELAY);
164
165     return 0;
166 }
167
168
169
170 static int tftp_send_error(struct tftp_session *spt,
171                            u_int16_t errorcode, const char *msg,
172                            struct tftp_t *recv_tp)
173 {
174   struct sockaddr_in saddr, daddr;
175   struct mbuf *m;
176   struct tftp_t *tp;
177   int nobytes;
178
179   m = m_get();
180
181   if (!m) {
182     return -1;
183   }
184
185   memset(m->m_data, 0, m->m_size);
186
187   m->m_data += IF_MAXLINKHDR;
188   tp = (void *)m->m_data;
189   m->m_data += sizeof(struct udpiphdr);
190
191   tp->tp_op = htons(TFTP_ERROR);
192   tp->x.tp_error.tp_error_code = htons(errorcode);
193   pstrcpy(tp->x.tp_error.tp_msg, sizeof(tp->x.tp_error.tp_msg), msg);
194
195   saddr.sin_addr = recv_tp->ip.ip_dst;
196   saddr.sin_port = recv_tp->udp.uh_dport;
197
198   daddr.sin_addr = spt->client_ip;
199   daddr.sin_port = spt->client_port;
200
201   nobytes = 2;
202
203   m->m_len = sizeof(struct tftp_t) - 514 + 3 + strlen(msg) -
204         sizeof(struct ip) - sizeof(struct udphdr);
205
206   udp_output2(NULL, m, &saddr, &daddr, IPTOS_LOWDELAY);
207
208   tftp_session_terminate(spt);
209
210   return 0;
211 }
212
213 static int tftp_send_data(struct tftp_session *spt,
214                           u_int16_t block_nr,
215                           struct tftp_t *recv_tp)
216 {
217   struct sockaddr_in saddr, daddr;
218   struct mbuf *m;
219   struct tftp_t *tp;
220   int nobytes;
221
222   if (block_nr < 1) {
223     return -1;
224   }
225
226   m = m_get();
227
228   if (!m) {
229     return -1;
230   }
231
232   memset(m->m_data, 0, m->m_size);
233
234   m->m_data += IF_MAXLINKHDR;
235   tp = (void *)m->m_data;
236   m->m_data += sizeof(struct udpiphdr);
237
238   tp->tp_op = htons(TFTP_DATA);
239   tp->x.tp_data.tp_block_nr = htons(block_nr);
240
241   saddr.sin_addr = recv_tp->ip.ip_dst;
242   saddr.sin_port = recv_tp->udp.uh_dport;
243
244   daddr.sin_addr = spt->client_ip;
245   daddr.sin_port = spt->client_port;
246
247   nobytes = tftp_read_data(spt, block_nr - 1, tp->x.tp_data.tp_buf, 512);
248
249   if (nobytes < 0) {
250     m_free(m);
251
252     /* send "file not found" error back */
253
254     tftp_send_error(spt, 1, "File not found", tp);
255
256     return -1;
257   }
258
259   m->m_len = sizeof(struct tftp_t) - (512 - nobytes) -
260         sizeof(struct ip) - sizeof(struct udphdr);
261
262   udp_output2(NULL, m, &saddr, &daddr, IPTOS_LOWDELAY);
263
264   if (nobytes == 512) {
265     tftp_session_update(spt);
266   }
267   else {
268     tftp_session_terminate(spt);
269   }
270
271   return 0;
272 }
273
274 static void tftp_handle_rrq(struct tftp_t *tp, int pktlen)
275 {
276   struct tftp_session *spt;
277   int s, k, n;
278   u_int8_t *src, *dst;
279
280   s = tftp_session_allocate(tp);
281
282   if (s < 0) {
283     return;
284   }
285
286   spt = &tftp_sessions[s];
287
288   src = tp->x.tp_buf;
289   dst = spt->filename;
290   n = pktlen - ((uint8_t *)&tp->x.tp_buf[0] - (uint8_t *)tp);
291
292   /* get name */
293
294   for (k = 0; k < n; k++) {
295     if (k < TFTP_FILENAME_MAX) {
296       dst[k] = src[k];
297     }
298     else {
299       return;
300     }
301
302     if (src[k] == '\0') {
303       break;
304     }
305   }
306
307   if (k >= n) {
308     return;
309   }
310
311   k++;
312
313   /* check mode */
314   if ((n - k) < 6) {
315     return;
316   }
317
318   if (memcmp(&src[k], "octet\0", 6) != 0) {
319       tftp_send_error(spt, 4, "Unsupported transfer mode", tp);
320       return;
321   }
322
323   k += 6; /* skipping octet */
324
325   /* do sanity checks on the filename */
326
327   if ((spt->filename[0] != '/')
328       || (spt->filename[strlen(spt->filename) - 1] == '/')
329       ||  strstr(spt->filename, "/../")) {
330       tftp_send_error(spt, 2, "Access violation", tp);
331       return;
332   }
333
334   /* only allow exported prefixes */
335
336   if (!tftp_prefix) {
337       tftp_send_error(spt, 2, "Access violation", tp);
338       return;
339   }
340
341   /* check if the file exists */
342
343   if (tftp_read_data(spt, 0, spt->filename, 0) < 0) {
344       tftp_send_error(spt, 1, "File not found", tp);
345       return;
346   }
347
348   if (src[n - 1] != 0) {
349       tftp_send_error(spt, 2, "Access violation", tp);
350       return;
351   }
352
353   while (k < n) {
354       const char *key, *value;
355
356       key = src + k;
357       k += strlen(key) + 1;
358
359       if (k >= n) {
360           tftp_send_error(spt, 2, "Access violation", tp);
361           return;
362       }
363
364       value = src + k;
365       k += strlen(value) + 1;
366
367       if (strcmp(key, "tsize") == 0) {
368           int tsize = atoi(value);
369           struct stat stat_p;
370
371           if (tsize == 0 && tftp_prefix) {
372               char buffer[1024];
373               int len;
374
375               len = snprintf(buffer, sizeof(buffer), "%s/%s",
376                              tftp_prefix, spt->filename);
377
378               if (stat(buffer, &stat_p) == 0)
379                   tsize = stat_p.st_size;
380               else {
381                   tftp_send_error(spt, 1, "File not found", tp);
382                   return;
383               }
384           }
385
386           tftp_send_oack(spt, "tsize", tsize, tp);
387       }
388   }
389
390   tftp_send_data(spt, 1, tp);
391 }
392
393 static void tftp_handle_ack(struct tftp_t *tp, int pktlen)
394 {
395   int s;
396
397   s = tftp_session_find(tp);
398
399   if (s < 0) {
400     return;
401   }
402
403   if (tftp_send_data(&tftp_sessions[s],
404                      ntohs(tp->x.tp_data.tp_block_nr) + 1,
405                      tp) < 0) {
406     return;
407   }
408 }
409
410 void tftp_input(struct mbuf *m)
411 {
412   struct tftp_t *tp = (struct tftp_t *)m->m_data;
413
414   switch(ntohs(tp->tp_op)) {
415   case TFTP_RRQ:
416     tftp_handle_rrq(tp, m->m_len);
417     break;
418
419   case TFTP_ACK:
420     tftp_handle_ack(tp, m->m_len);
421     break;
422   }
423 }