qemu-nbd: remove useless parameter from nbd_negotiate() (Laurent Vivier)
[qemu] / nbd.c
1 /*
2  *  Copyright (C) 2005  Anthony Liguori <anthony@codemonkey.ws>
3  *
4  *  Network Block Device
5  *
6  *  This program is free software; you can redistribute it and/or modify
7  *  it under the terms of the GNU General Public License as published by
8  *  the Free Software Foundation; under version 2 of the License.
9  *
10  *  This program is distributed in the hope that it will be useful,
11  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
12  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13  *  GNU General Public License for more details.
14  *
15  *  You should have received a copy of the GNU General Public License
16  *  along with this program; if not, write to the Free Software
17  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
18  */
19
20 #include "nbd.h"
21
22 #include <errno.h>
23 #include <string.h>
24 #include <sys/ioctl.h>
25 #ifdef __sun__
26 #include <sys/ioccom.h>
27 #endif
28 #include <ctype.h>
29 #include <inttypes.h>
30 #include <sys/socket.h>
31 #include <sys/un.h>
32 #include <netinet/in.h>
33 #include <netinet/tcp.h>
34 #include <arpa/inet.h>
35 #include <netdb.h>
36
37 #if defined(QEMU_NBD)
38 extern int verbose;
39 #else
40 static int verbose = 0;
41 #endif
42
43 #define TRACE(msg, ...) do { \
44     if (verbose) LOG(msg, ## __VA_ARGS__); \
45 } while(0)
46
47 #define LOG(msg, ...) do { \
48     fprintf(stderr, "%s:%s():L%d: " msg "\n", \
49             __FILE__, __FUNCTION__, __LINE__, ## __VA_ARGS__); \
50 } while(0)
51
52 /* This is all part of the "official" NBD API */
53
54 #define NBD_REQUEST_MAGIC       0x25609513
55 #define NBD_REPLY_MAGIC         0x67446698
56
57 #define NBD_SET_SOCK            _IO(0xab, 0)
58 #define NBD_SET_BLKSIZE         _IO(0xab, 1)
59 #define NBD_SET_SIZE            _IO(0xab, 2)
60 #define NBD_DO_IT               _IO(0xab, 3)
61 #define NBD_CLEAR_SOCK          _IO(0xab, 4)
62 #define NBD_CLEAR_QUE           _IO(0xab, 5)
63 #define NBD_PRINT_DEBUG         _IO(0xab, 6)
64 #define NBD_SET_SIZE_BLOCKS     _IO(0xab, 7)
65 #define NBD_DISCONNECT          _IO(0xab, 8)
66
67 /* That's all folks */
68
69 #define read_sync(fd, buffer, size) nbd_wr_sync(fd, buffer, size, true)
70 #define write_sync(fd, buffer, size) nbd_wr_sync(fd, buffer, size, false)
71
72 size_t nbd_wr_sync(int fd, void *buffer, size_t size, bool do_read)
73 {
74     size_t offset = 0;
75
76     while (offset < size) {
77         ssize_t len;
78
79         if (do_read) {
80             len = read(fd, buffer + offset, size - offset);
81         } else {
82             len = write(fd, buffer + offset, size - offset);
83         }
84
85         /* recoverable error */
86         if (len == -1 && (errno == EAGAIN || errno == EINTR)) {
87             continue;
88         }
89
90         /* eof */
91         if (len == 0) {
92             break;
93         }
94
95         /* unrecoverable error */
96         if (len == -1) {
97             return 0;
98         }
99
100         offset += len;
101     }
102
103     return offset;
104 }
105
106 int tcp_socket_outgoing(const char *address, uint16_t port)
107 {
108     int s;
109     struct in_addr in;
110     struct sockaddr_in addr;
111     int serrno;
112
113     s = socket(PF_INET, SOCK_STREAM, 0);
114     if (s == -1) {
115         return -1;
116     }
117
118     if (inet_aton(address, &in) == 0) {
119         struct hostent *ent;
120
121         ent = gethostbyname(address);
122         if (ent == NULL) {
123             goto error;
124         }
125
126         memcpy(&in, ent->h_addr, sizeof(in));
127     }
128
129     addr.sin_family = AF_INET;
130     addr.sin_port = htons(port);
131     memcpy(&addr.sin_addr.s_addr, &in, sizeof(in));
132
133     if (connect(s, (struct sockaddr *)&addr, sizeof(addr)) == -1) {
134         goto error;
135     }
136
137     return s;
138 error:
139     serrno = errno;
140     close(s);
141     errno = serrno;
142     return -1;
143 }
144
145 int tcp_socket_incoming(const char *address, uint16_t port)
146 {
147     int s;
148     struct in_addr in;
149     struct sockaddr_in addr;
150     int serrno;
151     int opt;
152
153     s = socket(PF_INET, SOCK_STREAM, 0);
154     if (s == -1) {
155         return -1;
156     }
157
158     if (inet_aton(address, &in) == 0) {
159         struct hostent *ent;
160
161         ent = gethostbyname(address);
162         if (ent == NULL) {
163             goto error;
164         }
165
166         memcpy(&in, ent->h_addr, sizeof(in));
167     }
168
169     addr.sin_family = AF_INET;
170     addr.sin_port = htons(port);
171     memcpy(&addr.sin_addr.s_addr, &in, sizeof(in));
172
173     opt = 1;
174     if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) == -1) {
175         goto error;
176     }
177
178     if (bind(s, (struct sockaddr *)&addr, sizeof(addr)) == -1) {
179         goto error;
180     }
181
182     if (listen(s, 128) == -1) {
183         goto error;
184     }
185
186     return s;
187 error:
188     serrno = errno;
189     close(s);
190     errno = serrno;
191     return -1;
192 }
193
194 int unix_socket_incoming(const char *path)
195 {
196     int s;
197     struct sockaddr_un addr;
198     int serrno;
199
200     s = socket(PF_UNIX, SOCK_STREAM, 0);
201     if (s == -1) {
202         return -1;
203     }
204
205     memset(&addr, 0, sizeof(addr));
206     addr.sun_family = AF_UNIX;
207     pstrcpy(addr.sun_path, sizeof(addr.sun_path), path);
208
209     if (bind(s, (struct sockaddr *)&addr, sizeof(addr)) == -1) {
210         goto error;
211     }
212
213     if (listen(s, 128) == -1) {
214         goto error;
215     }
216
217     return s;
218 error:
219     serrno = errno;
220     close(s);
221     errno = serrno;
222     return -1;
223 }
224
225 int unix_socket_outgoing(const char *path)
226 {
227     int s;
228     struct sockaddr_un addr;
229     int serrno;
230
231     s = socket(PF_UNIX, SOCK_STREAM, 0);
232     if (s == -1) {
233         return -1;
234     }
235
236     memset(&addr, 0, sizeof(addr));
237     addr.sun_family = AF_UNIX;
238     pstrcpy(addr.sun_path, sizeof(addr.sun_path), path);
239
240     if (connect(s, (struct sockaddr *)&addr, sizeof(addr)) == -1) {
241         goto error;
242     }
243
244     return s;
245 error:
246     serrno = errno;
247     close(s);
248     errno = serrno;
249     return -1;
250 }
251
252
253 /* Basic flow
254
255    Server         Client
256
257    Negotiate
258                   Request
259    Response
260                   Request
261    Response
262                   ...
263    ...
264                   Request (type == 2)
265 */
266
267 int nbd_negotiate(int csock, off_t size)
268 {
269         char buf[8 + 8 + 8 + 128];
270
271         /* Negotiate
272            [ 0 ..   7]   passwd   ("NBDMAGIC")
273            [ 8 ..  15]   magic    (0x00420281861253)
274            [16 ..  23]   size
275            [24 .. 151]   reserved (0)
276          */
277
278         TRACE("Beginning negotiation.");
279         memcpy(buf, "NBDMAGIC", 8);
280         cpu_to_be64w((uint64_t*)(buf + 8), 0x00420281861253LL);
281         cpu_to_be64w((uint64_t*)(buf + 16), size);
282         memset(buf + 24, 0, 128);
283
284         if (write_sync(csock, buf, sizeof(buf)) != sizeof(buf)) {
285                 LOG("write failed");
286                 errno = EINVAL;
287                 return -1;
288         }
289
290         TRACE("Negotation succeeded.");
291
292         return 0;
293 }
294
295 int nbd_receive_negotiate(int csock, off_t *size, size_t *blocksize)
296 {
297         char buf[8 + 8 + 8 + 128];
298         uint64_t magic;
299
300         TRACE("Receiving negotation.");
301
302         if (read_sync(csock, buf, sizeof(buf)) != sizeof(buf)) {
303                 LOG("read failed");
304                 errno = EINVAL;
305                 return -1;
306         }
307
308         magic = be64_to_cpup((uint64_t*)(buf + 8));
309         *size = be64_to_cpup((uint64_t*)(buf + 16));
310         *blocksize = 1024;
311
312         TRACE("Magic is %c%c%c%c%c%c%c%c",
313               isprint(buf[0]) ? buf[0] : '.',
314               isprint(buf[1]) ? buf[1] : '.',
315               isprint(buf[2]) ? buf[2] : '.',
316               isprint(buf[3]) ? buf[3] : '.',
317               isprint(buf[4]) ? buf[4] : '.',
318               isprint(buf[5]) ? buf[5] : '.',
319               isprint(buf[6]) ? buf[6] : '.',
320               isprint(buf[7]) ? buf[7] : '.');
321         TRACE("Magic is 0x%" PRIx64, magic);
322         TRACE("Size is %" PRIu64, *size);
323
324         if (memcmp(buf, "NBDMAGIC", 8) != 0) {
325                 LOG("Invalid magic received");
326                 errno = EINVAL;
327                 return -1;
328         }
329
330         TRACE("Checking magic");
331
332         if (magic != 0x00420281861253LL) {
333                 LOG("Bad magic received");
334                 errno = EINVAL;
335                 return -1;
336         }
337         return 0;
338 }
339
340 int nbd_init(int fd, int csock, off_t size, size_t blocksize)
341 {
342         TRACE("Setting block size to %lu", (unsigned long)blocksize);
343
344         if (ioctl(fd, NBD_SET_BLKSIZE, blocksize) == -1) {
345                 int serrno = errno;
346                 LOG("Failed setting NBD block size");
347                 errno = serrno;
348                 return -1;
349         }
350
351         TRACE("Setting size to %llu block(s)",
352               (unsigned long long)(size / blocksize));
353
354         if (ioctl(fd, NBD_SET_SIZE_BLOCKS, size / blocksize) == -1) {
355                 int serrno = errno;
356                 LOG("Failed setting size (in blocks)");
357                 errno = serrno;
358                 return -1;
359         }
360
361         TRACE("Clearing NBD socket");
362
363         if (ioctl(fd, NBD_CLEAR_SOCK) == -1) {
364                 int serrno = errno;
365                 LOG("Failed clearing NBD socket");
366                 errno = serrno;
367                 return -1;
368         }
369
370         TRACE("Setting NBD socket");
371
372         if (ioctl(fd, NBD_SET_SOCK, csock) == -1) {
373                 int serrno = errno;
374                 LOG("Failed to set NBD socket");
375                 errno = serrno;
376                 return -1;
377         }
378
379         TRACE("Negotiation ended");
380
381         return 0;
382 }
383
384 int nbd_disconnect(int fd)
385 {
386         ioctl(fd, NBD_CLEAR_QUE);
387         ioctl(fd, NBD_DISCONNECT);
388         ioctl(fd, NBD_CLEAR_SOCK);
389         return 0;
390 }
391
392 int nbd_client(int fd, int csock)
393 {
394         int ret;
395         int serrno;
396
397         TRACE("Doing NBD loop");
398
399         ret = ioctl(fd, NBD_DO_IT);
400         serrno = errno;
401
402         TRACE("NBD loop returned %d: %s", ret, strerror(serrno));
403
404         TRACE("Clearing NBD queue");
405         ioctl(fd, NBD_CLEAR_QUE);
406
407         TRACE("Clearing NBD socket");
408         ioctl(fd, NBD_CLEAR_SOCK);
409
410         errno = serrno;
411         return ret;
412 }
413
414 int nbd_send_request(int csock, struct nbd_request *request)
415 {
416         uint8_t buf[4 + 4 + 8 + 8 + 4];
417
418         cpu_to_be32w((uint32_t*)buf, NBD_REQUEST_MAGIC);
419         cpu_to_be32w((uint32_t*)(buf + 4), request->type);
420         cpu_to_be64w((uint64_t*)(buf + 8), request->handle);
421         cpu_to_be64w((uint64_t*)(buf + 16), request->from);
422         cpu_to_be32w((uint32_t*)(buf + 24), request->len);
423
424         TRACE("Sending request to client");
425
426         if (write_sync(csock, buf, sizeof(buf)) != sizeof(buf)) {
427                 LOG("writing to socket failed");
428                 errno = EINVAL;
429                 return -1;
430         }
431         return 0;
432 }
433
434
435 static int nbd_receive_request(int csock, struct nbd_request *request)
436 {
437         uint8_t buf[4 + 4 + 8 + 8 + 4];
438         uint32_t magic;
439
440         if (read_sync(csock, buf, sizeof(buf)) != sizeof(buf)) {
441                 LOG("read failed");
442                 errno = EINVAL;
443                 return -1;
444         }
445
446         /* Request
447            [ 0 ..  3]   magic   (NBD_REQUEST_MAGIC)
448            [ 4 ..  7]   type    (0 == READ, 1 == WRITE)
449            [ 8 .. 15]   handle
450            [16 .. 23]   from
451            [24 .. 27]   len
452          */
453
454         magic = be32_to_cpup((uint32_t*)buf);
455         request->type  = be32_to_cpup((uint32_t*)(buf + 4));
456         request->handle = be64_to_cpup((uint64_t*)(buf + 8));
457         request->from  = be64_to_cpup((uint64_t*)(buf + 16));
458         request->len   = be32_to_cpup((uint32_t*)(buf + 24));
459
460         TRACE("Got request: "
461               "{ magic = 0x%x, .type = %d, from = %" PRIu64" , len = %u }",
462               magic, request->type, request->from, request->len);
463
464         if (magic != NBD_REQUEST_MAGIC) {
465                 LOG("invalid magic (got 0x%x)", magic);
466                 errno = EINVAL;
467                 return -1;
468         }
469         return 0;
470 }
471
472 int nbd_receive_reply(int csock, struct nbd_reply *reply)
473 {
474         uint8_t buf[4 + 4 + 8];
475         uint32_t magic;
476
477         memset(buf, 0xAA, sizeof(buf));
478
479         if (read_sync(csock, buf, sizeof(buf)) != sizeof(buf)) {
480                 LOG("read failed");
481                 errno = EINVAL;
482                 return -1;
483         }
484
485         /* Reply
486            [ 0 ..  3]    magic   (NBD_REPLY_MAGIC)
487            [ 4 ..  7]    error   (0 == no error)
488            [ 7 .. 15]    handle
489          */
490
491         magic = be32_to_cpup((uint32_t*)buf);
492         reply->error  = be32_to_cpup((uint32_t*)(buf + 4));
493         reply->handle = be64_to_cpup((uint64_t*)(buf + 8));
494
495         TRACE("Got reply: "
496               "{ magic = 0x%x, .error = %d, handle = %" PRIu64" }",
497               magic, reply->error, reply->handle);
498
499         if (magic != NBD_REPLY_MAGIC) {
500                 LOG("invalid magic (got 0x%x)", magic);
501                 errno = EINVAL;
502                 return -1;
503         }
504         return 0;
505 }
506
507 static int nbd_send_reply(int csock, struct nbd_reply *reply)
508 {
509         uint8_t buf[4 + 4 + 8];
510
511         /* Reply
512            [ 0 ..  3]    magic   (NBD_REPLY_MAGIC)
513            [ 4 ..  7]    error   (0 == no error)
514            [ 7 .. 15]    handle
515          */
516         cpu_to_be32w((uint32_t*)buf, NBD_REPLY_MAGIC);
517         cpu_to_be32w((uint32_t*)(buf + 4), reply->error);
518         cpu_to_be64w((uint64_t*)(buf + 8), reply->handle);
519
520         TRACE("Sending response to client");
521
522         if (write_sync(csock, buf, sizeof(buf)) != sizeof(buf)) {
523                 LOG("writing to socket failed");
524                 errno = EINVAL;
525                 return -1;
526         }
527         return 0;
528 }
529
530 int nbd_trip(BlockDriverState *bs, int csock, off_t size, uint64_t dev_offset,
531              off_t *offset, bool readonly, uint8_t *data, int data_size)
532 {
533         struct nbd_request request;
534         struct nbd_reply reply;
535
536         TRACE("Reading request.");
537
538         if (nbd_receive_request(csock, &request) == -1)
539                 return -1;
540
541         if (request.len > data_size) {
542                 LOG("len (%u) is larger than max len (%u)",
543                     request.len, data_size);
544                 errno = EINVAL;
545                 return -1;
546         }
547
548         if ((request.from + request.len) < request.from) {
549                 LOG("integer overflow detected! "
550                     "you're probably being attacked");
551                 errno = EINVAL;
552                 return -1;
553         }
554
555         if ((request.from + request.len) > size) {
556                 LOG("From: %" PRIu64 ", Len: %u, Size: %" PRIu64
557                     ", Offset: %" PRIu64 "\n",
558                      request.from, request.len, size, dev_offset);
559                 LOG("requested operation past EOF--bad client?");
560                 errno = EINVAL;
561                 return -1;
562         }
563
564         TRACE("Decoding type");
565
566         reply.handle = request.handle;
567         reply.error = 0;
568
569         switch (request.type) {
570         case NBD_CMD_READ:
571                 TRACE("Request type is READ");
572
573                 if (bdrv_read(bs, (request.from + dev_offset) / 512, data,
574                               request.len / 512) == -1) {
575                         LOG("reading from file failed");
576                         errno = EINVAL;
577                         return -1;
578                 }
579                 *offset += request.len;
580
581                 TRACE("Read %u byte(s)", request.len);
582
583                 if (nbd_send_reply(csock, &reply) == -1)
584                         return -1;
585
586                 TRACE("Sending data to client");
587
588                 if (write_sync(csock, data, request.len) != request.len) {
589                         LOG("writing to socket failed");
590                         errno = EINVAL;
591                         return -1;
592                 }
593                 break;
594         case NBD_CMD_WRITE:
595                 TRACE("Request type is WRITE");
596
597                 TRACE("Reading %u byte(s)", request.len);
598
599                 if (read_sync(csock, data, request.len) != request.len) {
600                         LOG("reading from socket failed");
601                         errno = EINVAL;
602                         return -1;
603                 }
604
605                 if (readonly) {
606                         TRACE("Server is read-only, return error");
607                         reply.error = 1;
608                 } else {
609                         TRACE("Writing to device");
610
611                         if (bdrv_write(bs, (request.from + dev_offset) / 512,
612                                        data, request.len / 512) == -1) {
613                                 LOG("writing to file failed");
614                                 errno = EINVAL;
615                                 return -1;
616                         }
617
618                         *offset += request.len;
619                 }
620
621                 if (nbd_send_reply(csock, &reply) == -1)
622                         return -1;
623                 break;
624         case NBD_CMD_DISC:
625                 TRACE("Request type is DISCONNECT");
626                 errno = 0;
627                 return 1;
628         default:
629                 LOG("invalid request type (%u) received", request.type);
630                 errno = EINVAL;
631                 return -1;
632         }
633
634         TRACE("Request/Reply complete");
635
636         return 0;
637 }