Skip to content

Commit

Permalink
updated mount path to match grading.sh
Browse files Browse the repository at this point in the history
  • Loading branch information
shrbb committed Sep 21, 2024
1 parent 34c0aa0 commit e81fb31
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 43 deletions.
10 changes: 3 additions & 7 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,21 @@ services:
context: .
dockerfile: model-train/Dockerfile.train
volumes:
- mnist:/opt/mount/model
- data:/opt/mount/data
- mnist:/opt/mount

evaluate:
build:
context: .
dockerfile: model-eval/Dockerfile.eval
volumes:
- mnist:/opt/mount/model
- data:/opt/mount/data
- mnist:/opt/mount

infer:
build:
context: .
dockerfile: model-infer/Dockerfile.infer
volumes:
- mnist:/opt/mount/model
- data:/opt/mount/data
- mnist:/opt/mount

volumes:
mnist:
data:
8 changes: 4 additions & 4 deletions model-eval/Dockerfile.eval
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
FROM shravankgl/elmo4-tsai:session2

WORKDIR /opt/mount
WORKDIR /workdir

COPY model.py /opt/mount/
COPY model-eval/eval.py /opt/mount/
COPY model.py /workdir/
COPY model-eval/eval.py /workdir/

CMD ["python", "eval.py"]
CMD ["python", "eval.py", "--save-dir", "/opt/mount"]
11 changes: 6 additions & 5 deletions model-eval/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test(args, model, device, dataset, dataloader_kwargs):

test_loader = torch.utils.data.DataLoader(dataset, **dataloader_kwargs)

test_epoch(model, device, test_loader)
return test_epoch(model, device, test_loader)

def test_epoch(model, device, data_loader):
# write code to test this epoch
Expand Down Expand Up @@ -66,15 +66,16 @@ def main():
)

# create MNIST test dataset and loader
test_dataset = datasets.MNIST('./data', train=False, download=False, transform=transform)
test_dataset = datasets.MNIST(os.path.join(args.save_dir, 'data'), train=False, download=False, transform=transform)

device = torch.device("cpu")
model = Net().to(device)

# create model and load state dict
if os.path.isfile("./model/mnist_cnn.pt"):
# create model and load state dict
model_checkpoint_path = os.path.join(args.save_dir, "model/mnist_cnn.pt")
if os.path.isfile(model_checkpoint_path):
print("Loading model_checkpoint")
model.load_state_dict(torch.load("./model/mnist_cnn.pt"))
model.load_state_dict(torch.load(model_checkpoint_path))

# test epoch function call
kwargs = {'batch_size': args.test_batch_size}
Expand Down
10 changes: 4 additions & 6 deletions model-infer/Dockerfile.infer
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
FROM shravankgl/elmo4-tsai:session2

WORKDIR /opt/mount
WORKDIR /workdir

COPY model.py /opt/mount/
COPY model-infer/infer.py /opt/mount/
COPY model.py /workdir/
COPY model-infer/infer.py /workdir/

RUN mkdir -p /opt/mount/model/results

CMD ["python", "infer.py"]
CMD ["python", "infer.py", "--save-dir", "/opt/mount"]
35 changes: 22 additions & 13 deletions model-infer/infer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import argparse
import json
import os
import time
import random
import torch
Expand Down Expand Up @@ -28,28 +30,35 @@ def infer(model, dataset, save_dir, num_samples=5):
img = Image.fromarray(image.squeeze().numpy() * 255).convert("L")

# Save the image, using the index and predicted label as the filename
img.save(results_dir / f"{pred}.png")
img.save(results_dir / f"{idx}_{pred}.png")

print(f"Saved image as {pred}.png in {results_dir}")


def main():
# Directory where results will be saved (inside the mounted mnist volume)
save_dir = "./model"
parser = argparse.ArgumentParser(description="MNIST Evaluation Script")
parser.add_argument(
"--save-dir", default="./", help="checkpoint will be saved in this directory"
)

args = parser.parse_args()

# Initialize the model and load checkpoint
model = Net() # Replace Net with your model's architecture
model.load_state_dict(torch.load("./model/mnist_cnn.pt")) # Load from the volume
model.eval() # Set model to evaluation mode
model_checkpoint_path = os.path.join(args.save_dir, "model/mnist_cnn.pt")
if os.path.isfile(model_checkpoint_path):
print("Loading model_checkpoint")
model = Net()
model.load_state_dict(torch.load(model_checkpoint_path))

model.eval() # Set model to evaluation mode

# Create transformations for the MNIST dataset (normalize to [0,1] range)
transform = transforms.Compose([transforms.ToTensor()])
# Create transformations for the MNIST dataset (normalize to [0,1] range)
transform = transforms.Compose([transforms.ToTensor()])

# Load MNIST test dataset (download=True ensures it is downloaded if not present)
dataset = datasets.MNIST(root="./data", train=False, download=False, transform=transform)
# Load MNIST test dataset (download=False: the data must already be present under <save_dir>/data)
dataset = datasets.MNIST(os.path.join(args.save_dir, 'data'), train=False, download=False, transform=transform)

# Run inference on the dataset and save results
infer(model, dataset, save_dir)
# Run inference on the dataset and save results
infer(model, dataset, args.save_dir)
print("Inference completed. Results saved in the 'results' folder inside the mnist volume.")


Expand Down
8 changes: 4 additions & 4 deletions model-train/Dockerfile.train
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
FROM shravankgl/elmo4-tsai:session2

WORKDIR /opt/mount
WORKDIR /workdir

COPY model.py /opt/mount/
COPY model-train/train.py /opt/mount/
COPY model.py /workdir/
COPY model-train/train.py /workdir/

CMD ["python", "train.py", "--epochs", "1", "--num-processes", "2"]
CMD ["python", "train.py", "--epochs", "1", "--num-processes", "2", "--save-dir", "/opt/mount"]
11 changes: 7 additions & 4 deletions model-train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,15 @@ def main():
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
dataset1 = datasets.MNIST('./data', train=True, download=True, transform=transform)
dataset1 = datasets.MNIST(os.path.join(args.save_dir, 'data'), train=True, download=True, transform=transform)
kwargs = {'batch_size': args.batch_size,
'shuffle': True}

# mnist hogwild training process
if os.path.isfile("./model/mnist_cnn.pt"):
model_checkpoint_path = os.path.join(args.save_dir, "model/mnist_cnn.pt")
if os.path.isfile(model_checkpoint_path):
print("Loading model_checkpoint")
model.load_state_dict(torch.load("./model/mnist_cnn.pt"))
model.load_state_dict(torch.load(model_checkpoint_path))
else:
print("No model checkpoint found. Starting from scratch.")

Expand All @@ -136,7 +137,9 @@ def main():


# save model ckpt
torch.save(model.state_dict(), "./model/mnist_cnn.pt")
results_dir = Path(os.path.join(args.save_dir, "model"))
results_dir.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), results_dir / "mnist_cnn.pt")

if __name__ == "__main__":
main()

0 comments on commit e81fb31

Please sign in to comment.