How to use the trains.backend_interface.model.Model class in trains

To help you get started, we’ve selected a few trains examples based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes (no build needed) and fix issues immediately.

github allegroai / trains / trains / task.py View on GitHub external
def flush(self, wait_for_uploads=False):
        """
        Flush pending reports and buffered console output.

        :param wait_for_uploads: when True, block until all pending ``BackendModel``
            uploads have completed before flushing the logs
        :return: always True
        """
        # Block on outstanding model uploads only when the caller asked to wait
        has_pending_uploads = BackendModel.get_num_results() > 0
        if has_pending_uploads and wait_for_uploads:
            BackendModel.wait_for_results()

        # Push buffered stdout through the task logger, if one exists
        if self._logger:
            # noinspection PyProtectedMember
            self._logger._flush_stdout_handler()

        # Presence is checked on the private attribute; flushing goes through ``reporter``
        if self._reporter:
            self.reporter.flush()

        LoggerRoot.flush()
        return True
github allegroai / trains / trains / backend_interface / task / task.py View on GitHub external
def _get_output_model(self, upload_required=True, force=False):
        """
        Build a backend ``Model`` wrapper for this task's output model.

        :param upload_required: passed as both ``raise_on_error`` and ``log_on_error`` when
            resolving the output destination (only consulted if ``self.storage_uri`` is falsy)
        :param force: when True, skip the task's stored ``output.model`` id and pass
            ``model_id=None`` instead
        :return: a newly constructed ``Model``
        """
        session = self.session
        if force:
            existing_model_id = None
        else:
            existing_model_id = self._get_task_property(
                'output.model', raise_on_error=False, log_on_error=False)
        cache_dir = self.cache_dir
        # Prefer an explicitly configured storage URI; otherwise fall back to the task's output destination
        storage_uri = self.storage_uri
        if not storage_uri:
            storage_uri = self.get_output_destination(
                raise_on_error=upload_required, log_on_error=upload_required)
        storage_suffix = self._get_output_destination_suffix('models')
        return Model(
            session=session,
            model_id=existing_model_id,
            cache_dir=cache_dir,
            upload_storage_uri=storage_uri,
            upload_storage_suffix=storage_suffix,
            log=self.log,
        )
github allegroai / trains / trains / model.py View on GitHub external
def _get_base_model(self):
        if self._base_model:
            return self._base_model

        if not self._base_model_id:
            # this shouldn't actually happen
            raise Exception('Missing model ID, cannot create an empty model')
        self._base_model = _Model(
            upload_storage_uri=None,
            cache_dir=get_cache_dir(),
            model_id=self._base_model_id,
        )
        return self._base_model
github allegroai / trains / trains / backend_interface / task / task.py View on GitHub external
def input_model(self):
        """
        Return the cached ``Model`` wrapper for the task's input model.

        The id is read from the task's ``execution.model`` property on every call.
        If it is empty, None is returned. Otherwise a ``Model`` without upload storage
        is created once and kept on ``self._input_model``.
        """
        execution_model_id = self._get_task_property('execution.model', raise_on_error=False)
        if not execution_model_id:
            return None
        if self._input_model is not None:
            return self._input_model
        self._input_model = Model(
            session=self.session,
            model_id=execution_model_id,
            cache_dir=self.cache_dir,
            log=self.log,
            upload_storage_uri=None,
        )
        return self._input_model
github allegroai / trains / trains / task.py View on GitHub external
if not self.__exit_hook.remote_user_aborted and self.__exit_hook.signal is None and \
                                not self.__exit_hook.exception:
                            task_status = ('completed', )
                        else:
                            task_status = ('stopped', )

            # wait for repository detection (if we didn't crash)
            if not is_sub_process and wait_for_uploads and self._logger:
                # we should print summary here
                self._summary_artifacts()
                # make sure that if we crashed the thread we are not waiting forever
                self._wait_for_repo_detection(timeout=10.)

            # wait for uploads
            print_done_waiting = False
            if wait_for_uploads and (BackendModel.get_num_results() > 0 or
                                     (self._reporter and self.reporter.get_num_results() > 0)):
                self.log.info('Waiting to finish uploads')
                print_done_waiting = True
            # from here, do not send log in background thread
            if wait_for_uploads:
                self.flush(wait_for_uploads=True)
                # wait until the reporter flush everything
                if self._reporter:
                    self.reporter.stop()
                if print_done_waiting:
                    self.log.info('Finished uploading')
            elif self._logger:
                self._logger._flush_stdout_handler()

            if not is_sub_process:
                # from here, do not check worker status
github allegroai / trains / trains / task.py View on GitHub external
def flush(self, wait_for_uploads=False):
        """
        Flush pending reports and buffered console output.

        :param wait_for_uploads: when True, block until all pending ``BackendModel``
            uploads have completed before flushing the logs
        :return: always True
        """
        # Block on outstanding model uploads only when the caller asked to wait
        has_pending_uploads = BackendModel.get_num_results() > 0
        if has_pending_uploads and wait_for_uploads:
            BackendModel.wait_for_results()

        # Push buffered stdout through the task logger, if one exists
        if self._logger:
            # noinspection PyProtectedMember
            self._logger._flush_stdout_handler()

        # Presence is checked on the private attribute; flushing goes through ``reporter``
        if self._reporter:
            self.reporter.flush()

        LoggerRoot.flush()
        return True
github allegroai / trains / trains / model.py View on GitHub external
logger.debug('A model with uri "{}" already exists. Selecting it'.format(weights_url))

            model = get_single_result(
                entity='model',
                query=weights_url,
                results=result.response.models,
                log=logger,
                raise_on_error=False,
            )

            logger.info("Selected model id: {}".format(model.id))

            return InputModel(model_id=model.id)

        base_model = _Model(
            upload_storage_uri=None,
            cache_dir=get_cache_dir(),
        )

        from .task import Task
        task = Task.current_task()
        if task:
            comment = 'Imported by task id: {}'.format(task.id) + ('\n'+comment if comment else '')
            project_id = task.project
            task_id = task.id
        else:
            project_id = None
            task_id = None

        if not framework:
            framework, file_ext = Framework._get_file_ext(