Statistics
| Branch: | Revision:

root / nbd.c @ 38ceff04

History | View | Annotate | Download (24.1 kB)

1
/*
 *  Copyright (C) 2005  Anthony Liguori <anthony@codemonkey.ws>
 *
 *  Network Block Device
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; under version 2 of the License.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, see <http://www.gnu.org/licenses/>.
 */
18

    
19
#include "nbd.h"
20
#include "block.h"
21

    
22
#include "qemu-coroutine.h"
23

    
24
#include <errno.h>
25
#include <string.h>
26
#ifndef _WIN32
27
#include <sys/ioctl.h>
28
#endif
29
#if defined(__sun__) || defined(__HAIKU__)
30
#include <sys/ioccom.h>
31
#endif
32
#include <ctype.h>
33
#include <inttypes.h>
34

    
35
#ifdef __linux__
36
#include <linux/fs.h>
37
#endif
38

    
39
#include "qemu_socket.h"
40
#include "qemu-queue.h"
41

    
42
//#define DEBUG_NBD
43

    
44
/* LOG: always print a message to stderr, prefixed with the file, function
 * and line of the call site.  A trailing newline is appended.
 */
#define LOG(msg, ...)                                                  \
    do {                                                               \
        fprintf(stderr, "%s:%s():L%d: " msg "\n",                      \
                __FILE__, __FUNCTION__, __LINE__, ## __VA_ARGS__);     \
    } while (0)

/* TRACE: forwards to LOG when DEBUG_NBD is defined.  Otherwise it expands
 * to an empty statement and its arguments are not evaluated.
 */
#ifndef DEBUG_NBD
#define TRACE(msg, ...) do { } while (0)
#else
#define TRACE(msg, ...)                                                \
    do {                                                               \
        LOG(msg, ## __VA_ARGS__);                                      \
    } while (0)
#endif
57

    
58
/* Wire-format and ioctl constants of the "official" NBD API. */

/* Reply header: 4-byte magic + 4-byte error + 8-byte handle. */
#define NBD_REPLY_SIZE          (4 + 4 + 8)
#define NBD_REQUEST_MAGIC       0x25609513
#define NBD_REPLY_MAGIC         0x67446698

/* ioctl request codes for the NBD device node. */
#define NBD_SET_SOCK            _IO(0xab, 0)
#define NBD_SET_BLKSIZE         _IO(0xab, 1)
#define NBD_SET_SIZE            _IO(0xab, 2)
#define NBD_DO_IT               _IO(0xab, 3)
#define NBD_CLEAR_SOCK          _IO(0xab, 4)
#define NBD_CLEAR_QUE           _IO(0xab, 5)
#define NBD_PRINT_DEBUG         _IO(0xab, 6)
#define NBD_SET_SIZE_BLOCKS     _IO(0xab, 7)
#define NBD_DISCONNECT          _IO(0xab, 8)
#define NBD_SET_TIMEOUT         _IO(0xab, 9)
#define NBD_SET_FLAGS           _IO(0xab, 10)

/* Option code sent by the client to select an export by name. */
#define NBD_OPT_EXPORT_NAME     (1 << 0)

/* End of the "official" NBD API constants. */
79

    
80
/*
 * Read from (do_read == true) or write to (do_read == false) the socket
 * @fd until @size bytes of @buffer have been transferred.
 *
 * Inside a coroutine the work is delegated to qemu_co_recv()/qemu_co_send()
 * and their result is returned unchanged.
 *
 * Otherwise the transfer loops over qemu_recv()/send().  EINTR is always
 * retried; EAGAIN is retried only once part of the buffer has already been
 * transferred.  Any other error is returned as a negative error code.
 * On end of file the number of bytes transferred so far is returned, which
 * may be less than @size.
 */
ssize_t nbd_wr_sync(int fd, void *buffer, size_t size, bool do_read)
{
    char *bytes = buffer;
    size_t done = 0;

    if (qemu_in_coroutine()) {
        if (do_read) {
            return qemu_co_recv(fd, buffer, size);
        }
        return qemu_co_send(fd, buffer, size);
    }

    while (done < size) {
        ssize_t n;
        int err;

        if (do_read) {
            n = qemu_recv(fd, bytes + done, size - done, 0);
        } else {
            n = send(fd, bytes + done, size - done, 0);
        }

        if (n > 0) {
            done += n;
            continue;
        }

        if (n == 0) {
            /* eof */
            break;
        }

        err = socket_error();

        /* recoverable error */
        if (err == EINTR || (err == EAGAIN && done > 0)) {
            continue;
        }

        /* unrecoverable error */
        return -err;
    }

    return done;
}
124

    
125
/*
 * Read @size bytes from @fd into @buffer via nbd_wr_sync().
 *
 * Sockets are kept in blocking mode in the negotiation phase.  After
 * that, a non-readable socket simply means that another thread stole
 * our request/reply.  Synchronization is done with recv_coroutine, so
 * that this is coroutine-safe.
 */
static ssize_t read_sync(int fd, void *buffer, size_t size)
{
    const bool reading = true;

    return nbd_wr_sync(fd, buffer, size, reading);
}
134

    
135
/*
 * Write @size bytes from @buffer to @fd via nbd_wr_sync().
 *
 * For writes, we do expect the socket to be writable, so the call is
 * repeated for as long as nbd_wr_sync() reports -EAGAIN.
 *
 * Returns whatever the final nbd_wr_sync() call returned.  The result is
 * kept in an ssize_t so that it is not narrowed on its way to the caller.
 */
static ssize_t write_sync(int fd, void *buffer, size_t size)
{
    ssize_t ret;

    do {
        ret = nbd_wr_sync(fd, buffer, size, false);
    } while (ret == -EAGAIN);

    return ret;
}
144

    
145
/*
 * Format "@address:@port" into @buf, writing at most @len bytes including
 * the terminating NUL (output is truncated if it does not fit).
 *
 * If the address part contains a colon it is taken to be an IPv6 literal
 * and is wrapped in square brackets: "[address]:port".
 */
static void combine_addr(char *buf, size_t len, const char *address,
                         uint16_t port)
{
    /* uint16_t promotes to int; cast so the argument matches %u. */
    unsigned int port_number = port;

    if (strchr(address, ':') != NULL) {
        snprintf(buf, len, "[%s]:%u", address, port_number);
    } else {
        snprintf(buf, len, "%s:%u", address, port_number);
    }
}
155

    
156
/*
 * Build an "address:port" specification from @address and @port and
 * connect to it.  Returns the result of tcp_socket_outgoing_spec().
 */
int tcp_socket_outgoing(const char *address, uint16_t port)
{
    char spec[128];

    combine_addr(spec, sizeof(spec), address, port);
    return tcp_socket_outgoing_spec(spec);
}
162

    
163
/*
 * Open an outgoing stream connection to @address_and_port.
 * Returns the result of inet_connect().
 */
int tcp_socket_outgoing_spec(const char *address_and_port)
{
    int fd = inet_connect(address_and_port, SOCK_STREAM);

    return fd;
}
167

    
168
/*
 * Build an "address:port" specification from @address and @port and
 * listen on it.  Returns the result of tcp_socket_incoming_spec().
 */
int tcp_socket_incoming(const char *address, uint16_t port)
{
    char spec[128];

    combine_addr(spec, sizeof(spec), address, port);
    return tcp_socket_incoming_spec(spec);
}
174

    
175
/*
 * Create a listening stream socket for @address_and_port.
 * No output string buffer is supplied to inet_listen() (NULL, length 0)
 * and its last argument is 0.  Returns the result of inet_listen().
 */
int tcp_socket_incoming_spec(const char *address_and_port)
{
    return inet_listen(address_and_port, NULL, 0, SOCK_STREAM, 0);
}
181

    
182
/*
 * Create a listening UNIX socket at @path.
 * No output string buffer is supplied to unix_listen() (NULL, length 0).
 * Returns the result of unix_listen().
 */
int unix_socket_incoming(const char *path)
{
    return unix_listen(path, NULL, 0);
}
189

    
190
/*
 * Connect to the UNIX socket at @path.
 * Returns the result of unix_connect().
 */
int unix_socket_outgoing(const char *path)
{
    int fd = unix_connect(path);

    return fd;
}
194

    
195
/* Basic flow

   Server         Client

   Negotiate
                  Request
   Response
                  Request
   Response
                  ...
   ...
                  Request (type == 2, NBD_CMD_DISC)
*/
208

    
209
static int nbd_send_negotiate(int csock, off_t size, uint32_t flags)
210
{
211
    char buf[8 + 8 + 8 + 128];
212
    int rc;
213

    
214
    /* Negotiate
215
        [ 0 ..   7]   passwd   ("NBDMAGIC")
216
        [ 8 ..  15]   magic    (0x00420281861253)
217
        [16 ..  23]   size
218
        [24 ..  27]   flags
219
        [28 .. 151]   reserved (0)
220
     */
221

    
222
    socket_set_block(csock);
223
    rc = -EINVAL;
224

    
225
    TRACE("Beginning negotiation.");
226
    memcpy(buf, "NBDMAGIC", 8);
227
    cpu_to_be64w((uint64_t*)(buf + 8), 0x00420281861253LL);
228
    cpu_to_be64w((uint64_t*)(buf + 16), size);
229
    cpu_to_be32w((uint32_t*)(buf + 24),
230
                 flags | NBD_FLAG_HAS_FLAGS | NBD_FLAG_SEND_TRIM |
231
                 NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA);
232
    memset(buf + 28, 0, 124);
233

    
234
    if (write_sync(csock, buf, sizeof(buf)) != sizeof(buf)) {
235
        LOG("write failed");
236
        goto fail;
237
    }
238

    
239
    TRACE("Negotiation succeeded.");
240
    rc = 0;
241
fail:
242
    socket_set_nonblock(csock);
243
    return rc;
244
}
245

    
246
int nbd_receive_negotiate(int csock, const char *name, uint32_t *flags,
247
                          off_t *size, size_t *blocksize)
248
{
249
    char buf[256];
250
    uint64_t magic, s;
251
    uint16_t tmp;
252
    int rc;
253

    
254
    TRACE("Receiving negotiation.");
255

    
256
    socket_set_block(csock);
257
    rc = -EINVAL;
258

    
259
    if (read_sync(csock, buf, 8) != 8) {
260
        LOG("read failed");
261
        goto fail;
262
    }
263

    
264
    buf[8] = '\0';
265
    if (strlen(buf) == 0) {
266
        LOG("server connection closed");
267
        goto fail;
268
    }
269

    
270
    TRACE("Magic is %c%c%c%c%c%c%c%c",
271
          qemu_isprint(buf[0]) ? buf[0] : '.',
272
          qemu_isprint(buf[1]) ? buf[1] : '.',
273
          qemu_isprint(buf[2]) ? buf[2] : '.',
274
          qemu_isprint(buf[3]) ? buf[3] : '.',
275
          qemu_isprint(buf[4]) ? buf[4] : '.',
276
          qemu_isprint(buf[5]) ? buf[5] : '.',
277
          qemu_isprint(buf[6]) ? buf[6] : '.',
278
          qemu_isprint(buf[7]) ? buf[7] : '.');
279

    
280
    if (memcmp(buf, "NBDMAGIC", 8) != 0) {
281
        LOG("Invalid magic received");
282
        goto fail;
283
    }
284

    
285
    if (read_sync(csock, &magic, sizeof(magic)) != sizeof(magic)) {
286
        LOG("read failed");
287
        goto fail;
288
    }
289
    magic = be64_to_cpu(magic);
290
    TRACE("Magic is 0x%" PRIx64, magic);
291

    
292
    if (name) {
293
        uint32_t reserved = 0;
294
        uint32_t opt;
295
        uint32_t namesize;
296

    
297
        TRACE("Checking magic (opts_magic)");
298
        if (magic != 0x49484156454F5054LL) {
299
            LOG("Bad magic received");
300
            goto fail;
301
        }
302
        if (read_sync(csock, &tmp, sizeof(tmp)) != sizeof(tmp)) {
303
            LOG("flags read failed");
304
            goto fail;
305
        }
306
        *flags = be16_to_cpu(tmp) << 16;
307
        /* reserved for future use */
308
        if (write_sync(csock, &reserved, sizeof(reserved)) !=
309
            sizeof(reserved)) {
310
            LOG("write failed (reserved)");
311
            goto fail;
312
        }
313
        /* write the export name */
314
        magic = cpu_to_be64(magic);
315
        if (write_sync(csock, &magic, sizeof(magic)) != sizeof(magic)) {
316
            LOG("write failed (magic)");
317
            goto fail;
318
        }
319
        opt = cpu_to_be32(NBD_OPT_EXPORT_NAME);
320
        if (write_sync(csock, &opt, sizeof(opt)) != sizeof(opt)) {
321
            LOG("write failed (opt)");
322
            goto fail;
323
        }
324
        namesize = cpu_to_be32(strlen(name));
325
        if (write_sync(csock, &namesize, sizeof(namesize)) !=
326
            sizeof(namesize)) {
327
            LOG("write failed (namesize)");
328
            goto fail;
329
        }
330
        if (write_sync(csock, (char*)name, strlen(name)) != strlen(name)) {
331
            LOG("write failed (name)");
332
            goto fail;
333
        }
334
    } else {
335
        TRACE("Checking magic (cli_magic)");
336

    
337
        if (magic != 0x00420281861253LL) {
338
            LOG("Bad magic received");
339
            goto fail;
340
        }
341
    }
342

    
343
    if (read_sync(csock, &s, sizeof(s)) != sizeof(s)) {
344
        LOG("read failed");
345
        goto fail;
346
    }
347
    *size = be64_to_cpu(s);
348
    *blocksize = 1024;
349
    TRACE("Size is %" PRIu64, *size);
350

    
351
    if (!name) {
352
        if (read_sync(csock, flags, sizeof(*flags)) != sizeof(*flags)) {
353
            LOG("read failed (flags)");
354
            goto fail;
355
        }
356
        *flags = be32_to_cpup(flags);
357
    } else {
358
        if (read_sync(csock, &tmp, sizeof(tmp)) != sizeof(tmp)) {
359
            LOG("read failed (tmp)");
360
            goto fail;
361
        }
362
        *flags |= be32_to_cpu(tmp);
363
    }
364
    if (read_sync(csock, &buf, 124) != 124) {
365
        LOG("read failed (buf)");
366
        goto fail;
367
    }
368
    rc = 0;
369

    
370
fail:
371
    socket_set_nonblock(csock);
372
    return rc;
373
}
374

    
375
#ifdef __linux__
376
int nbd_init(int fd, int csock, uint32_t flags, off_t size, size_t blocksize)
377
{
378
    TRACE("Setting NBD socket");
379

    
380
    if (ioctl(fd, NBD_SET_SOCK, csock) < 0) {
381
        int serrno = errno;
382
        LOG("Failed to set NBD socket");
383
        return -serrno;
384
    }
385

    
386
    TRACE("Setting block size to %lu", (unsigned long)blocksize);
387

    
388
    if (ioctl(fd, NBD_SET_BLKSIZE, blocksize) < 0) {
389
        int serrno = errno;
390
        LOG("Failed setting NBD block size");
391
        return -serrno;
392
    }
393

    
394
        TRACE("Setting size to %zd block(s)", (size_t)(size / blocksize));
395

    
396
    if (ioctl(fd, NBD_SET_SIZE_BLOCKS, size / blocksize) < 0) {
397
        int serrno = errno;
398
        LOG("Failed setting size (in blocks)");
399
        return -serrno;
400
    }
401

    
402
    if (flags & NBD_FLAG_READ_ONLY) {
403
        int read_only = 1;
404
        TRACE("Setting readonly attribute");
405

    
406
        if (ioctl(fd, BLKROSET, (unsigned long) &read_only) < 0) {
407
            int serrno = errno;
408
            LOG("Failed setting read-only attribute");
409
            return -serrno;
410
        }
411
    }
412

    
413
    if (ioctl(fd, NBD_SET_FLAGS, flags) < 0
414
        && errno != ENOTTY) {
415
        int serrno = errno;
416
        LOG("Failed setting flags");
417
        return -serrno;
418
    }
419

    
420
    TRACE("Negotiation ended");
421

    
422
    return 0;
423
}
424

    
425
int nbd_disconnect(int fd)
426
{
427
    ioctl(fd, NBD_CLEAR_QUE);
428
    ioctl(fd, NBD_DISCONNECT);
429
    ioctl(fd, NBD_CLEAR_SOCK);
430
    return 0;
431
}
432

    
433
int nbd_client(int fd)
434
{
435
    int ret;
436
    int serrno;
437

    
438
    TRACE("Doing NBD loop");
439

    
440
    ret = ioctl(fd, NBD_DO_IT);
441
    if (ret < 0 && errno == EPIPE) {
442
        /* NBD_DO_IT normally returns EPIPE when someone has disconnected
443
         * the socket via NBD_DISCONNECT.  We do not want to return 1 in
444
         * that case.
445
         */
446
        ret = 0;
447
    }
448
    serrno = errno;
449

    
450
    TRACE("NBD loop returned %d: %s", ret, strerror(serrno));
451

    
452
    TRACE("Clearing NBD queue");
453
    ioctl(fd, NBD_CLEAR_QUE);
454

    
455
    TRACE("Clearing NBD socket");
456
    ioctl(fd, NBD_CLEAR_SOCK);
457

    
458
    errno = serrno;
459
    return ret;
460
}
461
#else
462
/*
 * Fallback used when __linux__ is not defined: the arguments are ignored
 * and the call always fails with -ENOTSUP.
 */
int nbd_init(int fd, int csock, uint32_t flags, off_t size, size_t blocksize)
{
    /* Explicitly mark every parameter as unused. */
    (void)fd;
    (void)csock;
    (void)flags;
    (void)size;
    (void)blocksize;

    return -ENOTSUP;
}
466

    
467
/*
 * Fallback used when __linux__ is not defined: @fd is ignored and the
 * call always fails with -ENOTSUP.
 */
int nbd_disconnect(int fd)
{
    (void)fd;   /* unused */

    return -ENOTSUP;
}
471

    
472
/*
 * Fallback used when __linux__ is not defined: @fd is ignored and the
 * call always fails with -ENOTSUP.
 */
int nbd_client(int fd)
{
    (void)fd;   /* unused */

    return -ENOTSUP;
}
476
#endif
477

    
478
ssize_t nbd_send_request(int csock, struct nbd_request *request)
479
{
480
    uint8_t buf[4 + 4 + 8 + 8 + 4];
481
    ssize_t ret;
482

    
483
    cpu_to_be32w((uint32_t*)buf, NBD_REQUEST_MAGIC);
484
    cpu_to_be32w((uint32_t*)(buf + 4), request->type);
485
    cpu_to_be64w((uint64_t*)(buf + 8), request->handle);
486
    cpu_to_be64w((uint64_t*)(buf + 16), request->from);
487
    cpu_to_be32w((uint32_t*)(buf + 24), request->len);
488

    
489
    TRACE("Sending request to client: "
490
          "{ .from = %" PRIu64", .len = %u, .handle = %" PRIu64", .type=%i}",
491
          request->from, request->len, request->handle, request->type);
492

    
493
    ret = write_sync(csock, buf, sizeof(buf));
494
    if (ret < 0) {
495
        return ret;
496
    }
497

    
498
    if (ret != sizeof(buf)) {
499
        LOG("writing to socket failed");
500
        return -EINVAL;
501
    }
502
    return 0;
503
}
504

    
505
static ssize_t nbd_receive_request(int csock, struct nbd_request *request)
506
{
507
    uint8_t buf[4 + 4 + 8 + 8 + 4];
508
    uint32_t magic;
509
    ssize_t ret;
510

    
511
    ret = read_sync(csock, buf, sizeof(buf));
512
    if (ret < 0) {
513
        return ret;
514
    }
515

    
516
    if (ret != sizeof(buf)) {
517
        LOG("read failed");
518
        return -EINVAL;
519
    }
520

    
521
    /* Request
522
       [ 0 ..  3]   magic   (NBD_REQUEST_MAGIC)
523
       [ 4 ..  7]   type    (0 == READ, 1 == WRITE)
524
       [ 8 .. 15]   handle
525
       [16 .. 23]   from
526
       [24 .. 27]   len
527
     */
528

    
529
    magic = be32_to_cpup((uint32_t*)buf);
530
    request->type  = be32_to_cpup((uint32_t*)(buf + 4));
531
    request->handle = be64_to_cpup((uint64_t*)(buf + 8));
532
    request->from  = be64_to_cpup((uint64_t*)(buf + 16));
533
    request->len   = be32_to_cpup((uint32_t*)(buf + 24));
534

    
535
    TRACE("Got request: "
536
          "{ magic = 0x%x, .type = %d, from = %" PRIu64" , len = %u }",
537
          magic, request->type, request->from, request->len);
538

    
539
    if (magic != NBD_REQUEST_MAGIC) {
540
        LOG("invalid magic (got 0x%x)", magic);
541
        return -EINVAL;
542
    }
543
    return 0;
544
}
545

    
546
ssize_t nbd_receive_reply(int csock, struct nbd_reply *reply)
547
{
548
    uint8_t buf[NBD_REPLY_SIZE];
549
    uint32_t magic;
550
    ssize_t ret;
551

    
552
    ret = read_sync(csock, buf, sizeof(buf));
553
    if (ret < 0) {
554
        return ret;
555
    }
556

    
557
    if (ret != sizeof(buf)) {
558
        LOG("read failed");
559
        return -EINVAL;
560
    }
561

    
562
    /* Reply
563
       [ 0 ..  3]    magic   (NBD_REPLY_MAGIC)
564
       [ 4 ..  7]    error   (0 == no error)
565
       [ 7 .. 15]    handle
566
     */
567

    
568
    magic = be32_to_cpup((uint32_t*)buf);
569
    reply->error  = be32_to_cpup((uint32_t*)(buf + 4));
570
    reply->handle = be64_to_cpup((uint64_t*)(buf + 8));
571

    
572
    TRACE("Got reply: "
573
          "{ magic = 0x%x, .error = %d, handle = %" PRIu64" }",
574
          magic, reply->error, reply->handle);
575

    
576
    if (magic != NBD_REPLY_MAGIC) {
577
        LOG("invalid magic (got 0x%x)", magic);
578
        return -EINVAL;
579
    }
580
    return 0;
581
}
582

    
583
static ssize_t nbd_send_reply(int csock, struct nbd_reply *reply)
584
{
585
    uint8_t buf[4 + 4 + 8];
586
    ssize_t ret;
587

    
588
    /* Reply
589
       [ 0 ..  3]    magic   (NBD_REPLY_MAGIC)
590
       [ 4 ..  7]    error   (0 == no error)
591
       [ 7 .. 15]    handle
592
     */
593
    cpu_to_be32w((uint32_t*)buf, NBD_REPLY_MAGIC);
594
    cpu_to_be32w((uint32_t*)(buf + 4), reply->error);
595
    cpu_to_be64w((uint64_t*)(buf + 8), reply->handle);
596

    
597
    TRACE("Sending response to client");
598

    
599
    ret = write_sync(csock, buf, sizeof(buf));
600
    if (ret < 0) {
601
        return ret;
602
    }
603

    
604
    if (ret != sizeof(buf)) {
605
        LOG("writing to socket failed");
606
        return -EINVAL;
607
    }
608
    return 0;
609
}
610

    
611
#define MAX_NBD_REQUESTS 16
612

    
613
typedef struct NBDRequest NBDRequest;
614

    
615
struct NBDRequest {
616
    QSIMPLEQ_ENTRY(NBDRequest) entry;
617
    NBDClient *client;
618
    uint8_t *data;
619
};
620

    
621
struct NBDExport {
622
    BlockDriverState *bs;
623
    off_t dev_offset;
624
    off_t size;
625
    uint32_t nbdflags;
626
    QSIMPLEQ_HEAD(, NBDRequest) requests;
627
};
628

    
629
struct NBDClient {
630
    int refcount;
631
    void (*close)(NBDClient *client);
632

    
633
    NBDExport *exp;
634
    int sock;
635

    
636
    Coroutine *recv_coroutine;
637

    
638
    CoMutex send_lock;
639
    Coroutine *send_coroutine;
640

    
641
    int nb_requests;
642
};
643

    
644
/* Take one additional reference on @client. */
static void nbd_client_get(NBDClient *client)
{
    client->refcount += 1;
}
648

    
649
/* Drop one reference on @client and free it when none remain. */
static void nbd_client_put(NBDClient *client)
{
    client->refcount -= 1;
    if (client->refcount != 0) {
        return;
    }

    g_free(client);
}
655

    
656
/*
 * Shut down @client: unregister its fd handlers, close the socket and
 * mark it as -1, run the optional close hook, then drop one reference.
 */
static void nbd_client_close(NBDClient *client)
{
    int fd = client->sock;

    qemu_set_fd_handler2(fd, NULL, NULL, NULL, NULL);
    close(fd);
    client->sock = -1;

    if (client->close != NULL) {
        client->close(client);
    }

    nbd_client_put(client);
}
666

    
667
/*
 * Obtain an NBDRequest for @client and count it as in flight.
 *
 * An idle request is taken from the export's queue when one is available;
 * otherwise a new one is allocated together with an NBD_BUFFER_SIZE data
 * buffer.  The returned request holds a reference on @client.
 *
 * The caller must not already have MAX_NBD_REQUESTS requests in flight.
 */
static NBDRequest *nbd_request_get(NBDClient *client)
{
    NBDExport *export = client->exp;
    NBDRequest *request;

    assert(client->nb_requests < MAX_NBD_REQUESTS);
    client->nb_requests++;

    if (!QSIMPLEQ_EMPTY(&export->requests)) {
        /* Reuse an idle request. */
        request = QSIMPLEQ_FIRST(&export->requests);
        QSIMPLEQ_REMOVE_HEAD(&export->requests, entry);
    } else {
        request = g_malloc0(sizeof(*request));
        request->data = qemu_blockalign(export->bs, NBD_BUFFER_SIZE);
    }

    nbd_client_get(client);
    request->client = client;
    return request;
}
686

    
687
/*
 * Return @req to its export's queue of idle requests and drop the client
 * reference it held.  If the client was at MAX_NBD_REQUESTS in-flight
 * requests before this call, qemu_notify_event() is invoked.
 */
static void nbd_request_put(NBDRequest *req)
{
    NBDClient *owner = req->client;
    bool was_full = (owner->nb_requests == MAX_NBD_REQUESTS);

    QSIMPLEQ_INSERT_HEAD(&owner->exp->requests, req, entry);

    owner->nb_requests--;
    if (was_full) {
        qemu_notify_event();
    }

    nbd_client_put(owner);
}
696

    
697
/*
 * Allocate an export for @bs.
 *
 * @dev_offset and @nbdflags are stored as given.  A @size of -1 means
 * "use bdrv_getlength(@bs)"; any other value is stored unchanged.
 */
NBDExport *nbd_export_new(BlockDriverState *bs, off_t dev_offset,
                          off_t size, uint32_t nbdflags)
{
    NBDExport *export = g_malloc0(sizeof(*export));

    QSIMPLEQ_INIT(&export->requests);
    export->bs = bs;
    export->dev_offset = dev_offset;
    export->nbdflags = nbdflags;

    if (size == -1) {
        export->size = bdrv_getlength(bs);
    } else {
        export->size = size;
    }

    return export;
}
708

    
709
/*
 * Free every idle request queued on @exp (and its data buffer), close
 * the block device, then free the export itself.
 */
void nbd_export_close(NBDExport *exp)
{
    for (;;) {
        NBDRequest *idle;

        if (QSIMPLEQ_EMPTY(&exp->requests)) {
            break;
        }

        idle = QSIMPLEQ_FIRST(&exp->requests);
        QSIMPLEQ_REMOVE_HEAD(&exp->requests, entry);
        qemu_vfree(idle->data);
        g_free(idle);
    }

    bdrv_close(exp->bs);
    g_free(exp);
}
721

    
722
/* fd-handler callbacks, defined further down in this file. */
static void nbd_restart_write(void *opaque);
static void nbd_read(void *opaque);
static int nbd_can_read(void *opaque);
725

    
726
/*
 * Send @reply for @req, followed by @len bytes of req->data when @len is
 * non-zero.
 *
 * The client's send_lock is held for the duration.  While sending,
 * nbd_restart_write is installed as the socket's write handler and
 * send_coroutine points at the current coroutine; both are reset
 * afterwards.  When a payload follows, the socket is corked around the
 * header and payload.
 *
 * Returns the result of nbd_send_reply(), or -EIO if the payload could
 * not be sent completely.
 */
static ssize_t nbd_co_send_reply(NBDRequest *req, struct nbd_reply *reply,
                                 int len)
{
    NBDClient *client = req->client;
    int csock = client->sock;
    ssize_t rc;

    qemu_co_mutex_lock(&client->send_lock);
    qemu_set_fd_handler2(csock, nbd_can_read, nbd_read,
                         nbd_restart_write, client);
    client->send_coroutine = qemu_coroutine_self();

    if (len) {
        socket_set_cork(csock, 1);
    }

    rc = nbd_send_reply(csock, reply);

    if (len) {
        if (rc >= 0) {
            ssize_t sent = qemu_co_send(csock, req->data, len);

            if (sent != len) {
                rc = -EIO;
            }
        }
        socket_set_cork(csock, 0);
    }

    client->send_coroutine = NULL;
    qemu_set_fd_handler2(csock, nbd_can_read, nbd_read, NULL, client);
    qemu_co_mutex_unlock(&client->send_lock);
    return rc;
}
757

    
758
/*
 * Receive one request for @req into @request and, for write commands,
 * its payload into req->data.
 *
 * client->recv_coroutine points at the current coroutine while the
 * request is being received and is cleared before returning.
 *
 * Returns 0 on success; -EAGAIN if nbd_receive_request() reported it;
 * -EIO for any other header error or a short payload read; -EINVAL if
 * the length exceeds NBD_BUFFER_SIZE or from + len overflows.
 */
static ssize_t nbd_co_receive_request(NBDRequest *req, struct nbd_request *request)
{
    NBDClient *client = req->client;
    int csock = client->sock;
    ssize_t rc;

    client->recv_coroutine = qemu_coroutine_self();

    rc = nbd_receive_request(csock, request);
    if (rc < 0) {
        /* Only -EAGAIN is passed through; everything else becomes -EIO. */
        if (rc != -EAGAIN) {
            rc = -EIO;
        }
    } else if (request->len > NBD_BUFFER_SIZE) {
        LOG("len (%u) is larger than max len (%u)",
            request->len, NBD_BUFFER_SIZE);
        rc = -EINVAL;
    } else if ((request->from + request->len) < request->from) {
        LOG("integer overflow detected! "
            "you're probably being attacked");
        rc = -EINVAL;
    } else {
        TRACE("Decoding type");

        rc = 0;
        if ((request->type & NBD_CMD_MASK_COMMAND) == NBD_CMD_WRITE) {
            TRACE("Reading %u byte(s)", request->len);

            if (qemu_co_recv(csock, req->data, request->len) != request->len) {
                LOG("reading from socket failed");
                rc = -EIO;
            }
        }
    }

    client->recv_coroutine = NULL;
    return rc;
}
804

    
805
static void nbd_trip(void *opaque)
806
{
807
    NBDClient *client = opaque;
808
    NBDRequest *req = nbd_request_get(client);
809
    NBDExport *exp = client->exp;
810
    struct nbd_request request;
811
    struct nbd_reply reply;
812
    ssize_t ret;
813

    
814
    TRACE("Reading request.");
815

    
816
    ret = nbd_co_receive_request(req, &request);
817
    if (ret == -EAGAIN) {
818
        goto done;
819
    }
820
    if (ret == -EIO) {
821
        goto out;
822
    }
823

    
824
    reply.handle = request.handle;
825
    reply.error = 0;
826

    
827
    if (ret < 0) {
828
        reply.error = -ret;
829
        goto error_reply;
830
    }
831

    
832
    if ((request.from + request.len) > exp->size) {
833
            LOG("From: %" PRIu64 ", Len: %u, Size: %" PRIu64
834
            ", Offset: %" PRIu64 "\n",
835
                    request.from, request.len,
836
                    (uint64_t)exp->size, (uint64_t)exp->dev_offset);
837
        LOG("requested operation past EOF--bad client?");
838
        goto invalid_request;
839
    }
840

    
841
    switch (request.type & NBD_CMD_MASK_COMMAND) {
842
    case NBD_CMD_READ:
843
        TRACE("Request type is READ");
844

    
845
        ret = bdrv_read(exp->bs, (request.from + exp->dev_offset) / 512,
846
                        req->data, request.len / 512);
847
        if (ret < 0) {
848
            LOG("reading from file failed");
849
            reply.error = -ret;
850
            goto error_reply;
851
        }
852

    
853
        TRACE("Read %u byte(s)", request.len);
854
        if (nbd_co_send_reply(req, &reply, request.len) < 0)
855
            goto out;
856
        break;
857
    case NBD_CMD_WRITE:
858
        TRACE("Request type is WRITE");
859

    
860
        if (exp->nbdflags & NBD_FLAG_READ_ONLY) {
861
            TRACE("Server is read-only, return error");
862
            reply.error = EROFS;
863
            goto error_reply;
864
        }
865

    
866
        TRACE("Writing to device");
867

    
868
        ret = bdrv_write(exp->bs, (request.from + exp->dev_offset) / 512,
869
                         req->data, request.len / 512);
870
        if (ret < 0) {
871
            LOG("writing to file failed");
872
            reply.error = -ret;
873
            goto error_reply;
874
        }
875

    
876
        if (request.type & NBD_CMD_FLAG_FUA) {
877
            ret = bdrv_co_flush(exp->bs);
878
            if (ret < 0) {
879
                LOG("flush failed");
880
                reply.error = -ret;
881
                goto error_reply;
882
            }
883
        }
884

    
885
        if (nbd_co_send_reply(req, &reply, 0) < 0) {
886
            goto out;
887
        }
888
        break;
889
    case NBD_CMD_DISC:
890
        TRACE("Request type is DISCONNECT");
891
        errno = 0;
892
        goto out;
893
    case NBD_CMD_FLUSH:
894
        TRACE("Request type is FLUSH");
895

    
896
        ret = bdrv_co_flush(exp->bs);
897
        if (ret < 0) {
898
            LOG("flush failed");
899
            reply.error = -ret;
900
        }
901
        if (nbd_co_send_reply(req, &reply, 0) < 0) {
902
            goto out;
903
        }
904
        break;
905
    case NBD_CMD_TRIM:
906
        TRACE("Request type is TRIM");
907
        ret = bdrv_co_discard(exp->bs, (request.from + exp->dev_offset) / 512,
908
                              request.len / 512);
909
        if (ret < 0) {
910
            LOG("discard failed");
911
            reply.error = -ret;
912
        }
913
        if (nbd_co_send_reply(req, &reply, 0) < 0) {
914
            goto out;
915
        }
916
        break;
917
    default:
918
        LOG("invalid request type (%u) received", request.type);
919
    invalid_request:
920
        reply.error = -EINVAL;
921
    error_reply:
922
        if (nbd_co_send_reply(req, &reply, 0) < 0) {
923
            goto out;
924
        }
925
        break;
926
    }
927

    
928
    TRACE("Request/Reply complete");
929

    
930
done:
931
    nbd_request_put(req);
932
    return;
933

    
934
out:
935
    nbd_request_put(req);
936
    nbd_client_close(client);
937
}
938

    
939
static int nbd_can_read(void *opaque)
940
{
941
    NBDClient *client = opaque;
942

    
943
    return client->recv_coroutine || client->nb_requests < MAX_NBD_REQUESTS;
944
}
945

    
946
static void nbd_read(void *opaque)
947
{
948
    NBDClient *client = opaque;
949

    
950
    if (client->recv_coroutine) {
951
        qemu_coroutine_enter(client->recv_coroutine, NULL);
952
    } else {
953
        qemu_coroutine_enter(qemu_coroutine_create(nbd_trip), client);
954
    }
955
}
956

    
957
static void nbd_restart_write(void *opaque)
958
{
959
    NBDClient *client = opaque;
960

    
961
    qemu_coroutine_enter(client->send_coroutine, NULL);
962
}
963

    
964
/*
 * Negotiate with the peer on @csock and create a client serving @exp.
 *
 * @close, if not NULL, is stored as the hook that nbd_client_close()
 * calls.  The new client starts with one reference and has nbd_can_read
 * and nbd_read installed as the socket's read handlers.
 *
 * Returns the new client, or NULL if the negotiation failed.
 */
NBDClient *nbd_client_new(NBDExport *exp, int csock,
                          void (*close)(NBDClient *))
{
    NBDClient *created;
    int negotiated = nbd_send_negotiate(csock, exp->size, exp->nbdflags);

    if (negotiated < 0) {
        return NULL;
    }

    created = g_malloc0(sizeof(*created));
    created->refcount = 1;
    created->exp = exp;
    created->sock = csock;
    created->close = close;
    qemu_co_mutex_init(&created->send_lock);
    qemu_set_fd_handler2(csock, nbd_can_read, nbd_read, NULL, created);

    return created;
}