made content module object-oriented; lots of code organization changes overall
[ncclient] / ncclient / ssh.py
1 # Copyright 2009 Shikhar Bhushan
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #    http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import logging
16 from cStringIO import StringIO
17 from os import SEEK_CUR
18 import socket
19
20 import paramiko
21
22
23 from session import Session, SessionError
24
25 logger = logging.getLogger('ncclient.ssh')
26
27
class SessionCloseError(SessionError):
    """Session error that carries the data left over at the time of failure.

    `in_buf` is the data received so far; `out_buf` is the data that could
    not be sent, or None when nothing was pending.
    """

    def __init__(self, in_buf, out_buf=None):
        SessionError.__init__(self)
        self._in_buf = in_buf
        self._out_buf = out_buf

    def __str__(self):
        received = self._in_buf
        unsent = self._out_buf
        return 'RECEIVED: %s | UNSENT: %s' % (received, unsent)
36
37
class SSHSession(Session):
    """Session that exchanges delimiter-framed messages over a paramiko channel.

    Outgoing messages are taken from ``self._q`` and written to the channel
    followed by `MSG_DELIM`. Incoming bytes are accumulated in an internal
    buffer and split on `MSG_DELIM`; every complete message is passed to
    ``self.dispatch('reply', msg)``. ``self._q`` and ``self.dispatch`` are not
    defined in this class (presumably inherited from `Session`).
    """

    # Maximum number of bytes requested from the channel per recv() call.
    BUF_SIZE = 4096
    # Marker appended to every outgoing message and used to split incoming data.
    MSG_DELIM = ']]>]]>'

    def __init__(self, load_known_hosts=True,
                 missing_host_key_policy=paramiko.RejectPolicy()):
        """Create an unconnected session.

        load_known_hosts -- when true, load the system host keys into the
            underlying paramiko client.
        missing_host_key_policy -- policy object handed to the paramiko
            client's set_missing_host_key_policy().
        """
        Session.__init__(self)
        self._client = paramiko.SSHClient()
        self._channel = None
        if load_known_hosts:
            self._client.load_system_host_keys()
        self._client.set_missing_host_key_policy(missing_host_key_policy)
        self._in_buf = StringIO()
        # Number of bytes at the start of _in_buf already scanned without
        # finding a complete delimiter.
        self._parsing_pos = 0

    def _close(self):
        """Close the channel (if one was ever opened) and mark as disconnected."""
        if self._channel is not None:
            self._channel.close()
        self._connected = False

    def _fresh_data(self):
        """Dispatch every complete message currently held in the input buffer.

        Called after new data has been appended to ``self._in_buf``. Each
        message terminated by `MSG_DELIM` is passed to
        ``self.dispatch('reply', msg)`` without the delimiter; whatever
        follows the last delimiter stays buffered for the next call.
        """
        delim = SSHSession.MSG_DELIM
        delim_len = len(delim)
        buf = self._in_buf
        # Only the unscanned tail needs searching, but a delimiter may be
        # split across two chunks, so back up by delim_len - 1 characters.
        scan_from = max(self._parsing_pos - (delim_len - 1), 0)
        while True:
            buf.seek(scan_from)
            hit = buf.read().find(delim)
            if hit < 0:
                break  # read() above left the position at the end of buf
            buf.seek(0)
            msg = buf.read(scan_from + hit)
            self.dispatch('reply', msg)
            buf.seek(delim_len, SEEK_CUR)  # skip over the delimiter itself
            rest = buf.read()
            # Start a fresh buffer holding only the bytes after the delimiter
            # and rescan it from the beginning for further messages.
            buf = StringIO()
            buf.write(rest)
            scan_from = 0
        self._in_buf = buf
        self._parsing_pos = buf.tell()

    def connect(self, hostname, port=830, username=None, password=None,
                key_filename=None, timeout=None, allow_agent=True,
                look_for_keys=True):
        """Open an SSH connection and start the 'netconf' subsystem on it.

        All arguments are forwarded unchanged to paramiko's
        SSHClient.connect(). Exceptions raised by paramiko propagate to the
        caller.
        """
        self._client.connect(hostname, port=port, username=username,
                             password=password, key_filename=key_filename,
                             timeout=timeout, allow_agent=allow_agent,
                             look_for_keys=look_for_keys)
        self._transport = self._client.get_transport()
        self._channel = self._transport.open_session()
        self._channel.invoke_subsystem('netconf')
        self._channel.set_name('netconf')
        self._connected = True
        # NOTE(review): _post_connect is not defined in this file; it is
        # expected to come from Session -- confirm.
        self._post_connect()

    def run(self):
        """Main loop: send queued messages and collect incoming data.

        Runs until the channel is closed or any exception occurs; the
        exception (a SessionCloseError for a closed channel) is then passed
        to ``self.dispatch('error', e)`` and the method returns.
        """
        chan = self._channel
        chan.setblocking(0)
        q = self._q
        # NOTE(review): nothing in this loop waits when there is neither data
        # to send nor data to receive, so it polls continuously.
        try:
            while True:
                if chan.closed:
                    raise SessionCloseError(self._in_buf.getvalue())
                if chan.send_ready() and not q.empty():
                    data = q.get() + SSHSession.MSG_DELIM
                    # send() may accept only part of the data; keep going
                    # until everything has been written.
                    while data:
                        n = chan.send(data)
                        if n <= 0:
                            raise SessionCloseError(self._in_buf.getvalue(), data)
                        data = data[n:]
                if chan.recv_ready():
                    data = chan.recv(SSHSession.BUF_SIZE)
                    if data:
                        self._in_buf.write(data)
                        self._fresh_data()
                    else:
                        # An empty read on a ready channel is treated as closure.
                        raise SessionCloseError(self._in_buf.getvalue())
        except Exception as e:
            logger.debug('*** broke out of main loop ***')
            self.dispatch('error', e)
148
class MissingHostKeyPolicy(paramiko.MissingHostKeyPolicy):
    """Host-key policy that leaves the accept/reject decision to a callback."""

    def __init__(self, cb):
        """Store the callback.

        cb -- callable invoked as ``cb(hostname, key)``; a true result
            accepts the key, a false result rejects it.
        """
        self._cb = cb

    def missing_host_key(self, client, hostname, key):
        """Ask the callback about an unknown host key.

        Raises paramiko.SSHException when the callback returns a false value;
        returns normally (accepting the key) otherwise.
        """
        if not self._cb(hostname, key):
            # The original raised the undefined name SSHError, which surfaced
            # as a NameError; use paramiko's own exception type instead.
            raise paramiko.SSHException(
                'Host key for %r rejected by callback' % hostname)
155         if not self._cb(hostname, key):
156             raise SSHError