How to use the tensorpack.utils.stats.StatCounter class in tensorpack

To help you get started, we've selected a few tensorpack examples based on popular ways StatCounter is used in public projects.

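StatCounter is a simple accumulator: you feed() it scalar values and read aggregate statistics back. A minimal sketch of the core API used throughout the examples below:

from tensorpack.utils.stats import StatCounter

stat = StatCounter()
for reward in [1.0, 0.5, 2.0]:
    stat.feed(reward)              # accumulate one value per call

print(stat.count)                  # 3 values fed so far
print(stat.average)                # mean of all fed values
print(stat.sum, stat.max, stat.min)

stat.reset()                       # clear the counter for the next run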

github thanosvlo / MARL-for-Anatomical-Landmark-Detection / common.py
                    # (fragment of Worker.run: the thread exits on RuntimeError)
                    return
                # feed each agent's episode reward and distance error into the queues
                for i in range(self.agents):
                    self.queue_put_stoppable(self.q, sum_r[i])
                    self.queue_put_stoppable(self.q_dist, dist[i])


    q = queue.Queue()
    q_dist = queue.Queue()

    threads = [Worker(f, q, q_dist, agents=agents) for f in predictors]

    # start all workers
    for k in threads:
        k.start()
        time.sleep(0.1)  # avoid simulator bugs
    stat = StatCounter()       # aggregates episode rewards from all workers
    dist_stat = StatCounter()  # aggregates distance errors from all workers

    # show progress bar w/ tqdm
    for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
        r = q.get()
        stat.feed(r)
        dist = q_dist.get()
        dist_stat.feed(dist)

    logger.info("Waiting for all the workers to finish the last run...")
    for k in threads:
        k.stop()
    for k in threads:
        k.join()
    # drain any results produced while the workers were stopping
    while q.qsize():
        r = q.get()
        stat.feed(r)
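This worker/queue/StatCounter pattern recurs throughout the examples on this page: worker threads push per-episode results into a queue, and the main thread drains the queue into a StatCounter. A stripped-down, self-contained sketch of the same idea (run_episode here is a stand-in for play_one_episode, not a tensorpack API):

import queue
import random
import threading

from tensorpack.utils.stats import StatCounter

def run_episode():
    # stand-in for an actual evaluation episode
    return random.random()

def worker(q, n_episodes):
    for _ in range(n_episodes):
        q.put(run_episode())

q = queue.Queue()
threads = [threading.Thread(target=worker, args=(q, 10)) for _ in range(4)]
for t in threads:
    t.start()

stat = StatCounter()
for _ in range(4 * 10):   # one q.get() per expected episode
    stat.feed(q.get())

for t in threads:
    t.join()
print('mean episode result: {:.3f}'.format(stat.average))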
github amiralansary / rl-medical / examples / LandmarkDetection / DQN / medical.py
    def reset_stat(self):
        """Reset all statistics counters."""
        self.stats = defaultdict(list)
        self.num_games = StatCounter()
        self.num_success = StatCounter()
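Creating fresh StatCounter instances like this resets statistics between evaluation runs. Feeding 0/1 outcomes also turns a counter's average into a rate, which is how the success and win-rate examples on this page work; a minimal sketch (the outcomes are invented for illustration):

from tensorpack.utils.stats import StatCounter

num_games = StatCounter()
num_success = StatCounter()

for outcome in [True, False, True, True]:   # hypothetical episode results
    num_games.feed(1)
    num_success.feed(int(outcome))

# the average of 0/1 values is the success rate
print('success rate: {:.2f} over {} games'.format(
    num_success.average, num_games.count))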
github qq456cvb / doudizhu-C / TensorPack / Hierarchical_Q / combination.py
def recursive():
    import timeit
    env = Pyenv()
    st = StatCounter()
    for _ in range(1):
        env.reset()
        env.prepare()
        # print(env.get_handcards())
        cards = env.get_handcards()[:15]
        cards = ['J', '10', '10', '7', '7', '6']

        # last_cards = ['3', '3']
        mask = get_mask_onehot60(cards, action_space, None).reshape(len(action_space), 15, 4).sum(-1).astype(np.uint8)
        valid = mask.sum(-1) > 0
        cards_target = Card.char2onehot60(cards).reshape(-1, 4).sum(-1).astype(np.uint8)
        t1 = timeit.default_timer()
        print(cards_target)
        print(mask[valid])
        combs = get_combinations_recursive(mask[valid, :], cards_target)
        print(combs)
github qq456cvb / doudizhu-C / simulator / expreplay.py
        # (inside __init__) copy constructor arguments onto the instance
        for k, v in locals().items():
            if k != 'self':
                setattr(self, k, v)
        self.agent_name = agent_name

        self.exploration = init_exploration
        self.num_actions = num_actions
        logger.info("Number of Legal actions: {}, {}".format(*self.num_actions))

        self.rng = get_rng(self)
        self._init_memory_flag = threading.Event()  # tell if memory has been initialized

        # a queue to receive notifications to populate memory
        self._populate_job_queue = queue.Queue(maxsize=5)

        self.mem = ReplayMemory(memory_size, state_shape)
        # self._current_ob, self._action_space = self.get_state_and_action_spaces()
        self._player_scores = StatCounter()       # one entry per finished game
        self._current_game_score = StatCounter()  # per-step rewards of the current game
github amiralansary / rl-medical / examples / AutomaticViewPlanning / DQN / expreplay.py
        for k, v in locals().items():
            if k != 'self':
                setattr(self, k, v)
        self.exploration = init_exploration
        self.num_actions = player.action_space.n
        logger.info("Number of Legal actions: {}".format(self.num_actions))

        self.rng = get_rng(self)
        self._init_memory_flag = threading.Event()  # tell if memory has been initialized

        # a queue to receive notifications to populate memory
        self._populate_job_queue = queue.Queue(maxsize=5)

        self.mem = ReplayMemory(memory_size, state_shape, history_len)
        self._current_ob = self.player.reset()
        self._player_scores = StatCounter()
        self._player_distError = StatCounter()
github qq456cvb / doudizhu-C / TensorPack / MA_Hierarchical_Q / expreplay.py
        logger.info("Number of Legal actions: {}, {}".format(*self.num_actions))

        self.rng = get_rng(self)
        self._init_memory_flag = threading.Event()  # tell if memory has been initialized

        # a queue to receive notifications to populate memory
        self._populate_job_queue = queue.Queue(maxsize=5)

        self.mem = ReplayMemory(memory_size, state_shape)
        self.player.reset()
        self.player.prepare()
        self._comb_mask = True
        self._fine_mask = None
        self._current_ob, self._action_space = self.get_state_and_action_spaces()
        self._player_scores = StatCounter()
        self._current_game_score = StatCounter()
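Several of these replay-buffer constructors pair two counters: one accumulating per-step reward within the game in progress, and one collecting finished-game scores. A hedged sketch of how such a pair is typically driven (the step loop and reward values are invented for illustration; tensorpack's own ExpReplay does this bookkeeping internally):

from tensorpack.utils.stats import StatCounter

player_scores = StatCounter()       # one entry per finished game
current_game_score = StatCounter()  # per-step rewards of the current game

for reward, game_over in [(1.0, False), (0.5, False), (2.0, True)]:
    current_game_score.feed(reward)
    if game_over:
        # a game's score is the sum of its per-step rewards
        player_scores.feed(current_game_score.sum)
        current_game_score.reset()

print('mean game score: {}'.format(player_scores.average))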
github qq456cvb / doudizhu-C / TensorPack / A3C_FC / evaluator_fc.py
#     self.eval_episode = int(self.eval_episode * 0.94)

    def _trigger_epoch(self):
        t = time.time()
        farmer_win_rate = eval_with_funcs(
            self.pred_funcs, self.eval_episode, self.get_player_fn, verbose=False)
        t = time.time() - t
        if t > 10 * 60:  # eval takes too long
            self.eval_episode = int(self.eval_episode * 0.94)
        self.trainer.monitors.put_scalar('farmer win rate', farmer_win_rate)
        self.trainer.monitors.put_scalar('lord win rate', 1 - farmer_win_rate)


if __name__ == '__main__':
    env = Env()
    stat = StatCounter()
    init_cards = np.arange(15)
    # init_cards = np.append(init_cards[::4], init_cards[1::4])
    for _ in range(1000):
        env.reset()
        env.prepare_manual(init_cards)
        r = 0
        while r == 0:
            _, r, _ = env.step_auto()
        stat.feed(int(r < 0))  # a negative final reward counts as a lord win
    print('lord win rate: {}'.format(stat.average))
github qq456cvb / doudizhu-C / TensorPack / Vanilla_Q / evaluator.py
            # (inside Worker.run) play episodes until the thread is stopped
            with self.default_sess():
                player = get_player_fn()
                while not self.stopped():
                    try:
                        val = play_one_episode(player, self.func)
                    except RuntimeError:
                        return
                    self.queue_put_stoppable(self.q, val)

    q = queue.Queue()
    threads = [Worker(f, q) for f in predictors]

    for k in threads:
        k.start()
        time.sleep(0.1)  # avoid simulator bugs
    stat = StatCounter()

    def fetch():
        val = q.get()
        stat.feed(val)
        if verbose:
            if val > 0:
                logger.info("farmer wins")
            else:
                logger.info("lord wins")

    for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
        fetch()
    logger.info("Waiting for all the workers to finish the last run...")
    for k in threads:
        k.stop()
    for k in threads:
        k.join()
github qq456cvb / doudizhu-C / TensorPack / Hierarchical_Q / evaluator.py
            with self.default_sess():
                player = get_player_fn()
                while not self.stopped():
                    try:
                        val = play_one_episode(player, self.func)
                    except RuntimeError:
                        return
                    self.queue_put_stoppable(self.q, val)

    q = queue.Queue()
    threads = [Worker(f, q) for f in predictors]

    for k in threads:
        k.start()
        time.sleep(0.1)  # avoid simulator bugs
    stat = StatCounter()

    def fetch():
        val = q.get()
        stat.feed(val)
        if verbose:
            if val > 0:
                logger.info("farmer wins")
            else:
                logger.info("lord wins")

    for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
        fetch()
    logger.info("Waiting for all the workers to finish the last run...")
    for k in threads:
        k.stop()
    for k in threads:
        k.join()
github qq456cvb / doudizhu-C / TensorPack / ValueSL / evaluator.py
            with self.default_sess():
                player = get_player_fn()
                while not self.stopped():
                    try:
                        stats = play_one_episode(player, self.func)
                    except RuntimeError:
                        return
                    scores = [stat.average if stat.count > 0 else -1 for stat in stats]
                    self.queue_put_stoppable(self.q, scores)

    q = queue.Queue()
    threads = [Worker(f, q) for f in predictors]

    for k in threads:
        k.start()
        time.sleep(0.1)  # avoid simulator bugs
    stats = [StatCounter() for _ in range(7)]  # one counter per decision type

    def fetch():
        scores = q.get()
        for i, score in enumerate(scores):
            if score >= 0:
                stats[i].feed(score)
        accs = [stat.average if stat.count > 0 else 0 for stat in stats]
        if verbose:
            logger.info("passive decision accuracy: {}\n"
                        "passive bomb accuracy: {}\n"
                        "passive response accuracy: {}\n"
                        "active decision accuracy: {}\n"
                        "active response accuracy: {}\n"
                        "active sequence accuracy: {}\n"
                        "minor response accuracy: {}\n".format(accs[0], accs[1], accs[2], accs[3], accs[4], accs[5], accs[6]))