diff options
Diffstat (limited to 'mpm/python/usrp_mpm')
| -rw-r--r-- | mpm/python/usrp_mpm/rpc_server.py | 69 | 
1 files changed, 47 insertions, 22 deletions
| diff --git a/mpm/python/usrp_mpm/rpc_server.py b/mpm/python/usrp_mpm/rpc_server.py index ce82393ab..7b8d1edba 100644 --- a/mpm/python/usrp_mpm/rpc_server.py +++ b/mpm/python/usrp_mpm/rpc_server.py @@ -18,6 +18,9 @@  Implemented RPC Servers  """  from __future__ import print_function +from random import choice +from string import ascii_letters, digits +from multiprocessing import Process  from gevent.server import StreamServer  from gevent.pool import Pool  from gevent import signal @@ -26,11 +29,19 @@ from gevent import Greenlet  from gevent import monkey  monkey.patch_all()  from mprpc import RPCServer -from random import choice -from string import ascii_letters, digits -from multiprocessing import Process  from .mpmlog import get_main_logger +TOKEN_LEN = 16 # Length of the token string + +def no_claim(func): +    " Decorator for functions that require no token check " +    func._notok = True +    return func + +def no_rpc(func): +    " Decorator for functions that should not be exposed via RPC " +    func._norpc = True +    return func  class MPMServer(RPCServer):      """ @@ -56,36 +67,48 @@ class MPMServer(RPCServer):      def _update_component_commands(self, component, namespace, storage):          """ -        Detect available methods for an object and add them to the RPC server -        """ -        for method in (m for m in dir(component) -                       if not m.startswith('_') and callable(getattr(component, m))): -            if method.startswith('safe_'): -                command_name = namespace + method.lstrip('safe_') -                self._add_safe_command(getattr(component, method), command_name) +        Detect available methods for an object and add them to the RPC server. + +        We skip all private methods, and all methods that use the @no_rpc +        decorator. +        """ +        for method_name in ( +                m for m in dir(component) +                if not m.startswith('_') \ +                    and callable(getattr(component, m)) \ +                    and not getattr(getattr(component, m), '_norpc', False) +            ): +            new_rpc_method = getattr(component, method_name) +            command_name = namespace + method_name +            if getattr(new_rpc_method, '_notok', False): +                self._add_safe_command(new_rpc_method, command_name)              else: -                command_name = namespace + method -                self._add_command(getattr(component, method), command_name) +                self._add_claimed_command(new_rpc_method, command_name)              getattr(self, storage).append(command_name) -    def _add_command(self, function, command): +    def _add_claimed_command(self, function, command):          """          Adds a method with the name command to the RPC server -        This command will require an acquired claim on the device +        This command will require an acquired claim on the device, and a valid +        token needs to be passed in for it to not fail. + +        If the method does not require a token, use _add_safe_command().          """          self.log.trace("adding command %s pointing to %s", command, function) -        def new_function(token, *args): -            if token[:256] != self._state.claim_token.value: +        def new_claimed_function(token, *args): +            " Define a function that requires a claim token check " +            if token[:TOKEN_LEN] != self._state.claim_token.value:                  return False              return function(*args) -        new_function.__doc__ = function.__doc__ -        setattr(self, command, new_function) +        new_claimed_function.__doc__ = function.__doc__ +        setattr(self, command, new_claimed_function)      def _add_safe_command(self, function, command):          """ -        Add a safe method which does not require a claim on the -        device +        Add a safe method which does not require a claim on the device. +        If the method should only be called by claimers, use +        _add_claimed_command().          """          self.log.trace("adding safe command %s pointing to %s", command, function)          setattr(self, command, function) @@ -117,7 +140,9 @@ class MPMServer(RPCServer):              return ""          self.log.debug("claiming from: %s", self.client_host)          self.periph_manager.claimed = True -        self._state.claim_token.value = ''.join(choice(ascii_letters + digits) for _ in range(256)) +        self._state.claim_token.value = ''.join( +            choice(ascii_letters + digits) for _ in range(TOKEN_LEN) +        )          self._state.claim_status.value = True          self._state.lock.release()          self.sender_id = sender_id @@ -133,7 +158,7 @@ class MPMServer(RPCServer):          """          self._state.lock.acquire()          if self._state.claim_status.value: -            if self._state.claim_token.value == token[:256]: +            if self._state.claim_token.value == token[:TOKEN_LEN]:                  self._state.lock.release()                  self.log.debug("reclaimed from: %s", self.client_host)                  self._reset_timer() | 
