How to use the deepforest.utilities.format_args function in deepforest

To help you get started, we’ve selected a few deepforest examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github weecology / DeepForest / tests / test_deepforest.py View on Github external
def test_retrain_release(annotations, release_model):
    """Check that the release model's weights are forwarded to retinanet.

    The config is trimmed to a single epoch/step with snapshots disabled,
    then the formatted argument list is inspected to confirm the value that
    follows ``--weights`` is the release model's weights.

    Args:
        annotations: Fixture passed straight through to
            ``utilities.format_args`` as the annotations argument.
        release_model: Fixture exposing a ``config`` mapping and a
            ``weights`` attribute.
    """
    release_model.config["epochs"] = 1
    release_model.config["save-snapshot"] = False
    release_model.config["steps"] = 1

    assert release_model.config["weights"] == release_model.weights

    # Test that it gets passed to retinanet.
    arg_list = utilities.format_args(annotations, release_model.config, images_per_epoch=1)

    # The argument list is flat: a flag's value is the element right after it.
    weights_position = arg_list.index("--weights") + 1

    # Without ``assert`` this comparison is evaluated and thrown away, so the
    # test could never fail here. Arguments are compared in string form, the
    # same way the steps test compares against '2'.
    assert arg_list[weights_position] == str(release_model.weights)
github weecology / DeepForest / tests / test_utilities.py View on Github external
def test_format_args(annotations, config):
    """format_args should hand its result back as a list."""
    assert isinstance(utilities.format_args(annotations, config), list)
github weecology / DeepForest / tests / test_deepforest.py View on Github external
def test_random_transform(annotations):
    """Turning on random_transform in the config adds the CLI flag."""
    model = deepforest.deepforest()
    model.config["random_transform"] = True

    formatted = utilities.format_args(annotations, model.config)

    assert "--random-transform" in formatted
github weecology / DeepForest / tests / test_utilities.py View on Github external
def test_format_args_steps(annotations, config):
    """With images_per_epoch=2, the entry after the steps flag must be '2'."""
    formatted = utilities.format_args(annotations, config, images_per_epoch=2)
    assert isinstance(formatted, list)

    # The arguments come back as a flat list, so the value to check is the
    # entry immediately after the first one that contains "--steps".
    contains_steps = ["--steps" in entry for entry in formatted]
    first_match = np.where(contains_steps)[0][0]
    assert formatted[first_match + 1] == '2'
github weecology / DeepForest / deepforest / deepforest.py View on Github external
def predict_generator(self, annotations, comet_experiment = None, iou_threshold=0.5, score_threshold=0.05, max_detections=200):
        """Predict bounding boxes for a model using a csv fit_generator
        
        Args:
            annotations (str): Path to csv label file, labels are in the format -> path/to/image.jpg,x1,y1,x2,y2,class_name
            iou_threshold(float): IoU Threshold to count for a positive detection (defaults to 0.5)
            score_threshold (float): Eliminate bounding boxes under this threshold
            max_detections (int): Maximum number of bounding box predictions
            comet_experiment(object): A comet experiment object to track
        
        Return:
            boxes_output: a pandas dataframe of bounding boxes for each image in the annotations file
        """
        #Format args for CSV generator 
        arg_list = utilities.format_args(annotations, self.config)
        args = parse_args(arg_list)
        
        #create generator
        generator = CSVGenerator(
            args.annotations,
            args.classes,
            image_min_side=args.image_min_side,
            image_max_side=args.image_max_side,
            config=args.config,
            shuffle_groups=False,
        )
        
        if self.prediction_model:
            boxes_output = [ ]
            #For each image, gather predictions
            for i in range(generator.size()):
github weecology / DeepForest / deepforest / deepforest.py View on Github external
def evaluate_generator(self, annotations, comet_experiment = None, iou_threshold=0.5, score_threshold=0.05, max_detections=200):
        """ Evaluate prediction model using a csv fit_generator
        
        Args:
            annotations (str): Path to csv label file, labels are in the format -> path/to/image.jpg,x1,y1,x2,y2,class_name
            iou_threshold(float): IoU Threshold to count for a positive detection (defaults to 0.5)
            score_threshold (float): Eliminate bounding boxes under this threshold
            max_detections (int): Maximum number of bounding box predictions
            comet_experiment(object): A comet experiment object to track
        
        Return:
            mAP: Mean average precision of the evaluated data
        """
        #Format args for CSV generator 
        arg_list = utilities.format_args(annotations, self.config)
        args = parse_args(arg_list)
        
        #create generator
        validation_generator = CSVGenerator(
            args.annotations,
            args.classes,
            image_min_side=args.image_min_side,
            image_max_side=args.image_max_side,
            config=args.config,
            shuffle_groups=False,
        )
        
        average_precisions = evaluate(
            validation_generator,
            self.prediction_model,
            iou_threshold=iou_threshold,
github weecology / DeepForest / deepforest / deepforest.py View on Github external
'''Train a deep learning tree detection model using keras-retinanet.
        This is the main entry point for training a new model, starting either from existing weights or from scratch
        
        Args:
            annotations (str): Path to csv label file, labels are in the format -> path/to/image.jpg,x1,y1,x2,y2,class_name
            comet_experiment: A comet ml object to log images. Optional.
            list_of_tfrecords: Ignored if input_type != "tfrecord", list of tf records to process
            input_type: "fit_generator" or "tfrecord"
            images_per_epoch: number of images to override default config of # images in annotations file / batch size. Useful for debug
        
        Returns:
            model (object): A trained keras model
            prediction model: with bbox nms
            trained model: without nms
        '''
        arg_list = utilities.format_args(annotations, self.config, images_per_epoch)
            
        print("Training retinanet with the following args {}".format(arg_list))
        
        #Train model
        self.model, self.prediction_model, self.training_model = retinanet_train(args=arg_list, input_type = input_type, list_of_tfrecords = list_of_tfrecords, comet_experiment = comet_experiment)