
moskomule/cca.pytorch


This library is deprecated in favor of anatome. I will no longer support this repository.

CCA.pytorch

PyTorch implementation of

The GPU is now the default device for the SVD computation.
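SVCCA reduces each layer's activations with an SVD before running CCA. The NumPy sketch below illustrates that step on the CPU only; it does not reproduce this library's GPU code path. The function name `svd_reduce` is hypothetical and not part of `cca`, and the 99% variance threshold is a commonly used setting assumed here.

```python
import numpy as np


def svd_reduce(acts: np.ndarray, keep: float = 0.99) -> np.ndarray:
    """Hypothetical helper: project activations (neurons x datapoints) onto
    the top singular directions that together explain `keep` of the variance."""
    # Center each neuron over the datapoints.
    centered = acts - acts.mean(axis=1, keepdims=True)
    # Thin SVD: centered = U @ diag(s) @ Vt.
    _, s, vt = np.linalg.svd(centered, full_matrices=False)
    # Variance explained by each direction is proportional to s**2.
    ratio = np.cumsum(s ** 2) / np.sum(s ** 2)
    # Smallest number of directions whose cumulative ratio reaches `keep`.
    k = int(np.searchsorted(ratio, keep)) + 1
    # Reduced representation with shape (k, datapoints), equal to U_k^T @ centered.
    return s[:k, None] * vt[:k]
```

On activations that only span a low-dimensional subspace, the output keeps at most that many rows, which is what makes the later CCA step cheaper and less noisy.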

Requirements

  • Python>=3.6
  • PyTorch>=0.4.1
  • torchvision>=0.2.1

To run example.py, you also need

Usage

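import torch
from torchvision.models import resnet18
from cca import CCAHook

device = "cuda"  # or "cpu"; device used for the SVD computation
model = resnet18()  # any model with submodules named "layer3.0.conv1" and "layer3.0.conv2"
# Hooks record the outputs of the two named layers during the forward pass
hook1 = CCAHook(model, "layer3.0.conv1", svd_device=device)
hook2 = CCAHook(model, "layer3.0.conv2", svd_device=device)
model.eval()
with torch.no_grad():
    model(torch.randn(1200, 3, 224, 224))  # forward pass fills both hooks
print(hook1.distance(hook2, size=8))  # resize feature maps to 8x8 before CCA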
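The usage above returns a CCA-based distance between two layers. As an illustration of one common definition, the NumPy sketch below computes 1 minus the mean canonical correlation between two activation matrices. This is an assumption: `cca_distance` is a hypothetical name, and `CCAHook.distance` may use a different formula (for example, a projection-weighted variant). The sketch also assumes each matrix has more datapoints than neurons and full row rank.

```python
import numpy as np


def cca_distance(x: np.ndarray, y: np.ndarray) -> float:
    """Hypothetical helper: 1 - mean canonical correlation between
    x (p neurons x n datapoints) and y (q neurons x n datapoints)."""
    # Center each neuron over the n datapoints.
    x = x - x.mean(axis=1, keepdims=True)
    y = y - y.mean(axis=1, keepdims=True)
    # Orthonormal bases (n x p and n x q) for the spaces spanned by the neurons.
    qx, _ = np.linalg.qr(x.T)
    qy, _ = np.linalg.qr(y.T)
    # Canonical correlations are the singular values of qx^T @ qy.
    rho = np.linalg.svd(qx.T @ qy, compute_uv=False)
    # Guard against tiny floating-point overshoot above 1.
    rho = np.clip(rho, 0.0, 1.0)
    return float(1.0 - rho.mean())
```

Because CCA is invariant to invertible linear transforms, comparing a layer with an invertible linear mix of its own neurons gives a distance close to 0, while unrelated random activations give a distance close to 1.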

Example

Running python example.py trains ResNet-20 on CIFAR-10 for 100 epochs and then measures the CCA distance between the trained model and its checkpoints.

Note

While the original SVCCA uses DFT for resizing, we use global average pooling for simplicity.
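To make the pooling idea concrete, here is a NumPy sketch that average-pools convolutional activations of shape (N, C, H, W) down to size x size and then treats every (sample, row, column) cell as a datapoint for each channel. It is an assumption about how pooled feature maps could be fed into CCA, not this library's exact procedure; `pool_and_flatten` is a hypothetical name, and the sketch only handles H and W divisible by `size` (in that case it matches non-overlapping adaptive average pooling).

```python
import numpy as np


def pool_and_flatten(acts: np.ndarray, size: int) -> np.ndarray:
    """Hypothetical helper: average-pool (N, C, H, W) activations to
    (N, C, size, size), then reshape to (C, N * size * size)."""
    n, c, h, w = acts.shape
    if h % size or w % size:
        raise ValueError("this sketch requires H and W divisible by size")
    # Split each spatial axis into (size, block) and average within each block.
    pooled = acts.reshape(n, c, size, h // size, size, w // size).mean(axis=(3, 5))
    # Channels become "neurons"; (sample, row, col) positions become datapoints.
    return pooled.transpose(1, 0, 2, 3).reshape(c, n * size * size)
```

For a single 4x4 map holding 0 to 15 row by row, pooling to 2x2 averages each 2x2 block and yields the row [2.5, 4.5, 10.5, 12.5].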

About

CCAs for looking into DNNs
