Skip to content

Commit

Permalink
update: modify file structure and tidy code
Browse files Browse the repository at this point in the history
  • Loading branch information
Howeng98 committed Nov 30, 2022
1 parent e518acd commit 5ad0edb
Show file tree
Hide file tree
Showing 13 changed files with 15 additions and 19 deletions.
Binary file added codes/__pycache__/datasets.cpython-38.pyc
Binary file not shown.
Binary file added codes/__pycache__/inspection.cpython-38.pyc
Binary file not shown.
Binary file added codes/__pycache__/mvtecad.cpython-38.pyc
Binary file not shown.
Binary file added codes/__pycache__/nearest_neighbor.cpython-38.pyc
Binary file not shown.
Binary file added codes/__pycache__/networks.cpython-38.pyc
Binary file not shown.
Binary file added codes/__pycache__/utils.cpython-38.pyc
Binary file not shown.
File renamed without changes.
File renamed without changes.
22 changes: 7 additions & 15 deletions codes/codes/mvtecad.py → codes/mvtecad.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,20 @@
from sklearn.metrics import roc_auc_score, precision_recall_curve, roc_curve


DATASET_PATH = '/home/ginger/mvtec_anomaly_detection'
DATASET_PATH = '../MvTecAD/'


__all__ = ['objs', 'set_root_path',
__all__ = ['objs',
'get_x', 'get_x_standardized',
'detection_auroc', 'segmentation_auroc']

objs = ['bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut',
'leather', 'metal_nut', 'pill', 'screw', 'tile', 'toothbrush',
'transistor', 'wood', 'zipper']


def resize(image, shape=(256, 256)):
return np.array(Image.fromarray(image).resize(shape[::-1]))


def bilinears(images, shape) -> np.ndarray:
import cv2
N = images.shape[0]
Expand All @@ -39,14 +37,9 @@ def gray2rgb(images):
return images


def set_root_path(new_path):
global DATASET_PATH
DATASET_PATH = new_path


def get_x(obj, mode='train'):
fpattern = os.path.join(DATASET_PATH, f'{obj}/{mode}/*/*.png')
fpaths = sorted(glob(fpattern))
fpattern = os.path.join(DATASET_PATH, f'{obj}/{mode}/*/*.png')
fpaths = sorted(glob(fpattern))

if mode == 'test':
fpaths1 = list(filter(lambda fpath: os.path.basename(os.path.dirname(fpath)) != 'good', fpaths))
Expand All @@ -62,13 +55,12 @@ def get_x(obj, mode='train'):
if images.shape[-1] != 3:
images = gray2rgb(images)
images = list(map(resize, images))
images = np.asarray(images)

images = np.asarray(images)
return images


def get_x_standardized(obj, mode='train'):
x = get_x(obj, mode=mode)
x = get_x(obj, mode=mode)
mean = get_mean(obj)
return (x.astype(np.float32) - mean) / 255

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
12 changes: 8 additions & 4 deletions main_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,20 @@
device = 'cuda'
torch.backends.cudnn.benchmark = True


# logging AUROC results
newline = '\n'
if not os.path.isdir(f'./log_result'):
os.mkdir(f'./log_result')
if not os.path.isdir(f'./ckpts'):
os.mkdir(f'./ckpts')

LOG = f'./log_result/AUROC_{args.obj}.log'
logging.basicConfig(filename=LOG, filemode="w", level=logging.INFO)
logging.info(f' [class:{args.obj}, lambda:{args.lambda_value}, learning rate:{args.lr}, total training epochs:{args.epochs}, groups_64:{args.groups_64}, groups_32:{args.groups_32}, groups_16:{args.groups_16}]{newline}{newline}')


if not os.path.isdir(f'ckpts/{args.obj}'):
os.mkdir(f'ckpts/{args.obj}')
if not os.path.isdir(f'./ckpts/{args.obj}'):
os.mkdir(f'./ckpts/{args.obj}')


def train():
Expand Down Expand Up @@ -73,7 +77,7 @@ def train():
opt = torch.optim.Adam(params=params, lr=lr)

with task('Datasets'):
train_x = mvtecad.get_x_standardized(obj, mode='train')
train_x = mvtecad.get_x_standardized(obj, mode='train')
train_x = NHWC2NCHW(train_x)

rep = 100
Expand Down

0 comments on commit 5ad0edb

Please sign in to comment.