-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathtest_blas.cpp
110 lines (91 loc) · 2.42 KB
/
test_blas.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <cstdio>
#include <cstdlib>
#include <random>
/* Integer type used for every integer argument of the Fortran routines
 * declared below. It is forced to `long` in this file; main() relies on that
 * width to probe how many bits the linked library actually reads and writes. */
#undef FINTEGER
#define FINTEGER long
extern "C" {
/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
/* Single-precision general matrix-matrix product (Fortran calling
 * convention: every argument, including scalars, is passed by pointer).
 * transa/transb select whether each operand is transposed; m, n, k are the
 * problem dimensions; lda, ldb, ldc are the leading dimensions of a, b, c.
 * The result is written to c.
 * NOTE(review): beta is declared non-const here, unlike alpha; reference BLAS
 * treats it as an input — confirm before tightening the declaration. */
int sgemm_(
const char* transa,
const char* transb,
FINTEGER* m,
FINTEGER* n,
FINTEGER* k,
const float* alpha,
const float* a,
FINTEGER* lda,
const float* b,
FINTEGER* ldb,
float* beta,
float* c,
FINTEGER* ldc);
/* Lapack functions, see http://www.netlib.org/clapack/old/single/sgeqrf.c */
/* Single-precision QR factorization of the m-by-n matrix a (leading
 * dimension lda). tau and work are output/scratch arrays, lwork is the size
 * of work, and the status code is written through info. */
int sgeqrf_(
FINTEGER* m,
FINTEGER* n,
float* a,
FINTEGER* lda,
float* tau,
float* work,
FINTEGER* lwork,
FINTEGER* info);
}
/**
 * Allocate an array of `size` floats and fill it with pseudo-random samples.
 *
 * The generator is a default-constructed std::mt19937, so every call starts
 * from the same seed and produces the same sequence of values. Samples are
 * drawn from the default std::uniform_real_distribution<> (double precision)
 * and narrowed to float on assignment.
 *
 * @param size number of elements to allocate and fill
 * @return pointer to a `new[]`-allocated buffer; the caller must release it
 *         with `delete[]`
 */
float* new_random_vec(int size) {
    float* values = new float[size];

    std::mt19937 engine;
    std::uniform_real_distribution<> unit_interval;

    // Walk the buffer with a cursor instead of indexing.
    float* const stop = values + size;
    for (float* cursor = values; cursor != stop; ++cursor) {
        *cursor = unit_interval(engine);
    }

    return values;
}
/**
 * Smoke test for the BLAS/LAPACK library the program is linked against.
 *
 * 1. Multiplies two random column-major matrices with sgemm_ and prints the
 *    element-wise difference against a naive triple loop (values near zero
 *    are expected).
 * 2. Calls sgeqrf_ with integer arguments whose upper 32 bits carry a marker
 *    value, then inspects `info` to report whether the library appears to
 *    use 32-bit or 64-bit integers.
 *
 * NOTE(review): the marker trick shifts a `long` by 32 bits, so it assumes a
 * platform where `long` is 64 bits wide.
 *
 * @return always 0; results are only printed, not asserted.
 */
int main() {
    FINTEGER m = 10, n = 20, k = 30;

    // a is m-by-k, b is k-by-n, c is m-by-n, all stored column-major.
    float* a = new_random_vec(m * k);
    float* b = new_random_vec(n * k);
    float* c = new float[n * m];
    float one = 1.0, zero = 0.0;

    printf("BLAS test\n");

    // c = 1 * a * b + 0 * c, with leading dimensions m, k and m.
    sgemm_("Not transposed",
           "Not transposed",
           &m,
           &n,
           &k,
           &one,
           a,
           &m,
           b,
           &k,
           &zero,
           c,
           &m);

    // Recompute every entry of the product by hand and print the difference
    // from what sgemm_ produced.
    printf("errors=\n");
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            float accu = 0;
            for (int l = 0; l < k; l++)
                accu += a[i + l * m] * b[l + j * k];
            printf("%6.3f ", accu - c[i + j * m]);
        }
        printf("\n");
    }

    // Pre-load the upper 32 bits of `info` with a marker. If the library only
    // writes a 32-bit status, the marker survives the call.
    long info = 0x64bL << 32;
    // Same marker in the upper half, the real row count in the lower half.
    // Presumably a 32-bit library reads just the low half (little-endian) and
    // sees a valid m, while a 64-bit library sees a huge m and rejects it.
    long mi = 0x64bL << 32 | m;
    float* tau = new float[m];
    FINTEGER lwork = -1; // LAPACK convention: -1 requests a workspace-size query
    float work1 = 0;     // receives the queried workspace size

    printf("Intentional Lapack error (appears only for 64-bit INTEGER):\n");
    sgeqrf_(&mi, &n, c, &m, tau, &work1, &lwork, (FINTEGER*)&info);
    // sgeqrf_ (&m, &n, c, &zeroi, tau, &work1, &lwork, (FINTEGER*)&info);
    printf("info=%016lx\n", static_cast<unsigned long>(info));

    if (info >> 32 == 0x64b) {
        printf("Lapack uses 32-bit integers\n");
    } else {
        printf("Lapack uses 64-bit integers\n");
    }

    // Release every buffer allocated above so leak checkers stay quiet.
    delete[] tau;
    delete[] c;
    delete[] b;
    delete[] a;
    return 0;
}