Statistics
| Branch: | Tag: | Revision:

root / vncauthproxy / proxy.py @ 1e3d1c7d

History | View | Annotate | Download (23.8 kB)

1
#!/usr/bin/env python
2
"""
3
vncauthproxy - a VNC authentication proxy
4
"""
5
#
6
# Copyright (c) 2010-2013 Greek Research and Technology Network S.A.
7
#
8
# This program is free software; you can redistribute it and/or modify
9
# it under the terms of the GNU General Public License as published by
10
# the Free Software Foundation; either version 2 of the License, or
11
# (at your option) any later version.
12
#
13
# This program is distributed in the hope that it will be useful, but
14
# WITHOUT ANY WARRANTY; without even the implied warranty of
15
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16
# General Public License for more details.
17
#
18
# You should have received a copy of the GNU General Public License
19
# along with this program; if not, write to the Free Software
20
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
21
# 02110-1301, USA.
22

    
23
DEFAULT_CTRL_SOCKET = "/var/run/vncauthproxy/ctrl.sock"
24
DEFAULT_LOG_FILE = "/var/log/vncauthproxy/vncauthproxy.log"
25
DEFAULT_PID_FILE = "/var/run/vncauthproxy/vncauthproxy.pid"
26
DEFAULT_CONNECT_TIMEOUT = 30
27
DEFAULT_CONNECT_RETRIES = 3
28
DEFAULT_RETRY_WAIT = 0.1
29
# We must take care not to fall into the ephemeral port range,
30
# this can lead to transient failures to bind a chosen port.
31
#
32
# By default, Linux uses 32768 to 61000, see:
33
# http://www.ncftp.com/ncftpd/doc/misc/ephemeral_ports.html#Linux
34
# so 25000-30000 seems to be a sensible default.
35
DEFAULT_MIN_PORT = 25000
36
DEFAULT_MAX_PORT = 30000
37

    
38
import os
39
import sys
40
import logging
41
import gevent
42
import gevent.event
43
import daemon
44
import random
45
import daemon.runner
46

    
47
import rfb
48

    
49
try:
50
    import simplejson as json
51
except ImportError:
52
    import json
53

    
54
from gevent import socket
55
from signal import SIGINT, SIGTERM
56
from gevent.select import select
57

    
58
from lockfile import LockTimeout, AlreadyLocked
59
# Take care of differences between python-daemon versions.
60
try:
61
    from daemon import pidfile as pidlockfile
62
except:
63
    from daemon import pidlockfile
64

    
65

    
66
logger = None
67

    
68

    
69
# Currently, gevent uses libevent-dns for asynchronous DNS resolution,
70
# which opens a socket upon initialization time. Since we can't get the fd
71
# reliably, We have to maintain all file descriptors open (which won't harm
72
# anyway)
73
class AllFilesDaemonContext(daemon.DaemonContext):
74
    """DaemonContext class keeping all file descriptors open"""
75
    def _get_exclude_file_descriptors(self):
76
        class All:
77
            def __contains__(self, value):
78
                return True
79
        return All()
80

    
81

    
82
class VncAuthProxy(gevent.Greenlet):
83
    """
84
    Simple class implementing a VNC Forwarder with MITM authentication as a
85
    Greenlet
86

87
    VncAuthProxy forwards VNC traffic from a specified port of the local host
88
    to a specified remote host:port. Furthermore, it implements VNC
89
    Authentication, intercepting the client/server handshake and asking the
90
    client for authentication even if the backend requires none.
91

92
    It is primarily intended for use in virtualization environments, as a VNC
93
    ``switch''.
94

95
    """
96
    id = 1
97

    
98
    def __init__(self, logger, listeners, pool, daddr, dport, server, password,
99
                 connect_timeout):
100
        """
101
        @type logger: logging.Logger
102
        @param logger: the logger to use
103
        @type listeners: list
104
        @param listeners: list of listening sockets to use for clients
105
        @type pool: list
106
        @param pool: if not None, return the client number into this port pool
107
        @type daddr: str
108
        @param daddr: destination address (IPv4, IPv6 or hostname)
109
        @type dport: int
110
        @param dport: destination port
111
        @type server: socket
112
        @param server: VNC server socket
113
        @type password: str
114
        @param password: password to request from the client
115
        @type connect_timeout: int
116
        @param connect_timeout: how long to wait for client connections
117
                                (seconds)
118

119
        """
120
        gevent.Greenlet.__init__(self)
121
        self.id = VncAuthProxy.id
122
        VncAuthProxy.id += 1
123
        self.log = logger
124
        self.listeners = listeners
125
        # A list of worker/forwarder greenlets, one for each direction
126
        self.workers = []
127
        # All listening sockets are assumed to be on the same port
128
        self.sport = listeners[0].getsockname()[1]
129
        self.pool = pool
130
        self.daddr = daddr
131
        self.dport = dport
132
        self.server = server
133
        self.password = password
134
        self.client = None
135
        self.timeout = connect_timeout
136

    
137
    def _cleanup(self):
138
        """Cleanup everything: workers, sockets, ports
139

140
        Kill all remaining forwarder greenlets, close all active sockets,
141
        return the source port to the pool if applicable, then exit
142
        gracefully.
143

144
        """
145
        # Make sure all greenlets are dead, then clean them up
146
        self.debug("Cleaning up %d workers", len(self.workers))
147
        for g in self.workers:
148
            g.kill()
149
        gevent.joinall(self.workers)
150
        del self.workers
151

    
152
        self.debug("Cleaning up sockets")
153
        while self.listeners:
154
            self.listeners.pop().close()
155
        if self.server:
156
            self.server.close()
157
        if self.client:
158
            self.client.close()
159

    
160
        # Reintroduce the port number of the client socket in
161
        # the port pool, if applicable.
162
        if not self.pool is None:
163
            self.pool.append(self.sport)
164
            self.debug("Returned port %d to port pool, contains %d ports",
165
                       self.sport, len(self.pool))
166

    
167
        self.info("Cleaned up connection, all done")
168
        raise gevent.GreenletExit
169

    
170
    def __str__(self):
171
        return "VncAuthProxy: %d -> %s:%d" % (self.sport, self.daddr,
172
                                              self.dport)
173

    
174
    def _forward(self, source, dest):
175
        """
176
        Forward traffic from source to dest
177

178
        @type source: socket
179
        @param source: source socket
180
        @type dest: socket
181
        @param dest: destination socket
182

183
        """
184

    
185
        while True:
186
            d = source.recv(16384)
187
            if d == '':
188
                if source == self.client:
189
                    self.info("Client connection closed")
190
                else:
191
                    self.info("Server connection closed")
192
                break
193
            dest.sendall(d)
194
        # No need to close the source and dest sockets here.
195
        # They are owned by and will be closed by the original greenlet.
196

    
197
    def _client_handshake(self):
198
        """
199
        Perform handshake/authentication with a connecting client
200

201
        Outline:
202
        1. Client connects
203
        2. We fake RFB 3.8 protocol and require VNC authentication
204
           [processing also supports RFB 3.3]
205
        3. Client accepts authentication method
206
        4. We send an authentication challenge
207
        5. Client sends the authentication response
208
        6. We check the authentication
209

210
        Upon return, self.client socket is connected to the client.
211

212
        """
213
        self.client.send(rfb.RFB_VERSION_3_8 + "\n")
214
        client_version_str = self.client.recv(1024)
215
        client_version = rfb.check_version(client_version_str)
216
        if not client_version:
217
            self.error("Invalid version: %s", client_version_str)
218
            raise gevent.GreenletExit
219

    
220
        # Both for RFB 3.3 and 3.8
221
        self.debug("Requesting authentication")
222
        auth_request = rfb.make_auth_request(rfb.RFB_AUTHTYPE_VNC,
223
                                             version=client_version)
224
        self.client.send(auth_request)
225

    
226
        # The client gets to propose an authtype only for RFB 3.8
227
        if client_version == rfb.RFB_VERSION_3_8:
228
            res = self.client.recv(1024)
229
            type = rfb.parse_client_authtype(res)
230
            if type == rfb.RFB_AUTHTYPE_ERROR:
231
                self.warn("Client refused authentication: %s", res[1:])
232
            else:
233
                self.debug("Client requested authtype %x", type)
234

    
235
            if type != rfb.RFB_AUTHTYPE_VNC:
236
                self.error("Wrong auth type: %d", type)
237
                self.client.send(rfb.to_u32(rfb.RFB_AUTH_ERROR))
238
                raise gevent.GreenletExit
239

    
240
        # Generate the challenge
241
        challenge = os.urandom(16)
242
        self.client.send(challenge)
243
        response = self.client.recv(1024)
244
        if len(response) != 16:
245
            self.error("Wrong response length %d, should be 16", len(response))
246
            raise gevent.GreenletExit
247

    
248
        if rfb.check_password(challenge, response, self.password):
249
            self.debug("Authentication successful")
250
        else:
251
            self.warn("Authentication failed")
252
            self.client.send(rfb.to_u32(rfb.RFB_AUTH_ERROR))
253
            raise gevent.GreenletExit
254

    
255
        # Accept the authentication
256
        self.client.send(rfb.to_u32(rfb.RFB_AUTH_SUCCESS))
257

    
258
    def _run(self):
259
        try:
260
            self.info("Waiting for a client to connect at %s",
261
                      ", ".join(["%s:%d" % s.getsockname()[:2]
262
                                 for s in self.listeners]))
263
            rlist, _, _ = select(self.listeners, [], [], timeout=self.timeout)
264

    
265
            if not rlist:
266
                self.info("Timed out, no connection after %d sec",
267
                          self.timeout)
268
                raise gevent.GreenletExit
269

    
270
            for sock in rlist:
271
                self.client, addrinfo = sock.accept()
272
                self.info("Connection from %s:%d", *addrinfo[:2])
273

    
274
                # Close all listening sockets, we only want a one-shot
275
                # connection from a single client.
276
                while self.listeners:
277
                    self.listeners.pop().close()
278
                break
279

    
280
            # Perform RFB handshake with the client.
281
            self._client_handshake()
282

    
283
            # Bridge both connections through two "forwarder" greenlets.
284
            # This greenlet will wait until any of the workers dies.
285
            # Final cleanup will take place in _cleanup().
286
            dead = gevent.event.Event()
287
            dead.clear()
288

    
289
            # This callback will get called if any of the two workers dies.
290
            def callback(g):
291
                self.debug("Worker %d/%d died", self.workers.index(g),
292
                           len(self.workers))
293
                dead.set()
294

    
295
            self.workers.append(gevent.spawn(self._forward,
296
                                             self.client, self.server))
297
            self.workers.append(gevent.spawn(self._forward,
298
                                             self.server, self.client))
299
            for g in self.workers:
300
                g.link(callback)
301

    
302
            # Wait until any of the workers dies
303
            self.debug("Waiting for any of %d workers to die",
304
                       len(self.workers))
305
            dead.wait()
306

    
307
            # We can go now, _cleanup() will take care of
308
            # all worker, socket and port cleanup
309
            self.debug("A forwarder died, our work here is done")
310
            raise gevent.GreenletExit
311
        except Exception, e:
312
            # Any unhandled exception in the previous block
313
            # is an error and must be logged accordingly
314
            if not isinstance(e, gevent.GreenletExit):
315
                self.exception(e)
316
            raise e
317
        finally:
318
            self._cleanup()
319

    
320
# Logging support inside VncAuthproxy
321
# Wrap all common logging functions in logging-specific methods
322
for funcname in ["info", "debug", "warn", "error", "critical",
323
                 "exception"]:
324
    def gen(funcname):
325
        def wrapped_log_func(self, *args, **kwargs):
326
            func = getattr(self.log, funcname)
327
            func("[C%d] %s" % (self.id, args[0]), *args[1:], **kwargs)
328
        return wrapped_log_func
329
    setattr(VncAuthProxy, funcname, gen(funcname))
330

    
331

    
332
def fatal_signal_handler(signame):
333
    logger.info("Caught %s, will raise SystemExit", signame)
334
    raise SystemExit
335

    
336

    
337
def get_listening_sockets(sport):
338
    sockets = []
339

    
340
    # Use two sockets, one for IPv4, one for IPv6. IPv4-to-IPv6 mapped
341
    # addresses do not work reliably everywhere (under linux it may have
342
    # been disabled in /proc/sys/net/ipv6/bind_ipv6_only).
343
    for res in socket.getaddrinfo(None, sport, socket.AF_UNSPEC,
344
                                  socket.SOCK_STREAM, 0, socket.AI_PASSIVE):
345
        af, socktype, proto, canonname, sa = res
346
        try:
347
            s = None
348
            s = socket.socket(af, socktype, proto)
349
            if af == socket.AF_INET6:
350
                # Bind v6 only when AF_INET6, otherwise either v4 or v6 bind
351
                # will fail.
352
                s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
353
            s.bind(sa)
354
            s.listen(1)
355
            sockets.append(s)
356
            logger.debug("Listening on %s:%d", *sa[:2])
357
        except socket.error, msg:
358
            logger.error("Error binding to %s:%d: %s", sa[0], sa[1], msg[1])
359
            if s:
360
                s.close()
361
            while sockets:
362
                sockets.pop().close()
363

    
364
            # Make sure we fail immediately if we cannot get a socket
365
            raise msg
366

    
367
    return sockets
368

    
369

    
370
def perform_server_handshake(daddr, dport, tries, retry_wait):
371
    """
372
    Initiate a connection with the backend server and perform basic
373
    RFB 3.8 handshake with it.
374

375
    Return a socket connected to the backend server.
376

377
    """
378
    server = None
379

    
380
    while tries:
381
        tries -= 1
382

    
383
        # Initiate server connection
384
        for res in socket.getaddrinfo(daddr, dport, socket.AF_UNSPEC,
385
                                      socket.SOCK_STREAM, 0,
386
                                      socket.AI_PASSIVE):
387
            af, socktype, proto, canonname, sa = res
388
            try:
389
                server = socket.socket(af, socktype, proto)
390
            except socket.error:
391
                server = None
392
                continue
393

    
394
            try:
395
                logger.debug("Connecting to %s:%s", *sa[:2])
396
                server.connect(sa)
397
                logger.debug("Connection to %s:%s successful", *sa[:2])
398
            except socket.error:
399
                server.close()
400
                server = None
401
                continue
402

    
403
            # We succesfully connected to the server
404
            tries = 0
405
            break
406

    
407
        # Wait and retry
408
        gevent.sleep(retry_wait)
409

    
410
    if server is None:
411
        raise Exception("Failed to connect to server")
412

    
413
    version = server.recv(1024)
414
    if not rfb.check_version(version):
415
        raise Exception("Unsupported RFB version: %s" % version.strip())
416

    
417
    server.send(rfb.RFB_VERSION_3_8 + "\n")
418

    
419
    res = server.recv(1024)
420
    types = rfb.parse_auth_request(res)
421
    if not types:
422
        raise Exception("Error handshaking with the server")
423

    
424
    else:
425
        logger.debug("Supported authentication types: %s",
426
                     " ".join([str(x) for x in types]))
427

    
428
    if rfb.RFB_AUTHTYPE_NONE not in types:
429
        raise Exception("Error, server demands authentication")
430

    
431
    server.send(rfb.to_u8(rfb.RFB_AUTHTYPE_NONE))
432

    
433
    # Check authentication response
434
    res = server.recv(4)
435
    res = rfb.from_u32(res)
436

    
437
    if res != 0:
438
        raise Exception("Authentication error")
439

    
440
    return server
441

    
442

    
443
def parse_arguments(args):
444
    from optparse import OptionParser
445

    
446
    parser = OptionParser()
447
    parser.add_option("-s", "--socket", dest="ctrl_socket",
448
                      default=DEFAULT_CTRL_SOCKET,
449
                      metavar="PATH",
450
                      help=("UNIX socket for control connections (default: "
451
                            "%s" % DEFAULT_CTRL_SOCKET))
452
    parser.add_option("-d", "--debug", action="store_true", dest="debug",
453
                      help="Enable debugging information")
454
    parser.add_option("-l", "--log", dest="log_file",
455
                      default=DEFAULT_LOG_FILE,
456
                      metavar="FILE",
457
                      help=("Write log to FILE instead of %s" %
458
                            DEFAULT_LOG_FILE))
459
    parser.add_option('--pid-file', dest="pid_file",
460
                      default=DEFAULT_PID_FILE,
461
                      metavar='PIDFILE',
462
                      help=("Save PID to file (default: %s)" %
463
                            DEFAULT_PID_FILE))
464
    parser.add_option("-t", "--connect-timeout", dest="connect_timeout",
465
                      default=DEFAULT_CONNECT_TIMEOUT, type="int",
466
                      metavar="SECONDS", help=("Wait SECONDS sec for a client "
467
                                               "to connect"))
468
    parser.add_option("-r", "--connect-retries", dest="connect_retries",
469
                      default=DEFAULT_CONNECT_RETRIES, type="int",
470
                      metavar="RETRIES",
471
                      help="How many times to try to connect to the server")
472
    parser.add_option("-w", "--retry-wait", dest="retry_wait",
473
                      default=DEFAULT_RETRY_WAIT, type="float",
474
                      metavar="SECONDS", help=("Retry connection to server "
475
                                               "every SECONDS sec"))
476
    parser.add_option("-p", "--min-port", dest="min_port",
477
                      default=DEFAULT_MIN_PORT, type="int", metavar="MIN_PORT",
478
                      help=("The minimum port number to use for automatically-"
479
                            "allocated ephemeral ports"))
480
    parser.add_option("-P", "--max-port", dest="max_port",
481
                      default=DEFAULT_MAX_PORT, type="int", metavar="MAX_PORT",
482
                      help=("The maximum port number to use for automatically-"
483
                            "allocated ephemeral ports"))
484

    
485
    return parser.parse_args(args)
486

    
487

    
488
def main():
489
    """Run the daemon from the command line"""
490

    
491
    (opts, args) = parse_arguments(sys.argv[1:])
492

    
493
    # Create pidfile
494
    pidf = pidlockfile.TimeoutPIDLockFile(opts.pid_file, 10)
495

    
496
    # Initialize logger
497
    lvl = logging.DEBUG if opts.debug else logging.INFO
498

    
499
    global logger
500
    logger = logging.getLogger("vncauthproxy")
501
    logger.setLevel(lvl)
502
    formatter = logging.Formatter(("%(asctime)s %(module)s[%(process)d] "
503
                                   " %(levelname)s: %(message)s"),
504
                                  "%Y-%m-%d %H:%M:%S")
505
    handler = logging.FileHandler(opts.log_file)
506
    handler.setFormatter(formatter)
507
    logger.addHandler(handler)
508

    
509
    # Become a daemon:
510
    # Redirect stdout and stderr to handler.stream to catch
511
    # early errors in the daemonization process [e.g., pidfile creation]
512
    # which will otherwise go to /dev/null.
513
    daemon_context = AllFilesDaemonContext(
514
        pidfile=pidf,
515
        umask=0022,
516
        stdout=handler.stream,
517
        stderr=handler.stream,
518
        files_preserve=[handler.stream])
519

    
520
    # Remove any stale PID files, left behind by previous invocations
521
    if daemon.runner.is_pidfile_stale(pidf):
522
        logger.warning("Removing stale PID lock file %s", pidf.path)
523
        pidf.break_lock()
524

    
525
    try:
526
        daemon_context.open()
527
    except (AlreadyLocked, LockTimeout):
528
        logger.critical(("Failed to lock PID file %s, another instance "
529
                         "running?"), pidf.path)
530
        sys.exit(1)
531
    logger.info("Became a daemon")
532

    
533
    # A fork() has occured while daemonizing,
534
    # we *must* reinit gevent
535
    gevent.reinit()
536

    
537
    if os.path.exists(opts.ctrl_socket):
538
        logger.critical("Socket '%s' already exists", opts.ctrl_socket)
539
        sys.exit(1)
540

    
541
    # TODO: make this tunable? chgrp as well?
542
    old_umask = os.umask(0007)
543

    
544
    ctrl = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
545
    ctrl.bind(opts.ctrl_socket)
546

    
547
    os.umask(old_umask)
548

    
549
    ctrl.listen(1)
550
    logger.info("Initialized, waiting for control connections at %s",
551
                opts.ctrl_socket)
552

    
553
    # Catch signals to ensure graceful shutdown,
554
    # e.g., to make sure the control socket gets unlink()ed.
555
    #
556
    # Uses gevent.signal so the handler fires even during
557
    # gevent.socket.accept()
558
    gevent.signal(SIGINT, fatal_signal_handler, "SIGINT")
559
    gevent.signal(SIGTERM, fatal_signal_handler, "SIGTERM")
560

    
561
    # Init ephemeral port pool
562
    ports = range(opts.min_port, opts.max_port + 1)
563

    
564
    while True:
565
        try:
566
            client, addr = ctrl.accept()
567
            logger.info("New control connection")
568

    
569
            # Receive and parse a client request.
570
            response = {
571
                "source_port": 0,
572
                "status": "FAILED"
573
            }
574
            try:
575
                # TODO: support multiple forwardings in the same message?
576
                #
577
                # Control request, in JSON:
578
                #
579
                # {
580
                #     "source_port":
581
                #         <source port or 0 for automatic allocation>,
582
                #     "destination_address":
583
                #         <destination address of backend server>,
584
                #     "destination_port":
585
                #         <destination port>
586
                #     "password":
587
                #         <the password to use to authenticate clients>
588
                # }
589
                #
590
                # The <password> is used for MITM authentication of clients
591
                # connecting to <source_port>, who will subsequently be
592
                # forwarded to a VNC server listening at
593
                # <destination_address>:<destination_port>
594
                #
595
                # Control reply, in JSON:
596
                # {
597
                #     "source_port": <the allocated source port>
598
                #     "status": <one of "OK" or "FAILED">
599
                # }
600
                #
601
                buf = client.recv(1024)
602
                req = json.loads(buf)
603

    
604
                sport_orig = int(req['source_port'])
605
                daddr = req['destination_address']
606
                dport = int(req['destination_port'])
607
                password = req['password']
608
            except Exception, e:
609
                logger.warn("Malformed request: %s", buf)
610
                client.send(json.dumps(response))
611
                client.close()
612
                continue
613

    
614
            # Spawn a new Greenlet to service the request.
615
            server = None
616
            try:
617
                # If the client has so indicated, pick an ephemeral source port
618
                # randomly, and remove it from the port pool.
619
                if sport_orig == 0:
620
                    sport = random.choice(ports)
621
                    ports.remove(sport)
622
                    logger.debug("Got port %d from pool, %d remaining",
623
                                 sport, len(ports))
624
                    pool = ports
625
                else:
626
                    sport = sport_orig
627
                    pool = None
628

    
629
                listeners = get_listening_sockets(sport)
630
                server = perform_server_handshake(daddr, dport,
631
                                                  opts.connect_retries,
632
                                                  opts.retry_wait)
633

    
634
                VncAuthProxy.spawn(logger, listeners, pool, daddr, dport,
635
                                   server, password, opts.connect_timeout)
636

    
637
                logger.info("New forwarding: %d (client req'd: %d) -> %s:%d",
638
                            sport, sport_orig, daddr, dport)
639
                response = {"source_port": sport,
640
                            "status": "OK"}
641
            except IndexError:
642
                logger.error(("FAILED forwarding, out of ports for [req'd by "
643
                              "client: %d -> %s:%d]"),
644
                             sport_orig, daddr, dport)
645
            except Exception, msg:
646
                logger.error(msg)
647
                logger.error(("FAILED forwarding: %d (client req'd: %d) -> "
648
                              "%s:%d"), sport, sport_orig, daddr, dport)
649
                if not pool is None:
650
                    pool.append(sport)
651
                    logger.debug("Returned port %d to pool, %d remanining",
652
                                 sport, len(pool))
653
                if not server is None:
654
                    server.close()
655
            finally:
656
                client.send(json.dumps(response))
657
                client.close()
658
        except Exception, e:
659
            logger.exception(e)
660
            continue
661
        except SystemExit:
662
            break
663

    
664
    logger.info("Unlinking control socket at %s", opts.ctrl_socket)
665
    os.unlink(opts.ctrl_socket)
666
    daemon_context.close()
667
    sys.exit(0)