Skip to content

Commit

Permalink
release code
Browse files Browse the repository at this point in the history
  • Loading branch information
JiangWeibeta committed Apr 26, 2023
1 parent 3870175 commit 0ffce54
Show file tree
Hide file tree
Showing 28 changed files with 1,641 additions and 0 deletions.
187 changes: 187 additions & 0 deletions config/args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import argparse


def train_options():
parser = argparse.ArgumentParser(description="Training script.")
parser.add_argument(
"-exp",
"--experiment",
default="0483mse",
type=str,
required=False,
help="Experiment name"
)
parser.add_argument(
"-d",
"--dataset",
default="/home/npr/dataset/",
type=str,
required=False,
help="Training dataset"
)
parser.add_argument(
"-e",
"--epochs",
default=60000,
type=int,
help="Number of epochs (default: %(default)s)",
)
parser.add_argument(
"-lr",
"--learning-rate",
default=1e-4,
type=float,
help="Learning rate (default: %(default)s)",
)
parser.add_argument(
"-n",
"--num-workers",
type=int,
default=8,
help="Dataloaders threads (default: %(default)s)",
)
parser.add_argument(
"--lambda",
dest="lmbda",
type=float,
default=0.0483,
help="Bit-rate distortion parameter (default: %(default)s)",
)
parser.add_argument(
"--metrics",
type=str,
default="mse",
help="Optimized for (default: %(default)s)",
)
parser.add_argument(
"--batch-size",
type=int,
default=8,
help="Batch size (default: %(default)s)"
)
parser.add_argument(
"--test-batch-size",
type=int,
default=1,
help="Test batch size (default: %(default)s)",
)
parser.add_argument(
"--aux-learning-rate",
default=1e-3,
help="Auxiliary loss learning rate (default: %(default)s)",
)
parser.add_argument(
"--patch-size",
type=int,
nargs=2,
default=(256, 256),
help="Size of the patches to be cropped (default: %(default)s)",
)
parser.add_argument(
"--gpu_id",
type=int,
default=0,
help="GPU ID"
)
parser.add_argument(
"--cuda",
default=True,
help="Use cuda"
)
parser.add_argument(
"--save",
default=True,
help="Save model to disk"
)
parser.add_argument(
"--seed",
type=float,
default=192.1,
help="Set random seed for reproducibility"
)
parser.add_argument(
"--clip_max_norm",
default=1.0,
type=float,
help="gradient clipping max norm (default: %(default)s",
)
parser.add_argument(
"-c",
"--checkpoint",
default=None,
type=str,
help="pretrained model path"
)
args = parser.parse_args()
return args


def test_options():
parser = argparse.ArgumentParser(description="Testing script.")
parser.add_argument(
"-exp",
"--experiment",
default="elic_test",
type=str,
required=False,
help="Experiment name"
)
parser.add_argument(
"--codestream_path",
default="experiments/elic_0800/codestream/100",
type=str,
required=False,
help="Path to the codestream"
)
parser.add_argument(
"-d",
"--dataset",
default="/home/npr/dataset/",
type=str,
required=False,
help="Training dataset"
)
parser.add_argument(
"-n",
"--num-workers",
type=int,
default=1,
help="Dataloaders threads (default: %(default)s)",
)
parser.add_argument(
"--metrics",
type=str,
default="mse",
help="Optimized for (default: %(default)s)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=1,
help="Test batch size (default: %(default)s)",
)
parser.add_argument(
"--gpu_id",
type=int,
default=0,
help="GPU ID"
)
parser.add_argument(
"--cuda",
default=True,
help="Use cuda"
)
parser.add_argument(
"--save",
default=True,
help="Save model to disk"
)
parser.add_argument(
"-c",
"--checkpoint",
default=None,
type=str,
help="pretrained model path"
)
args = parser.parse_args()
return args
17 changes: 17 additions & 0 deletions config/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from utils.utils import Config

def model_config():
config = Config({
# MLIC and MLIC+
"N": 192,
"M": 320,
"slice_num": 10,
"context_window": 5,
"slice_ch": [8, 8, 8, 8, 16, 16, 32, 32, 96, 96],
"quant": "ste",
"elic_lambda_list": [0.05, 0.07, 0.09, 0.11],
"mlicex_lambda_list": [0.04, 0.07, 0.075, 0.09, 0.11],
"interpolated_type": "exponential",
})

return config
1 change: 1 addition & 0 deletions loss/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .rd_loss import *
34 changes: 34 additions & 0 deletions loss/rd_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import math
import torch
import torch.nn as nn
from pytorch_msssim import ms_ssim


class RateDistortionLoss(nn.Module):
"""Custom rate distortion loss with a Lagrangian parameter."""

def __init__(self, lmbda=1e-2, metrics='mse'):
super().__init__()
self.mse = nn.MSELoss()
self.lmbda = lmbda
self.metrics = metrics

def forward(self, output, target):
N, _, H, W = target.size()
out = {}
num_pixels = N * H * W

out["bpp_loss"] = sum(
(torch.log(likelihoods).sum() / (-math.log(2) * num_pixels))
for likelihoods in output["likelihoods"].values()
)
if self.metrics == 'mse':
out["mse_loss"] = self.mse(output["x_hat"], target)
out["ms_ssim_loss"] = None
out["loss"] = self.lmbda * 255 ** 2 * out["mse_loss"] + out["bpp_loss"]
elif self.metrics == 'ms-ssim':
out["mse_loss"] = None
out["ms_ssim_loss"] = 1 - ms_ssim(output["x_hat"], target, data_range=1.0)
out["loss"] = self.lmbda * out["ms_ssim_loss"] + out["bpp_loss"]

return out
1 change: 1 addition & 0 deletions models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .elic import ELIC
Loading

0 comments on commit 0ffce54

Please sign in to comment.