-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3870175
commit 0ffce54
Showing
28 changed files
with
1,641 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
import argparse | ||
|
||
|
||
def train_options(argv=None):
    """Build the training command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what a
            plain ``train_options()`` call has always done.

    Returns:
        argparse.Namespace holding ``experiment``, ``dataset``, ``epochs``,
        ``learning_rate``, ``num_workers``, ``lmbda``, ``metrics``,
        ``batch_size``, ``test_batch_size``, ``aux_learning_rate``,
        ``patch_size``, ``gpu_id``, ``cuda``, ``save``, ``seed``,
        ``clip_max_norm`` and ``checkpoint``.
    """

    def str2bool(value):
        """Convert a command-line token such as ``"False"`` into a real bool."""
        if isinstance(value, bool):
            return value
        token = value.strip().lower()
        if token in ("true", "t", "yes", "y", "1"):
            return True
        if token in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value, got %r" % (value,)
        )

    parser = argparse.ArgumentParser(description="Training script.")
    parser.add_argument(
        "-exp",
        "--experiment",
        default="0483mse",
        type=str,
        required=False,
        help="Experiment name"
    )
    parser.add_argument(
        "-d",
        "--dataset",
        default="/home/npr/dataset/",
        type=str,
        required=False,
        help="Training dataset"
    )
    parser.add_argument(
        "-e",
        "--epochs",
        default=60000,
        type=int,
        help="Number of epochs (default: %(default)s)",
    )
    parser.add_argument(
        "-lr",
        "--learning-rate",
        default=1e-4,
        type=float,
        help="Learning rate (default: %(default)s)",
    )
    parser.add_argument(
        "-n",
        "--num-workers",
        type=int,
        default=8,
        help="Dataloaders threads (default: %(default)s)",
    )
    parser.add_argument(
        "--lambda",
        dest="lmbda",  # "lambda" is a Python keyword, so store it under another name
        type=float,
        default=0.0483,
        help="Bit-rate distortion parameter (default: %(default)s)",
    )
    parser.add_argument(
        "--metrics",
        type=str,
        default="mse",
        help="Optimized for (default: %(default)s)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="Batch size (default: %(default)s)"
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1,
        help="Test batch size (default: %(default)s)",
    )
    parser.add_argument(
        "--aux-learning-rate",
        # Without an explicit type a value given on the command line would
        # be kept as a string instead of a number.
        type=float,
        default=1e-3,
        help="Auxiliary loss learning rate (default: %(default)s)",
    )
    parser.add_argument(
        "--patch-size",
        type=int,
        nargs=2,
        default=(256, 256),
        help="Size of the patches to be cropped (default: %(default)s)",
    )
    parser.add_argument(
        "--gpu_id",
        type=int,
        default=0,
        help="GPU ID"
    )
    parser.add_argument(
        "--cuda",
        # Untyped, "--cuda False" produced the truthy string "False".
        type=str2bool,
        default=True,
        help="Use cuda"
    )
    parser.add_argument(
        "--save",
        type=str2bool,
        default=True,
        help="Save model to disk"
    )
    parser.add_argument(
        "--seed",
        type=float,
        default=192.1,
        help="Set random seed for reproducibility"
    )
    parser.add_argument(
        "--clip_max_norm",
        default=1.0,
        type=float,
        help="gradient clipping max norm (default: %(default)s)",
    )
    parser.add_argument(
        "-c",
        "--checkpoint",
        default=None,
        type=str,
        help="pretrained model path"
    )
    args = parser.parse_args(argv)
    return args
|
||
|
||
def test_options(argv=None):
    """Build the testing command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what a
            plain ``test_options()`` call has always done.

    Returns:
        argparse.Namespace holding ``experiment``, ``codestream_path``,
        ``dataset``, ``num_workers``, ``metrics``, ``test_batch_size``,
        ``gpu_id``, ``cuda``, ``save`` and ``checkpoint``.
    """

    def str2bool(value):
        """Convert a command-line token such as ``"False"`` into a real bool."""
        if isinstance(value, bool):
            return value
        token = value.strip().lower()
        if token in ("true", "t", "yes", "y", "1"):
            return True
        if token in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value, got %r" % (value,)
        )

    parser = argparse.ArgumentParser(description="Testing script.")
    parser.add_argument(
        "-exp",
        "--experiment",
        default="elic_test",
        type=str,
        required=False,
        help="Experiment name"
    )
    parser.add_argument(
        "--codestream_path",
        default="experiments/elic_0800/codestream/100",
        type=str,
        required=False,
        help="Path to the codestream"
    )
    parser.add_argument(
        "-d",
        "--dataset",
        default="/home/npr/dataset/",
        type=str,
        required=False,
        help="Training dataset"
    )
    parser.add_argument(
        "-n",
        "--num-workers",
        type=int,
        default=1,
        help="Dataloaders threads (default: %(default)s)",
    )
    parser.add_argument(
        "--metrics",
        type=str,
        default="mse",
        help="Optimized for (default: %(default)s)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1,
        help="Test batch size (default: %(default)s)",
    )
    parser.add_argument(
        "--gpu_id",
        type=int,
        default=0,
        help="GPU ID"
    )
    parser.add_argument(
        "--cuda",
        # Untyped, "--cuda False" produced the truthy string "False".
        type=str2bool,
        default=True,
        help="Use cuda"
    )
    parser.add_argument(
        "--save",
        type=str2bool,
        default=True,
        help="Save model to disk"
    )
    parser.add_argument(
        "-c",
        "--checkpoint",
        default=None,
        type=str,
        help="pretrained model path"
    )
    args = parser.parse_args(argv)
    return args


# The name matches pytest's default ``test_*`` pattern; flag it so pytest does
# not collect and run this option parser as a test when it is imported into a
# test module.
test_options.__test__ = False
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from utils.utils import Config | ||
|
||
def model_config():
    """Return the model hyper-parameters wrapped in a ``Config`` object.

    The values are fixed; the function takes no arguments and builds a new
    ``Config`` on every call.
    """
    # MLIC and MLIC+
    hyper_params = dict(
        N=192,
        M=320,
        slice_num=10,
        context_window=5,
        # Ten entries (== slice_num) that add up to 320 (== M).
        slice_ch=[8, 8, 8, 8, 16, 16, 32, 32, 96, 96],
        quant="ste",
        elic_lambda_list=[0.05, 0.07, 0.09, 0.11],
        mlicex_lambda_list=[0.04, 0.07, 0.075, 0.09, 0.11],
        interpolated_type="exponential",
    )
    return Config(hyper_params)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .rd_loss import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import math | ||
import torch | ||
import torch.nn as nn | ||
from pytorch_msssim import ms_ssim | ||
|
||
|
||
class RateDistortionLoss(nn.Module):
    """Custom rate distortion loss with a Lagrangian parameter.

    The total loss is ``lmbda * distortion + bpp_loss``. The distortion term
    is selected by ``metrics``: ``'mse'`` or ``'ms-ssim'``.

    Args:
        lmbda: Weight applied to the distortion term.
        metrics: Distortion measure, either ``'mse'`` or ``'ms-ssim'``.

    Raises:
        ValueError: If ``metrics`` is not one of the supported values.
    """

    _SUPPORTED_METRICS = ("mse", "ms-ssim")

    def __init__(self, lmbda=1e-2, metrics='mse'):
        super().__init__()
        # Fail early: an unknown metric would otherwise make forward() return
        # a dict with no "loss" entry.
        if metrics not in self._SUPPORTED_METRICS:
            raise ValueError(
                "unsupported metrics %r; expected one of %r"
                % (metrics, self._SUPPORTED_METRICS)
            )
        self.mse = nn.MSELoss()
        self.lmbda = lmbda
        self.metrics = metrics

    def forward(self, output, target):
        """Compute the rate, distortion and combined loss terms.

        Args:
            output: Mapping with ``"x_hat"`` (the reconstruction) and
                ``"likelihoods"`` (a mapping whose values are likelihood
                tensors).
            target: 4-D tensor; its first, third and fourth dimensions are
                multiplied to obtain the pixel count (presumably an
                ``(N, C, H, W)`` image batch -- confirm against the caller).

        Returns:
            dict with keys ``"bpp_loss"``, ``"mse_loss"``, ``"ms_ssim_loss"``
            and ``"loss"``. The distortion entry that is not selected by
            ``metrics`` is ``None``.

        Raises:
            ValueError: If ``self.metrics`` was changed to an unsupported
                value after construction.
        """
        N, _, H, W = target.size()
        out = {}
        num_pixels = N * H * W

        # Sum of -log2(likelihood) over every likelihood tensor, per pixel.
        out["bpp_loss"] = sum(
            (torch.log(likelihoods).sum() / (-math.log(2) * num_pixels))
            for likelihoods in output["likelihoods"].values()
        )
        if self.metrics == 'mse':
            out["mse_loss"] = self.mse(output["x_hat"], target)
            out["ms_ssim_loss"] = None
            # NOTE(review): the 255 ** 2 factor presumably rescales an MSE
            # computed on [0, 1] data to the 8-bit range -- confirm.
            out["loss"] = self.lmbda * 255 ** 2 * out["mse_loss"] + out["bpp_loss"]
        elif self.metrics == 'ms-ssim':
            out["mse_loss"] = None
            out["ms_ssim_loss"] = 1 - ms_ssim(output["x_hat"], target, data_range=1.0)
            out["loss"] = self.lmbda * out["ms_ssim_loss"] + out["bpp_loss"]
        else:
            raise ValueError(
                "unsupported metrics %r; expected one of %r"
                % (self.metrics, self._SUPPORTED_METRICS)
            )

        return out
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .elic import ELIC |
Oops, something went wrong.