Statistics
| Branch: | Tag: | Revision:

root / vncauthproxy / proxy.py @ c87d99e9

History | View | Annotate | Download (22.4 kB)

1
#!/usr/bin/env python
2
"""
3
vncauthproxy - a VNC authentication proxy
4
"""
5
#
6
# Copyright (c) 2010-2013 Greek Research and Technology Network S.A.
7
#
8
# This program is free software; you can redistribute it and/or modify
9
# it under the terms of the GNU General Public License as published by
10
# the Free Software Foundation; either version 2 of the License, or
11
# (at your option) any later version.
12
#
13
# This program is distributed in the hope that it will be useful, but
14
# WITHOUT ANY WARRANTY; without even the implied warranty of
15
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16
# General Public License for more details.
17
#
18
# You should have received a copy of the GNU General Public License
19
# along with this program; if not, write to the Free Software
20
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
21
# 02110-1301, USA.
22

    
23
DEFAULT_CTRL_SOCKET = "/var/run/vncauthproxy/ctrl.sock"
24
DEFAULT_LOG_FILE = "/var/log/vncauthproxy/vncauthproxy.log"
25
DEFAULT_PID_FILE = "/var/run/vncauthproxy/vncauthproxy.pid"
26
DEFAULT_CONNECT_TIMEOUT = 30
27
DEFAULT_CONNECT_RETRIES = 3
28
DEFAULT_RETRY_WAIT = 0.1
29
# Default values per http://www.iana.org/assignments/port-numbers
30
DEFAULT_MIN_PORT = 49152
31
DEFAULT_MAX_PORT = 65535
32

    
33
import os
34
import sys
35
import logging
36
import gevent
37
import daemon
38
import random
39
import daemon.runner
40

    
41
import rfb
42

    
43
try:
44
    import simplejson as json
45
except ImportError:
46
    import json
47

    
48
from gevent import socket
49
from signal import SIGINT, SIGTERM
50
from gevent.select import select
51
from time import sleep
52

    
53
from lockfile import LockTimeout, AlreadyLocked
54
# Take care of differences between python-daemon versions.
55
try:
56
    from daemon import pidfile as pidlockfile
57
except:
58
    from daemon import pidlockfile
59

    
60

    
61
logger = None
62

    
63

    
64
# Currently, gevent uses libevent-dns for asynchronous DNS resolution,
65
# which opens a socket upon initialization time. Since we can't get the fd
66
# reliably, We have to maintain all file descriptors open (which won't harm
67
# anyway)
68
class AllFilesDaemonContext(daemon.DaemonContext):
69
    """DaemonContext class keeping all file descriptors open"""
70
    def _get_exclude_file_descriptors(self):
71
        class All:
72
            def __contains__(self, value):
73
                return True
74
        return All()
75

    
76

    
77
class VncAuthProxy(gevent.Greenlet):
78
    """
79
    Simple class implementing a VNC Forwarder with MITM authentication as a
80
    Greenlet
81

82
    VncAuthProxy forwards VNC traffic from a specified port of the local host
83
    to a specified remote host:port. Furthermore, it implements VNC
84
    Authentication, intercepting the client/server handshake and asking the
85
    client for authentication even if the backend requires none.
86

87
    It is primarily intended for use in virtualization environments, as a VNC
88
    ``switch''.
89

90
    """
91
    id = 1
92

    
93
    def __init__(self, logger, listeners, pool, daddr, dport, server, password,
94
                 connect_timeout):
95
        """
96
        @type logger: logging.Logger
97
        @param logger: the logger to use
98
        @type listeners: list
99
        @param listeners: list of listening sockets to use for clients
100
        @type pool: list
101
        @param pool: if not None, return the client number into this port pool
102
        @type daddr: str
103
        @param daddr: destination address (IPv4, IPv6 or hostname)
104
        @type dport: int
105
        @param dport: destination port
106
        @type server: socket
107
        @param server: VNC server socket
108
        @type password: str
109
        @param password: password to request from the client
110
        @type connect_timeout: int
111
        @param connect_timeout: how long to wait for client connections
112
                                (seconds)
113

114
        """
115
        gevent.Greenlet.__init__(self)
116
        self.id = VncAuthProxy.id
117
        VncAuthProxy.id += 1
118
        self.log = logger
119
        self.listeners = listeners
120
        # All listening sockets are assumed to be on the same port
121
        self.sport = listeners[0].getsockname()[1]
122
        self.pool = pool
123
        self.daddr = daddr
124
        self.dport = dport
125
        self.server = server
126
        self.password = password
127
        self.client = None
128
        self.timeout = connect_timeout
129

    
130
    def _cleanup(self):
131
        """Close all active sockets and exit gracefully"""
132
        # Reintroduce the port number of the client socket in
133
        # the port pool, if applicable.
134
        if not self.pool is None:
135
            self.pool.append(self.sport)
136
            self.debug("Returned port %d to port pool, contains %d ports",
137
                       self.sport, len(self.pool))
138

    
139
        while self.listeners:
140
            self.listeners.pop().close()
141
        if self.server:
142
            self.server.close()
143
        if self.client:
144
            self.client.close()
145

    
146
        raise gevent.GreenletExit
147

    
148
    def __str__(self):
149
        return "VncAuthProxy: %d -> %s:%d" % (self.sport, self.daddr,
150
                                              self.dport)
151

    
152
    def _forward(self, source, dest):
153
        """
154
        Forward traffic from source to dest
155

156
        @type source: socket
157
        @param source: source socket
158
        @type dest: socket
159
        @param dest: destination socket
160

161
        """
162

    
163
        while True:
164
            d = source.recv(16384)
165
            if d == '':
166
                if source == self.client:
167
                    self.info("Client connection closed")
168
                else:
169
                    self.info("Server connection closed")
170
                break
171
            dest.sendall(d)
172
        # No need to close the source and dest sockets here.
173
        # They are owned by and will be closed by the original greenlet.
174

    
175
    def _client_handshake(self):
176
        """
177
        Perform handshake/authentication with a connecting client
178

179
        Outline:
180
        1. Client connects
181
        2. We fake RFB 3.8 protocol and require VNC authentication
182
           [processing also supports RFB 3.3]
183
        3. Client accepts authentication method
184
        4. We send an authentication challenge
185
        5. Client sends the authentication response
186
        6. We check the authentication
187

188
        Upon return, self.client socket is connected to the client.
189

190
        """
191
        self.client.send(rfb.RFB_VERSION_3_8 + "\n")
192
        client_version_str = self.client.recv(1024)
193
        client_version = rfb.check_version(client_version_str)
194
        if not client_version:
195
            self.error("Invalid version: %s", client_version_str)
196
            raise gevent.GreenletExit
197

    
198
        # Both for RFB 3.3 and 3.8
199
        self.debug("Requesting authentication")
200
        auth_request = rfb.make_auth_request(rfb.RFB_AUTHTYPE_VNC,
201
                                             version=client_version)
202
        self.client.send(auth_request)
203

    
204
        # The client gets to propose an authtype only for RFB 3.8
205
        if client_version == rfb.RFB_VERSION_3_8:
206
            res = self.client.recv(1024)
207
            type = rfb.parse_client_authtype(res)
208
            if type == rfb.RFB_AUTHTYPE_ERROR:
209
                self.warn("Client refused authentication: %s", res[1:])
210
            else:
211
                self.debug("Client requested authtype %x", type)
212

    
213
            if type != rfb.RFB_AUTHTYPE_VNC:
214
                self.error("Wrong auth type: %d", type)
215
                self.client.send(rfb.to_u32(rfb.RFB_AUTH_ERROR))
216
                raise gevent.GreenletExit
217

    
218
        # Generate the challenge
219
        challenge = os.urandom(16)
220
        self.client.send(challenge)
221
        response = self.client.recv(1024)
222
        if len(response) != 16:
223
            self.error("Wrong response length %d, should be 16", len(response))
224
            raise gevent.GreenletExit
225

    
226
        if rfb.check_password(challenge, response, self.password):
227
            self.debug("Authentication successful")
228
        else:
229
            self.warn("Authentication failed")
230
            self.client.send(rfb.to_u32(rfb.RFB_AUTH_ERROR))
231
            raise gevent.GreenletExit
232

    
233
        # Accept the authentication
234
        self.client.send(rfb.to_u32(rfb.RFB_AUTH_SUCCESS))
235

    
236
    def _run(self):
237
        try:
238
            self.info("Waiting for a client to connect at %s",
239
                      ", ".join(["%s:%d" % s.getsockname()[:2]
240
                                 for s in self.listeners]))
241
            rlist, _, _ = select(self.listeners, [], [], timeout=self.timeout)
242

    
243
            if not rlist:
244
                self.info("Timed out, no connection after %d sec",
245
                          self.timeout)
246
                raise gevent.GreenletExit
247

    
248
            for sock in rlist:
249
                self.client, addrinfo = sock.accept()
250
                self.info("Connection from %s:%d", addrinfo[:2])
251

    
252
                # Close all listening sockets, we only want a one-shot
253
                # connection from a single client.
254
                while self.listeners:
255
                    self.listeners.pop().close()
256
                break
257

    
258
            # Perform RFB handshake with the client.
259
            self._client_handshake()
260

    
261
            # Bridge both connections through two "forwarder" greenlets.
262
            self.workers = [gevent.spawn(self._forward,
263
                                         self.client, self.server),
264
                            gevent.spawn(self._forward,
265
                                         self.server, self.client)]
266

    
267
            # If one greenlet goes, the other has to go too.
268
            self.workers[0].link(self.workers[1])
269
            self.workers[1].link(self.workers[0])
270
            gevent.joinall(self.workers)
271
            del self.workers
272
            raise gevent.GreenletExit
273
        except Exception, e:
274
            # Any unhandled exception in the previous block
275
            # is an error and must be logged accordingly
276
            if not isinstance(e, gevent.GreenletExit):
277
                self.exception(e)
278
            raise e
279
        finally:
280
            self._cleanup()
281

    
282
# Logging support inside VncAuthproxy
283
# Wrap all common logging functions in logging-specific methods
284
for funcname in ["info", "debug", "warn", "error", "critical",
285
                 "exception"]:
286
    def gen(funcname):
287
        def wrapped_log_func(self, *args, **kwargs):
288
            func = getattr(self.log, funcname)
289
            func("[C%d] %s" % (self.id, args[0]), *args[1:], **kwargs)
290
        return wrapped_log_func
291
    setattr(VncAuthProxy, funcname, gen(funcname))
292

    
293

    
294
def fatal_signal_handler(signame):
295
    logger.info("Caught %s, will raise SystemExit", signame)
296
    raise SystemExit
297

    
298

    
299
def get_listening_sockets(sport):
300
    sockets = []
301

    
302
    # Use two sockets, one for IPv4, one for IPv6. IPv4-to-IPv6 mapped
303
    # addresses do not work reliably everywhere (under linux it may have
304
    # been disabled in /proc/sys/net/ipv6/bind_ipv6_only).
305
    for res in socket.getaddrinfo(None, sport, socket.AF_UNSPEC,
306
                                  socket.SOCK_STREAM, 0, socket.AI_PASSIVE):
307
        af, socktype, proto, canonname, sa = res
308
        try:
309
            s = None
310
            s = socket.socket(af, socktype, proto)
311
            if af == socket.AF_INET6:
312
                # Bind v6 only when AF_INET6, otherwise either v4 or v6 bind
313
                # will fail.
314
                s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
315
            s.bind(sa)
316
            s.listen(1)
317
            sockets.append(s)
318
            logger.debug("Listening on %s:%d", sa[:2])
319
        except socket.error, msg:
320
            logger.error("Error binding to %s:%d: %s", sa[0], sa[1], msg[1])
321
            if s:
322
                s.close()
323
            while sockets:
324
                sockets.pop().close()
325

    
326
            # Make sure we fail immediately if we cannot get a socket
327
            raise msg
328

    
329
    return sockets
330

    
331

    
332
def perform_server_handshake(daddr, dport, tries, retry_wait):
333
    """
334
    Initiate a connection with the backend server and perform basic
335
    RFB 3.8 handshake with it.
336

337
    Return a socket connected to the backend server.
338

339
    """
340
    server = None
341

    
342
    while tries:
343
        tries -= 1
344

    
345
        # Initiate server connection
346
        for res in socket.getaddrinfo(daddr, dport, socket.AF_UNSPEC,
347
                                      socket.SOCK_STREAM, 0,
348
                                      socket.AI_PASSIVE):
349
            af, socktype, proto, canonname, sa = res
350
            try:
351
                server = socket.socket(af, socktype, proto)
352
            except socket.error:
353
                server = None
354
                continue
355

    
356
            try:
357
                logger.debug("Connecting to %s:%s", sa[:2])
358
                server.connect(sa)
359
                logger.debug("Connection to %s:%s successful", sa[:2])
360
            except socket.error:
361
                server.close()
362
                server = None
363
                continue
364

    
365
            # We succesfully connected to the server
366
            tries = 0
367
            break
368

    
369
        # Wait and retry
370
        sleep(retry_wait)
371

    
372
    if server is None:
373
        raise Exception("Failed to connect to server")
374

    
375
    version = server.recv(1024)
376
    if not rfb.check_version(version):
377
        raise Exception("Unsupported RFB version: %s" % version.strip())
378

    
379
    server.send(rfb.RFB_VERSION_3_8 + "\n")
380

    
381
    res = server.recv(1024)
382
    types = rfb.parse_auth_request(res)
383
    if not types:
384
        raise Exception("Error handshaking with the server")
385

    
386
    else:
387
        logger.debug("Supported authentication types: %s",
388
                     " ".join([str(x) for x in types]))
389

    
390
    if rfb.RFB_AUTHTYPE_NONE not in types:
391
        raise Exception("Error, server demands authentication")
392

    
393
    server.send(rfb.to_u8(rfb.RFB_AUTHTYPE_NONE))
394

    
395
    # Check authentication response
396
    res = server.recv(4)
397
    res = rfb.from_u32(res)
398

    
399
    if res != 0:
400
        raise Exception("Authentication error")
401

    
402
    return server
403

    
404

    
405
def parse_arguments(args):
406
    from optparse import OptionParser
407

    
408
    parser = OptionParser()
409
    parser.add_option("-s", "--socket", dest="ctrl_socket",
410
                      default=DEFAULT_CTRL_SOCKET,
411
                      metavar="PATH",
412
                      help=("UNIX socket for control connections (default: "
413
                            "%s" % DEFAULT_CTRL_SOCKET))
414
    parser.add_option("-d", "--debug", action="store_true", dest="debug",
415
                      help="Enable debugging information")
416
    parser.add_option("-l", "--log", dest="log_file",
417
                      default=DEFAULT_LOG_FILE,
418
                      metavar="FILE",
419
                      help=("Write log to FILE instead of %s" %
420
                            DEFAULT_LOG_FILE))
421
    parser.add_option('--pid-file', dest="pid_file",
422
                      default=DEFAULT_PID_FILE,
423
                      metavar='PIDFILE',
424
                      help=("Save PID to file (default: %s)" %
425
                            DEFAULT_PID_FILE))
426
    parser.add_option("-t", "--connect-timeout", dest="connect_timeout",
427
                      default=DEFAULT_CONNECT_TIMEOUT, type="int",
428
                      metavar="SECONDS", help=("Wait SECONDS sec for a client "
429
                                               "to connect"))
430
    parser.add_option("-r", "--connect-retries", dest="connect_retries",
431
                      default=DEFAULT_CONNECT_RETRIES, type="int",
432
                      metavar="RETRIES",
433
                      help="How many times to try to connect to the server")
434
    parser.add_option("-w", "--retry-wait", dest="retry_wait",
435
                      default=DEFAULT_RETRY_WAIT, type="float",
436
                      metavar="SECONDS", help=("Retry connection to server "
437
                                               "every SECONDS sec"))
438
    parser.add_option("-p", "--min-port", dest="min_port",
439
                      default=DEFAULT_MIN_PORT, type="int", metavar="MIN_PORT",
440
                      help=("The minimum port number to use for automatically-"
441
                            "allocated ephemeral ports"))
442
    parser.add_option("-P", "--max-port", dest="max_port",
443
                      default=DEFAULT_MAX_PORT, type="int", metavar="MAX_PORT",
444
                      help=("The maximum port number to use for automatically-"
445
                            "allocated ephemeral ports"))
446

    
447
    return parser.parse_args(args)
448

    
449

    
450
def main():
451
    """Run the daemon from the command line."""
452

    
453
    (opts, args) = parse_arguments(sys.argv[1:])
454

    
455
    # Create pidfile
456
    pidf = pidlockfile.TimeoutPIDLockFile(opts.pid_file, 10)
457

    
458
    # Initialize logger
459
    lvl = logging.DEBUG if opts.debug else logging.INFO
460

    
461
    global logger
462
    logger = logging.getLogger("vncauthproxy")
463
    logger.setLevel(lvl)
464
    formatter = logging.Formatter(("%(asctime)s %(module)s[%(process)d] "
465
                                   " %(levelname)s: %(message)s"),
466
                                  "%Y-%m-%d %H:%M:%S")
467
    handler = logging.FileHandler(opts.log_file)
468
    handler.setFormatter(formatter)
469
    logger.addHandler(handler)
470

    
471
    # Become a daemon:
472
    # Redirect stdout and stderr to handler.stream to catch
473
    # early errors in the daemonization process [e.g., pidfile creation]
474
    # which will otherwise go to /dev/null.
475
    daemon_context = AllFilesDaemonContext(
476
        pidfile=pidf,
477
        umask=0022,
478
        stdout=handler.stream,
479
        stderr=handler.stream,
480
        files_preserve=[handler.stream])
481

    
482
    # Remove any stale PID files, left behind by previous invocations
483
    if daemon.runner.is_pidfile_stale(pidf):
484
        logger.warning("Removing stale PID lock file %s", pidf.path)
485
        pidf.break_lock()
486

    
487
    try:
488
        daemon_context.open()
489
    except (AlreadyLocked, LockTimeout):
490
        logger.critical(("Failed to lock PID file %s, another instance "
491
                         "running?"), pidf.path)
492
        sys.exit(1)
493
    logger.info("Became a daemon")
494

    
495
    # A fork() has occured while daemonizing,
496
    # we *must* reinit gevent
497
    gevent.reinit()
498

    
499
    if os.path.exists(opts.ctrl_socket):
500
        logger.critical("Socket '%s' already exists" % opts.ctrl_socket)
501
        sys.exit(1)
502

    
503
    # TODO: make this tunable? chgrp as well?
504
    old_umask = os.umask(0007)
505

    
506
    ctrl = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
507
    ctrl.bind(opts.ctrl_socket)
508

    
509
    os.umask(old_umask)
510

    
511
    ctrl.listen(1)
512
    logger.info(("Initialized, waiting for control connections at %s" %
513
                 opts.ctrl_socket))
514

    
515
    # Catch signals to ensure graceful shutdown,
516
    # e.g., to make sure the control socket gets unlink()ed.
517
    #
518
    # Uses gevent.signal so the handler fires even during
519
    # gevent.socket.accept()
520
    gevent.signal(SIGINT, fatal_signal_handler, "SIGINT")
521
    gevent.signal(SIGTERM, fatal_signal_handler, "SIGTERM")
522

    
523
    # Init ephemeral port pool
524
    ports = range(opts.min_port, opts.max_port + 1)
525

    
526
    while True:
527
        try:
528
            client, addr = ctrl.accept()
529
            logger.info("New control connection")
530

    
531
            # Receive and parse a client request.
532
            response = {
533
                "source_port": 0,
534
                "status": "FAILED"
535
            }
536
            try:
537
                # TODO: support multiple forwardings in the same message?
538
                #
539
                # Control request, in JSON:
540
                #
541
                # {
542
                #     "source_port":
543
                #         <source port or 0 for automatic allocation>,
544
                #     "destination_address":
545
                #         <destination address of backend server>,
546
                #     "destination_port":
547
                #         <destination port>
548
                #     "password":
549
                #         <the password to use to authenticate clients>
550
                # }
551
                #
552
                # The <password> is used for MITM authentication of clients
553
                # connecting to <source_port>, who will subsequently be
554
                # forwarded to a VNC server listening at
555
                # <destination_address>:<destination_port>
556
                #
557
                # Control reply, in JSON:
558
                # {
559
                #     "source_port": <the allocated source port>
560
                #     "status": <one of "OK" or "FAILED">
561
                # }
562
                #
563
                buf = client.recv(1024)
564
                req = json.loads(buf)
565

    
566
                sport_orig = int(req['source_port'])
567
                daddr = req['destination_address']
568
                dport = int(req['destination_port'])
569
                password = req['password']
570
            except Exception, e:
571
                logger.warn("Malformed request: %s" % buf)
572
                client.send(json.dumps(response))
573
                client.close()
574
                continue
575

    
576
            # Spawn a new Greenlet to service the request.
577
            server = None
578
            try:
579
                # If the client has so indicated, pick an ephemeral source port
580
                # randomly, and remove it from the port pool.
581
                if sport_orig == 0:
582
                    sport = random.choice(ports)
583
                    ports.remove(sport)
584
                    logger.debug(("Got port %d from pool, %d remaining",
585
                                  sport, len(ports)))
586
                    pool = ports
587
                else:
588
                    sport = sport_orig
589
                    pool = None
590

    
591
                listeners = get_listening_sockets(sport)
592
                server = perform_server_handshake(daddr, dport,
593
                                                  opts.connect_retries,
594
                                                  opts.retry_wait)
595

    
596
                VncAuthProxy.spawn(logger, listeners, pool, daddr, dport,
597
                                   server, password, opts.connect_timeout)
598

    
599
                logger.info(("New forwarding: %d (client req'd: %d) -> %s:%d" %
600
                             (sport, sport_orig, daddr, dport)))
601
                response = {"source_port": sport,
602
                            "status": "OK"}
603
            except IndexError:
604
                logger.error(("FAILED forwarding, out of ports for [req'd by "
605
                              "client: %d -> %s:%d]" % (sport_orig, daddr,
606
                                                        dport)))
607
            except Exception, msg:
608
                logger.error(msg)
609
                logger.error(("FAILED forwarding: %d (client req'd: %d) -> "
610
                              "%s:%d" % (sport, sport_orig, daddr, dport)))
611
                if not pool is None:
612
                    pool.append(sport)
613
                    logger.debug(("Returned port %d to pool, %d remanining",
614
                                  sport, len(pool)))
615
                if not server is None:
616
                    server.close()
617
            finally:
618
                client.send(json.dumps(response))
619
                client.close()
620
        except Exception, e:
621
            logger.exception(e)
622
            continue
623
        except SystemExit:
624
            break
625

    
626
    logger.info("Unlinking control socket at %s" % opts.ctrl_socket)
627
    os.unlink(opts.ctrl_socket)
628
    daemon_context.close()
629
    sys.exit(0)