Skip to content

Commit

Permalink
add notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
jzenn committed Nov 2, 2023
1 parent 2e40870 commit a0ca5bc
Showing 1 changed file with 246 additions and 0 deletions.
246 changes: 246 additions & 0 deletions notebooks/svhn_remix.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "d9e84836-d1eb-4acc-ac9e-8e19ce3bb423",
"metadata": {},
"outputs": [],
"source": [
"import os"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9bdbf297-1b13-4bd3-8588-d10801b7448b",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import scipy.io as sio\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ded5262c-5d5b-41e0-94db-1c21c35c992f",
"metadata": {},
"outputs": [],
"source": [
"##"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d215bfa0-d644-4e04-ab5c-98699f823a48",
"metadata": {},
"outputs": [],
"source": [
"path_to_svhn_dataset = \"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "befd0491-fe8d-4e00-99b2-eb1f1da23a15",
"metadata": {},
"outputs": [],
"source": [
"def open_svhn(mat_path: str, train=True):\n",
" loaded_mat = sio.loadmat(mat_path)\n",
"\n",
" images = loaded_mat[\"X\"]\n",
" print(\"X.shape\", images.shape)\n",
" print(\"X.dtype\", images.dtype)\n",
"\n",
" labels = loaded_mat[\"y\"].squeeze()\n",
" print(\"y.shape\", loaded_mat[\"y\"].shape)\n",
" print(\"y.dtype\", loaded_mat[\"y\"].dtype)\n",
"\n",
" np.place(labels, labels == 10, 0)\n",
" images = np.transpose(images, (3, 0, 1, 2))\n",
"\n",
" assert (\n",
" images.shape == ((73257 if train else 26032), 32, 32, 3)\n",
" and images.dtype == np.uint8\n",
" )\n",
" assert labels.shape == ((73257 if train else 26032),) and labels.dtype == np.uint8\n",
" assert np.min(images) == 0 and np.max(images) == 255\n",
" assert np.min(labels) == 0 and np.max(labels) == 9\n",
"\n",
" return images, labels, loaded_mat[\"X\"], loaded_mat[\"y\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fce351b7-6fd7-4046-8706-a3eced395455",
"metadata": {},
"outputs": [],
"source": [
"train_images, train_labels, train_X, train_y = open_svhn(\n",
" os.path.join(path_to_svhn_dataset, \"train_32x32.mat\"), train=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a79f95af-3d54-4da1-bcf6-68c65e6507b4",
"metadata": {},
"outputs": [],
"source": [
"test_images, test_labels, test_X, test_y = open_svhn(\n",
" os.path.join(path_to_svhn_dataset, \"test_32x32.mat\"), train=False\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bee0f789-336b-4020-9f85-900bd8fe3fdf",
"metadata": {},
"outputs": [],
"source": [
"##"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4df9d24d-4f4d-42a9-a7af-1829c3305d6e",
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(0)\n",
"\n",
"all_images = np.concatenate([train_images, test_images])\n",
"all_labels = np.concatenate([train_labels, test_labels])\n",
"\n",
"all_idx = np.arange(len(all_images))\n",
"new_train_idx = list()\n",
"new_test_idx = list()\n",
"\n",
"new_train_labels = list()\n",
"new_test_labels = list()\n",
"\n",
"new_train_images = list()\n",
"new_test_images = list()\n",
"\n",
"for class_idx in list(range(10)):\n",
" all_labels_idx = all_labels == class_idx\n",
" train_labels_idx = train_labels == class_idx\n",
" test_labels_idx = test_labels == class_idx\n",
"\n",
" n_of_test = int(test_labels_idx.sum())\n",
" n_of_train = int(train_labels_idx.sum())\n",
"\n",
" new_class_test_indices = np.random.choice(\n",
" n_of_test + n_of_train, n_of_test, replace=False\n",
" )\n",
" new_class_train_indices = np.array(\n",
" [i for i in range(n_of_test + n_of_train) if i not in new_class_test_indices]\n",
" )\n",
"\n",
" new_train_labels.extend(all_labels[all_labels_idx][new_class_train_indices])\n",
" new_train_images.extend(all_images[all_labels_idx][new_class_train_indices])\n",
" new_train_idx.extend(all_idx[all_labels_idx][new_class_train_indices])\n",
"\n",
" new_test_labels.extend(all_labels[all_labels_idx][new_class_test_indices])\n",
" new_test_images.extend(all_images[all_labels_idx][new_class_test_indices])\n",
" new_test_idx.extend(all_idx[all_labels_idx][new_class_test_indices])\n",
"\n",
"\n",
"train_idx_shuffle = np.random.permutation(len(new_train_labels))\n",
"test_idx_shuffle = np.random.permutation(len(new_test_labels))\n",
"\n",
"new_train_labels = np.array(new_train_labels)[train_idx_shuffle].squeeze()\n",
"new_test_labels = np.array(new_test_labels)[test_idx_shuffle].squeeze()\n",
"\n",
"new_train_idx = np.array(new_train_idx)[train_idx_shuffle].squeeze()\n",
"new_test_idx = np.array(new_test_idx)[test_idx_shuffle].squeeze()\n",
"\n",
"new_train_images = np.array(new_train_images)[train_idx_shuffle].squeeze()\n",
"new_test_images = np.array(new_test_images)[test_idx_shuffle].squeeze()\n",
"\n",
"assert len(new_train_labels) == len(train_labels)\n",
"assert len(new_test_labels) == len(test_labels)\n",
"\n",
"for i in range(10):\n",
" assert (new_train_labels == i).sum() == (train_labels == i).sum()\n",
"\n",
"for i in range(10):\n",
" assert (new_test_labels == i).sum() == (test_labels == i).sum()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a4a59c9f-0067-473e-bb7b-e8bb18a555a4",
"metadata": {},
"outputs": [],
"source": [
"##"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3a89777-c1c2-4782-be94-4c3ac2ad63fe",
"metadata": {},
"outputs": [],
"source": [
"train_mat = {\n",
" \"X\": np.transpose(new_train_images, (1, 2, 3, 0)),\n",
" \"y\": new_train_labels.reshape(-1, 1),\n",
"}\n",
"test_mat = {\n",
" \"X\": np.transpose(new_test_images, (1, 2, 3, 0)),\n",
" \"y\": new_test_labels.reshape(-1, 1),\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "332911f3-5517-4e29-b3e4-76e540dd81de",
"metadata": {},
"outputs": [],
"source": [
"sio.savemat(\"../dataset/train_32x32_remix.mat\", train_mat)\n",
"sio.savemat(\"../dataset/test_32x32_remix.mat\", test_mat)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "72d479f3-390a-43b3-a2e5-fce9d9f88fe9",
"metadata": {},
"outputs": [],
"source": [
"np.save(\"../index/train_index_remix.npy\", new_train_idx)\n",
"np.save(\"../index/test_index_remix.npy\", new_test_idx)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad212a80-a213-4701-b424-4a1d4f0c6cd9",
"metadata": {},
"outputs": [],
"source": [
"##"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit a0ca5bc

Please sign in to comment.