diff options
| -rwxr-xr-x | mpm/tools/mpm_shell.py | 139 | 
1 files changed, 86 insertions, 53 deletions
| diff --git a/mpm/tools/mpm_shell.py b/mpm/tools/mpm_shell.py index e480a4857..ed2998809 100755 --- a/mpm/tools/mpm_shell.py +++ b/mpm/tools/mpm_shell.py @@ -12,10 +12,8 @@ from __future__ import print_function  import cmd  import time  import argparse -import threading +import multiprocessing  from importlib import import_module -from mprpc import RPCClient -from mprpc.exceptions import RPCError  try:      from usrp_mpm.mpmtypes import MPM_RPC_PORT @@ -69,54 +67,97 @@ class MPMClaimer(object):      """      Holds a claim.      """ -    def __init__(self, host, port, disc_callback): +    def __init__(self, host, port):          self.token = None -        self._exit_loop = False -        self._disc_callback = disc_callback -        self._claim_loop = threading.Thread( +        self.hijacked = False +        self._cmd_q = multiprocessing.Queue() +        self._token_q = multiprocessing.Queue() +        self._claim_loop = multiprocessing.Process(              target=self.claim_loop,              name="Claimer Loop", -            args=(host, port, self._disc_callback) +            args=(host, port, self._cmd_q, self._token_q)          )          self._claim_loop.start() -    def claim_loop(self, host, port, disc_callback): +    def claim_loop(self, host, port, cmd_q, token_q):          """          Run a claim loop          """ +        from mprpc import RPCClient +        from mprpc.exceptions import RPCError +        cmd = None +        token = None +        exit_loop = False          client = RPCClient(host, port, pack_params={'use_bin_type': True}) -        self.token = client.call('claim', 'MPM Shell')          try: -            while not self._exit_loop: -                client.call('reclaim', self.token) +            while not exit_loop: +                if token and not cmd: +                    client.call('reclaim', token) +                elif cmd == 'claim': +                    if not token: +                        token = client.call('claim', 'MPM Shell') +                    else: +                        print("Already have claim") +                    token_q.put(token) +                elif cmd == 'unclaim': +                    if token: +                        client.call('unclaim', token) +                    token = None +                    token_q.put(None) +                elif cmd == 'exit': +                    if token: +                        client.call('unclaim', token) +                    token = None +                    token_q.put(None) +                    exit_loop = True                  time.sleep(1) -            client.call('unclaim', self.token) +                cmd = None +                if not cmd_q.empty(): +                    cmd = cmd_q.get(False)          except RPCError as ex:              print("Unexpected RPC error in claimer loop!")              print(str(ex)) -        disc_callback() -        self.token = None -    def unclaim(self): +    def exit(self):          """          Unclaim device and exit claim loop.          """ -        self._exit_loop = True +        self.unclaim() +        self._cmd_q.put('exit')          self._claim_loop.join() -class MPMHijacker(object): -    """ -    Looks like a claimer object, but doesn't actually claim. -    """ -    def __init__(self, token): -        self.token = token -      def unclaim(self):          """ -        Unclaim device and exit claim loop. +        Unclaim device. +        """ +        if not self.hijacked: +            self._cmd_q.put('unclaim') +        else: +            self.hijacked = False +        self.token = None + +    def claim(self): +        """ +        Claim device. +        """ +        self._cmd_q.put('claim') +        self.token = self._token_q.get(True, 5.0) + +    def get_token(self): +        """ +        Get current token (if any)          """ -        pass +        if not self._token_q.empty(): +            self.token = self._token_q.get(False) +        return self.token +    def hijack(self, token): +        if self.token: +            print("Already have token") +            return +        else: +            self.token = token +        self.hijacked = True  class MPMShell(cmd.Cmd):      """ @@ -127,10 +168,10 @@ class MPMShell(cmd.Cmd):          self.prompt = "> "          self.client = None          self.remote_methods = [] -        self._claimer = None          self._host = host          self._port = port          self._device_info = None +        self._claimer = MPMClaimer(self._host, self._port)          if host is not None:              self.connect(host, port)              if claim: @@ -156,15 +197,16 @@ class MPMShell(cmd.Cmd):          """          Template function to create new RPC shell commands          """ +        from mprpc.exceptions import RPCError          if requires_token and \ -                (self._claimer is None or self._claimer.token is None): +                (self._claimer is None or self._claimer.get_token() is None):              print("Cannot execute `{}' -- no claim available!")              return          try:              if args or requires_token:                  expanded_args = self.expand_args(args)                  if requires_token: -                    expanded_args.insert(0, self._claimer.token) +                    expanded_args.insert(0, self._claimer.get_token())                  response = self.client.call(command, *expanded_args)              else:                  response = self.client.call(command) @@ -214,6 +256,8 @@ class MPMShell(cmd.Cmd):          """          Launch a connection.          """ +        from mprpc import RPCClient +        from mprpc.exceptions import RPCError          print("Attempting to connect to {host}:{port}...".format(              host=host, port=port          )) @@ -239,9 +283,10 @@ class MPMShell(cmd.Cmd):          """          Clean up after a connection was closed.          """ +        from mprpc.exceptions import RPCError          self._device_info = None          if self._claimer is not None: -            self._claimer.unclaim() +            self._claimer.exit()          if self.client:              try:                  self.client.close() @@ -257,37 +302,24 @@ class MPMShell(cmd.Cmd):      def claim(self):          " Initialize claim " -        assert self.client is not None -        if self._claimer is not None: -            print("Claimer already active.") -            return True          print("Claiming device...") -        self._claimer = MPMClaimer(self._host, self._port, self.unclaim_hook) +        self._claimer.claim()          return True      def hijack(self, token):          " Hijack running session " -        assert self.client is not None -        if self._claimer is not None: +        if self._claimer.hijacked:              print("Claimer already active. Can't hijack.")              return False          print("Hijacking device...") -        self._claimer = MPMHijacker(token) +        self._claimer.hijack(token)          return True      def unclaim(self):          """          unclaim          """ -        if self._claimer is not None: -            self._claimer.unclaim() -            self._claimer = None - -    def unclaim_hook(self): -        """ -        Hook -        """ -        pass +        self._claimer.unclaim()      def update_prompt(self):          """ @@ -296,12 +328,13 @@ class MPMShell(cmd.Cmd):          if self._device_info is None:              self.prompt = '> '          else: -            if self._claimer is None: +            token = self._claimer.get_token() +            if token is None:                  claim_status = '' -            elif isinstance(self._claimer, MPMClaimer): -                claim_status = ' [C]' -            elif isinstance(self._claimer, MPMHijacker): +            elif self._claimer.hijacked:                  claim_status = ' [H]' +            else: +                claim_status = ' [C]'              self.prompt = '{dev_id}{claim_status}> '.format(                  dev_id=self._device_info.get(                      'name', self._device_info.get('serial', '?') @@ -313,8 +346,8 @@ class MPMShell(cmd.Cmd):          """          Takes a string and returns a list          """ -        if self._claimer is not None and self._claimer.token is not None: -            args = args.replace('$T', str(self._claimer.token)) +        if self._claimer is not None and self._claimer.get_token() is not None: +            args = args.replace('$T', str(self._claimer.get_token()))          eval_preamble = '='          args = args.strip()          if args.startswith(eval_preamble): | 
