Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
mkang315 authored Dec 4, 2023
1 parent a19b901 commit c03a839
Show file tree
Hide file tree
Showing 4 changed files with 498 additions and 0 deletions.
7 changes: 7 additions & 0 deletions yolo/bgf/detect/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# BGF-YOLO based on Ultralytics YOLOv8x 8.0.109 object detection model with same license, AGPL-3.0 license

from .predict import DetectionPredictor, predict
from .train import DetectionTrainer, train
from .val import DetectionValidator, val

__all__ = 'DetectionPredictor', 'predict', 'DetectionTrainer', 'train', 'DetectionValidator', 'val'
47 changes: 47 additions & 0 deletions yolo/bgf/detect/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# BGF-YOLO based on Ultralytics YOLOv8x 8.0.109 object detection model with same license, AGPL-3.0 license

import torch

from ...yolo.engine.predictor import BasePredictor
from ...yolo.engine.results import Results
from ...yolo.utils import DEFAULT_CFG, ROOT, ops


class DetectionPredictor(BasePredictor):

def postprocess(self, preds, img, orig_imgs):
"""Postprocesses predictions and returns a list of Results objects."""
preds = ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
classes=self.args.classes)

results = []
for i, pred in enumerate(preds):
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
if not isinstance(orig_imgs, torch.Tensor):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
path = self.batch[0]
img_path = path[i] if isinstance(path, list) else path
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
return results


def predict(cfg=DEFAULT_CFG, use_python=False):
model = cfg.model or 'yolov8n.pt'
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
else 'https://ultralytics.com/images/bus.jpg'

args = dict(model=model, source=source)
if use_python:
from ultralytics import YOLO
YOLO(model)(**args)
else:
predictor = DetectionPredictor(overrides=args)
predictor.predict_cli()


if __name__ == '__main__':
predict()
144 changes: 144 additions & 0 deletions yolo/bgf/detect/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# BGF-YOLO based on Ultralytics YOLOv8x 8.0.109 object detection model with same license, AGPL-3.0 license
from copy import copy

import numpy as np
import sys
sys.path.append("/root/BGF-YOLO")
from ...nn.tasks import DetectionModel
from ...yolo import bgf
from ...yolo.data import build_dataloader, build_yolo_dataset
from ...yolo.data.dataloaders.v5loader import create_dataloader
from ...yolo.engine.trainer import BaseTrainer
from ...yolo.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
from ...yolo.utils.plotting import plot_images, plot_labels, plot_results
from ...yolo.utils.torch_utils import de_parallel, torch_distributed_zero_first


# BaseTrainer python usage
class DetectionTrainer(BaseTrainer):

def build_dataset(self, img_path, mode='train', batch=None):
"""Build YOLO Dataset
Args:
img_path (str): Path to the folder containing images.
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
"""
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == 'val', stride=gs)

def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
"""TODO: manage splits differently."""
# Calculate stride - check if model is initialized
if self.args.v5loader:
LOGGER.warning("WARNING ⚠️ 'v5loader' feature is deprecated and will be removed soon. You can train using "
'the default YOLOv8 dataloader instead, no argument is needed.')
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
return create_dataloader(path=dataset_path,
imgsz=self.args.imgsz,
batch_size=batch_size,
stride=gs,
hyp=vars(self.args),
augment=mode == 'train',
cache=self.args.cache,
pad=0 if mode == 'train' else 0.5,
rect=self.args.rect or mode == 'val',
rank=rank,
workers=self.args.workers,
close_mosaic=self.args.close_mosaic != 0,
prefix=colorstr(f'{mode}: '),
shuffle=mode == 'train',
seed=self.args.seed)[0]
assert mode in ['train', 'val']
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = self.build_dataset(dataset_path, mode, batch_size)
shuffle = mode == 'train'
if getattr(dataset, 'rect', False) and shuffle:
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
shuffle = False
workers = self.args.workers if mode == 'train' else self.args.workers * 2
return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader

def preprocess_batch(self, batch):
"""Preprocesses a batch of images by scaling and converting to float."""
batch['img'] = batch['img'].to(self.device, non_blocking=True).float() / 255
return batch

def set_model_attributes(self):
"""nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)."""
# self.args.box *= 3 / nl # scale to layers
# self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
# self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
self.model.nc = self.data['nc'] # attach number of classes to model
self.model.names = self.data['names'] # attach class names to model
self.model.args = self.args # attach hyperparameters to model
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc

def get_model(self, cfg=None, weights=None, verbose=True):
"""Return a YOLO detection model."""
model = DetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model

def get_validator(self):
"""Returns a DetectionValidator for YOLO model validation."""
self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
return v8.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

def label_loss_items(self, loss_items=None, prefix='train'):
"""
Returns a loss dict with labelled training loss items tensor
"""
# Not needed for classification but necessary for segmentation & detection
keys = [f'{prefix}/{x}' for x in self.loss_names]
if loss_items is not None:
loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
return dict(zip(keys, loss_items))
else:
return keys

def progress_string(self):
"""Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
return ('\n' + '%11s' *
(4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')

def plot_training_samples(self, batch, ni):
"""Plots training samples with their annotations."""
plot_images(images=batch['img'],
batch_idx=batch['batch_idx'],
cls=batch['cls'].squeeze(-1),
bboxes=batch['bboxes'],
paths=batch['im_file'],
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)

def plot_metrics(self):
"""Plots metrics from a CSV file."""
plot_results(file=self.csv, on_plot=self.on_plot) # save results.png

def plot_training_labels(self):
"""Create a labeled training plot of the YOLO model."""
boxes = np.concatenate([lb['bboxes'] for lb in self.train_loader.dataset.labels], 0)
cls = np.concatenate([lb['cls'] for lb in self.train_loader.dataset.labels], 0)
plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir, on_plot=self.on_plot)


def train(cfg=DEFAULT_CFG, use_python=False):
"""Train and optimize YOLO model given training data and device."""
model = cfg.model or 'yolov8x.pt'
data = cfg.data or 'coco.yaml' # or yolo.ClassificationDataset("mnist")
device = cfg.device if cfg.device is not None else ''

args = dict(model=model, data=data, device=device)
if use_python:
from ultralytics import YOLO
YOLO(model).train(**args)
else:
trainer = DetectionTrainer(overrides=args)
trainer.train()


if __name__ == '__main__':
train()
Loading

0 comments on commit c03a839

Please sign in to comment.