%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% This script performs micro-expression recognition with STSTNet under the leave-one-subject-out cross-validation (LOSOCV) protocol.
% Reference:
% Liong, S. T., Gan, Y. S., See, J., Khor, H. Q., & Huang, Y. C. (2019, May). Shallow triple stream three-dimensional CNN (STSTNet) for micro-expression recognition. In 2019 14th IEEE International Conference on Automatic Face & Gesture Recognition (FG 2019) (pp. 1-5). IEEE.
%
% The files include:
% 1) main.m              : Script that trains and tests STSTNet
% 2) STSTNet.mat         : The STSTNet architecture
% 3) video442subName.txt : List of subject names
% 4) input               : Input images (28x28x3) arranged for LOSOCV: one folder per held-out subject, each containing u_train and u_test
%
% The MATLAB version was written by Sze Teng Liong and tested on MATLAB R2018b.
% If you have any problems, please contact Sze Teng Liong ([email protected]).
%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Read the 68 subject names (one per line)
fid = fopen('video442subName.txt');
if fid == -1
    error('Cannot open video442subName.txt');
end
nameList = textscan(fid, '%s', 'delimiter', '\n');
fclose(fid);
subName = nameList{1};
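% Load the STSTNet architecture (provides the variable STSTNet)
load('STSTNet.mat', 'STSTNet');
% Network configuration
opts = trainingOptions('adam', 'InitialLearnRate', 0.00005, 'MaxEpochs', 500, ...
    'MiniBatchSize', 256, 'Plots', 'training-progress');
% Per-fold ground-truth and predicted labels, pooled after the loop
allDesired = cell(numel(subName), 1);
allPredicted = cell(numel(subName), 1);
% LOSOCV train and test: each fold holds out one subject
for nSub = 1:numel(subName)
    % fullfile builds platform-independent paths, so no cd is needed
    subDir = fullfile('input', subName{nSub});
    % Read train images (all other subjects); labels come from the folder names
    trainingImages = imageDatastore(fullfile(subDir, 'u_train'), ...
        'IncludeSubfolders', true, 'LabelSource', 'foldernames');
    % Train a new model for this fold
    myNet = trainNetwork(trainingImages, STSTNet, opts);
    % Read test images and labels of the held-out subject
    testImages = imageDatastore(fullfile(subDir, 'u_test'), ...
        'IncludeSubfolders', true, 'LabelSource', 'foldernames');
    desiredLabels = testImages.Labels;
    % Test images using the trained model
    predictedLabels = classify(myNet, testImages);
    % Keep this fold's results instead of overwriting them in the next fold
    allDesired{nSub} = desiredLabels;
    allPredicted{nSub} = predictedLabels;
end
% Overall accuracy over all held-out samples
desired = vertcat(allDesired{:});
predicted = vertcat(allPredicted{:});
fprintf('LOSOCV accuracy: %.4f\n', mean(desired == predicted));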
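%% Optional evaluation sketch (not part of the original release)
% A minimal sketch of unweighted F1 (UF1) and unweighted average recall (UAR),
% the scores commonly reported for the 442-video composite micro-expression
% benchmark, computed from labels pooled over all LOSOCV folds.
% scoreUF1UAR is a hypothetical helper name. It assumes desired and predicted are
% categorical column vectors of equal length holding the ground-truth and
% predicted labels of every held-out sample.
% Per class c: F1_c = 2*TP_c / (2*TP_c + FP_c + FN_c) and recall_c = TP_c / N_c;
% UF1 and UAR are the means over the classes present in the ground truth.
% Example call, placed in the script body above this function:
%   [uf1, uar] = scoreUF1UAR(desired, predicted);
% Local functions in scripts need MATLAB R2016b or later and must stay at the end of the file.
function [uf1, uar] = scoreUF1UAR(desired, predicted)
    % Score only the classes that actually occur in the ground truth
    classes = categories(removecats(desired));
    nClass = numel(classes);
    f1 = zeros(nClass, 1);
    recall = zeros(nClass, 1);
    for c = 1:nClass
        isTrue = (desired == classes{c});
        isPred = (predicted == classes{c});
        tp = sum(isTrue & isPred);   % correctly predicted as class c
        fp = sum(~isTrue & isPred);  % wrongly predicted as class c
        fn = sum(isTrue & ~isPred);  % class c samples that were missed
        f1(c) = 2 * tp / (2 * tp + fp + fn);
        recall(c) = tp / (tp + fn);  % tp + fn >= 1 because class c occurs in desired
    end
    uf1 = mean(f1);
    uar = mean(recall);
end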