-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
498 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# BGF-YOLO based on Ultralytics YOLOv8x 8.0.109 object detection model with same license, AGPL-3.0 license | ||
|
||
from .predict import DetectionPredictor, predict | ||
from .train import DetectionTrainer, train | ||
from .val import DetectionValidator, val | ||
|
||
__all__ = 'DetectionPredictor', 'predict', 'DetectionTrainer', 'train', 'DetectionValidator', 'val' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# BGF-YOLO based on Ultralytics YOLOv8x 8.0.109 object detection model with same license, AGPL-3.0 license | ||
|
||
import torch | ||
|
||
from ...yolo.engine.predictor import BasePredictor | ||
from ...yolo.engine.results import Results | ||
from ...yolo.utils import DEFAULT_CFG, ROOT, ops | ||
|
||
|
||
class DetectionPredictor(BasePredictor): | ||
|
||
def postprocess(self, preds, img, orig_imgs): | ||
"""Postprocesses predictions and returns a list of Results objects.""" | ||
preds = ops.non_max_suppression(preds, | ||
self.args.conf, | ||
self.args.iou, | ||
agnostic=self.args.agnostic_nms, | ||
max_det=self.args.max_det, | ||
classes=self.args.classes) | ||
|
||
results = [] | ||
for i, pred in enumerate(preds): | ||
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs | ||
if not isinstance(orig_imgs, torch.Tensor): | ||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) | ||
path = self.batch[0] | ||
img_path = path[i] if isinstance(path, list) else path | ||
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred)) | ||
return results | ||
|
||
|
||
def predict(cfg=DEFAULT_CFG, use_python=False): | ||
model = cfg.model or 'yolov8n.pt' | ||
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \ | ||
else 'https://ultralytics.com/images/bus.jpg' | ||
|
||
args = dict(model=model, source=source) | ||
if use_python: | ||
from ultralytics import YOLO | ||
YOLO(model)(**args) | ||
else: | ||
predictor = DetectionPredictor(overrides=args) | ||
predictor.predict_cli() | ||
|
||
|
||
if __name__ == '__main__': | ||
predict() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
# BGF-YOLO based on Ultralytics YOLOv8x 8.0.109 object detection model with same license, AGPL-3.0 license | ||
from copy import copy | ||
|
||
import numpy as np | ||
import sys | ||
sys.path.append("/root/BGF-YOLO") | ||
from ...nn.tasks import DetectionModel | ||
from ...yolo import bgf | ||
from ...yolo.data import build_dataloader, build_yolo_dataset | ||
from ...yolo.data.dataloaders.v5loader import create_dataloader | ||
from ...yolo.engine.trainer import BaseTrainer | ||
from ...yolo.utils import DEFAULT_CFG, LOGGER, RANK, colorstr | ||
from ...yolo.utils.plotting import plot_images, plot_labels, plot_results | ||
from ...yolo.utils.torch_utils import de_parallel, torch_distributed_zero_first | ||
|
||
|
||
# BaseTrainer python usage | ||
class DetectionTrainer(BaseTrainer): | ||
|
||
def build_dataset(self, img_path, mode='train', batch=None): | ||
"""Build YOLO Dataset | ||
Args: | ||
img_path (str): Path to the folder containing images. | ||
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. | ||
batch (int, optional): Size of batches, this is for `rect`. Defaults to None. | ||
""" | ||
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) | ||
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == 'val', stride=gs) | ||
|
||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'): | ||
"""TODO: manage splits differently.""" | ||
# Calculate stride - check if model is initialized | ||
if self.args.v5loader: | ||
LOGGER.warning("WARNING ⚠️ 'v5loader' feature is deprecated and will be removed soon. You can train using " | ||
'the default YOLOv8 dataloader instead, no argument is needed.') | ||
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) | ||
return create_dataloader(path=dataset_path, | ||
imgsz=self.args.imgsz, | ||
batch_size=batch_size, | ||
stride=gs, | ||
hyp=vars(self.args), | ||
augment=mode == 'train', | ||
cache=self.args.cache, | ||
pad=0 if mode == 'train' else 0.5, | ||
rect=self.args.rect or mode == 'val', | ||
rank=rank, | ||
workers=self.args.workers, | ||
close_mosaic=self.args.close_mosaic != 0, | ||
prefix=colorstr(f'{mode}: '), | ||
shuffle=mode == 'train', | ||
seed=self.args.seed)[0] | ||
assert mode in ['train', 'val'] | ||
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP | ||
dataset = self.build_dataset(dataset_path, mode, batch_size) | ||
shuffle = mode == 'train' | ||
if getattr(dataset, 'rect', False) and shuffle: | ||
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False") | ||
shuffle = False | ||
workers = self.args.workers if mode == 'train' else self.args.workers * 2 | ||
return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader | ||
|
||
def preprocess_batch(self, batch): | ||
"""Preprocesses a batch of images by scaling and converting to float.""" | ||
batch['img'] = batch['img'].to(self.device, non_blocking=True).float() / 255 | ||
return batch | ||
|
||
def set_model_attributes(self): | ||
"""nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps).""" | ||
# self.args.box *= 3 / nl # scale to layers | ||
# self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers | ||
# self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers | ||
self.model.nc = self.data['nc'] # attach number of classes to model | ||
self.model.names = self.data['names'] # attach class names to model | ||
self.model.args = self.args # attach hyperparameters to model | ||
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc | ||
|
||
def get_model(self, cfg=None, weights=None, verbose=True): | ||
"""Return a YOLO detection model.""" | ||
model = DetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1) | ||
if weights: | ||
model.load(weights) | ||
return model | ||
|
||
def get_validator(self): | ||
"""Returns a DetectionValidator for YOLO model validation.""" | ||
self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss' | ||
return v8.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) | ||
|
||
def label_loss_items(self, loss_items=None, prefix='train'): | ||
""" | ||
Returns a loss dict with labelled training loss items tensor | ||
""" | ||
# Not needed for classification but necessary for segmentation & detection | ||
keys = [f'{prefix}/{x}' for x in self.loss_names] | ||
if loss_items is not None: | ||
loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats | ||
return dict(zip(keys, loss_items)) | ||
else: | ||
return keys | ||
|
||
def progress_string(self): | ||
"""Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size.""" | ||
return ('\n' + '%11s' * | ||
(4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size') | ||
|
||
def plot_training_samples(self, batch, ni): | ||
"""Plots training samples with their annotations.""" | ||
plot_images(images=batch['img'], | ||
batch_idx=batch['batch_idx'], | ||
cls=batch['cls'].squeeze(-1), | ||
bboxes=batch['bboxes'], | ||
paths=batch['im_file'], | ||
fname=self.save_dir / f'train_batch{ni}.jpg', | ||
on_plot=self.on_plot) | ||
|
||
def plot_metrics(self): | ||
"""Plots metrics from a CSV file.""" | ||
plot_results(file=self.csv, on_plot=self.on_plot) # save results.png | ||
|
||
def plot_training_labels(self): | ||
"""Create a labeled training plot of the YOLO model.""" | ||
boxes = np.concatenate([lb['bboxes'] for lb in self.train_loader.dataset.labels], 0) | ||
cls = np.concatenate([lb['cls'] for lb in self.train_loader.dataset.labels], 0) | ||
plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir, on_plot=self.on_plot) | ||
|
||
|
||
def train(cfg=DEFAULT_CFG, use_python=False): | ||
"""Train and optimize YOLO model given training data and device.""" | ||
model = cfg.model or 'yolov8x.pt' | ||
data = cfg.data or 'coco.yaml' # or yolo.ClassificationDataset("mnist") | ||
device = cfg.device if cfg.device is not None else '' | ||
|
||
args = dict(model=model, data=data, device=device) | ||
if use_python: | ||
from ultralytics import YOLO | ||
YOLO(model).train(**args) | ||
else: | ||
trainer = DetectionTrainer(overrides=args) | ||
trainer.train() | ||
|
||
|
||
if __name__ == '__main__': | ||
train() |
Oops, something went wrong.