Add #defines needed by OpenSolaris, fix breakage by the #defines
[qemu] / slirp / tftp.c
1 /*
2  * tftp.c - a simple, read-only tftp server for qemu
3  *
4  * Copyright (c) 2004 Magnus Damm <damm@opensource.se>
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to deal
8  * in the Software without restriction, including without limitation the rights
9  * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10  * copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in
14  * all copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
19  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22  * THE SOFTWARE.
23  */
24
25 #include <slirp.h>
26 #include "qemu-common.h"
27
28 static inline int tftp_session_in_use(struct tftp_session *spt)
29 {
30     return (spt->slirp != NULL);
31 }
32
33 static inline void tftp_session_update(struct tftp_session *spt)
34 {
35     spt->timestamp = curtime;
36 }
37
38 static void tftp_session_terminate(struct tftp_session *spt)
39 {
40     qemu_free(spt->filename);
41     spt->slirp = NULL;
42 }
43
44 static int tftp_session_allocate(Slirp *slirp, struct tftp_t *tp)
45 {
46   struct tftp_session *spt;
47   int k;
48
49   for (k = 0; k < TFTP_SESSIONS_MAX; k++) {
50     spt = &slirp->tftp_sessions[k];
51
52     if (!tftp_session_in_use(spt))
53         goto found;
54
55     /* sessions time out after 5 inactive seconds */
56     if ((int)(curtime - spt->timestamp) > 5000) {
57         qemu_free(spt->filename);
58         goto found;
59     }
60   }
61
62   return -1;
63
64  found:
65   memset(spt, 0, sizeof(*spt));
66   memcpy(&spt->client_ip, &tp->ip.ip_src, sizeof(spt->client_ip));
67   spt->client_port = tp->udp.uh_sport;
68   spt->slirp = slirp;
69
70   tftp_session_update(spt);
71
72   return k;
73 }
74
75 static int tftp_session_find(Slirp *slirp, struct tftp_t *tp)
76 {
77   struct tftp_session *spt;
78   int k;
79
80   for (k = 0; k < TFTP_SESSIONS_MAX; k++) {
81     spt = &slirp->tftp_sessions[k];
82
83     if (tftp_session_in_use(spt)) {
84       if (!memcmp(&spt->client_ip, &tp->ip.ip_src, sizeof(spt->client_ip))) {
85         if (spt->client_port == tp->udp.uh_sport) {
86           return k;
87         }
88       }
89     }
90   }
91
92   return -1;
93 }
94
95 static int tftp_read_data(struct tftp_session *spt, u_int16_t block_nr,
96                           u_int8_t *buf, int len)
97 {
98   int fd;
99   int bytes_read = 0;
100
101   fd = open(spt->filename, O_RDONLY | O_BINARY);
102
103   if (fd < 0) {
104     return -1;
105   }
106
107   if (len) {
108     lseek(fd, block_nr * 512, SEEK_SET);
109
110     bytes_read = read(fd, buf, len);
111   }
112
113   close(fd);
114
115   return bytes_read;
116 }
117
118 static int tftp_send_oack(struct tftp_session *spt,
119                           const char *key, uint32_t value,
120                           struct tftp_t *recv_tp)
121 {
122     struct sockaddr_in saddr, daddr;
123     struct mbuf *m;
124     struct tftp_t *tp;
125     int n = 0;
126
127     m = m_get(spt->slirp);
128
129     if (!m)
130         return -1;
131
132     memset(m->m_data, 0, m->m_size);
133
134     m->m_data += IF_MAXLINKHDR;
135     tp = (void *)m->m_data;
136     m->m_data += sizeof(struct udpiphdr);
137
138     tp->tp_op = htons(TFTP_OACK);
139     n += snprintf((char *)tp->x.tp_buf + n, sizeof(tp->x.tp_buf) - n, "%s",
140                   key) + 1;
141     n += snprintf((char *)tp->x.tp_buf + n, sizeof(tp->x.tp_buf) - n, "%u",
142                   value) + 1;
143
144     saddr.sin_addr = recv_tp->ip.ip_dst;
145     saddr.sin_port = recv_tp->udp.uh_dport;
146
147     daddr.sin_addr = spt->client_ip;
148     daddr.sin_port = spt->client_port;
149
150     m->m_len = sizeof(struct tftp_t) - 514 + n -
151         sizeof(struct ip) - sizeof(struct udphdr);
152     udp_output2(NULL, m, &saddr, &daddr, IPTOS_LOWDELAY);
153
154     return 0;
155 }
156
157 static void tftp_send_error(struct tftp_session *spt,
158                             u_int16_t errorcode, const char *msg,
159                             struct tftp_t *recv_tp)
160 {
161   struct sockaddr_in saddr, daddr;
162   struct mbuf *m;
163   struct tftp_t *tp;
164   int nobytes;
165
166   m = m_get(spt->slirp);
167
168   if (!m) {
169     goto out;
170   }
171
172   memset(m->m_data, 0, m->m_size);
173
174   m->m_data += IF_MAXLINKHDR;
175   tp = (void *)m->m_data;
176   m->m_data += sizeof(struct udpiphdr);
177
178   tp->tp_op = htons(TFTP_ERROR);
179   tp->x.tp_error.tp_error_code = htons(errorcode);
180   pstrcpy((char *)tp->x.tp_error.tp_msg, sizeof(tp->x.tp_error.tp_msg), msg);
181
182   saddr.sin_addr = recv_tp->ip.ip_dst;
183   saddr.sin_port = recv_tp->udp.uh_dport;
184
185   daddr.sin_addr = spt->client_ip;
186   daddr.sin_port = spt->client_port;
187
188   nobytes = 2;
189
190   m->m_len = sizeof(struct tftp_t) - 514 + 3 + strlen(msg) -
191         sizeof(struct ip) - sizeof(struct udphdr);
192
193   udp_output2(NULL, m, &saddr, &daddr, IPTOS_LOWDELAY);
194
195 out:
196   tftp_session_terminate(spt);
197 }
198
199 static int tftp_send_data(struct tftp_session *spt,
200                           u_int16_t block_nr,
201                           struct tftp_t *recv_tp)
202 {
203   struct sockaddr_in saddr, daddr;
204   struct mbuf *m;
205   struct tftp_t *tp;
206   int nobytes;
207
208   if (block_nr < 1) {
209     return -1;
210   }
211
212   m = m_get(spt->slirp);
213
214   if (!m) {
215     return -1;
216   }
217
218   memset(m->m_data, 0, m->m_size);
219
220   m->m_data += IF_MAXLINKHDR;
221   tp = (void *)m->m_data;
222   m->m_data += sizeof(struct udpiphdr);
223
224   tp->tp_op = htons(TFTP_DATA);
225   tp->x.tp_data.tp_block_nr = htons(block_nr);
226
227   saddr.sin_addr = recv_tp->ip.ip_dst;
228   saddr.sin_port = recv_tp->udp.uh_dport;
229
230   daddr.sin_addr = spt->client_ip;
231   daddr.sin_port = spt->client_port;
232
233   nobytes = tftp_read_data(spt, block_nr - 1, tp->x.tp_data.tp_buf, 512);
234
235   if (nobytes < 0) {
236     m_free(m);
237
238     /* send "file not found" error back */
239
240     tftp_send_error(spt, 1, "File not found", tp);
241
242     return -1;
243   }
244
245   m->m_len = sizeof(struct tftp_t) - (512 - nobytes) -
246         sizeof(struct ip) - sizeof(struct udphdr);
247
248   udp_output2(NULL, m, &saddr, &daddr, IPTOS_LOWDELAY);
249
250   if (nobytes == 512) {
251     tftp_session_update(spt);
252   }
253   else {
254     tftp_session_terminate(spt);
255   }
256
257   return 0;
258 }
259
260 static void tftp_handle_rrq(Slirp *slirp, struct tftp_t *tp, int pktlen)
261 {
262   struct tftp_session *spt;
263   int s, k;
264   size_t prefix_len;
265   char *req_fname;
266
267   s = tftp_session_allocate(slirp, tp);
268
269   if (s < 0) {
270     return;
271   }
272
273   spt = &slirp->tftp_sessions[s];
274
275   /* unspecifed prefix means service disabled */
276   if (!slirp->tftp_prefix) {
277       tftp_send_error(spt, 2, "Access violation", tp);
278       return;
279   }
280
281   /* skip header fields */
282   k = 0;
283   pktlen -= ((uint8_t *)&tp->x.tp_buf[0] - (uint8_t *)tp);
284
285   /* prepend tftp_prefix */
286   prefix_len = strlen(slirp->tftp_prefix);
287   spt->filename = qemu_malloc(prefix_len + TFTP_FILENAME_MAX + 2);
288   memcpy(spt->filename, slirp->tftp_prefix, prefix_len);
289   spt->filename[prefix_len] = '/';
290
291   /* get name */
292   req_fname = spt->filename + prefix_len + 1;
293
294   while (1) {
295     if (k >= TFTP_FILENAME_MAX || k >= pktlen) {
296       tftp_send_error(spt, 2, "Access violation", tp);
297       return;
298     }
299     req_fname[k] = (char)tp->x.tp_buf[k];
300     if (req_fname[k++] == '\0') {
301       break;
302     }
303   }
304
305   /* check mode */
306   if ((pktlen - k) < 6) {
307     tftp_send_error(spt, 2, "Access violation", tp);
308     return;
309   }
310
311   if (memcmp(&tp->x.tp_buf[k], "octet\0", 6) != 0) {
312       tftp_send_error(spt, 4, "Unsupported transfer mode", tp);
313       return;
314   }
315
316   k += 6; /* skipping octet */
317
318   /* do sanity checks on the filename */
319   if (!strncmp(req_fname, "../", 3) ||
320       req_fname[strlen(req_fname) - 1] == '/' ||
321       strstr(req_fname, "/../")) {
322       tftp_send_error(spt, 2, "Access violation", tp);
323       return;
324   }
325
326   /* check if the file exists */
327   if (tftp_read_data(spt, 0, NULL, 0) < 0) {
328       tftp_send_error(spt, 1, "File not found", tp);
329       return;
330   }
331
332   if (tp->x.tp_buf[pktlen - 1] != 0) {
333       tftp_send_error(spt, 2, "Access violation", tp);
334       return;
335   }
336
337   while (k < pktlen) {
338       const char *key, *value;
339
340       key = (const char *)&tp->x.tp_buf[k];
341       k += strlen(key) + 1;
342
343       if (k >= pktlen) {
344           tftp_send_error(spt, 2, "Access violation", tp);
345           return;
346       }
347
348       value = (const char *)&tp->x.tp_buf[k];
349       k += strlen(value) + 1;
350
351       if (strcmp(key, "tsize") == 0) {
352           int tsize = atoi(value);
353           struct stat stat_p;
354
355           if (tsize == 0) {
356               if (stat(spt->filename, &stat_p) == 0)
357                   tsize = stat_p.st_size;
358               else {
359                   tftp_send_error(spt, 1, "File not found", tp);
360                   return;
361               }
362           }
363
364           tftp_send_oack(spt, "tsize", tsize, tp);
365       }
366   }
367
368   tftp_send_data(spt, 1, tp);
369 }
370
371 static void tftp_handle_ack(Slirp *slirp, struct tftp_t *tp, int pktlen)
372 {
373   int s;
374
375   s = tftp_session_find(slirp, tp);
376
377   if (s < 0) {
378     return;
379   }
380
381   if (tftp_send_data(&slirp->tftp_sessions[s],
382                      ntohs(tp->x.tp_data.tp_block_nr) + 1,
383                      tp) < 0) {
384     return;
385   }
386 }
387
388 void tftp_input(struct mbuf *m)
389 {
390   struct tftp_t *tp = (struct tftp_t *)m->m_data;
391
392   switch(ntohs(tp->tp_op)) {
393   case TFTP_RRQ:
394     tftp_handle_rrq(m->slirp, tp, m->m_len);
395     break;
396
397   case TFTP_ACK:
398     tftp_handle_ack(m->slirp, tp, m->m_len);
399     break;
400   }
401 }