How to use the tensorpack.callbacks.base.Callback class in tensorpack

To help you get started, we’ve selected a few tensorpack examples based on popular ways the Callback class is used in public projects.


github amiralansary/rl-medical/examples/AutomaticViewPlanning/DQN/common.py
import multiprocessing

from tensorpack.callbacks.base import Callback
from tensorpack.utils import logger

# eval_with_funcs() is defined elsewhere in the same file (common.py).

###############################################################################

def eval_model_multithread(pred, nr_eval, get_player_fn):
    """
    Args:
        pred (OfflinePredictor): state -> Qvalue
    """
    NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
    with pred.sess.as_default():
        mean_score, max_score, mean_dist, max_dist = eval_with_funcs([pred] * NR_PROC, nr_eval, get_player_fn)
    logger.info("Average Score: {}; Max Score: {}; Average Distance: {}; Max Distance: {}".format(mean_score, max_score, mean_dist, max_dist))

###############################################################################

class Evaluator(Callback):

    def __init__(self, nr_eval, input_names, output_names,
                 get_player_fn, directory, files_list=None):
        self.directory = directory
        self.files_list = files_list
        self.eval_episode = nr_eval
        self.input_names = input_names
        self.output_names = output_names
        self.get_player_fn = get_player_fn

    def _setup_graph(self):
        NR_PROC = min(multiprocessing.cpu_count() // 2, 20)
        self.pred_funcs = [self.trainer.get_predictor(
            self.input_names, self.output_names)] * NR_PROC

    def _trigger(self):
        ...  # evaluation body truncated in this excerpt
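
The Evaluator above overrides _setup_graph to build prediction functions and _trigger to run periodic evaluation. A minimal sketch of how such a custom callback is registered with a trainer, using the modern tensorpack API; MyModel, my_dataflow, and get_player are illustrative placeholders, not names from the repository:

from tensorpack import TrainConfig, SimpleTrainer, launch_train_with_config

config = TrainConfig(
    model=MyModel(),              # hypothetical ModelDesc subclass
    dataflow=my_dataflow,         # hypothetical DataFlow
    callbacks=[
        Evaluator(nr_eval=50,
                  input_names=['state'],
                  output_names=['Qvalue'],
                  get_player_fn=get_player,
                  directory='./data'),
    ],
    max_epoch=100,
)
launch_train_with_config(config, SimpleTrainer())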
github tensorpack/tensorpack/tensorpack/callbacks/misc.py
# File: misc.py


import numpy as np
import os
import time
from collections import deque

from ..utils import logger
from ..utils.utils import humanize_time_delta
from .base import Callback

__all__ = ['SendStat', 'InjectShell', 'EstimatedTimeLeft']


class SendStat(Callback):
    """ An equivalent of :class:`SendMonitorData`, but as a normal callback. """
    def __init__(self, command, names):
        self.command = command
        if not isinstance(names, list):
            names = [names]
        self.names = names

    def _trigger(self):
        M = self.trainer.monitors
        v = {k: M.get_latest(k) for k in self.names}
        cmd = self.command.format(**v)
        ret = os.system(cmd)
        if ret != 0:
            logger.error("Command {} failed with ret={}!".format(cmd, ret))
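
A hedged usage sketch for SendStat: the shell command template is filled in with the latest monitor values by name. The endpoint and the stat name here are illustrative:

from tensorpack.callbacks import SendStat

# After every epoch, push the latest validation error to a hypothetical
# tracking endpoint; '{val_error}' is substituted from trainer.monitors.
send_cb = SendStat(
    'curl -s "https://example.com/track?err={val_error}"',
    ['val_error'])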
github tensorpack/tensorpack/tensorpack/callbacks/summary.py
# File: summary.py


import numpy as np
from collections import deque

from ..compat import tfv1 as tf
from ..tfutils.common import get_op_tensor_name
from ..utils import logger
from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from .base import Callback

__all__ = ['MovingAverageSummary', 'MergeAllSummaries', 'SimpleMovingAverage']


class MovingAverageSummary(Callback):
    """
    Maintain the moving average of summarized tensors in every step,
    by ops added to the collection.
    Note that this callback only **maintains** the moving averages by
    updating the relevant variables in the graph; the actual summary
    should be done in other callbacks.

    This callback is one of the :func:`DEFAULT_CALLBACKS()`.
    """
    def __init__(self, collection=MOVING_SUMMARY_OPS_KEY, train_op=None):
        """
        Args:
            collection(str): the collection of EMA-maintaining ops.
                The default value would work with
                the tensors you added by :func:`tfutils.summary.add_moving_summary()`,
                but you can use other collections as well.
github tensorpack/tensorpack/tensorpack/callbacks/summary.py
        else:
            if isinstance(self._train_op, tf.Tensor):
                self._train_op = self._train_op.op
            if not isinstance(self._train_op, tf.Operation):
                self._train_op = self.graph.get_operation_by_name(self._train_op)
            self._train_op._add_control_inputs(ops)
            logger.info("[MovingAverageSummary] {} operations in collection '{}'"
                        " will be run together with operation '{}'.".format(
                            len(ops), self._collection, self._train_op.name))

    def _before_run(self, _):
        if self._train_op is None:
            return self._fetch
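
MovingAverageSummary pairs with tfutils.summary.add_moving_summary, which puts the EMA-maintaining ops into the collection this callback runs each step. A hedged sketch of the graph-building side; the loss tensor is an illustrative stand-in for a scalar in your model:

from tensorpack.tfutils.summary import add_moving_summary

# Inside build_graph(): register a scalar so its exponential moving
# average is maintained every step and summarized by other callbacks.
add_moving_summary(loss)  # 'loss' is a scalar tf.Tensor from your graph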


class MergeAllSummaries_RunAlone(Callback):
    def __init__(self, period, key):
        self._period = period
        self._key = key

    def _setup_graph(self):
        size = len(tf.get_collection(self._key))
        logger.info("Summarizing collection '{}' of size {}.".format(self._key, size))
        self.summary_op = tf.summary.merge_all(self._key)

    def _trigger_step(self):
        if self._period:
            if (self.local_step + 1) % self._period == 0:
                self._trigger()

    def _trigger(self):
        if self.summary_op:
            summary = self.summary_op.eval()
            self.trainer.monitors.put_summary(summary)
github tensorpack/tensorpack/tensorpack/callbacks/summary.py
        size = len(tf.get_collection(self._key))
        logger.info("Summarizing collection '{}' of size {}.".format(self._key, size))
        self.summary_op = tf.summary.merge_all(self._key)

    def _trigger_step(self):
        if self._period:
            if (self.local_step + 1) % self._period == 0:
                self._trigger()

    def _trigger(self):
        if self.summary_op:
            summary = self.summary_op.eval()
            self.trainer.monitors.put_summary(summary)


class MergeAllSummaries_RunWithOp(Callback):
    def __init__(self, period, key):
        self._period = period
        self._key = key

    def _setup_graph(self):
        size = len(tf.get_collection(self._key))
        logger.info("Summarizing collection '{}' of size {}.".format(self._key, size))
        self.summary_op = tf.summary.merge_all(self._key)
        if self.summary_op is not None:
            self._fetches = tf.train.SessionRunArgs(self.summary_op)
        else:
            self._fetches = None

    def _need_run(self):
        # Runs at the last step of every epoch; the period-based runs
        # mirror the logic of _trigger_step above.
        if self.local_step == self.trainer.steps_per_epoch - 1:
            return True
        if self._period > 0 and (self.local_step + 1) % self._period == 0:
            return True
        return False
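
The two MergeAllSummaries_* classes above are normally created through the public MergeAllSummaries factory rather than instantiated directly. A hedged sketch:

from tensorpack.callbacks import MergeAllSummaries

# period=0 (the default) summarizes once per epoch; period=100 also writes
# merged summaries every 100 steps. run_alone=True picks the RunAlone
# variant, which runs the summary op in its own session.run() call.
summary_cb = MergeAllSummaries(period=100, run_alone=False)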
github tensorpack/tensorpack/tensorpack/callbacks/saver.py
            collection_list=self.graph.get_all_collection_keys())

    def _trigger(self):
        try:
            self.saver.save(
                tf.get_default_session(),
                self.path,
                global_step=tf.train.get_global_step(),
                write_meta_graph=False)
            logger.info("Model saved to %s." % tf.train.get_checkpoint_state(self.checkpoint_dir).model_checkpoint_path)
        except (OSError, IOError, tf.errors.PermissionDeniedError,
                tf.errors.ResourceExhaustedError):   # disk error sometimes.. just ignore it
            logger.exception("Exception in ModelSaver!")


class MinSaver(Callback):
    """
    Separately save the model with minimum value of some statistics.
    """
    def __init__(self, monitor_stat, reverse=False, filename=None, checkpoint_dir=None):
        """
        Args:
            monitor_stat(str): the name of the statistics.
            reverse (bool): if True, will save the maximum.
            filename (str): the name for the saved model.
                Defaults to ``min-{monitor_stat}.tfmodel``.
            checkpoint_dir (str): the directory containing checkpoints.

        Example:
            Save the model with minimum validation error to
            "min-val-error.tfmodel":
github pkumusic/E-DRL/tensorpack/callbacks/graph.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: graph.py
# Author: Yuxin Wu 

""" Graph related callbacks"""

from .base import Callback
from ..utils import logger

__all__ = ['RunOp']

class RunOp(Callback):
    """ Run an op periodically"""
    def __init__(self, setup_func, run_before=True, run_epoch=True):
        """
        :param setup_func: a function that returns the op in the graph
        :param run_before: run the op before training
        :param run_epoch: run the op on every epoch trigger
        """
        self.setup_func = setup_func
        self.run_before = run_before
        self.run_epoch = run_epoch

    def _setup_graph(self):
        self._op = self.setup_func()
        #self._op_name = self._op.name

    def _before_train(self):
        if self.run_before:
            self._op.run()
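
A hedged sketch of this older RunOp API: setup_func is called once the graph is built and must return the op to run. The target-network sync op is a hypothetical example of a common use in DQN-style training:

import tensorflow as tf

# Hypothetical: copy online-network weights to a target network, before
# training starts and again at every epoch trigger.
def make_sync_op():
    online = tf.get_collection('online_vars')
    target = tf.get_collection('target_vars')
    return tf.group(*[t.assign(o) for o, t in zip(online, target)])

sync_cb = RunOp(make_sync_op, run_before=True, run_epoch=True)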
github tensorpack/tensorpack/tensorpack/callbacks/hooks.py
# The excerpt begins mid-docstring; the enclosing class in hooks.py is
# CallbackToHook, the adapter in the opposite direction of HookToCallback below.
class CallbackToHook(tf.train.SessionRunHook):
    """
    Runs a Callback's ``before_run``/``after_run`` as a ``tf.train.SessionRunHook``.
    You shouldn't need to use this.
    """

    def __init__(self, cb):
        self._cb = cb

    @HIDE_DOC
    def before_run(self, ctx):
        return self._cb.before_run(ctx)

    @HIDE_DOC
    def after_run(self, ctx, vals):
        self._cb.after_run(ctx, vals)


class HookToCallback(Callback):
    """
    Make a ``tf.train.SessionRunHook`` into a callback.
    Note that when ``SessionRunHook.after_create_session`` is called, the ``coord`` argument will be None.
    """

    _chief_only = False

    def __init__(self, hook):
        """
        Args:
            hook (tf.train.SessionRunHook):
        """
        self._hook = hook

    def _setup_graph(self):
        with tf.name_scope(None):   # jump out of the name scope
            self._hook.begin()
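
A hedged sketch wrapping a stock TF1 hook with HookToCallback:

import tensorflow as tf
from tensorpack.callbacks import HookToCallback

# Run TensorFlow's built-in ProfilerHook as a tensorpack callback.
profile_cb = HookToCallback(
    tf.train.ProfilerHook(save_steps=1000, output_dir='./profile'))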
github pkumusic/E-DRL/tensorpack/RL/expreplay.py
import threading
from collections import namedtuple  # needed by Experience below

from tqdm import tqdm
import six
from six.moves import queue

from ..dataflow import DataFlow
from ..utils import *
from ..utils.concurrency import LoopThread
from ..callbacks.base import Callback

__all__ = ['ExpReplay']

Experience = namedtuple('Experience',
        ['state', 'action', 'reward', 'isOver'])

class ExpReplay(DataFlow, Callback):
    """
    Implement the experience replay from the paper
    `Human-level control through deep reinforcement learning`.

    This implementation provides the interface of a DataFlow.
    This DataFlow is not fork-safe (it does not support multiprocess prefetching).
    """
    def __init__(self,
            predictor_io_names,
            player,
            batch_size=32,
            memory_size=1e6,
            init_memory_size=50000,
            exploration=1,
            end_exploration=0.1,
            exploration_epoch_anneal=0.002,
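
ExpReplay subclasses both DataFlow and Callback, so the same object is passed to the trainer in both roles: as the training dataflow and as a callback that anneals exploration and collects experience each epoch. A hedged sketch using only the constructor arguments visible above; get_atari_player is a hypothetical environment factory:

expreplay = ExpReplay(
    predictor_io_names=(['state'], ['Qvalue']),
    player=get_atari_player(),
    batch_size=32,
    memory_size=int(1e6),
    init_memory_size=50000,
    exploration=1.0,
    end_exploration=0.1,
    exploration_epoch_anneal=0.002)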