-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_test_simple.cpp
280 lines (224 loc) · 8.44 KB
/
main_test_simple.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
// =========================================================================================
//
// =========================================================================================
// Original message from Peter Kontschieder and Samuel Rota Bulò:
// Structured Class-Label in Random Forests. This is a re-implementation of
// the work we presented at ICCV'11 in Barcelona, Spain.
//
// In case of using this code, please cite the following paper:
// P. Kontschieder, S. Rota Bulò, H. Bischof and M. Pelillo.
// Structured Class-Labels in Random Forests for Semantic Image Labelling. In (ICCV), 2011.
//
// Implementation by Peter Kontschieder and Samuel Rota Bulò
// October 2013
//
// =========================================================================================
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <clocale>
#include <iostream>
#include <string>
#include <unistd.h>
#include <omp.h>
#include <sys/stat.h>
#include "Global.h"
#include "ConfigReader.h"
#include "ImageData.h"
#include "ImageDataFloat.h"
#include "SemanticSegmentationForests.h"
#include "StrucClassSSF.h"
#include "label.h"
using namespace std;
using namespace vision;
/***************************************************************************
 USAGE
 Prints the command-line synopsis to stderr and terminates the process
 with exit status 1. The synopsis lists exactly the three arguments that
 main() reads: config file, number of trees, tree-model file prefix.
 ***************************************************************************/
void usage (char *com)
{
    std::cerr << "usage: " << com << " <configfile> <n.o.trees> <tree-model-prefix>\n";
    exit(1);
}
/***************************************************************************
 Profiling helper.
 Returns the processor time, in seconds as measured by clock(), elapsed
 since the previous call on the same clock. If s is non-NULL, the value is
 also printed to stderr. The chosen clock is then restarted.
 whichClock selects the clock to use; NULL selects the global
 LastProfilingClock.
 ***************************************************************************/
clock_t LastProfilingClock;
inline float profiling (const char *s, clock_t *whichClock=NULL)
{
    clock_t *const clk = (whichClock != NULL) ? whichClock : &LastProfilingClock;
    const clock_t now = clock();
    const float elapsedSec = (float) (now - *clk) / (float) CLOCKS_PER_SEC;
    *clk = now;
    if (s != NULL)
        std::cerr << "Time: " << s << ": " << elapsedSec << std::endl;
    return elapsedSec;
}
/***************************************************************************
 Wall-clock profiling helper.
 Returns the number of seconds between *whichClock and time(NULL). If s
 is non-NULL, the value is also printed to stderr.
 Unlike profiling(), the reference time *whichClock is NOT updated.
 whichClock must not be NULL, because it is dereferenced unconditionally.
 ***************************************************************************/
inline float profilingTime (const char *s, time_t *whichClock)
{
    const time_t now = time(NULL);
    const float elapsed = (float) (now - *whichClock);
    if (s == NULL)
        return elapsed;
    std::cerr << "Time(real): " << s << ": " << elapsed << std::endl;
    return elapsed;
}
/***************************************************************************
 Builds "<folder>/<stem><index as %04d>.png".
 Uses std::string so that a long output folder cannot overflow a
 fixed-size buffer.
 ***************************************************************************/
static std::string makeSegmapPath(const std::string &folder, const char *stem, int index)
{
    char idx[16];
    snprintf(idx, sizeof(idx), "%04d", index);
    return folder + "/" + stem + idx + ".png";
}

/***************************************************************************
 Test a simple image
 For every image selected in pTS:
   - every tree predicts, for each pixel s, a patch of labels. The patch is
     walked row by row (y outer, x inner) and centred on s.
   - each predicted label casts one vote at its patch position, but only
     for positions inside the image.
   - each pixel gets the label with the most votes (ties go to the lowest
     label index).
 Outputs are written to cr->outputFolder as
 segmap_1st_stage%04d.png (label map) and segmap_1st_stage_RGB%04d.png
 (colour map).
 Returns early on the first write failure. Exits the process if a tree
 predicts a label >= cr->numLabels.
 NOTE(review): the patch loop covers 2*(size/2)+1 positions per axis. This
 matches the predicted patch only for odd labelPatchWidth/Height
 (TODO confirm).
 ***************************************************************************/
void testStructClassForest(StrucClassSSF<float> *forest, ConfigReader *cr, TrainingSetSelection<float> *pTS)
{
    cv::Point pt;

    for (int iImage = 0; iImage < pTS->getNbImages(); ++iImage)
    {
        // The Sample carries the image id plus the pixel coordinates queried below.
        Sample<float> s;
        std::cout << "Testing image nr. " << iImage+1 << "\n";
        s.imageId = iImage;
        cv::Rect box(0, 0, pTS->getImgWidth(s.imageId), pTS->getImgHeight(s.imageId));
        cv::Mat mapResult = cv::Mat::ones(box.size(), CV_8UC1) * cr->numLabels;

        // ==============================================
        // THE CLASSICAL CPU SOLUTION
        // ==============================================
        profiling(NULL);    // restart the profiling clock without printing
        const int lPXOff = cr->labelPatchWidth / 2;
        const int lPYOff = cr->labelPatchHeight / 2;

        // One vote accumulator per label
        vector<cv::Mat> result(cr->numLabels);
        for (size_t j = 0; j < result.size(); ++j)
            result[j] = cv::Mat::zeros(box.size(), CV_32FC1);

        // Accumulate the structured predictions of all trees
        for (s.y = 0; s.y < box.height; ++s.y)
            for (s.x = 0; s.x < box.width; ++s.x)
            {
                for (size_t t = 0; t < cr->numTrees; ++t)
                {
                    // p walks the predicted label patch centred on s
                    vector<uint32_t>::const_iterator p = forest[t].predictPtr(s);
                    for (pt.y = (int)s.y - lPYOff; pt.y <= (int)s.y + lPYOff; ++pt.y)
                        for (pt.x = (int)s.x - lPXOff; pt.x <= (int)s.x + lPXOff; ++pt.x, ++p)
                        {
                            // *p is uint32_t, so only the upper bound can be violated
                            if (*p >= (size_t)cr->numLabels)
                            {
                                std::cerr << "Invalid label in prediction: " << (int) *p << "\n";
                                exit(1);
                            }
                            if (box.contains(pt))
                                result[*p].at<float>(pt) += 1;
                        }
                }
            }

        // Argmax of result ===> mapResult
        for (pt.y = 0; pt.y < box.height; ++pt.y)
            for (pt.x = 0; pt.x < box.width; ++pt.x)
            {
                int maxIdx = 0;
                for (int j = 1; j < cr->numLabels; ++j)
                    if (result[j].at<float>(pt) > result[maxIdx].at<float>(pt))
                        maxIdx = j;
                mapResult.at<uint8_t>(pt) = (uint8_t)maxIdx;
            }
        profiling("Prediction");

        // Write segmentation map
        const std::string segmapPath = makeSegmapPath(cr->outputFolder, "segmap_1st_stage", iImage);
        if (!cv::imwrite(segmapPath, mapResult))
        {
            cout<<"Failed to write to "<<segmapPath<<endl;
            return;
        }

        // Write RGB segmentation map
        cv::Mat imgResultRGB;
        convertLabelToRGB(mapResult, imgResultRGB);
        const std::string rgbPath = makeSegmapPath(cr->outputFolder, "segmap_1st_stage_RGB", iImage);
        if (!cv::imwrite(rgbPath, imgResultRGB))
        {
            cout<<"Failed to write to "<<rgbPath<<endl;
            return;
        }
    }
}
/***************************************************************************
 MAIN PROGRAM
 Command line: <prog> <configfile> <n.o.trees> <tree-model-prefix>
 Reads the configuration and selects a single image (index 0). Loads the
 trees from "<tree-model-prefix>1.txt" .. "<tree-model-prefix><N>.txt" and
 runs testStructClassForest() on that selection.
 Exits with status 1 (via usage()) on a wrong argument count. Returns -1
 on invalid arguments or configuration errors, and 0 on success.
 ***************************************************************************/
int main(int argc, char* argv[])
{
    string strConfigFile;
    ConfigReader cr;
    ImageData *idata = new ImageDataFloat();
    TrainingSetSelection<float> *pTrainingSet;
    int optNumTrees=-1;
    char *optTreeFnamePrefix=NULL;
    char buffer[2048];

    srand(time(0));
    setlocale(LC_NUMERIC, "C");
    profiling(NULL);

#ifndef NDEBUG
    std::cout << "******************************************************\n"
        << "DEBUG MODE!!!!!\n"
        << "******************************************************\n";
#endif

    if (argc!=4)
        usage(*argv);   // calls exit(1)
    else
    {
        strConfigFile = argv[1];
        optNumTrees = atoi(argv[2]);
        optTreeFnamePrefix = argv[3];
    }

    // atoi() yields 0 for non-numeric text. A non-positive tree count would
    // make the new[] below invalid, or leave every pixel without votes.
    if (optNumTrees <= 0)
    {
        cout<<"Invalid number of trees: "<<argv[2]<<endl;
        delete idata;
        return -1;
    }

    if (cr.readConfigFile(strConfigFile)==false)
    {
        cout<<"Failed to read config file "<<strConfigFile<<endl;
        delete idata;
        return -1;
    }

    // Load image data
    idata->bGenerateFeatures = true;
    if (idata->setConfiguration(cr)==false)
    {
        cout<<"Failed to initialize image data with configuration"<<endl;
        delete idata;
        return -1;
    }

    // CW Create a dummy training set selection with a single image number (index 0).
    // NOTE(review): the meaning of the constructor's first argument (9) is not visible here, confirm.
    pTrainingSet = new TrainingSetSelection<float>(9, idata);
    pTrainingSet->vectSelectedImagesIndices.push_back(0);
    cout<<pTrainingSet->getNbImages()<<" test images"<<endl;

    // Load forest
    StrucClassSSF<float> *forest = new StrucClassSSF<float>[optNumTrees];
    profiling("Init + feature extraction");
    cr.numTrees = optNumTrees;
    cout << "Loading " << cr.numTrees << " trees: \n";
    for(int iTree = 0; iTree < optNumTrees; ++iTree)
    {
        // Bounded formatting: the prefix comes from the command line and has no length limit.
        const int len = snprintf(buffer, sizeof(buffer), "%s%d.txt", optTreeFnamePrefix, iTree+1);
        if (len < 0 || (size_t)len >= sizeof(buffer))
        {
            cout<<"Tree file name too long for prefix "<<optTreeFnamePrefix<<endl;
            delete [] forest;
            delete pTrainingSet;
            delete idata;
            return -1;
        }
        std::cout << "Loading tree from file " << buffer << "\n";
        forest[iTree].bUseRandomBoxes = true;
        forest[iTree].load(buffer);
        forest[iTree].setTrainingSet(pTrainingSet);
    }
    cout << "done!" << endl;
    profiling("Tree loading");

    testStructClassForest(forest, &cr, pTrainingSet);

    // Release in reverse order of construction. The trees were handed the
    // training-set pointer, and the training set was built from idata.
    delete [] forest;
    delete pTrainingSet;
    delete idata;
    std::cout << "Terminated successfully.\n";
    return 0;
}