/*
 * Min.cu — forked from torch/cunn.
 * (The original web-scraped header and line-number gutter were removed
 *  so this file is valid CUDA C.)
 */
/*
 * Description:
 * this function finds the min along the innermost dimension
 * Nd input, (N-1)d output, (N-1)d argmin
 *
 * Layout: the input is viewed as a contiguous (nrows x ncols) row-major
 * matrix; each thread reduces one full row sequentially.
 *
 * Launch: 1-D grid of 1-D blocks with at least nrows total threads;
 * surplus threads exit on the bounds check, so any covering grid works.
 *
 * input   - device ptr, nrows*ncols floats (read-only)
 * output  - device ptr, nrows floats; receives each row's minimum
 * indices - device ptr, nrows floats; receives the 1-based argmin
 *           (1-based to match Lua/Torch indexing)
 */
__global__ void min_output(const float *__restrict__ input,
                           float *__restrict__ output,
                           float *__restrict__ indices,
                           long nrows, long ncols)
{
  // Row handled by this thread. Cast before multiplying: blockDim.x and
  // blockIdx.x are 32-bit, so the product could overflow for huge grids.
  long row = threadIdx.x + (long)blockDim.x * blockIdx.x;
  if (row >= nrows) return;

  // Start of this thread's row in the flattened input.
  const float *rowData = input + row * ncols;

  // Sequential scan for the row minimum and its position.
  // ("best" rather than "min" to avoid shadowing the device min() function.)
  float best = rowData[0];
  long bestIdx = 0;
  for (long col = 1; col < ncols; col++) {
    float val = rowData[col];
    if (val < best) {
      best = val;
      bestIdx = col;
    }
  }

  // Store the minimum and its 1-based index (Lua convention).
  output[row] = best;
  indices[row] = bestIdx + 1;
}
/*
 * Backward pass of the row-wise min: each row's single incoming gradient
 * is routed to the position where the minimum was found.
 *
 * Launch: 1-D grid with at least nrows threads; extras exit early.
 *
 * input   - gradInput, nrows*ncols floats; written at one slot per row
 *           (assumes the caller pre-zeroed it — see updateGradInput)
 * output  - gradOutput, nrows floats (read-only)
 * indices - 1-based argmin per row, as produced by min_output (read-only)
 */
__global__ void min_gradInput(float *__restrict__ input,
                              const float *__restrict__ output,
                              const float *__restrict__ indices,
                              long nrows, long ncols)
{
  // Row handled by this thread; widen before multiplying so the 32-bit
  // blockDim.x * blockIdx.x product cannot overflow on huge grids.
  long row = threadIdx.x + (long)blockDim.x * blockIdx.x;
  if (row >= nrows) return;

  // indices holds Lua-style 1-based positions; convert back to 0-based.
  long idx = (long)indices[row] - 1;
  input[row * ncols + idx] = output[row];
}
/*
 * Lua binding for nn.Min:updateOutput(input).
 *
 * Stack: arg 1 = module (self), arg 2 = input CudaTensor.
 * Reads self.dimension (1-based in Lua, converted to 0-based here) and
 * resizes/fills self.output and self.indices. Only the innermost
 * dimension is supported by the CUDA kernel.
 *
 * Returns 1 (self.output is left referenced on the Lua side).
 */
static int cunn_Min_updateOutput(lua_State *L)
{
  THCudaTensor *input = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor");
  // Lua dimensions are 1-based; shift to C's 0-based indexing.
  int dimension = luaT_getfieldcheckint(L, 1, "dimension")-1;
  THCudaTensor *indices = (THCudaTensor *)luaT_getfieldcheckudata(L, 1, "indices", "torch.CudaTensor");
  THCudaTensor *output = (THCudaTensor *)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");

  luaL_argcheck(L, dimension >= 0 && dimension < input->nDimension, 2, "dimension out of range");
  luaL_argcheck(L, dimension == input->nDimension-1, 2, "only supported dimension is innermost (CUDA kernel only)");

  // Kernel indexes the tensor as a flat (nrows x ncols) matrix, so the
  // data must be contiguous. newContiguous returns a new reference that
  // must be freed below.
  input = THCudaTensor_newContiguous(input);

  // Output/indices have the input's shape with the reduced dim set to 1.
  THLongStorage *dim = THLongStorage_newWithSize(input->nDimension);
  long i;
  for(i = 0; i < input->nDimension; i++)
    dim->data[i] = input->size[i];
  dim->data[dimension] = 1;
  THCudaTensor_resize(output, dim, NULL);
  THCudaTensor_resize(indices, dim, NULL);
  THLongStorage_free(dim);

  float *input_data = THCudaTensor_data(input);
  float *output_data = THCudaTensor_data(output);
  float *indices_data = THCudaTensor_data(indices);

  long nrows = THCudaTensor_nElement(output);
  long ncols = input->size[dimension];

  // cuda blocks & threads: one thread per output row.
  // Integer ceil-division — the old ceil((float)nrows/nthreads) could
  // under-count blocks once nrows exceeds float's 24-bit precision.
  long nthreads = 256;
  long nblocks = (nrows + nthreads - 1) / nthreads;
  dim3 blocks(nblocks);
  dim3 threads(nthreads);

  // kernel:
  min_output <<<blocks, threads>>> (input_data, output_data, indices_data, nrows, ncols);

  // check for (launch) errors; execution errors surface at the next sync
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in Min.updateOutput: %s\n", cudaGetErrorString(err));
    THError("aborting");
  }

  // Release the contiguous copy and drop the reduced singleton dimension
  // from the output (indices keeps it — preserved historical behavior).
  THCudaTensor_free(input);
  THCudaTensor_select(output, NULL, dimension, 0);
  return 1;
}
/*
 * Lua binding for nn.Min:updateGradInput(input, gradOutput).
 *
 * Stack: arg 1 = module (self), arg 2 = input, arg 3 = gradOutput.
 * Uses self.indices (1-based argmin from updateOutput) to scatter each
 * row's gradient into self.gradInput; all other positions stay zero.
 *
 * Returns 1.
 */
static int cunn_Min_updateGradInput(lua_State *L)
{
  THCudaTensor *input = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor");
  THCudaTensor *gradOutput = (THCudaTensor *)luaT_checkudata(L, 3, "torch.CudaTensor");
  THCudaTensor *indices = (THCudaTensor *)luaT_getfieldcheckudata(L, 1, "indices", "torch.CudaTensor");
  // Lua dimensions are 1-based; shift to C's 0-based indexing.
  int dimension = luaT_getfieldcheckint(L, 1, "dimension")-1;
  THCudaTensor *gradInput = (THCudaTensor *)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");

  // Zero the whole gradient first; the kernel writes only one slot per row.
  THCudaTensor_resizeAs(gradInput, input);
  THCudaTensor_zero(gradInput);

  float *gradInput_data = THCudaTensor_data(gradInput);
  float *gradOutput_data = THCudaTensor_data(gradOutput);
  float *indices_data = THCudaTensor_data(indices);

  long nrows = THCudaTensor_nElement(gradOutput);
  long ncols = gradInput->size[dimension];

  // cuda blocks & threads: one thread per row; exact integer ceil-div
  // (float ceil loses precision for very large nrows).
  long nthreads = 256;
  long nblocks = (nrows + nthreads - 1) / nthreads;
  dim3 blocks(nblocks);
  dim3 threads(nthreads);

  // kernel:
  min_gradInput <<<blocks, threads>>> (gradInput_data, gradOutput_data, indices_data, nrows, ncols);

  // check for (launch) errors — message fixed: this is updateGradInput,
  // not updateOutput (copy-paste bug in the original)
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in Min.updateGradInput: %s\n", cudaGetErrorString(err));
    THError("aborting");
  }
  return 1;
}
// Lua method table: maps Lua-visible names to the C implementations,
// registered under the "nn" namespace by cunn_Min_init below.
static const struct luaL_Reg cunn_Min__ [] = {
{"Min_updateOutput", cunn_Min_updateOutput},
{"Min_updateGradInput", cunn_Min_updateGradInput},
{NULL, NULL}  // sentinel terminating the registration list
};
// Registers the Min methods on the torch.CudaTensor metatable under the
// "nn" field, then pops the metatable to leave the Lua stack balanced.
static void cunn_Min_init(lua_State *L)
{
luaT_pushmetatable(L, "torch.CudaTensor");
luaT_registeratname(L, cunn_Min__, "nn");
lua_pop(L,1);
}