Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
hardikdava committed Dec 6, 2024
1 parent 672baf4 commit f2dc224
Show file tree
Hide file tree
Showing 14 changed files with 38 additions and 21 deletions.
1 change: 1 addition & 0 deletions tests/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from trolo.utils.logging.glob_logger import platform_safe_emojis, configure_logger, add_separator_method


@pytest.fixture
def test_logger():
"""Fixture to create a test logger with separator support."""
Expand Down
2 changes: 1 addition & 1 deletion trolo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@

from .loaders.registry import GLOBAL_CONFIG
from .inference import DetectionPredictor
from .utils.box_ops import to_sv
from .utils.box_ops import to_sv
1 change: 1 addition & 0 deletions trolo/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from trolo.utils.draw_utils import draw_predictions
from trolo.utils.logging import LOGGER


class BasePredictor(ABC):
def __init__(self, model_path: str, device: Optional[str] = None):
self.device = torch.device(infer_device(device))
Expand Down
5 changes: 4 additions & 1 deletion trolo/inference/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from ..utils.logging import LOGGER


class DetectionPredictor(BasePredictor):
def __init__(
self,
Expand Down Expand Up @@ -111,7 +112,9 @@ def preprocess(self, inputs: Union[str, List[str], Image.Image, List[Image.Image

return torch.stack(images).to(self.device)

def postprocess(self, outputs: torch.Tensor, letterbox_sizes: List[Tuple[int, int]], original_sizes: List[Tuple[int, int]]) -> List[Dict[str, Any]]:
def postprocess(
self, outputs: torch.Tensor, letterbox_sizes: List[Tuple[int, int]], original_sizes: List[Tuple[int, int]]
) -> List[Dict[str, Any]]:
"""Convert model outputs to boxes, scores, labels
Returns:
Expand Down
2 changes: 1 addition & 1 deletion trolo/loaders/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pathlib import Path
from typing import Callable, List, Dict

from ..utils.logging import LOGGER
from ..utils.logging import LOGGER


__all__ = [
Expand Down
2 changes: 1 addition & 1 deletion trolo/loaders/yaml_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .registry import create_from_config
from .yaml_utils import load_config, merge_config, merge_dict

from ..utils.logging import LOGGER
from ..utils.logging import LOGGER


class YAMLConfig(BaseConfig):
Expand Down
1 change: 1 addition & 0 deletions trolo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def init_distributed_mode(device="cpu"):
logger.warning(f"Distributed training initialization failed: {e}")
logger.info("Running in non-distributed mode.")


def train_model(
config: str,
resume: str = None,
Expand Down
8 changes: 6 additions & 2 deletions trolo/trainers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,12 @@ def _load_separate_configs(self, model, dataset, **overrides) -> YAMLConfig:
raise TypeError(f"Unsupported dataset type: {type(dataset)}")

# Print configs before merge for debugging
LOGGER.info(f"Model config transforms: {model_config.get('train_dataloader', {}).get('dataset', {}).get('transforms')}")
LOGGER.info(f"Dataset config transforms: {dataset_config.get('train_dataloader', {}).get('dataset', {}).get('transforms')}")
LOGGER.info(
f"Model config transforms: {model_config.get('train_dataloader', {}).get('dataset', {}).get('transforms')}"
)
LOGGER.info(
f"Dataset config transforms: {dataset_config.get('train_dataloader', {}).get('dataset', {}).get('transforms')}"
)

# Merge configs
cfg = YAMLConfig.merge_configs(model_config, dataset_config, **overrides)
Expand Down
2 changes: 1 addition & 1 deletion trolo/trainers/clas_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def evaluate(model, criterion, dataloader, device):
metric_logger.update(**reduced_values)

metric_logger.synchronize_between_processes()
logger.info(f"Averaged stats: {metric_logger}" )
logger.info(f"Averaged stats: {metric_logger}")

stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
return stats
2 changes: 1 addition & 1 deletion trolo/trainers/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def fit(
args = self.cfg

n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
logger.info(f"Number of params: {n_parameters}" )
logger.info(f"Number of params: {n_parameters}")

output_dir = Path(args.output_dir)
output_dir.mkdir(exist_ok=True)
Expand Down
2 changes: 1 addition & 1 deletion trolo/trainers/det_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def train_one_epoch(
loss_value = sum(loss_dict_reduced.values())

if not math.isfinite(loss_value):
LOGGER.info(F"Loss is {loss_value}, stopping training")
LOGGER.info(f"Loss is {loss_value}, stopping training")
LOGGER.info(f"{loss_dict_reduced}")
sys.exit(1)

Expand Down
1 change: 1 addition & 0 deletions trolo/utils/dist_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# from torch.utils.data.dataloader import DataLoader
from ..data import DataLoader


def setup_distributed(
print_rank: int = 0,
print_method: str = "builtin",
Expand Down
2 changes: 0 additions & 2 deletions trolo/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,5 +231,3 @@ def log_every(self, iterable, print_freq, header=None):
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))


28 changes: 18 additions & 10 deletions trolo/utils/logging/glob_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
import sys
import logging


def platform_safe_emojis(emoji_str=""):
"""Return emoji-safe version of the string."""
return emoji.emojize(emoji_str, language='alias')
return emoji.emojize(emoji_str, language="alias")


class ColorLogger(logging.Formatter):
"""
Custom formatter with colors & emojis for all log levels.
"""

# No color for INFO
COLORS = {
"DEBUG": Fore.BLUE,
Expand All @@ -26,24 +29,27 @@ class ColorLogger(logging.Formatter):
"ERROR": "❌",
"CRITICAL": "🔥",
}

def format(self, record):
levelname = record.levelname
message = super().format(record)
color = self.COLORS.get(levelname, Style.RESET_ALL)
emoji_symbol = platform_safe_emojis(self.EMOJIS.get(levelname, ""))
return f"{color}{emoji_symbol} {message}{Style.RESET_ALL}"


def add_separator_method(logger):
"""
Add a separator method to the logger.
:param logger: The logger instance
:return: The logger with added separator method
"""
def separator(text, char='-', width=85):

def separator(text, char="-", width=85):
"""
Create a separator line with optional text.
:param text: Text to display in the separator
:param char: Character to use for separator
:param width: Total width of the separator
Expand All @@ -52,25 +58,26 @@ def separator(text, char='-', width=85):
total_width = width
text_with_spaces = f" {text} "
padding_length = (total_width - len(text_with_spaces)) // 2

# Create separator line
separator_line = char * padding_length + text_with_spaces + char * padding_length

# Ensure exact width by trimming or padding
separator_line = separator_line[:total_width]
separator_line = separator_line.ljust(total_width, char)

# Log the separator
logger.info(separator_line)

# Attach the separator method to the logger
logger.separator = separator
return logger


def configure_logger(name="default_logger", verbose=True, rank=0):
"""
Configure logger with color and emoji formatting and customizable verbosity.
:param name: Name of the logger
:param verbose: Whether to set logging level to INFO
:param rank: Rank of the process (to control logging)
Expand All @@ -91,5 +98,6 @@ def configure_logger(name="default_logger", verbose=True, rank=0):

return logger


# Lazy initialization
LOGGER = add_separator_method(configure_logger())
LOGGER = add_separator_method(configure_logger())

0 comments on commit f2dc224

Please sign in to comment.