Revision ee4bb099

b/ncclient/capabilities.py
1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

  
15
class Capabilities:
16
    
17
    def __init__(self, capabilities=None):
18
        self._dict = {}
19
        if isinstance(capabilities, dict):
20
            self._dict = capabilities
21
        elif isinstance(capabilities, list):
22
            for uri in capabilities:
23
                self._dict[uri] = Capabilities.guess_shorthand(uri)
24
    
25
    def __contains__(self, key):
26
        return ( key in self._dict ) or ( key in self._dict.values() )
27
    
28
    def __repr__(self):
29
        return self.to_xml()
30
    
31
    def add(self, uri, shorthand=None):
32
        if shorthand is None:
33
            shorthand = Capabilities.guess_shorthand(uri)
34
        self._dict[uri] = shorthand
35
    
36
    set = add
37
    
38
    def remove(self, key):
39
        if key in self._dict:
40
            del self._dict[key]
41
        else:
42
            for uri in self._dict:
43
                if self._dict[uri] == key:
44
                    del self._dict[uri]
45
                    break
46
    
47
    def to_xml(self):
48
        elems = ['<capability>%s</capability>' % uri for uri in self._dict]
49
        return ('<capabilities>%s</capabilities>' % ''.join(elems))
50
    
51
    @staticmethod
52
    def guess_shorthand(uri):
53
        if uri.startswith('urn:ietf:params:netconf:capability:'):
54
            return (':' + uri.split(':')[5])
55

  
56
    
57
CAPABILITIES = Capabilities([
58
    'urn:ietf:params:netconf:base:1.0',
59
    'urn:ietf:params:netconf:capability:writable-running:1.0',
60
    'urn:ietf:params:netconf:capability:candidate:1.0',
61
    'urn:ietf:params:netconf:capability:confirmed-commit:1.0',
62
    'urn:ietf:params:netconf:capability:rollback-on-error:1.0',
63
    'urn:ietf:params:netconf:capability:startup:1.0',
64
    'urn:ietf:params:netconf:capability:url:1.0',
65
    'urn:ietf:params:netconf:capability:validate:1.0',
66
    'urn:ietf:params:netconf:capability:xpath:1.0',
67
    'urn:ietf:params:netconf:capability:notification:1.0',
68
    'urn:ietf:params:netconf:capability:interleave:1.0'
69
    ])
70

  
71
if __name__ == "__main__":
72
    assert(':validate' in CAPABILITIES) # test __contains__
73
    print CAPABILITIES # test __repr__
/dev/null
1
# Copyright 2009 Shikhar Bhushan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#    http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

  
15
class Capabilities:
16
    
17
    def __init__(self, capabilities=None):
18
        self._dict = {}
19
        if isinstance(capabilities, dict):
20
            self._dict = capabilities
21
        elif isinstance(capabilities, list):
22
            for uri in capabilities:
23
                self._dict[uri] = Capabilities.guess_shorthand(uri)
24
    
25
    def __contains__(self, key):
26
        return ( key in self._dict ) or ( key in self._dict.values() )
27
    
28
    def __repr__(self):
29
        return self.to_xml()
30
    
31
    def add(self, uri, shorthand=None):
32
        if shorthand is None:
33
            shorthand = Capabilities.guess_shorthand(uri)
34
        self._dict[uri] = shorthand
35
    
36
    set = add
37
    
38
    def remove(self, key):
39
        if key in self._dict:
40
            del self._dict[key]
41
        else:
42
            for uri in self._dict:
43
                if self._dict[uri] == key:
44
                    del self._dict[uri]
45
                    break
46
    
47
    def to_xml(self):
48
        elems = ['<capability>%s</capability>' % uri for uri in self._dict]
49
        return ('<capabilities>%s</capabilities>' % ''.join(elems))
50
    
51
    @staticmethod
52
    def guess_shorthand(uri):
53
        if uri.startswith('urn:ietf:params:netconf:capability:'):
54
            return (':' + uri.split(':')[5])
55

  
56
    
57
CAPABILITIES = Capabilities([
58
    'urn:ietf:params:netconf:base:1.0',
59
    'urn:ietf:params:netconf:capability:writable-running:1.0',
60
    'urn:ietf:params:netconf:capability:candidate:1.0',
61
    'urn:ietf:params:netconf:capability:confirmed-commit:1.0',
62
    'urn:ietf:params:netconf:capability:rollback-on-error:1.0',
63
    'urn:ietf:params:netconf:capability:startup:1.0',
64
    'urn:ietf:params:netconf:capability:url:1.0',
65
    'urn:ietf:params:netconf:capability:validate:1.0',
66
    'urn:ietf:params:netconf:capability:xpath:1.0',
67
    'urn:ietf:params:netconf:capability:notification:1.0',
68
    'urn:ietf:params:netconf:capability:interleave:1.0'
69
    ])
70

  
71
if __name__ == "__main__":
72
    assert(':validate' in CAPABILITIES) # test __contains__
73
    print CAPABILITIES # test __repr__
b/ncclient/content.py
12 12
# See the License for the specific language governing permissions and
13 13
# limitations under the License.
14 14

  
15
import logging
15 16
from xml.etree import cElementTree as ElementTree
16 17

  
17
NAMESPACE = 'urn:ietf:params:xml:ns:netconf:base:1.0'
18
logger = logging.getLogger('ncclient.content')
18 19

  
19
def qualify(tag, ns=NAMESPACE):
20
BASE_NS = 'urn:ietf:params:xml:ns:netconf:base:1.0'
21
NOTIFICATION_NS = 'urn:ietf:params:xml:ns:netconf:notification:1.0'
22

  
23
def qualify(tag, ns=BASE_NS):
20 24
    return '{%s}%s' % (ns, tag)
21 25

  
22 26
_ = qualify
23 27

  
24 28
def make_hello(capabilities):
25
    return '<hello xmlns="%s">%s</hello>' % (NAMESPACE, capabilities)
29
    return '<hello xmlns="%s">%s</hello>' % (BASE_NS, capabilities)
26 30

  
27 31
def make_rpc(id, op):
28
    return '<rpc message-id="%s" xmlns="%s">%s</rpc>' % (id, NAMESPACE, op)
32
    return '<rpc message-id="%s" xmlns="%s">%s</rpc>' % (id, BASE_NS, op)
29 33

  
30 34
def parse_hello(raw):
31
    from capability import Capabilities
35
    from capabilities import Capabilities
32 36
    id, capabilities = 0, Capabilities()
33 37
    root = ElementTree.fromstring(raw)
34 38
    if root.tag == _('hello'):
......
40 44
                    capabilities.add(cap.text)
41 45
    return id, capabilities
42 46

  
43
def parse_message_type(raw):
44
    
45
    target = RootElementParser()
46
    parser = ElementTree.XMLTreeBuilder(target=target)
47
    parser.feed(raw)
48
    return target.id
47
def parse_message_root(raw):
48
    from cStringIO import StringIO
49
    fp = StringIO(raw)
50
    for event, element in ElementTree.iterparse(fp, events=('start',)):
51
        if element.tag == _('rpc'):
52
            return element.attrib['message-id']
53
        elif element.tag == _('notification', NOTIFICATION_NS):
54
            return 'notification'
55
        else:
56
            return None
b/ncclient/listeners.py
1
                                                                                                                                    # Copyright 2009 Shikhar Bhushan
1
# Copyright 2009 Shikhar Bhushan
2 2
#
3 3
# Licensed under the Apache License, Version 2.0 (the "License");
4 4
# you may not use this file except in compliance with the License.
......
13 13
# limitations under the License.
14 14

  
15 15
import logging
16
import weakref
16
from weakref import WeakValueDictionary
17

  
18
import content
17 19

  
18 20
logger = logging.getLogger('ncclient.listeners')
19 21

  
20
import content
22
session_listeners = {}
23
def session_listener_factory(session):
24
    try:
25
        return session_listeners[session]
26
    except KeyError:
27
        session_listeners[session] = SessionListener()
28
        return session_listeners[session]
21 29

  
22 30
class SessionListener(object):
23 31
    
24
    'A multiton - one listener per session'
25
    
26
    instances = weakref.WeakValueDictionary()
27
    
28
    def __new__(cls, sid):
29
        if sid in instances:# not been gc'd
30
            return cls.instances[sid]
31
        else:
32
            inst = object.__new__(cls)
33
            cls.instances[sid] = inst
34
            return inst
32
    def __init__(self):
33
        self._id2rpc = WeakValueDictionary()
34
        self._expecting_close = False
35
        self._subscription = None
35 36
    
36 37
    def __str__(self):
37 38
        return 'SessionListener'
38 39
    
39
    def set_subscription(self, id):     
40
    def set_subscription(self, id):   
40 41
        self._subscription = id
41 42
    
43
    def expect_close(self):
44
        self._expecting_close = True
45
    
42 46
    def register(self, id, op):
43 47
        self._id2rpc[id] = op
44 48
    
45
    def unregister(self, id):
46
        del self._id2prc[id]
47
    
48 49
    ### Events
49 50
    
50 51
    def reply(self, raw):
51
        id = content.parse_message(raw)
52
        if id:
53
            self._id2rpc[id]._deliver(raw)
54
        else:
55
            self._id2rpc[self._sub_id]._notify(raw)
52
        try:
53
            id = content.parse_message_root(raw)
54
            if id is None:
55
                pass
56
            elif id == 'notification':
57
                self._id2rpc[self._sub_id]._notify(raw)
58
            else:
59
                self._id2rpc[id]._response_cb(raw)
60
        except Exception as e:
61
            logger.warning(e)
56 62
    
57
    def close(self, buf):
58
        pass # TODO
63
    def error(self, err):
64
        from ssh import SessionCloseError
65
        if err is SessionCloseError:
66
            logger.debug('received session close, expecting_close=%s' %
67
                         self._expecting_close)
68
            if not self._expecting_close:
69
                raise err
59 70

  
60 71
class DebugListener:
61 72
    
......
63 74
        return 'DebugListener'
64 75
    
65 76
    def reply(self, raw):
66
        logger.debug('reply:\n%s' % raw)
77
        logger.debug('DebugListener:reply:\n%s' % raw)
67 78
    
68 79
    def error(self, err):
69
        logger.debug(err)
80
        logger.debug('DebugListener:error:\n%s' % err)
b/ncclient/rpc.py
13 13
# limitations under the License.
14 14

  
15 15
from threading import Event, Lock
16

  
17
from listener import SessionListener
18

  
19 16
from uuid import uuid1
20 17

  
18
import content
19
from listeners import session_listener_factory
20

  
21 21
class RPC:
22 22
    
23
    def __init__(self, session, async=False):
23
    def __init__(self, session, async=False, parse=True):
24 24
        self._session = session
25 25
        self._async = async
26
        self._id = uuid1().urn
27
        self._listener = session_listener_factory(self._session)
28
        listener.register(self._id, self)
29
        session.add_listener(self._listener)
26 30
        self._reply = None
27 31
        self._reply_event = Event()
28
        self._id = uuid1().urn
29 32

  
30
    def _listener(self):
31
        if not RPC.listeners.has_key(self._session.id):
32
            RPC.listeners[self._session.id] = SessionListener()
33
        return RPC.listeners[self._session.id]
34

  
35
    def request(self, async=False):
36
        self._async = async
37
        listener = SessionListener(self._session.id)
38
        session.add_listener(listener)
39
        listener.register(self._id, self)
40
        self._session.send(self.to_xml())
41
    
42
    def response_cb(self, reply):
43
        self._reply = reply # does callback parse??
33
    def _response_cb(self, reply):
34
        self._reply = reply
44 35
        self._event.set()
45 36
    
37
    def _do_request(self, op):
38
        self._session.send(content.make_rpc(self._id, op))
39
        if not self._async:
40
            self._reply_event.wait()
41
        return self._reply
42
    
43
    def request(self):
44
        raise NotImplementedError
45
    
46
    def wait_for_reply(self, timeout=None):
47
        self._reply_event.wait(timeout)
48
    
46 49
    @property
47 50
    def has_reply(self):
48 51
        return self._reply_event.isSet()
49 52
    
50
    def wait_on_reply(self, timeout=None):
51
        self._reply_event.wait(timeout)
52
    
53 53
    @property
54 54
    def is_async(self):
55 55
        return self._async
56 56
    
57 57
    @property
58
    def reply(self):
59
        return self._reply
60
    
61
    @property
58 62
    def id(self):
59
        return self._id
63
        return self._id
64
    
65
    @property
66
    def session(self):
67
        return self._session
68

  
69

  
70
class RPCReply:
71
    pass
72

  
73
class RPCError:
74
    pass
b/ncclient/session.py
13 13
# limitations under the License.
14 14

  
15 15
import logging
16

  
17
import content
18

  
19 16
from threading import Thread, Event
20 17
from Queue import Queue
21 18

  
22
from capability import CAPABILITIES
19
import content
20
from capabilities import CAPABILITIES
23 21
from error import ClientError
24 22
from subject import Subject
25 23

  
26 24
logger = logging.getLogger('ncclient.session')
27 25

  
28
class SessionError(ClientError):
29
    
30
    pass
26
class SessionError(ClientError): pass
31 27

  
32 28
class Session(Thread, Subject):
33 29
    
......
37 33
        self._client_capabilities = CAPABILITIES
38 34
        self._server_capabilities = None # yet
39 35
        self._id = None # session-id
40
        self._connected = False # subclasses should set this
41 36
        self._error = None
42 37
        self._init_event = Event()
43 38
        self._q = Queue()
39
        self._connected = False # to be set/cleared by subclass
40
    
41
    def _post_connect(self):
42
        # start the subclass' main loop
43
        self.start()
44
        # queue client's hello message for sending
45
        self.send(content.make_hello(self._client_capabilities))
46
        # we expect server's hello message, wait for _init_event to be set by HelloListener
47
        self._init_event.wait()
48
        # there may have been an error
49
        if self._error:
50
            self._close()
51
            raise self._error
44 52
    
45 53
    def send(self, message):
46
        message = (u'<?xml version="1.0" encoding="UTF-8"?>%s' % message).encode('utf-8')
54
        message = (u'<?xml version="1.0" encoding="UTF-8"?>%s' %
55
                   message).encode('utf-8')
47 56
        logger.debug('queueing message: \n%s' % message)
48 57
        self._q.put(message)
49 58
    
......
60 69
        return self._client_capabilities
61 70
    
62 71
    @property
63
    def serve_capabilities(self):
72
    def server_capabilities(self):
64 73
        return self._server_capabilities
65 74
    
66 75
    @property
......
71 80
    def id(self):
72 81
        return self._id
73 82
    
74
    def _post_connect(self):
75
        # start the subclass' main loop
76
        self.start()
77
        # queue client's hello message for sending
78
        self.send(content.make_hello(self._client_capabilities))
79
        # we expect server's hello message, wait for _init_event to be set by HelloListener
80
        self._init_event.wait()
81
        # there may have been an error
82
        if self._error:
83
            self._close()
84
            raise self._error
85
    
86 83
    class HelloListener:
87 84
        
88 85
        def __str__(self):
b/ncclient/ssh.py
13 13
# limitations under the License.
14 14

  
15 15
import logging
16
import paramiko
17

  
18
from os import SEEK_CUR
19 16
from cStringIO import StringIO
17
from os import SEEK_CUR
18

  
19
import paramiko
20 20

  
21 21
from session import Session, SessionError
22 22

  
......
42 42
                 missing_host_key_policy=paramiko.RejectPolicy):
43 43
        Session.__init__(self)
44 44
        self._client = paramiko.SSHClient()
45
        self._channel = None
45 46
        if load_known_hosts:
46 47
            self._client.load_system_host_keys()
47 48
        self._client.set_missing_host_key_policy(missing_host_key_policy)
......
49 50
        self._parsing_state = 0
50 51
        self._parsing_pos = 0
51 52
    
53
    def _close(self):
54
        self._channel.close()
55
        self._connected = False
56
    
57
    def _fresh_data(self):
58
        delim = SSHSession.MSG_DELIM
59
        n = len(delim) - 1
60
        state = self._parsing_state
61
        buf = self._in_buf
62
        buf.seek(self._parsing_pos)
63
        while True:
64
            x = buf.read(1)
65
            if not x: # done reading
66
                break
67
            elif x == delim[state]:
68
                state += 1
69
            else:
70
                continue
71
            # loop till last delim char expected, break if other char encountered
72
            for i in range(state, n):
73
                x = buf.read(1)
74
                if not x: # done reading
75
                    break
76
                if x==delim[i]: # what we expected
77
                    state += 1 # expect the next delim char
78
                else:
79
                    state = 0 # reset
80
                    break
81
            else: # if we didn't break out of above loop, full delim parsed
82
                till = buf.tell() - n
83
                buf.seek(0)
84
                msg = buf.read(till)
85
                self.dispatch('reply', msg)
86
                buf.seek(n+1, SEEK_CUR)
87
                rest = buf.read()
88
                buf = StringIO()
89
                buf.write(rest)
90
                buf.seek(0)
91
                state = 0
92
        self._in_buf = buf
93
        self._parsing_state = state
94
        self._parsing_pos = self._in_buf.tell()
95

  
52 96
    def load_host_keys(self, filename):
53 97
        self._client.load_host_keys(filename)
54 98
    
......
96 140
        except Exception as e:
97 141
            logger.debug('*** broke out of main loop ***')
98 142
            self.dispatch('error', e)
99
    
100
    def _close(self):
101
        self._channel.close()
102
        self._connected = False
103
    
104
    def _fresh_data(self):
105
        delim = SSHSession.MSG_DELIM
106
        n = len(delim) - 1
107
        state = self._parsing_state
108
        buf = self._in_buf
109
        buf.seek(self._parsing_pos)
110
        while True:
111
            x = buf.read(1)
112
            if not x: # done reading
113
                break
114
            elif x == delim[state]:
115
                state += 1
116
            else:
117
                continue
118
            # loop till last delim char expected, break if other char encountered
119
            for i in range(state, n):
120
                x = buf.read(1)
121
                if not x: # done reading
122
                    break
123
                if x==delim[i]: # what we expected
124
                    state += 1 # expect the next delim char
125
                else:
126
                    state = 0 # reset
127
                    break
128
            else: # if we didn't break out of above loop, full delim parsed
129
                till = buf.tell() - n
130
                buf.seek(0)
131
                msg = buf.read(till)
132
                self.dispatch('reply', msg)
133
                buf.seek(n+1, SEEK_CUR)
134
                rest = buf.read()
135
                buf = StringIO()
136
                buf.write(rest)
137
                buf.seek(0)
138
                state = 0
139
        self._in_buf = buf
140
        self._parsing_state = state
141
        self._parsing_pos = self._in_buf.tell()
142 143

  
143 144
class MissingHostKeyPolicy(paramiko.MissingHostKeyPolicy):
144 145
    

Also available in: Unified diff