Statistics
| Branch: | Tag: | Revision:

root / vncauthproxy / proxy.py @ 0423d976

History | View | Annotate | Download (22.4 kB)

1
#!/usr/bin/env python
2
"""
3
vncauthproxy - a VNC authentication proxy
4
"""
5
#
6
# Copyright (c) 2010-2013 Greek Research and Technology Network S.A.
7
#
8
# This program is free software; you can redistribute it and/or modify
9
# it under the terms of the GNU General Public License as published by
10
# the Free Software Foundation; either version 2 of the License, or
11
# (at your option) any later version.
12
#
13
# This program is distributed in the hope that it will be useful, but
14
# WITHOUT ANY WARRANTY; without even the implied warranty of
15
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16
# General Public License for more details.
17
#
18
# You should have received a copy of the GNU General Public License
19
# along with this program; if not, write to the Free Software
20
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
21
# 02110-1301, USA.
22

    
23
DEFAULT_CTRL_SOCKET = "/var/run/vncauthproxy/ctrl.sock"
24
DEFAULT_LOG_FILE = "/var/log/vncauthproxy/vncauthproxy.log"
25
DEFAULT_PID_FILE = "/var/run/vncauthproxy/vncauthproxy.pid"
26
DEFAULT_CONNECT_TIMEOUT = 30
27
DEFAULT_CONNECT_RETRIES = 3
28
DEFAULT_RETRY_WAIT = 0.1
29
# Default values per http://www.iana.org/assignments/port-numbers
30
DEFAULT_MIN_PORT = 49152
31
DEFAULT_MAX_PORT = 65535
32

    
33
import os
34
import sys
35
import logging
36
import gevent
37
import daemon
38
import random
39
import daemon.runner
40

    
41
import rfb
42

    
43
try:
44
    import simplejson as json
45
except ImportError:
46
    import json
47

    
48
from gevent import socket
49
from signal import SIGINT, SIGTERM
50
from gevent.select import select
51

    
52
from lockfile import LockTimeout, AlreadyLocked
53
# Take care of differences between python-daemon versions.
54
try:
55
    from daemon import pidfile as pidlockfile
56
except:
57
    from daemon import pidlockfile
58

    
59

    
60
logger = None
61

    
62

    
63
# Currently, gevent uses libevent-dns for asynchronous DNS resolution,
64
# which opens a socket upon initialization time. Since we can't get the fd
65
# reliably, We have to maintain all file descriptors open (which won't harm
66
# anyway)
67
class AllFilesDaemonContext(daemon.DaemonContext):
68
    """DaemonContext class keeping all file descriptors open"""
69
    def _get_exclude_file_descriptors(self):
70
        class All:
71
            def __contains__(self, value):
72
                return True
73
        return All()
74

    
75

    
76
class VncAuthProxy(gevent.Greenlet):
77
    """
78
    Simple class implementing a VNC Forwarder with MITM authentication as a
79
    Greenlet
80

81
    VncAuthProxy forwards VNC traffic from a specified port of the local host
82
    to a specified remote host:port. Furthermore, it implements VNC
83
    Authentication, intercepting the client/server handshake and asking the
84
    client for authentication even if the backend requires none.
85

86
    It is primarily intended for use in virtualization environments, as a VNC
87
    ``switch''.
88

89
    """
90
    id = 1
91

    
92
    def __init__(self, logger, listeners, pool, daddr, dport, server, password,
93
                 connect_timeout):
94
        """
95
        @type logger: logging.Logger
96
        @param logger: the logger to use
97
        @type listeners: list
98
        @param listeners: list of listening sockets to use for clients
99
        @type pool: list
100
        @param pool: if not None, return the client number into this port pool
101
        @type daddr: str
102
        @param daddr: destination address (IPv4, IPv6 or hostname)
103
        @type dport: int
104
        @param dport: destination port
105
        @type server: socket
106
        @param server: VNC server socket
107
        @type password: str
108
        @param password: password to request from the client
109
        @type connect_timeout: int
110
        @param connect_timeout: how long to wait for client connections
111
                                (seconds)
112

113
        """
114
        gevent.Greenlet.__init__(self)
115
        self.id = VncAuthProxy.id
116
        VncAuthProxy.id += 1
117
        self.log = logger
118
        self.listeners = listeners
119
        # All listening sockets are assumed to be on the same port
120
        self.sport = listeners[0].getsockname()[1]
121
        self.pool = pool
122
        self.daddr = daddr
123
        self.dport = dport
124
        self.server = server
125
        self.password = password
126
        self.client = None
127
        self.timeout = connect_timeout
128

    
129
    def _cleanup(self):
130
        """Close all active sockets and exit gracefully"""
131
        # Reintroduce the port number of the client socket in
132
        # the port pool, if applicable.
133
        if not self.pool is None:
134
            self.pool.append(self.sport)
135
            self.debug("Returned port %d to port pool, contains %d ports",
136
                       self.sport, len(self.pool))
137

    
138
        while self.listeners:
139
            self.listeners.pop().close()
140
        if self.server:
141
            self.server.close()
142
        if self.client:
143
            self.client.close()
144

    
145
        raise gevent.GreenletExit
146

    
147
    def __str__(self):
148
        return "VncAuthProxy: %d -> %s:%d" % (self.sport, self.daddr,
149
                                              self.dport)
150

    
151
    def _forward(self, source, dest):
152
        """
153
        Forward traffic from source to dest
154

155
        @type source: socket
156
        @param source: source socket
157
        @type dest: socket
158
        @param dest: destination socket
159

160
        """
161

    
162
        while True:
163
            d = source.recv(16384)
164
            if d == '':
165
                if source == self.client:
166
                    self.info("Client connection closed")
167
                else:
168
                    self.info("Server connection closed")
169
                break
170
            dest.sendall(d)
171
        # No need to close the source and dest sockets here.
172
        # They are owned by and will be closed by the original greenlet.
173

    
174
    def _client_handshake(self):
175
        """
176
        Perform handshake/authentication with a connecting client
177

178
        Outline:
179
        1. Client connects
180
        2. We fake RFB 3.8 protocol and require VNC authentication
181
           [processing also supports RFB 3.3]
182
        3. Client accepts authentication method
183
        4. We send an authentication challenge
184
        5. Client sends the authentication response
185
        6. We check the authentication
186

187
        Upon return, self.client socket is connected to the client.
188

189
        """
190
        self.client.send(rfb.RFB_VERSION_3_8 + "\n")
191
        client_version_str = self.client.recv(1024)
192
        client_version = rfb.check_version(client_version_str)
193
        if not client_version:
194
            self.error("Invalid version: %s", client_version_str)
195
            raise gevent.GreenletExit
196

    
197
        # Both for RFB 3.3 and 3.8
198
        self.debug("Requesting authentication")
199
        auth_request = rfb.make_auth_request(rfb.RFB_AUTHTYPE_VNC,
200
                                             version=client_version)
201
        self.client.send(auth_request)
202

    
203
        # The client gets to propose an authtype only for RFB 3.8
204
        if client_version == rfb.RFB_VERSION_3_8:
205
            res = self.client.recv(1024)
206
            type = rfb.parse_client_authtype(res)
207
            if type == rfb.RFB_AUTHTYPE_ERROR:
208
                self.warn("Client refused authentication: %s", res[1:])
209
            else:
210
                self.debug("Client requested authtype %x", type)
211

    
212
            if type != rfb.RFB_AUTHTYPE_VNC:
213
                self.error("Wrong auth type: %d", type)
214
                self.client.send(rfb.to_u32(rfb.RFB_AUTH_ERROR))
215
                raise gevent.GreenletExit
216

    
217
        # Generate the challenge
218
        challenge = os.urandom(16)
219
        self.client.send(challenge)
220
        response = self.client.recv(1024)
221
        if len(response) != 16:
222
            self.error("Wrong response length %d, should be 16", len(response))
223
            raise gevent.GreenletExit
224

    
225
        if rfb.check_password(challenge, response, self.password):
226
            self.debug("Authentication successful")
227
        else:
228
            self.warn("Authentication failed")
229
            self.client.send(rfb.to_u32(rfb.RFB_AUTH_ERROR))
230
            raise gevent.GreenletExit
231

    
232
        # Accept the authentication
233
        self.client.send(rfb.to_u32(rfb.RFB_AUTH_SUCCESS))
234

    
235
    def _run(self):
236
        try:
237
            self.info("Waiting for a client to connect at %s",
238
                      ", ".join(["%s:%d" % s.getsockname()[:2]
239
                                 for s in self.listeners]))
240
            rlist, _, _ = select(self.listeners, [], [], timeout=self.timeout)
241

    
242
            if not rlist:
243
                self.info("Timed out, no connection after %d sec",
244
                          self.timeout)
245
                raise gevent.GreenletExit
246

    
247
            for sock in rlist:
248
                self.client, addrinfo = sock.accept()
249
                self.info("Connection from %s:%d", addrinfo[:2])
250

    
251
                # Close all listening sockets, we only want a one-shot
252
                # connection from a single client.
253
                while self.listeners:
254
                    self.listeners.pop().close()
255
                break
256

    
257
            # Perform RFB handshake with the client.
258
            self._client_handshake()
259

    
260
            # Bridge both connections through two "forwarder" greenlets.
261
            self.workers = [gevent.spawn(self._forward,
262
                                         self.client, self.server),
263
                            gevent.spawn(self._forward,
264
                                         self.server, self.client)]
265

    
266
            # If one greenlet goes, the other has to go too.
267
            self.workers[0].link(self.workers[1])
268
            self.workers[1].link(self.workers[0])
269
            gevent.joinall(self.workers)
270
            del self.workers
271
            raise gevent.GreenletExit
272
        except Exception, e:
273
            # Any unhandled exception in the previous block
274
            # is an error and must be logged accordingly
275
            if not isinstance(e, gevent.GreenletExit):
276
                self.exception(e)
277
            raise e
278
        finally:
279
            self._cleanup()
280

    
281
# Logging support inside VncAuthproxy
282
# Wrap all common logging functions in logging-specific methods
283
for funcname in ["info", "debug", "warn", "error", "critical",
284
                 "exception"]:
285
    def gen(funcname):
286
        def wrapped_log_func(self, *args, **kwargs):
287
            func = getattr(self.log, funcname)
288
            func("[C%d] %s" % (self.id, args[0]), *args[1:], **kwargs)
289
        return wrapped_log_func
290
    setattr(VncAuthProxy, funcname, gen(funcname))
291

    
292

    
293
def fatal_signal_handler(signame):
294
    logger.info("Caught %s, will raise SystemExit", signame)
295
    raise SystemExit
296

    
297

    
298
def get_listening_sockets(sport):
299
    sockets = []
300

    
301
    # Use two sockets, one for IPv4, one for IPv6. IPv4-to-IPv6 mapped
302
    # addresses do not work reliably everywhere (under linux it may have
303
    # been disabled in /proc/sys/net/ipv6/bind_ipv6_only).
304
    for res in socket.getaddrinfo(None, sport, socket.AF_UNSPEC,
305
                                  socket.SOCK_STREAM, 0, socket.AI_PASSIVE):
306
        af, socktype, proto, canonname, sa = res
307
        try:
308
            s = None
309
            s = socket.socket(af, socktype, proto)
310
            if af == socket.AF_INET6:
311
                # Bind v6 only when AF_INET6, otherwise either v4 or v6 bind
312
                # will fail.
313
                s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
314
            s.bind(sa)
315
            s.listen(1)
316
            sockets.append(s)
317
            logger.debug("Listening on %s:%d", sa[:2])
318
        except socket.error, msg:
319
            logger.error("Error binding to %s:%d: %s", sa[0], sa[1], msg[1])
320
            if s:
321
                s.close()
322
            while sockets:
323
                sockets.pop().close()
324

    
325
            # Make sure we fail immediately if we cannot get a socket
326
            raise msg
327

    
328
    return sockets
329

    
330

    
331
def perform_server_handshake(daddr, dport, tries, retry_wait):
332
    """
333
    Initiate a connection with the backend server and perform basic
334
    RFB 3.8 handshake with it.
335

336
    Return a socket connected to the backend server.
337

338
    """
339
    server = None
340

    
341
    while tries:
342
        tries -= 1
343

    
344
        # Initiate server connection
345
        for res in socket.getaddrinfo(daddr, dport, socket.AF_UNSPEC,
346
                                      socket.SOCK_STREAM, 0,
347
                                      socket.AI_PASSIVE):
348
            af, socktype, proto, canonname, sa = res
349
            try:
350
                server = socket.socket(af, socktype, proto)
351
            except socket.error:
352
                server = None
353
                continue
354

    
355
            try:
356
                logger.debug("Connecting to %s:%s", sa[:2])
357
                server.connect(sa)
358
                logger.debug("Connection to %s:%s successful", sa[:2])
359
            except socket.error:
360
                server.close()
361
                server = None
362
                continue
363

    
364
            # We succesfully connected to the server
365
            tries = 0
366
            break
367

    
368
        # Wait and retry
369
        gevent.sleep(retry_wait)
370

    
371
    if server is None:
372
        raise Exception("Failed to connect to server")
373

    
374
    version = server.recv(1024)
375
    if not rfb.check_version(version):
376
        raise Exception("Unsupported RFB version: %s" % version.strip())
377

    
378
    server.send(rfb.RFB_VERSION_3_8 + "\n")
379

    
380
    res = server.recv(1024)
381
    types = rfb.parse_auth_request(res)
382
    if not types:
383
        raise Exception("Error handshaking with the server")
384

    
385
    else:
386
        logger.debug("Supported authentication types: %s",
387
                     " ".join([str(x) for x in types]))
388

    
389
    if rfb.RFB_AUTHTYPE_NONE not in types:
390
        raise Exception("Error, server demands authentication")
391

    
392
    server.send(rfb.to_u8(rfb.RFB_AUTHTYPE_NONE))
393

    
394
    # Check authentication response
395
    res = server.recv(4)
396
    res = rfb.from_u32(res)
397

    
398
    if res != 0:
399
        raise Exception("Authentication error")
400

    
401
    return server
402

    
403

    
404
def parse_arguments(args):
405
    from optparse import OptionParser
406

    
407
    parser = OptionParser()
408
    parser.add_option("-s", "--socket", dest="ctrl_socket",
409
                      default=DEFAULT_CTRL_SOCKET,
410
                      metavar="PATH",
411
                      help=("UNIX socket for control connections (default: "
412
                            "%s" % DEFAULT_CTRL_SOCKET))
413
    parser.add_option("-d", "--debug", action="store_true", dest="debug",
414
                      help="Enable debugging information")
415
    parser.add_option("-l", "--log", dest="log_file",
416
                      default=DEFAULT_LOG_FILE,
417
                      metavar="FILE",
418
                      help=("Write log to FILE instead of %s" %
419
                            DEFAULT_LOG_FILE))
420
    parser.add_option('--pid-file', dest="pid_file",
421
                      default=DEFAULT_PID_FILE,
422
                      metavar='PIDFILE',
423
                      help=("Save PID to file (default: %s)" %
424
                            DEFAULT_PID_FILE))
425
    parser.add_option("-t", "--connect-timeout", dest="connect_timeout",
426
                      default=DEFAULT_CONNECT_TIMEOUT, type="int",
427
                      metavar="SECONDS", help=("Wait SECONDS sec for a client "
428
                                               "to connect"))
429
    parser.add_option("-r", "--connect-retries", dest="connect_retries",
430
                      default=DEFAULT_CONNECT_RETRIES, type="int",
431
                      metavar="RETRIES",
432
                      help="How many times to try to connect to the server")
433
    parser.add_option("-w", "--retry-wait", dest="retry_wait",
434
                      default=DEFAULT_RETRY_WAIT, type="float",
435
                      metavar="SECONDS", help=("Retry connection to server "
436
                                               "every SECONDS sec"))
437
    parser.add_option("-p", "--min-port", dest="min_port",
438
                      default=DEFAULT_MIN_PORT, type="int", metavar="MIN_PORT",
439
                      help=("The minimum port number to use for automatically-"
440
                            "allocated ephemeral ports"))
441
    parser.add_option("-P", "--max-port", dest="max_port",
442
                      default=DEFAULT_MAX_PORT, type="int", metavar="MAX_PORT",
443
                      help=("The maximum port number to use for automatically-"
444
                            "allocated ephemeral ports"))
445

    
446
    return parser.parse_args(args)
447

    
448

    
449
def main():
450
    """Run the daemon from the command line."""
451

    
452
    (opts, args) = parse_arguments(sys.argv[1:])
453

    
454
    # Create pidfile
455
    pidf = pidlockfile.TimeoutPIDLockFile(opts.pid_file, 10)
456

    
457
    # Initialize logger
458
    lvl = logging.DEBUG if opts.debug else logging.INFO
459

    
460
    global logger
461
    logger = logging.getLogger("vncauthproxy")
462
    logger.setLevel(lvl)
463
    formatter = logging.Formatter(("%(asctime)s %(module)s[%(process)d] "
464
                                   " %(levelname)s: %(message)s"),
465
                                  "%Y-%m-%d %H:%M:%S")
466
    handler = logging.FileHandler(opts.log_file)
467
    handler.setFormatter(formatter)
468
    logger.addHandler(handler)
469

    
470
    # Become a daemon:
471
    # Redirect stdout and stderr to handler.stream to catch
472
    # early errors in the daemonization process [e.g., pidfile creation]
473
    # which will otherwise go to /dev/null.
474
    daemon_context = AllFilesDaemonContext(
475
        pidfile=pidf,
476
        umask=0022,
477
        stdout=handler.stream,
478
        stderr=handler.stream,
479
        files_preserve=[handler.stream])
480

    
481
    # Remove any stale PID files, left behind by previous invocations
482
    if daemon.runner.is_pidfile_stale(pidf):
483
        logger.warning("Removing stale PID lock file %s", pidf.path)
484
        pidf.break_lock()
485

    
486
    try:
487
        daemon_context.open()
488
    except (AlreadyLocked, LockTimeout):
489
        logger.critical(("Failed to lock PID file %s, another instance "
490
                         "running?"), pidf.path)
491
        sys.exit(1)
492
    logger.info("Became a daemon")
493

    
494
    # A fork() has occured while daemonizing,
495
    # we *must* reinit gevent
496
    gevent.reinit()
497

    
498
    if os.path.exists(opts.ctrl_socket):
499
        logger.critical("Socket '%s' already exists" % opts.ctrl_socket)
500
        sys.exit(1)
501

    
502
    # TODO: make this tunable? chgrp as well?
503
    old_umask = os.umask(0007)
504

    
505
    ctrl = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
506
    ctrl.bind(opts.ctrl_socket)
507

    
508
    os.umask(old_umask)
509

    
510
    ctrl.listen(1)
511
    logger.info(("Initialized, waiting for control connections at %s" %
512
                 opts.ctrl_socket))
513

    
514
    # Catch signals to ensure graceful shutdown,
515
    # e.g., to make sure the control socket gets unlink()ed.
516
    #
517
    # Uses gevent.signal so the handler fires even during
518
    # gevent.socket.accept()
519
    gevent.signal(SIGINT, fatal_signal_handler, "SIGINT")
520
    gevent.signal(SIGTERM, fatal_signal_handler, "SIGTERM")
521

    
522
    # Init ephemeral port pool
523
    ports = range(opts.min_port, opts.max_port + 1)
524

    
525
    while True:
526
        try:
527
            client, addr = ctrl.accept()
528
            logger.info("New control connection")
529

    
530
            # Receive and parse a client request.
531
            response = {
532
                "source_port": 0,
533
                "status": "FAILED"
534
            }
535
            try:
536
                # TODO: support multiple forwardings in the same message?
537
                #
538
                # Control request, in JSON:
539
                #
540
                # {
541
                #     "source_port":
542
                #         <source port or 0 for automatic allocation>,
543
                #     "destination_address":
544
                #         <destination address of backend server>,
545
                #     "destination_port":
546
                #         <destination port>
547
                #     "password":
548
                #         <the password to use to authenticate clients>
549
                # }
550
                #
551
                # The <password> is used for MITM authentication of clients
552
                # connecting to <source_port>, who will subsequently be
553
                # forwarded to a VNC server listening at
554
                # <destination_address>:<destination_port>
555
                #
556
                # Control reply, in JSON:
557
                # {
558
                #     "source_port": <the allocated source port>
559
                #     "status": <one of "OK" or "FAILED">
560
                # }
561
                #
562
                buf = client.recv(1024)
563
                req = json.loads(buf)
564

    
565
                sport_orig = int(req['source_port'])
566
                daddr = req['destination_address']
567
                dport = int(req['destination_port'])
568
                password = req['password']
569
            except Exception, e:
570
                logger.warn("Malformed request: %s" % buf)
571
                client.send(json.dumps(response))
572
                client.close()
573
                continue
574

    
575
            # Spawn a new Greenlet to service the request.
576
            server = None
577
            try:
578
                # If the client has so indicated, pick an ephemeral source port
579
                # randomly, and remove it from the port pool.
580
                if sport_orig == 0:
581
                    sport = random.choice(ports)
582
                    ports.remove(sport)
583
                    logger.debug(("Got port %d from pool, %d remaining",
584
                                  sport, len(ports)))
585
                    pool = ports
586
                else:
587
                    sport = sport_orig
588
                    pool = None
589

    
590
                listeners = get_listening_sockets(sport)
591
                server = perform_server_handshake(daddr, dport,
592
                                                  opts.connect_retries,
593
                                                  opts.retry_wait)
594

    
595
                VncAuthProxy.spawn(logger, listeners, pool, daddr, dport,
596
                                   server, password, opts.connect_timeout)
597

    
598
                logger.info(("New forwarding: %d (client req'd: %d) -> %s:%d" %
599
                             (sport, sport_orig, daddr, dport)))
600
                response = {"source_port": sport,
601
                            "status": "OK"}
602
            except IndexError:
603
                logger.error(("FAILED forwarding, out of ports for [req'd by "
604
                              "client: %d -> %s:%d]" % (sport_orig, daddr,
605
                                                        dport)))
606
            except Exception, msg:
607
                logger.error(msg)
608
                logger.error(("FAILED forwarding: %d (client req'd: %d) -> "
609
                              "%s:%d" % (sport, sport_orig, daddr, dport)))
610
                if not pool is None:
611
                    pool.append(sport)
612
                    logger.debug(("Returned port %d to pool, %d remanining",
613
                                  sport, len(pool)))
614
                if not server is None:
615
                    server.close()
616
            finally:
617
                client.send(json.dumps(response))
618
                client.close()
619
        except Exception, e:
620
            logger.exception(e)
621
            continue
622
        except SystemExit:
623
            break
624

    
625
    logger.info("Unlinking control socket at %s" % opts.ctrl_socket)
626
    os.unlink(opts.ctrl_socket)
627
    daemon_context.close()
628
    sys.exit(0)