Skip to content

Commit

Permalink
misc
Browse files Browse the repository at this point in the history
  • Loading branch information
lianapatel committed Jun 11, 2024
1 parent db87eaf commit 61db7c9
Showing 1 changed file with 110 additions and 0 deletions.
110 changes: 110 additions & 0 deletions misc/test_blas.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <cstdio>
#include <cstdlib>
#include <random>

#undef FINTEGER
#define FINTEGER long

extern "C" {

/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */

int sgemm_(
const char* transa,
const char* transb,
FINTEGER* m,
FINTEGER* n,
FINTEGER* k,
const float* alpha,
const float* a,
FINTEGER* lda,
const float* b,
FINTEGER* ldb,
float* beta,
float* c,
FINTEGER* ldc);

/* Lapack functions, see http://www.netlib.org/clapack/old/single/sgeqrf.c */

int sgeqrf_(
FINTEGER* m,
FINTEGER* n,
float* a,
FINTEGER* lda,
float* tau,
float* work,
FINTEGER* lwork,
FINTEGER* info);
}

float* new_random_vec(int size) {
float* x = new float[size];
std::mt19937 rng;
std::uniform_real_distribution<> distrib;
for (int i = 0; i < size; i++)
x[i] = distrib(rng);
return x;
}

int main() {
FINTEGER m = 10, n = 20, k = 30;
float *a = new_random_vec(m * k), *b = new_random_vec(n * k),
*c = new float[n * m];
float one = 1.0, zero = 0.0;

printf("BLAS test\n");

sgemm_("Not transposed",
"Not transposed",
&m,
&n,
&k,
&one,
a,
&m,
b,
&k,
&zero,
c,
&m);

printf("errors=\n");

for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
float accu = 0;
for (int l = 0; l < k; l++)
accu += a[i + l * m] * b[l + j * k];
printf("%6.3f ", accu - c[i + j * m]);
}
printf("\n");
}

long info = 0x64bL << 32;
long mi = 0x64bL << 32 | m;
float* tau = new float[m];
FINTEGER lwork = -1;

float work1;

printf("Intentional Lapack error (appears only for 64-bit INTEGER):\n");
sgeqrf_(&mi, &n, c, &m, tau, &work1, &lwork, (FINTEGER*)&info);

// sgeqrf_ (&m, &n, c, &zeroi, tau, &work1, &lwork, (FINTEGER*)&info);
printf("info=%016lx\n", info);

if (info >> 32 == 0x64b) {
printf("Lapack uses 32-bit integers\n");
} else {
printf("Lapack uses 64-bit integers\n");
}

return 0;
}

0 comments on commit 61db7c9

Please sign in to comment.