diff --git a/docker-compose.yml b/docker-compose.yml index a871d00..0e999e5 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -4,25 +4,21 @@ services: context: . dockerfile: model-train/Dockerfile.train volumes: - - mnist:/opt/mount/model - - data:/opt/mount/data + - mnist:/opt/mount evaluate: build: context: . dockerfile: model-eval/Dockerfile.eval volumes: - - mnist:/opt/mount/model - - data:/opt/mount/data + - mnist:/opt/mount infer: build: context: . dockerfile: model-infer/Dockerfile.infer volumes: - - mnist:/opt/mount/model - - data:/opt/mount/data + - mnist:/opt/mount volumes: mnist: - data: diff --git a/model-eval/Dockerfile.eval b/model-eval/Dockerfile.eval index de66472..cbb9301 100644 --- a/model-eval/Dockerfile.eval +++ b/model-eval/Dockerfile.eval @@ -1,8 +1,8 @@ FROM shravankgl/elmo4-tsai:session2 -WORKDIR /opt/mount +WORKDIR /workdir -COPY model.py /opt/mount/ -COPY model-eval/eval.py /opt/mount/ +COPY model.py /workdir/ +COPY model-eval/eval.py /workdir/ -CMD ["python", "eval.py"] \ No newline at end of file +CMD ["python", "eval.py", "--save-dir", "/opt/mount"] \ No newline at end of file diff --git a/model-eval/eval.py b/model-eval/eval.py index 64f1446..78a22b9 100644 --- a/model-eval/eval.py +++ b/model-eval/eval.py @@ -15,7 +15,7 @@ def test(args, model, device, dataset, dataloader_kwargs): test_loader = torch.utils.data.DataLoader(dataset, **dataloader_kwargs) - test_epoch(model, device, test_loader) + return test_epoch(model, device, test_loader) def test_epoch(model, device, data_loader): # write code to test this epoch @@ -66,15 +66,16 @@ def main(): ) # create MNIST test dataset and loader - test_dataset = datasets.MNIST('./data', train=False, download=False, transform=transform) + test_dataset = datasets.MNIST(os.path.join(args.save_dir, 'data'), train=False, download=False, transform=transform) device = torch.device("cpu") model = Net().to(device) - # create model and load state dict - if os.path.isfile("./model/mnist_cnn.pt"): + # create model and load state dict + model_checkpoint_path = os.path.join(args.save_dir, "model/mnist_cnn.pt") + if os.path.isfile(model_checkpoint_path): print("Loading model_checkpoint") - model.load_state_dict(torch.load("./model/mnist_cnn.pt")) + model.load_state_dict(torch.load(model_checkpoint_path)) # test epoch function call kwargs = {'batch_size': args.test_batch_size} diff --git a/model-infer/Dockerfile.infer b/model-infer/Dockerfile.infer index 41ecc0b..e829ea9 100644 --- a/model-infer/Dockerfile.infer +++ b/model-infer/Dockerfile.infer @@ -1,10 +1,8 @@ FROM shravankgl/elmo4-tsai:session2 -WORKDIR /opt/mount +WORKDIR /workdir -COPY model.py /opt/mount/ -COPY model-infer/infer.py /opt/mount/ +COPY model.py /workdir/ +COPY model-infer/infer.py /workdir/ -RUN mkdir -p /opt/mount/model/results - -CMD ["python", "infer.py"] \ No newline at end of file +CMD ["python", "infer.py", "--save-dir", "/opt/mount"] \ No newline at end of file diff --git a/model-infer/infer.py b/model-infer/infer.py index 485ccb7..b08e1c0 100644 --- a/model-infer/infer.py +++ b/model-infer/infer.py @@ -1,4 +1,6 @@ +import argparse import json +import os import time import random import torch @@ -28,28 +30,35 @@ def infer(model, dataset, save_dir, num_samples=5): img = Image.fromarray(image.squeeze().numpy() * 255).convert("L") # Save the image with the predicted label as filename - img.save(results_dir / f"{pred}.png") + img.save(results_dir / f"{idx}_{pred}.png") print(f"Saved image as {pred}.png in {results_dir}") def main(): - # Directory where results will be saved (inside the mounted mnist volume) - save_dir = "./model" + parser = argparse.ArgumentParser(description="MNIST Evaluation Script") + parser.add_argument( + "--save-dir", default="./", help="checkpoint will be saved in this directory" + ) + + args = parser.parse_args() - # Initialize the model and load checkpoint - model = Net() # Replace Net with your model's architecture - model.load_state_dict(torch.load("./model/mnist_cnn.pt")) # Load from the volume - model.eval() # Set model to evaluation mode + model_checkpoint_path = os.path.join(args.save_dir, "model/mnist_cnn.pt") + if os.path.isfile(model_checkpoint_path): + print("Loading model_checkpoint") + model = Net() + model.load_state_dict(torch.load(model_checkpoint_path)) + + model.eval() # Set model to evaluation mode - # Create transformations for the MNIST dataset (normalize to [0,1] range) - transform = transforms.Compose([transforms.ToTensor()]) + # Create transformations for the MNIST dataset (normalize to [0,1] range) + transform = transforms.Compose([transforms.ToTensor()]) - # Load MNIST test dataset (download=True ensures it is downloaded if not present) - dataset = datasets.MNIST(root="./data", train=False, download=False, transform=transform) + # Load MNIST test dataset (download=True ensures it is downloaded if not present) + dataset = datasets.MNIST(os.path.join(args.save_dir, 'data'), train=False, download=False, transform=transform) - # Run inference on the dataset and save results - infer(model, dataset, save_dir) + # Run inference on the dataset and save results + infer(model, dataset, args.save_dir) print("Inference completed. Results saved in the 'results' folder inside the mnist volume.") diff --git a/model-train/Dockerfile.train b/model-train/Dockerfile.train index a5c0d52..f3b15bb 100644 --- a/model-train/Dockerfile.train +++ b/model-train/Dockerfile.train @@ -1,8 +1,8 @@ FROM shravankgl/elmo4-tsai:session2 -WORKDIR /opt/mount +WORKDIR /workdir -COPY model.py /opt/mount/ -COPY model-train/train.py /opt/mount/ +COPY model.py /workdir/ +COPY model-train/train.py /workdir/ -CMD ["python", "train.py", "--epochs", "1", "--num-processes", "2"] \ No newline at end of file +CMD ["python", "train.py", "--epochs", "1", "--num-processes", "2", "--save-dir", "/opt/mount"] \ No newline at end of file diff --git a/model-train/train.py b/model-train/train.py index bf92c79..645d654 100644 --- a/model-train/train.py +++ b/model-train/train.py @@ -113,14 +113,15 @@ def main(): transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) - dataset1 = datasets.MNIST('./data', train=True, download=True, transform=transform) + dataset1 = datasets.MNIST(os.path.join(args.save_dir, 'data'), train=True, download=True, transform=transform) kwargs = {'batch_size': args.batch_size, 'shuffle': True} # mnist hogwild training process - if os.path.isfile("./model/mnist_cnn.pt"): + model_checkpoint_path = os.path.join(args.save_dir, "model/mnist_cnn.pt") + if os.path.isfile(model_checkpoint_path): print("Loading model_checkpoint") - model.load_state_dict(torch.load("./model/mnist_cnn.pt")) + model.load_state_dict(torch.load(model_checkpoint_path)) else: print("No model checkpoint found. Starting from scratch.") @@ -136,7 +137,9 @@ def main(): # save model ckpt - torch.save(model.state_dict(), "./model/mnist_cnn.pt") + results_dir = Path(os.path.join(args.save_dir, "model")) + results_dir.mkdir(parents=True, exist_ok=True) + torch.save(model.state_dict(), results_dir / "mnist_cnn.pt") if __name__ == "__main__": main()