Skip to content

Commit

Permalink
Resolve linter warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
jaagut committed Dec 29, 2021
1 parent 616e504 commit 7a3d02a
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 6 deletions.
6 changes: 4 additions & 2 deletions pytorchyolo/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def forward(self, x):
x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)
return x


class Mish(nn.Module):
""" The MISH activation function (https://github.com/digantamisra98/Mish) """

Expand All @@ -123,6 +124,7 @@ def __init__(self):
def forward(self, x):
return x * torch.tanh(F.softplus(x))


class YOLOLayer(nn.Module):
"""Detection layer"""

Expand Down Expand Up @@ -152,7 +154,7 @@ def forward(self, x, img_size):
self.grid = self._make_grid(nx, ny).to(x.device)

x[..., 0:2] = (x[..., 0:2].sigmoid() + self.grid) * stride # xy
x[..., 2:4] = torch.exp(x[..., 2:4]) * self.anchor_grid # wh
x[..., 2:4] = torch.exp(x[..., 2:4]) * self.anchor_grid # wh
x[..., 4:] = x[..., 4:].sigmoid()
x = x.view(bs, -1, self.no)

Expand Down Expand Up @@ -186,7 +188,7 @@ def forward(self, x):
combined_outputs = torch.cat([layer_outputs[int(layer_i)] for layer_i in module_def["layers"].split(",")], 1)
group_size = combined_outputs.shape[1] // int(module_def.get("groups", 1))
group_id = int(module_def.get("group_id", 0))
x = combined_outputs[:, group_size * group_id : group_size * (group_id + 1)] # Slice groupings used by yolo v4
x = combined_outputs[:, group_size * group_id:group_size * (group_id + 1)] # Slice groupings used by yolo v4
elif module_def["type"] == "shortcut":
layer_i = int(module_def["from"])
x = layer_outputs[-1] + layer_outputs[layer_i]
Expand Down
2 changes: 1 addition & 1 deletion pytorchyolo/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch.autograd import Variable

from pytorchyolo.models import load_model
from pytorchyolo.utils.utils import load_classes, ap_per_class, get_batch_statistics, non_max_suppression, to_cpu, xywh2xyxy, print_environment_info
from pytorchyolo.utils.utils import load_classes, ap_per_class, get_batch_statistics, non_max_suppression, xywh2xyxy, print_environment_info
from pytorchyolo.utils.datasets import ListDataset
from pytorchyolo.utils.transforms import DEFAULT_TRANSFORMS
from pytorchyolo.utils.parse_config import parse_data_config
Expand Down
2 changes: 1 addition & 1 deletion pytorchyolo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pytorchyolo.utils.utils import to_cpu, load_classes, print_environment_info, provide_determinism, worker_seed_set
from pytorchyolo.utils.datasets import ListDataset
from pytorchyolo.utils.augmentations import AUGMENTATION_TRANSFORMS
#from pytorchyolo.utils.transforms import DEFAULT_TRANSFORMS
# from pytorchyolo.utils.transforms import DEFAULT_TRANSFORMS
from pytorchyolo.utils.parse_config import parse_data_config
from pytorchyolo.utils.loss import compute_loss
from pytorchyolo.test import _evaluate, _create_validation_data_loader
Expand Down
2 changes: 1 addition & 1 deletion pytorchyolo/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def compute_loss(predictions, targets, model):

# Classification of the objectness the sequel
# Calculate the BCE loss between the on the fly generated target and the network prediction
lobj += BCEobj(layer_predictions[..., 4], tobj) # obj loss
lobj += BCEobj(layer_predictions[..., 4], tobj) # obj loss

lbox *= 0.05
lobj *= 1.0
Expand Down
4 changes: 3 additions & 1 deletion pytorchyolo/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def provide_determinism(seed=42):
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


def worker_seed_set(worker_id):
# See for details of numpy:
# https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
Expand Down Expand Up @@ -390,6 +391,7 @@ def print_environment_info():

# Print commit hash if possible
try:
print(f"Current Commit Hash: {subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], stderr=subprocess.DEVNULL).decode('ascii').strip()}")
commit_hash = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], stderr=subprocess.DEVNULL).decode('ascii').strip()
print(f"Current Commit Hash: {commit_hash}")
except (subprocess.CalledProcessError, FileNotFoundError):
print("No git or repo found")

0 comments on commit 7a3d02a

Please sign in to comment.