diff options
| -rw-r--r-- | mpm/python/usrp_mpm/rpc_server.py | 39 | 
1 files changed, 31 insertions, 8 deletions
| diff --git a/mpm/python/usrp_mpm/rpc_server.py b/mpm/python/usrp_mpm/rpc_server.py index 0084184f1..4fbaa248f 100644 --- a/mpm/python/usrp_mpm/rpc_server.py +++ b/mpm/python/usrp_mpm/rpc_server.py @@ -34,6 +34,7 @@ from builtins import range  from mprpc import RPCServer  from .mpmlog import get_main_logger +TIMEOUT_INTERVAL = 3.0 # Seconds before claim expires  TOKEN_LEN = 16 # Length of the token string  def no_claim(func): @@ -60,6 +61,7 @@ class MPMServer(RPCServer):          self.log = get_main_logger().getChild('RPCServer')          self._state = state          self._timer = Greenlet() +        self.session_id = None          self.periph_manager = mgr          # add public mboard methods without namespace          self._update_component_commands(mgr, '', '_mb_methods') @@ -68,6 +70,17 @@ class MPMServer(RPCServer):              self._update_component_commands(dboard, 'db_' + str(db_slot) + '_', '_db_methods')          super(MPMServer, self).__init__(*args, **kwargs) +    def _check_token_valid(self, token): +        """ +        Returns True iff: +        - The device is currently claimed +        - The claim token matches the one passed in +        """ +        return self._state.claim_status.value and \ +                len(token) == TOKEN_LEN and \ +                self._state.claim_token.value == bytes(token, 'ascii') + +      def _update_component_commands(self, component, namespace, storage):          """          Detect available methods for an object and add them to the RPC server. @@ -101,9 +114,9 @@ class MPMServer(RPCServer):          self.log.trace("adding command %s pointing to %s", command, function)          def new_claimed_function(token, *args):              " Define a function that requires a claim token check " -            if bytes(token[:TOKEN_LEN], 'ascii') != self._state.claim_token.value: +            if not self._check_token_valid(token):                  self.log.warning( -                    "Stopped attempt to access function `{}' with invalid " \ +                    "Thwarted attempt to access function `{}' with invalid " \                      "token `{}'.".format(command, token)                  )                  raise RuntimeError("Invalid token!") @@ -168,7 +181,7 @@ class MPMServer(RPCServer):          """          self._state.lock.acquire()          if self._state.claim_status.value: -            if self._state.claim_token.value == bytes(token[:TOKEN_LEN], 'ascii'): +            if self._check_token_valid(token):                  self._state.lock.release()                  self.log.debug("reclaimed from: %s", self.client_host)                  self._reset_timer() @@ -179,14 +192,19 @@ class MPMServer(RPCServer):                  self.client_host, token[:TOKEN_LEN]              )              return False -        self.log.debug("trying to reclaim unclaimed device from: %s", self.client_host) +        self.log.debug( +            "trying to reclaim unclaimed device from: %s", +            self.client_host +        )          return False      def _unclaim(self):          """          unconditional unclaim - for internal use          """ -        self.log.debug("releasing claim") +        self.log.debug("Releasing claim on session `{}' by `{}'".format( +            self.session_id, self.client_host +        ))          self._state.claim_status.value = False          self._state.claim_token.value = b''          self.session_id = None @@ -199,15 +217,17 @@ class MPMServer(RPCServer):          reset unclaim timer          """          self._timer.kill() -        self._timer = spawn_later(2.0, self._unclaim) +        self._timer = spawn_later(TIMEOUT_INTERVAL, self._unclaim)      def unclaim(self, token):          """ -        unclaim `token` - unclaims the MPM device if it is claimed with this token +        unclaim `token` - unclaims the MPM device if it is claimed with this +        token          """ -        if self._state.claim_status.value and self._state.claim_token.value == token: +        if self._check_token_valid(token):              self._unclaim()              return True +        self.log.warning("Attempt to unclaim session with invalid token!")          return False      def get_device_info(self): @@ -227,6 +247,9 @@ class MPMServer(RPCServer):          Forwards the call to periph_manager._allocate_sid with the client ip addresss          as argument. Should be used to setup interfaces          """ +        if not self._check_token_valid(token): +            self.log.warning("Attempt to allocate SID without valid token!") +            return None          return self.periph_manager._allocate_sid(self.client_host, *args) | 
