Revision 4f650d54

b/ncclient/content.py
53 53
    @staticmethod
54 54
    def Element(spec):
55 55
        """DictTree -> Element
56
        
56

  
57 57
        :type spec: :obj:`dict` or :obj:`string` or :class:`~xml.etree.ElementTree.Element`
58

  
58 59
        :rtype: :class:`~xml.etree.ElementTree.Element`
59 60
        """
60 61
        if iselement(spec):
......
64 65
        if not isinstance(spec, dict):
65 66
            raise ContentError("Invalid tree spec")
66 67
        if 'tag' in spec:
67
            ele = ET.Element(spec.get('tag'), spec.get('attributes', {}))
68
            ele = ET.Element(spec.get('tag'), spec.get('attrib', {}))
68 69
            ele.text = spec.get('text', '')
69 70
            ele.tail = spec.get('tail', '')
70 71
            subtree = spec.get('subtree', [])
......
78 79
            return ET.Comment(spec.get('comment'))
79 80
        else:
80 81
            raise ContentError('Invalid tree spec')
81
    
82

  
82 83
    @staticmethod
83 84
    def XML(spec, encoding='UTF-8'):
84 85
        """DictTree -> XML
85
        
86

  
86 87
        :type spec: :obj:`dict` or :obj:`string` or :class:`~xml.etree.ElementTree.Element`
88

  
87 89
        :arg encoding: chraracter encoding
90

  
88 91
        :rtype: string
89 92
        """
90 93
        return Element.XML(DictTree.Element(spec), encoding)
91 94

  
92 95
class Element:
93
    
96

  
94 97
    @staticmethod
95 98
    def DictTree(ele):
96 99
        """DictTree -> Element
97
        
100

  
98 101
        :type spec: :class:`~xml.etree.ElementTree.Element`
99 102
        :rtype: :obj:`dict`
100 103
        """
......
105 108
            'tail': ele.tail,
106 109
            'subtree': [ Element.DictTree(child) for child in root.getchildren() ]
107 110
        }
108
    
111

  
109 112
    @staticmethod
110 113
    def XML(ele, encoding='UTF-8'):
111 114
        """Element -> XML
112
        
115

  
113 116
        :type spec: :class:`~xml.etree.ElementTree.Element`
114 117
        :arg encoding: character encoding
115 118
        :rtype: :obj:`string`
......
121 124
            return '<?xml version="1.0" encoding="%s"?>%s' % (encoding, xml)
122 125

  
123 126
class XML:
124
    
127

  
125 128
    @staticmethod
126 129
    def DictTree(xml):
127 130
        """XML -> DictTree
128
        
131

  
129 132
        :type spec: :obj:`string`
130 133
        :rtype: :obj:`dict`
131 134
        """
132 135
        return Element.DictTree(XML.Element(xml))
133
    
136

  
134 137
    @staticmethod
135 138
    def Element(xml):
136 139
        """XML -> Element
137
        
140

  
138 141
        :type xml: :obj:`string`
139 142
        :rtype: :class:`~xml.etree.ElementTree.Element`
140 143
        """
......
153 156

  
154 157
def find(ele, tag, nslist=[]):
155 158
    """If `nslist` is empty, same as :meth:`xml.etree.ElementTree.Element.find`. If it is not, `tag` is interpreted as an unqualified name and qualified using each item in `nslist`. The first match is returned.
156
    
159

  
157 160
    :arg nslist: optional list of namespaces
158 161
    """
159 162
    if nslist:
......
166 169

  
167 170
def parse_root(raw):
168 171
    """Efficiently parses the root element of an XML document.
169
    
172

  
170 173
    :type raw: string
171 174
    :returns: a tuple of `(tag, attributes)`, where `tag` is the (qualified) name of the element and `attributes` is a dictionary of its attributes.
172 175
    """
......
176 179

  
177 180
def validated_element(rep, tag=None, attrs=None, text=None):
178 181
    """Checks if the root element meets the supplied criteria. Returns a :class:`~xml.etree.ElementTree.Element` instance if so, otherwise raises :exc:`ContentError`.
179
    
182

  
180 183
    :arg tag: tag name or a list of allowable tag names
181 184
    :arg attrs: list of required attribute names, each item may be a list of allowable alternatives
182 185
    :arg text: textual content to match
183 186
    :type rep: :obj:`dict` or :obj:`string` or :class:`~xml.etree.ElementTree.Element`
184
    :see: :ref:`dtree`
185 187
    """
186 188
    ele = dtree2ele(rep)
187 189
    err = False
b/ncclient/manager.py
28 28
RAISE_ALL, RAISE_ERROR, RAISE_NONE = range(3)
29 29

  
30 30
class Manager:
31
    
31

  
32 32
    "Thin layer of abstraction for the ncclient API."
33
    
34
    def __init__(self, session, rpc_error=RAISE_ALL):
33

  
34
    def __init__(self, session):
35 35
        self._session = session
36
        self._raise = rpc_error
36
        self._rpc_error_handling = RAISE_ALL
37

  
38
    def set_rpc_error_option(self, option):
39
        self._rpc_error_handling = option
37 40

  
38 41
    def do(self, op, *args, **kwds):
39 42
        op = operations.OPERATIONS[op](self._session)
......
46 49
                    if error.severity == 'error':
47 50
                        raise error
48 51
        return reply
49
    
52

  
50 53
    def __enter__(self):
51 54
        pass
52
    
55

  
53 56
    def __exit__(self, *args):
54 57
        self.close()
55 58
        return False
56
    
59

  
57 60
    def locked(self, target):
58
        """Returns a context manager for use withthe 'with' statement.
61
        """Returns a context manager for use with the 'with' statement.
59 62
        `target` is the datastore to lock, e.g. 'candidate
60 63
        """
61 64
        return operations.LockContext(self._session, target)
62
     
65

  
63 66
    get = lambda self, *args, **kwds: self.do('get', *args, **kwds).data
64
    
67

  
65 68
    get_config = lambda self, *args, **kwds: self.do('get-config', *args, **kwds).data
66
    
69

  
67 70
    edit_config = lambda self, *args, **kwds: self.do('edit-config', *args, **kwds)
68
    
71

  
69 72
    copy_config = lambda self, *args, **kwds: self.do('copy-config', *args, **kwds)
70
    
73

  
71 74
    validate = lambda self, *args, **kwds: self.do('validate', *args, **kwds)
72
    
75

  
73 76
    commit = lambda self, *args, **kwds: self.do('commit', *args, **kwds)
74
    
77

  
75 78
    discard_changes = lambda self, *args, **kwds: self.do('discard-changes', *args, **kwds)
76
    
79

  
77 80
    delete_config = lambda self, *args, **kwds: self.do('delete-config', *args, **kwds)
78
    
81

  
79 82
    lock = lambda self, *args, **kwds: self.do('lock', *args, **kwds)
80
    
83

  
81 84
    unlock = lambda self, *args, **kwds: self.do('unlock', *args, **kwds)
82
    
85

  
83 86
    close_session = lambda self, *args, **kwds: self.do('close-session', *args, **kwds)
84
    
87

  
85 88
    kill_session = lambda self, *args, **kwds: self.do('kill-session', *args, **kwds)
86
    
89

  
87 90
    def close(self):
88 91
        try: # try doing it clean
89 92
            self.close_session()
......
91 94
            pass
92 95
        if self._session.connected: # if that didn't work...
93 96
            self._session.close()
94
    
97

  
95 98
    @property
96 99
    def session(self, session):
97 100
        return self._session
98
    
101

  
99 102
    def get_capabilities(self, whose):
100 103
        if whose in ('manager', 'client'):
101 104
            return self._session._client_capabilities
102 105
        elif whose in ('agent', 'server'):
103 106
            return self._session._server_capabilities
104
    
107

  
105 108
    @property
106 109
    def capabilities(self):
107 110
        return self._session._client_capabilities
111

  
112
    @property
113
    def server_capabilities(self):
114
        return self._session._server_capabilities
b/ncclient/operations/retrieve.py
28 28
    def _parsing_hook(self, root):
29 29
        self._data = None
30 30
        if not self._errors:
31
            self._data = content.find(root, 'data', strict=False)
31
            self._data = content.find(root, 'data', nslist=[content.BASE_NS, content.CISCO_BS])
32 32
    
33 33
    @property
34 34
    def data(self):
b/ncclient/operations/rpc.py
22 22
from errors import OperationError
23 23

  
24 24
import logging
25
logger = logging.getLogger('ncclient.rpc')
25
logger = logging.getLogger('ncclient.operations.rpc')
26 26

  
27 27

  
28 28
class RPCReply:
29
    
30
    'NOTES: memory considerations?? storing both raw xml + ET.Element'
31
    
29

  
32 30
    def __init__(self, raw):
33 31
        self._raw = raw
34 32
        self._parsed = False
35 33
        self._root = None
36 34
        self._errors = []
37
    
35

  
38 36
    def __repr__(self):
39 37
        return self._raw
40
    
41
    def _parsing_hook(self, root):
42
        pass
43
    
38

  
39
    def _parsing_hook(self, root): pass
40

  
44 41
    def parse(self):
45 42
        if self._parsed:
46 43
            return
47 44
        root = self._root = content.xml2ele(self._raw) # <rpc-reply> element
48 45
        # per rfc 4741 an <ok/> tag is sent when there are no errors or warnings
49
        ok = content.find(root, 'data', strict=False)
46
        ok = content.find(root, 'data', nslist=[content.BASE_NS, content.CISCO_BS])
50 47
        if ok is not None:
51 48
            logger.debug('parsed [%s]' % ok.tag)
52 49
        else: # create RPCError objects from <rpc-error> elements
53
            error = content.find(root, 'data', strict=False)
50
            error = content.find(root, 'data', nslist=[content.BASE_NS, content.CISCO_BS])
54 51
            if error is not None:
55 52
                logger.debug('parsed [%s]' % error.tag)
56 53
                for err in root.getiterator(error.tag):
......
65 62
                    self._errors.append(RPCError(d))
66 63
        self._parsing_hook(root)
67 64
        self._parsed = True
68
    
65

  
69 66
    @property
70 67
    def xml(self):
71 68
        '<rpc-reply> as returned'
72 69
        return self._raw
73
    
70

  
74 71
    @property
75 72
    def ok(self):
76 73
        if not self._parsed:
77 74
            self.parse()
78 75
        return not self._errors # empty list => false
79
    
76

  
80 77
    @property
81 78
    def error(self):
82 79
        if not self._parsed:
......
85 82
            return self._errors[0]
86 83
        else:
87 84
            return None
88
    
85

  
89 86
    @property
90 87
    def errors(self):
91 88
        'List of RPCError objects. Will be empty if no <rpc-error> elements in reply.'
......
95 92

  
96 93

  
97 94
class RPCError(OperationError): # raise it if you like
98
    
95

  
99 96
    def __init__(self, err_dict):
100 97
        self._dict = err_dict
101 98
        if self.message is not None:
102 99
            OperationError.__init__(self, self.message)
103 100
        else:
104 101
            OperationError.__init__(self)
105
    
102

  
106 103
    @property
107 104
    def type(self):
108 105
        return self.get('error-type', None)
109
    
106

  
110 107
    @property
111 108
    def severity(self):
112 109
        return self.get('error-severity', None)
113
    
110

  
114 111
    @property
115 112
    def tag(self):
116 113
        return self.get('error-tag', None)
117
    
114

  
118 115
    @property
119 116
    def path(self):
120 117
        return self.get('error-path', None)
121
    
118

  
122 119
    @property
123 120
    def message(self):
124 121
        return self.get('error-message', None)
125
    
122

  
126 123
    @property
127 124
    def info(self):
128 125
        return self.get('error-info', None)
129 126

  
130 127
    ## dictionary interface
131
    
128

  
132 129
    __getitem__ = lambda self, key: self._dict.__getitem__(key)
133
    
130

  
134 131
    __iter__ = lambda self: self._dict.__iter__()
135
    
132

  
136 133
    __contains__ = lambda self, key: self._dict.__contains__(key)
137
    
134

  
138 135
    keys = lambda self: self._dict.keys()
139
    
136

  
140 137
    get = lambda self, key, default: self._dict.get(key, default)
141
        
138

  
142 139
    iteritems = lambda self: self._dict.iteritems()
143
    
140

  
144 141
    iterkeys = lambda self: self._dict.iterkeys()
145
    
142

  
146 143
    itervalues = lambda self: self._dict.itervalues()
147
    
144

  
148 145
    values = lambda self: self._dict.values()
149
    
146

  
150 147
    items = lambda self: self._dict.items()
151
    
148

  
152 149
    __repr__ = lambda self: repr(self._dict)
153 150

  
154 151

  
155 152
class RPCReplyListener(SessionListener):
156
    
153

  
157 154
    # one instance per session
158 155
    def __new__(cls, session):
159 156
        instance = session.get_listener_instance(cls)
......
164 161
            instance._pipelined = session.can_pipeline
165 162
            session.add_listener(instance)
166 163
        return instance
167
    
164

  
168 165
    def register(self, id, rpc):
169 166
        with self._lock:
170 167
            self._id2rpc[id] = rpc
171
    
168

  
172 169
    def callback(self, root, raw):
173 170
        tag, attrs = root
174 171
        if content.unqualify(tag) != 'rpc-reply':
......
195 192
                logger.warning('<rpc-reply> without message-id received: %s' % raw)
196 193
        logger.debug('delivering to %r' % rpc)
197 194
        rpc.deliver(raw)
198
    
195

  
199 196
    def errback(self, err):
200 197
        for rpc in self._id2rpc.values():
201 198
            rpc.error(err)
202 199

  
203 200

  
204 201
class RPC(object):
205
    
202

  
206 203
    DEPENDS = []
207 204
    REPLY_CLS = RPCReply
208
    
205

  
209 206
    def __init__(self, session, async=False, timeout=None):
210 207
        if not session.can_pipeline:
211 208
            raise UserWarning('Asynchronous mode not supported for this device/session')
......
214 211
            for cap in self.DEPENDS:
215 212
                self._assert(cap)
216 213
        except AttributeError:
217
            pass        
214
            pass
218 215
        self._async = async
219 216
        self._timeout = timeout
220 217
        # keeps things simple instead of having a class attr that has to be locked
......
223 220
        self._listener = RPCReplyListener(session)
224 221
        self._listener.register(self._id, self)
225 222
        self._reply = None
223
        self._error = None
226 224
        self._reply_event = Event()
227
    
225

  
228 226
    def _build(self, opspec):
229 227
        "TODO: docstring"
230 228
        spec = {
231 229
            'tag': content.qualify('rpc'),
232
            'attributes': {'message-id': self._id},
230
            'attrib': {'message-id': self._id},
233 231
            'subtree': opspec
234 232
            }
235 233
        return content.dtree2xml(spec)
236
    
234

  
237 235
    def _request(self, op):
238 236
        req = self._build(op)
239 237
        self._session.send(req)
240 238
        if self._async:
241
            return (self._reply_event, self._error_event)
239
            return self._reply_event
242 240
        else:
243 241
            self._reply_event.wait(self._timeout)
244
            if self._reply_event.is_set():
242
            if self._reply_event.isSet():
245 243
                if self._error:
246 244
                    raise self._error
247 245
                self._reply.parse()
248 246
                return self._reply
249 247
            else:
250 248
                raise ReplyTimeoutError
251
    
249

  
252 250
    def request(self):
253 251
        return self._request(self.SPEC)
254
    
252

  
255 253
    def _delivery_hook(self):
256 254
        'For subclasses'
257 255
        pass
258
    
256

  
259 257
    def _assert(self, capability):
260 258
        if capability not in self._session.server_capabilities:
261 259
            raise MissingCapabilityError('Server does not support [%s]' % cap)
262
    
260

  
263 261
    def deliver(self, raw):
264 262
        self._reply = self.REPLY_CLS(raw)
265 263
        self._delivery_hook()
266 264
        self._reply_event.set()
267
    
265

  
268 266
    def error(self, err):
269 267
        self._error = err
270 268
        self._reply_event.set()
271
    
269

  
272 270
    @property
273 271
    def has_reply(self):
274 272
        return self._reply_event.is_set()
275
    
273

  
276 274
    @property
277 275
    def reply(self):
276
        if self.error:
277
            raise self._error
278 278
        return self._reply
279
    
279

  
280 280
    @property
281 281
    def id(self):
282 282
        return self._id
283
    
283

  
284 284
    @property
285 285
    def session(self):
286 286
        return self._session
287
    
287

  
288 288
    @property
289 289
    def reply_event(self):
290 290
        return self._reply_event
291
    
291

  
292 292
    def set_async(self, bool): self._async = bool
293 293
    async = property(fget=lambda self: self._async, fset=set_async)
294
    
294

  
295 295
    def set_timeout(self, timeout): self._timeout = timeout
296 296
    timeout = property(fget=lambda self: self._timeout, fset=set_timeout)
b/ncclient/operations/subscribe.py
14 14

  
15 15
from rpc import RPC
16 16

  
17
from ncclient.glue import Listener
18 17
from ncclient.content import qualify as _
18
from ncclient.transport import SessionListener
19

  
20
NOTIFICATION_NS = 'urn:ietf:params:xml:ns:netconf:notification:1.0'
19 21

  
20 22
# TODO when can actually test it...
21 23

  
......
28 30

  
29 31
class Notification: pass
30 32

  
31
class NotificationListener(Listener): pass
33
class NotificationListener(SessionListener): pass
b/ncclient/operations/util.py
46 46
        type, criteria = tuple
47 47
        rep = {
48 48
            'tag': 'filter',
49
            'attributes': {'type': type},
49
            'attrib': {'type': type},
50 50
            'subtree': criteria
51 51
        }
52 52
    else:
b/ncclient/transport/errors.py
21 21
    pass
22 22

  
23 23
class SessionCloseError(TransportError):
24
    
24

  
25 25
    def __init__(self, in_buf, out_buf=None):
26 26
        msg = 'Unexpected session close.'
27 27
        if in_buf:
......
34 34
    pass
35 35

  
36 36
class SSHUnknownHostError(SSHError):
37
    
38
    def __init__(self, hostname, key):
39
        from binascii import hexlify
40
        SSHError(self, 'Unknown host key [%s] for [%s]'
41
                 % (hexlify(key.get_fingerprint()), hostname))
42
        self.hostname = hostname
37

  
38
    def __init__(self, host, fingerprint):
39
        SSHError.__init__(self, 'Unknown host key [%s] for [%s]'
40
                          % (fingerprint, host))
41
        self.host = host
42
        self.fingerprint = fingerprint
b/ncclient/transport/session.py
22 22
logger = logging.getLogger('ncclient.transport.session')
23 23

  
24 24
class Session(Thread):
25
    "This is a base class for use by protocol implementations"
26
    
25
    "Base class for use by transport protocol implementations."
26

  
27 27
    def __init__(self, capabilities):
28 28
        Thread.__init__(self)
29
        self.set_daemon(True)
30
        self._listeners = set() # 3.0's weakset ideal
29
        self.setDaemon(True)
30
        self._listeners = set() # 3.0's weakset would be ideal
31 31
        self._lock = Lock()
32
        self.set_name('session')
32
        self.setName('session')
33 33
        self._q = Queue()
34 34
        self._client_capabilities = capabilities
35 35
        self._server_capabilities = None # yet
......
37 37
        self._connected = False # to be set/cleared by subclass implementation
38 38
        logger.debug('%r created: client_capabilities=%r' %
39 39
                     (self, self._client_capabilities))
40
    
40

  
41 41
    def _dispatch_message(self, raw):
42 42
        try:
43 43
            root = content.parse_root(raw)
......
52 52
                l.callback(root, raw)
53 53
            except Exception as e:
54 54
                logger.warning('[error] %r' % e)
55
    
55

  
56 56
    def _dispatch_error(self, err):
57 57
        with self._lock:
58 58
            listeners = list(self._listeners)
......
62 62
                l.errback(err)
63 63
            except Exception as e:
64 64
                logger.warning('error %r' % e)
65
    
65

  
66 66
    def _post_connect(self):
67 67
        "Greeting stuff"
68 68
        init_event = Event()
......
86 86
        self.remove_listener(listener)
87 87
        if error[0]:
88 88
            raise error[0]
89
        logger.info('initialized: session-id=%s | server_capabilities=%s' % (self._id, self._server_capabilities))
90
    
89
        logger.info('initialized: session-id=%s | server_capabilities=%s' %
90
                    (self._id, self._server_capabilities))
91

  
91 92
    def add_listener(self, listener):
92
        """Register a listener that will be notified of incoming messages and errors.
93
        
94
        :type listener: :class:`SessionListener`
93
        """Register a listener that will be notified of incoming messages and
94
        errors.
95

  
96
        :arg listener: :class:`SessionListener`
95 97
        """
96 98
        logger.debug('installing listener %r' % listener)
97 99
        if not isinstance(listener, SessionListener):
98 100
            raise SessionError("Listener must be a SessionListener type")
99 101
        with self._lock:
100 102
            self._listeners.add(listener)
101
    
103

  
102 104
    def remove_listener(self, listener):
103
        "Unregister some listener; ignoring if the listener was never registered."
105
        """Unregister some listener; ignore if the listener was never
106
        registered."""
104 107
        logger.debug('discarding listener %r' % listener)
105 108
        with self._lock:
106 109
            self._listeners.discard(listener)
107
    
110

  
108 111
    def get_listener_instance(self, cls):
109
        """If a listener of the specified type is registered, returns it. This is useful when it is desirable to have only one instance of a particular type per session, i.e. a multiton.
110
        
111
        :type cls: :class:`type`
112
        :rtype: :class:`SessionListener` or :const:`None`
112
        """If a listener of the sspecified type is registered, returns the
113
        instance. This is useful when it is desirable to have only one instance
114
        of a particular type per session, i.e. a multiton.
115

  
116
        :arg cls: class of the listener
113 117
        """
114 118
        with self._lock:
115 119
            for listener in self._listeners:
116 120
                if isinstance(listener, cls):
117 121
                    return listener
118
    
122

  
119 123
    def connect(self, *args, **kwds): # subclass implements
120 124
        raise NotImplementedError
121 125

  
122 126
    def run(self): # subclass implements
123 127
        raise NotImplementedError
124
    
128

  
125 129
    def send(self, message):
126
        """
127
        :param message: XML document
128
        :type message: string
130
        """Send the supplied *message* to NETCONF server.
131

  
132
        :arg message: an XML document
133

  
134
        :type message: :obj:`string`
129 135
        """
130 136
        logger.debug('queueing %s' % message)
131 137
        self._q.put(message)
132
    
138

  
133 139
    ### Properties
134 140

  
135 141
    @property
136 142
    def connected(self):
137
        ":rtype: bool"
143
        "Connection status of the session."
138 144
        return self._connected
139 145

  
140 146
    @property
141 147
    def client_capabilities(self):
142
        ":rtype: :class:`Capabilities`"
148
        "Client's :class:`Capabilities`"
143 149
        return self._client_capabilities
144
    
150

  
145 151
    @property
146 152
    def server_capabilities(self):
147
        ":rtype: :class:`Capabilities` or :const:`None`"
153
        "Server's :class:`Capabilities`"
148 154
        return self._server_capabilities
149
    
155

  
150 156
    @property
151 157
    def id(self):
152
        ":rtype: :obj:`string` or :const:`None`"
158
        """A :obj:`string` representing the `session-id`. If the session has not
159
        been initialized it will be :const:`None`"""
153 160
        return self._id
154
    
161

  
155 162
    @property
156 163
    def can_pipeline(self):
157
        ":rtype: :obj:`bool`"
164
        "Whether this session supports pipelining"
158 165
        return True
159 166

  
160 167

  
161 168
class SessionListener(object):
162
    
163
    """'Listen' to incoming messages on a NETCONF :class:`Session`
164
    
169

  
170
    """Base class for :class:`Session` listeners, which are notified when a new
171
    NETCONF message is received or an error occurs.
172

  
165 173
    .. note::
166
        Avoid computationally intensive tasks in the callbacks.
174
        Avoid time-intensive tasks in a callback's context.
167 175
    """
168
    
176

  
169 177
    def callback(self, root, raw):
170
        """Called when a new XML document is received. The `root` argument allows the callback to determine whether it wants to further process the document.
171
        
172
        :param root: tuple of (tag, attrs) where tag is the qualified name of the root element and attrs is a dictionary of its attributes (also qualified names)
173
        :param raw: XML document
174
        :type raw: string
178
        """Called when a new XML document is received. The `root` argument
179
        allows the callback to determine whether it wants to further process the
180
        document.
181

  
182
        :arg root: is a tuple of `(tag, attributes)` where `tag` is the qualified name of the root element and `attributes` is a dictionary of its attributes (also qualified names)
183

  
184
        :arg raw: XML document
185
        :type raw: :obj:`string`
175 186
        """
176 187
        raise NotImplementedError
177
    
188

  
178 189
    def errback(self, ex):
179 190
        """Called when an error occurs.
180
        
181
        :type ex: :class:`Exception`
191

  
192
        :type ex: :exc:`Exception`
182 193
        """
183 194
        raise NotImplementedError
184 195

  
185 196

  
186 197
class HelloHandler(SessionListener):
187
    
198

  
188 199
    def __init__(self, init_cb, error_cb):
189 200
        self._init_cb = init_cb
190 201
        self._error_cb = error_cb
191
    
202

  
192 203
    def callback(self, root, raw):
193 204
        if content.unqualify(root[0]) == 'hello':
194 205
            try:
......
197 208
                self._error_cb(e)
198 209
            else:
199 210
                self._init_cb(id, capabilities)
200
    
211

  
201 212
    def errback(self, err):
202 213
        self._error_cb(err)
203
    
214

  
204 215
    @staticmethod
205 216
    def build(capabilities):
206 217
        "Given a list of capability URI's returns <hello> message XML string"
......
213 224
                }]
214 225
            }
215 226
        return content.dtree2xml(spec)
216
    
227

  
217 228
    @staticmethod
218 229
    def parse(raw):
219 230
        "Returns tuple of (session-id (str), capabilities (Capabilities)"
b/ncclient/transport/ssh.py
30 30
MSG_DELIM = ']]>]]>'
31 31
TICK = 0.1
32 32

  
33
def default_unknown_host_cb(host, key):
34
    """An `unknown host callback` returns :const:`True` if it finds the key
35
    acceptable, and :const:`False` if not.
36

  
37
    :arg host: the hostname/address which needs to be verified
38

  
39
    :arg key: a hex string representing the host key fingerprint
40

  
41
    :returns: this default callback always returns :const:`False`
42
    """
43
    return False
44

  
45

  
33 46
class SSHSession(Session):
34
    
35
    "A NETCONF SSH session, per :rfc:`4742`"
36
    
47

  
48
    "Implements a :rfc:`4742` NETCONF session over SSH."
49

  
37 50
    def __init__(self, capabilities):
38 51
        Session.__init__(self, capabilities)
39 52
        self._host_keys = paramiko.HostKeys()
......
44 57
        self._expecting_close = False
45 58
        self._buffer = StringIO() # for incoming data
46 59
        # parsing-related, see _parse()
47
        self._parsing_state = 0 
60
        self._parsing_state = 0
48 61
        self._parsing_pos = 0
49
    
62

  
50 63
    def _parse(self):
51 64
        '''Messages ae delimited by MSG_DELIM. The buffer could have grown by a
52 65
        maximum of BUF_SIZE bytes everytime this method is called. Retains state
53
        across method calls and if a byte has been read it will not be considered
54
        again.
55
        '''
66
        across method calls and if a byte has been read it will not be
67
        considered again. '''
56 68
        delim = MSG_DELIM
57 69
        n = len(delim) - 1
58 70
        expect = self._parsing_state
......
90 102
        self._buffer = buf
91 103
        self._parsing_state = expect
92 104
        self._parsing_pos = self._buffer.tell()
93
    
105

  
94 106
    def load_system_host_keys(self, filename=None):
95 107
        if filename is None:
96 108
            filename = os.path.expanduser('~/.ssh/known_hosts')
......
105 117
                    pass
106 118
            return
107 119
        self._system_host_keys.load(filename)
108
    
120

  
109 121
    def load_host_keys(self, filename):
110 122
        self._host_keys.load(filename)
111 123

  
112 124
    def add_host_key(self, key):
113 125
        self._host_keys.add(key)
114
    
126

  
115 127
    def save_host_keys(self, filename):
116 128
        f = open(filename, 'w')
117 129
        for host, keys in self._host_keys.iteritems():
118 130
            for keytype, key in keys.iteritems():
119 131
                f.write('%s %s %s\n' % (host, keytype, key.get_base64()))
120
        f.close()    
121
    
132
        f.close()
133

  
122 134
    def close(self):
123 135
        self._expecting_close = True
124 136
        if self._transport.is_active():
125 137
            self._transport.close()
126 138
        self._connected = False
127
    
139

  
128 140
    def connect(self, host, port=830, timeout=None,
129
                unknown_host_cb=None, username=None, password=None,
141
                unknown_host_cb=default_unknown_host_cb,
142
                username=None, password=None,
130 143
                key_filename=None, allow_agent=True, look_for_keys=True):
144
        """Connect via SSH and initialize the NETCONF session. First attempts
145
        the publickey authentication method and then password authentication.
146

  
147
        To disable publickey authentication, call with *allow_agent* and
148
        *look_for_keys* as :const:`False`
149

  
150
        :arg host: the hostname or IP address to connect to
151

  
152
        :arg port: by default 830, but some devices use the default SSH port of 22 so this may need to be specified
153

  
154
        :arg timeout: an optional timeout for the TCP handshake
155

  
156
        :arg unknown_host_cb: called when a host key is not known. See :func:`unknown_host_cb` for details on signature
157

  
158
        :arg username: the username to use for SSH authentication
159

  
160
        :arg password: the password used if using password authentication, or the passphrase to use in order to unlock keys that require it
161

  
162
        :arg key_filename: a filename where a the private key to be used can be found
163

  
164
        :arg allow_agent: enables querying SSH agent (if found) for keys
165

  
166
        :arg look_for_keys: enables looking in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`)
167
        """
168

  
131 169
        assert(username is not None)
132
        
170

  
133 171
        for (family, socktype, proto, canonname, sockaddr) in \
134 172
        socket.getaddrinfo(host, port):
135 173
            if socktype == socket.SOCK_STREAM:
......
143 181
        sock.connect(addr)
144 182
        t = self._transport = paramiko.Transport(sock)
145 183
        t.set_log_channel(logger.name)
146
        
184

  
147 185
        try:
148 186
            t.start_client()
149 187
        except paramiko.SSHException:
150 188
            raise SSHError('Negotiation failed')
151
        
189

  
152 190
        # host key verification
153 191
        server_key = t.get_remote_server_key()
154 192
        known_host = self._host_keys.check(host, server_key) or \
155 193
                        self._system_host_keys.check(host, server_key)
156
        
157
        if unknown_host_cb is None:
158
            unknown_host_cb = lambda *args: False
159
        if not known_host and not unknown_host_cb(host, server_key):
160
                raise SSHUnknownHostError(host, server_key)
161
        
194

  
195
        fp = hexlify(server_key.get_fingerprint())
196
        if not known_host and not unknown_host_cb(host, fp):
197
            raise SSHUnknownHostError(host, fp)
198

  
162 199
        if key_filename is None:
163 200
            key_filenames = []
164 201
        elif isinstance(key_filename, basestring):
165 202
            key_filenames = [ key_filename ]
166 203
        else:
167 204
            key_filenames = key_filename
168
        
205

  
169 206
        self._auth(username, password, key_filenames, allow_agent, look_for_keys)
170
        
207

  
171 208
        self._connected = True # there was no error authenticating
172
        
209

  
173 210
        c = self._channel = self._transport.open_session()
174 211
        c.set_name('netconf')
175 212
        c.invoke_subsystem('netconf')
176
        
213

  
177 214
        self._post_connect()
178
    
215

  
179 216
    # on the lines of paramiko.SSHClient._auth()
180 217
    def _auth(self, username, password, key_filenames, allow_agent,
181 218
              look_for_keys):
182 219
        saved_exception = None
183
        
220

  
184 221
        for key_filename in key_filenames:
185 222
            for cls in (paramiko.RSAKey, paramiko.DSSKey):
186 223
                try:
......
192 229
                except Exception as e:
193 230
                    saved_exception = e
194 231
                    logger.debug(e)
195
        
232

  
196 233
        if allow_agent:
197 234
            for key in paramiko.Agent().get_keys():
198 235
                try:
......
203 240
                except Exception as e:
204 241
                    saved_exception = e
205 242
                    logger.debug(e)
206
        
243

  
207 244
        keyfiles = []
208 245
        if look_for_keys:
209 246
            rsa_key = os.path.expanduser('~/.ssh/id_rsa')
......
219 256
                keyfiles.append((paramiko.RSAKey, rsa_key))
220 257
            if os.path.isfile(dsa_key):
221 258
                keyfiles.append((paramiko.DSSKey, dsa_key))
222
        
259

  
223 260
        for cls, filename in keyfiles:
224 261
            try:
225 262
                key = cls.from_private_key_file(filename, password)
......
230 267
            except Exception as e:
231 268
                saved_exception = e
232 269
                logger.debug(e)
233
        
270

  
234 271
        if password is not None:
235 272
            try:
236 273
                self._transport.auth_password(username, password)
......
238 275
            except Exception as e:
239 276
                saved_exception = e
240 277
                logger.debug(e)
241
        
278

  
242 279
        if saved_exception is not None:
243 280
            # need pep-3134 to do this right
244 281
            raise SSHAuthenticationError(repr(saved_exception))
245
        
282

  
246 283
        raise SSHAuthenticationError('No authentication methods available')
247
    
284

  
248 285
    def run(self):
249 286
        chan = self._channel
250 287
        chan.setblocking(0)
......
252 289
        try:
253 290
            while True:
254 291
                # select on a paramiko ssh channel object does not ever return
255
                # it in the writable list, so it channel's don't exactly emulate 
292
                # it in the writable list, so it channel's don't exactly emulate
256 293
                # the socket api
257 294
                r, w, e = select([chan], [], [], TICK)
258 295
                # will wakeup evey TICK seconds to check if something
......
278 315
            self.close()
279 316
            if not (isinstance(e, SessionCloseError) and self._expecting_close):
280 317
                self._dispatch_error(e)
281
    
318

  
282 319
    @property
283 320
    def transport(self):
284
        "gug"
321
        """The underlying `paramiko.Transport
322
        <http://www.lag.net/paramiko/docs/paramiko.Transport-class.html>`_
323
        object. This makes it possible to call methods like set_keepalive on it.
324
        """
285 325
        return self._transport
286
    
326

  
287 327
    @property
288 328
    def can_pipeline(self):
289 329
        if 'Cisco' in self._transport.remote_version:
b/ncclient/util.py
12 12
# See the License for the specific language governing permissions and
13 13
# limitations under the License.
14 14

  
15
from ncclient.glue import Listener
15
from ncclient.transport import SessionListener
16

  
17
class PrintListener(SessionListener):
16 18

  
17
class PrintListener(Listener):
18
    
19 19
    def callback(self, root, raw):
20 20
        print('\n# RECEIVED MESSAGE with root=[tag=%r, attrs=%r] #\n%r\n' %
21 21
              (root[0], root[1], raw))
22
    
22

  
23 23
    def errback(self, err):
24 24
        print('\n# RECEIVED ERROR #\n%r\n' % err)

Also available in: Unified diff