From 323382962a0d2098668841f473e489f4bd8c9955 Mon Sep 17 00:00:00 2001 From: chris-1187 Date: Mon, 20 Jan 2025 19:02:58 +0100 Subject: [PATCH 01/13] Counter-based Philox RNG - Co-authored-by: ichbinstudent <45435943+ichbinstudent@users.noreply.github.com> Signed-off-by: chris-1187 --- .../runtime/matrix/data/LibMatrixDatagen.java | 60 ++++-- .../matrix/data/RandomMatrixGenerator.java | 11 +- .../util/CounterBasedPRNGenerator.java | 28 +++ .../sysds/runtime/util/IPRNGenerator.java | 25 +++ .../sysds/runtime/util/PRNGenerator.java | 2 +- .../runtime/util/PhiloxCBPRNGenerator.java | 174 ++++++++++++++++++ 6 files changed, 274 insertions(+), 26 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/util/CounterBasedPRNGenerator.java create mode 100644 src/main/java/org/apache/sysds/runtime/util/IPRNGenerator.java create mode 100644 src/main/java/org/apache/sysds/runtime/util/PhiloxCBPRNGenerator.java diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java index c934ca02adf..d85517cb718 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java @@ -36,12 +36,7 @@ import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; -import org.apache.sysds.runtime.util.CommonThreadPool; -import org.apache.sysds.runtime.util.NormalPRNGenerator; -import org.apache.sysds.runtime.util.PRNGenerator; -import org.apache.sysds.runtime.util.PoissonPRNGenerator; -import org.apache.sysds.runtime.util.UniformPRNGenerator; -import org.apache.sysds.runtime.util.UtilFunctions; +import org.apache.sysds.runtime.util.*; public class LibMatrixDatagen { @@ -465,13 +460,16 @@ private static void genRandomNumbers(boolean invokedFromCP, int rl, int ru, int int ncb = (int) 
Math.ceil((double)cols/blen); int counter = 0; + long[] ctr = {0, 0, 0, 0}; // Counter based RNG counter + // Setup Pseudo Random Number Generator for cell values based on 'pdf'. - PRNGenerator valuePRNG = rgen._valuePRNG; + IPRNGenerator valuePRNG = rgen._valuePRNG; if (valuePRNG == null) { switch (rgen._pdf) { case UNIFORM: valuePRNG = new UniformPRNGenerator(); break; case NORMAL: valuePRNG = new NormalPRNGenerator(); break; case POISSON: valuePRNG = new PoissonPRNGenerator(); break; + case CB_UNIFORM: valuePRNG = new PhiloxCBPRNGenerator(); break; default: throw new DMLRuntimeException("Unsupported distribution function for Rand: " + rgen._pdf); } @@ -505,7 +503,13 @@ private static void genRandomNumbers(boolean invokedFromCP, int rl, int ru, int // Also note that we cannot use the same seed here, because for ultra-sparse generation // the number of calls to the valuePRNG and nnzPRNG are the same, thus creating correlated // outcomes (bias toward the end of the value range) - nnzPRNG.setSeed((long)(valuePRNG.nextDouble()*Long.MAX_VALUE)); + + if (valuePRNG instanceof CounterBasedPRNGenerator) { + nnzPRNG.setSeed((long)(((CounterBasedPRNGenerator) valuePRNG).getDoubles(ctr, 1)[0]*Long.MAX_VALUE )); + } else { + nnzPRNG.setSeed((long)(((PRNGenerator) valuePRNG).nextDouble()*Long.MAX_VALUE)); + } + boolean localSparse = sparsity < 1 && MatrixBlock.evalSparseFormatInMemory( blockrows, blockcols, (long)(sparsity*blockrows*blockcols)); if ( localSparse) { @@ -515,17 +519,22 @@ private static void genRandomNumbers(boolean invokedFromCP, int rl, int ru, int out.sparse = true; //otherwise ignored c = out.sparseBlock; } - genSparse(c, clen, blockrows, blockcols, rowoffset, coloffset, - sparsity, estnnzRow, min, range, valuePRNG, nnzPRNG); + if (valuePRNG instanceof PRNGenerator) { + genSparse(c, clen, blockrows, blockcols, rowoffset, coloffset, + sparsity, estnnzRow, min, range, (PRNGenerator) valuePRNG, nnzPRNG); + } } else { if (sparsity == 1.0) { 
genFullyDense(out.getDenseBlock(), blockrows, blockcols, - rowoffset, coloffset, min, range, valuePRNG); + rowoffset, coloffset, min, range, valuePRNG, ctr); } else { - genDense(out, clen, blockrows, blockcols, rowoffset, coloffset, - sparsity, estnnzRow, min, range, valuePRNG, nnzPRNG); + if (valuePRNG instanceof PRNGenerator) { + genDense(out, clen, blockrows, blockcols, rowoffset, coloffset, + sparsity, estnnzRow, min, range, (PRNGenerator) valuePRNG, nnzPRNG); + } + } } // sparse or dense } // cbj @@ -590,13 +599,26 @@ private static void genDense(MatrixBlock out, int clen, int blockrows, int block } private static void genFullyDense(DenseBlock c, int blockrows, int blockcols, int rowoffset, int coloffset, - double min, double range, PRNGenerator valuePRNG) + double min, double range, IPRNGenerator valuePRNG, long[] ctr) { - for(int i = rowoffset; i < rowoffset+blockrows; i++) { - double[] cvals = c.values(i); - int cix = c.pos(i, coloffset); - for(int j = 0; j < blockcols; j++) - cvals[cix+j] = min + (range * valuePRNG.nextDouble()); + if (valuePRNG instanceof PRNGenerator) { + for(int i = rowoffset; i < rowoffset+blockrows; i++) { + double[] cvals = c.values(i); + int cix = c.pos(i, coloffset); + for(int j = 0; j < blockcols; j++) + cvals[cix+j] = min + (range * ((PRNGenerator)valuePRNG).nextDouble()); + } + } else { + double[] randomDoubles = ((CounterBasedPRNGenerator)valuePRNG).getDoubles(ctr, blockrows * blockcols); + int index = 0; + for (int i = rowoffset; i < rowoffset + blockrows; i++) { + double[] cvals = c.values(i); + int cix = c.pos(i, coloffset); + for (int j = 0; j < blockcols; j++) { + cvals[cix + j] = min + (range * randomDoubles[index]); + index++; + } + } } } diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/RandomMatrixGenerator.java b/src/main/java/org/apache/sysds/runtime/matrix/data/RandomMatrixGenerator.java index 38f92be4cbd..939b01644d4 100644 --- 
a/src/main/java/org/apache/sysds/runtime/matrix/data/RandomMatrixGenerator.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/RandomMatrixGenerator.java @@ -20,10 +20,7 @@ package org.apache.sysds.runtime.matrix.data; import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.util.NormalPRNGenerator; -import org.apache.sysds.runtime.util.PRNGenerator; -import org.apache.sysds.runtime.util.PoissonPRNGenerator; -import org.apache.sysds.runtime.util.UniformPRNGenerator; +import org.apache.sysds.runtime.util.*; public class RandomMatrixGenerator { @@ -31,14 +28,14 @@ public class RandomMatrixGenerator { * Types of Probability density functions */ public enum PDF { - NORMAL, UNIFORM, POISSON + NORMAL, UNIFORM, POISSON, CB_UNIFORM } PDF _pdf; int _rows, _cols, _blocksize; double _sparsity, _mean; double _min, _max; - PRNGenerator _valuePRNG; + IPRNGenerator _valuePRNG; public RandomMatrixGenerator() { _pdf = PDF.UNIFORM; @@ -166,6 +163,8 @@ protected void setupValuePRNG() { throw new DMLRuntimeException("Invalid parameter (" + _mean + ") for Poisson distribution."); _valuePRNG = new PoissonPRNGenerator(_mean); break; + case CB_UNIFORM: + _valuePRNG = new PhiloxCBPRNGenerator(); default: throw new DMLRuntimeException("Unsupported probability density function"); } diff --git a/src/main/java/org/apache/sysds/runtime/util/CounterBasedPRNGenerator.java b/src/main/java/org/apache/sysds/runtime/util/CounterBasedPRNGenerator.java new file mode 100644 index 00000000000..d163012ec2c --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/util/CounterBasedPRNGenerator.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.sysds.runtime.util; + +public abstract class CounterBasedPRNGenerator implements IPRNGenerator { + + public abstract void setSeed(long sd); + + public abstract double[] getDoubles(long[] ctr, int size); +} diff --git a/src/main/java/org/apache/sysds/runtime/util/IPRNGenerator.java b/src/main/java/org/apache/sysds/runtime/util/IPRNGenerator.java new file mode 100644 index 00000000000..e348874cee1 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/util/IPRNGenerator.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + + +package org.apache.sysds.runtime.util; + +public interface IPRNGenerator { + public void setSeed(long seed); +} diff --git a/src/main/java/org/apache/sysds/runtime/util/PRNGenerator.java b/src/main/java/org/apache/sysds/runtime/util/PRNGenerator.java index ec1fc512efd..ce59978cbae 100644 --- a/src/main/java/org/apache/sysds/runtime/util/PRNGenerator.java +++ b/src/main/java/org/apache/sysds/runtime/util/PRNGenerator.java @@ -20,7 +20,7 @@ package org.apache.sysds.runtime.util; -public abstract class PRNGenerator { +public abstract class PRNGenerator implements IPRNGenerator { public abstract void setSeed(long sd); diff --git a/src/main/java/org/apache/sysds/runtime/util/PhiloxCBPRNGenerator.java b/src/main/java/org/apache/sysds/runtime/util/PhiloxCBPRNGenerator.java new file mode 100644 index 00000000000..1f2e367be76 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/util/PhiloxCBPRNGenerator.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + + +package org.apache.sysds.runtime.util; + +public class PhiloxCBPRNGenerator extends CounterBasedPRNGenerator { + + // Constants for Philox + public static final long PHILOX_M4x64_0_hi = 0xD2E7470EE14C6C93L >>> 32; + public static final long PHILOX_M4x64_0_lo = 0xD2E7470EE14C6C93L & 0xFFFFFFFFL; + public static final long PHILOX_M4x64_1_hi = 0xCA5A826395121157L >>> 32; + public static final long PHILOX_M4x64_1_lo = 0xCA5A826395121157L & 0xFFFFFFFFL; + public static final long PHILOX_W64_0 = 0x9E3779B97F4A7C15L; + public static final long PHILOX_W64_1 = 0xBB67AE8584CAA73BL; + private static final double ULONG_TO_11 = (1.0 / Long.MAX_VALUE); + + // Default number of rounds + public static final int PHILOX4x64_DEFAULT_ROUNDS = 10; + long[] seed; + + public void setSeed(long sd) { + this.seed = new long[2]; + this.seed[0] = sd; + this.seed[1] = sd; + } + + public double[] getDoubles(long[] ctr, int size) { + // Ensure the key is correct size + if (this.seed.length != 2) { + throw new IllegalArgumentException("Key must be 128 bits"); + } + // Ensure the counter is correct size + if (ctr.length != 4) { + throw new IllegalArgumentException("Counter must be 256 bits"); + } + + int iterations = size / 4; + long[] result = new long[size]; + long[] currentKey = new long[]{this.seed[0], this.seed[1]}; // Create a copy of the key + + // Reusable arrays for counters + long[] currentCtr = ctr.clone(); + + for (int i = 0; i < iterations; i++) { + for (int j = 0; j < PHILOX4x64_DEFAULT_ROUNDS; j++) { + // Multiply as 128-bit + long bHigh = currentCtr[0] >>> 32; + long bLow = currentCtr[0] & 0xFFFFFFFFL; + + long hi0 = PHILOX_M4x64_0_hi * bHigh; + long mid1 = PHILOX_M4x64_0_hi * bLow; + long mid2 = PHILOX_M4x64_0_lo * bHigh; + long lo0 = PHILOX_M4x64_0_lo * bLow; + + // Combine results + long carry = (lo0 >>> 32) + (mid1 & 0xFFFFFFFFL) + (mid2 & 0xFFFFFFFFL); + hi0 += (mid1 >>> 32) + (mid2 >>> 32) + (carry >>> 32); + lo0 = (lo0 & 0xFFFFFFFFL) | (carry << 32); + + // 
Multiply as 128-bit + bHigh = currentCtr[2] >>> 32; + bLow = currentCtr[2] & 0xFFFFFFFFL; + + long hi1 = PHILOX_M4x64_1_hi * bHigh; + mid1 = PHILOX_M4x64_1_hi * bLow; + mid2 = PHILOX_M4x64_1_lo * bHigh; + long lo1 = PHILOX_M4x64_1_lo * bLow; + + // Combine results + carry = (lo1 >>> 32) + (mid1 & 0xFFFFFFFFL) + (mid2 & 0xFFFFFFFFL); + hi1 += (mid1 >>> 32) + (mid2 >>> 32) + (carry >>> 32); + lo1 = (lo1 & 0xFFFFFFFFL) | (carry << 32); + + currentCtr[0] = hi1 ^ currentCtr[1] ^ currentKey[0]; + currentCtr[2] = hi0 ^ currentCtr[3] ^ currentKey[1]; + currentCtr[1] = lo1; + currentCtr[3] = lo0; + + currentKey[0] += PHILOX_W64_0; + currentKey[1] += PHILOX_W64_1; + } + + // Unpack the results + result[i * 4] = currentCtr[0]; + result[i * 4 + 1] = currentCtr[1]; + result[i * 4 + 2] = currentCtr[2]; + result[i * 4 + 3] = currentCtr[3]; + + // Increment the counter + if (++ctr[0] == 0 && ++ctr[1] == 0 && ++ctr[2] == 0) { + ++ctr[3]; + } + currentCtr[0] = ctr[0]; + currentCtr[1] = ctr[1]; + currentCtr[2] = ctr[2]; + currentCtr[3] = ctr[3]; + currentKey[0] = this.seed[0]; + currentKey[1] = this.seed[1]; + } + + // Handle remaining elements + if (size % 4 != 0) { + for (int j = 0; j < PHILOX4x64_DEFAULT_ROUNDS; j++) { + // Multiply as 128-bit + long bHigh = currentCtr[0] >>> 32; + long bLow = currentCtr[0] & 0xFFFFFFFFL; + + long hi0 = PHILOX_M4x64_0_hi * bHigh; + long mid1 = PHILOX_M4x64_0_hi * bLow; + long mid2 = PHILOX_M4x64_0_lo * bHigh; + long lo0 = PHILOX_M4x64_0_lo * bLow; + + // Combine results + long carry = (lo0 >>> 32) + (mid1 & 0xFFFFFFFFL) + (mid2 & 0xFFFFFFFFL); + hi0 += (mid1 >>> 32) + (mid2 >>> 32) + (carry >>> 32); + lo0 = (lo0 & 0xFFFFFFFFL) | (carry << 32); + + // Multiply as 128-bit + bHigh = currentCtr[2] >>> 32; + bLow = currentCtr[2] & 0xFFFFFFFFL; + + long hi1 = PHILOX_M4x64_1_hi * bHigh; + mid1 = PHILOX_M4x64_1_hi * bLow; + mid2 = PHILOX_M4x64_1_lo * bHigh; + long lo1 = PHILOX_M4x64_1_lo * bLow; + + // Combine results + carry = (lo1 >>> 32) + (mid1 & 
0xFFFFFFFFL) + (mid2 & 0xFFFFFFFFL); + hi1 += (mid1 >>> 32) + (mid2 >>> 32) + (carry >>> 32); + lo1 = (lo1 & 0xFFFFFFFFL) | (carry << 32); + + currentCtr[0] = hi1 ^ currentCtr[1] ^ currentKey[0]; + currentCtr[2] = hi0 ^ currentCtr[3] ^ currentKey[1]; + currentCtr[1] = lo1; + currentCtr[3] = lo0; + + currentKey[0] += PHILOX_W64_0; + currentKey[1] += PHILOX_W64_1; + } + + // Store the remaining results + switch (size % 4) { + case 3: + result[iterations * 4 + 2] = currentCtr[2]; + case 2: + result[iterations * 4 + 1] = currentCtr[1]; + case 1: + result[iterations * 4] = currentCtr[0]; + } + } + double[] double_result = new double[result.length]; + for (int i = 0; i < result.length; i++) { + double_result[i] = result[i]; + } + return double_result; + } +} From fc01ff5867708f73a8862bbc07c04a66e50ef405 Mon Sep 17 00:00:00 2001 From: ichbinstudent <45435943+ichbinstudent@users.noreply.github.com> Date: Wed, 22 Jan 2025 10:58:04 +0100 Subject: [PATCH 02/13] Added unittests and added counter based normal distributed rng using boxmuller --- .../runtime/matrix/data/LibMatrixDatagen.java | 5 +- .../matrix/data/RandomMatrixGenerator.java | 8 +- .../util/PhiloxNormalCBPRNGenerator.java | 65 +++++++++++++ ....java => PhiloxUniformCBPRNGenerator.java} | 29 +++--- .../matrix/LibMatrixDatagenTest.java | 94 +++++++++++++++++++ 5 files changed, 187 insertions(+), 14 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/util/PhiloxNormalCBPRNGenerator.java rename src/main/java/org/apache/sysds/runtime/util/{PhiloxCBPRNGenerator.java => PhiloxUniformCBPRNGenerator.java} (85%) create mode 100644 src/test/java/org/apache/sysds/test/component/matrix/LibMatrixDatagenTest.java diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java index d85517cb718..4e84c5014be 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java +++ 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java @@ -135,9 +135,11 @@ public static RandomMatrixGenerator createRandomMatrixGenerator(String pdfStr, i RandomMatrixGenerator rgen = null; switch (pdf) { case UNIFORM: + case CB_UNIFORM: rgen = new RandomMatrixGenerator(pdf, r, c, blen, sp, min, max); break; case NORMAL: + case CB_NORMAL: rgen = new RandomMatrixGenerator(pdf, r, c, blen, sp); break; case POISSON: @@ -469,7 +471,8 @@ private static void genRandomNumbers(boolean invokedFromCP, int rl, int ru, int case UNIFORM: valuePRNG = new UniformPRNGenerator(); break; case NORMAL: valuePRNG = new NormalPRNGenerator(); break; case POISSON: valuePRNG = new PoissonPRNGenerator(); break; - case CB_UNIFORM: valuePRNG = new PhiloxCBPRNGenerator(); break; + case CB_UNIFORM: valuePRNG = new PhiloxUniformCBPRNGenerator(); break; + case CB_NORMAL: valuePRNG = new PhiloxNormalCBPRNGenerator(); break; default: throw new DMLRuntimeException("Unsupported distribution function for Rand: " + rgen._pdf); } diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/RandomMatrixGenerator.java b/src/main/java/org/apache/sysds/runtime/matrix/data/RandomMatrixGenerator.java index 939b01644d4..e4b8eaa5539 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/RandomMatrixGenerator.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/RandomMatrixGenerator.java @@ -28,7 +28,7 @@ public class RandomMatrixGenerator { * Types of Probability density functions */ public enum PDF { - NORMAL, UNIFORM, POISSON, CB_UNIFORM + NORMAL, UNIFORM, POISSON, CB_UNIFORM, CB_NORMAL } PDF _pdf; @@ -164,7 +164,11 @@ protected void setupValuePRNG() { _valuePRNG = new PoissonPRNGenerator(_mean); break; case CB_UNIFORM: - _valuePRNG = new PhiloxCBPRNGenerator(); + _valuePRNG = new PhiloxUniformCBPRNGenerator(); + break; + case CB_NORMAL: + _valuePRNG = new PhiloxNormalCBPRNGenerator(); + break; default: throw new DMLRuntimeException("Unsupported probability density 
function"); } diff --git a/src/main/java/org/apache/sysds/runtime/util/PhiloxNormalCBPRNGenerator.java b/src/main/java/org/apache/sysds/runtime/util/PhiloxNormalCBPRNGenerator.java new file mode 100644 index 00000000000..f1d9594858d --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/util/PhiloxNormalCBPRNGenerator.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.sysds.runtime.util; + +public class PhiloxNormalCBPRNGenerator extends CounterBasedPRNGenerator { + private long[] seed; + private PhiloxUniformCBPRNGenerator uniformGen; + + public void setSeed(long sd) { + this.seed = new long[2]; + this.seed[0] = sd; + this.seed[1] = sd; + uniformGen = new PhiloxUniformCBPRNGenerator(); + uniformGen.setSeed(this.seed[0]); + } + + /** + * Generate a sequence of random doubles using the Philox4x64 counter-based PRNG. 
+ * + * @param ctr The start counter to use for the PRNG + * @param size The number of doubles to generate + * @return An array of random doubles distributed normally with mean 0 and variance 1 + */ + public double[] getDoubles(long[] ctr, int size) { + // Ensure the key is correct size + if (this.seed.length != 2) { + throw new IllegalArgumentException("Key must be 128 bits"); + } + // Ensure the counter is correct size + if (ctr.length != 4) { + throw new IllegalArgumentException("Counter must be 256 bits"); + } + + double[] uniform = uniformGen.getDoubles(ctr, size + size % 2); + double[] normal = new double[size]; + for (int i = 0; i < size; i+=2) { + double v1 = Math.sqrt(-2*Math.log(uniform[i])); + double v2 = 2*Math.PI*uniform[i + 1]; + normal[i] = v1 * Math.cos(v2); + if (i + 1 < size) { + normal[i + 1] = v1 * Math.sin(v2); + } + } + + return normal; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/util/PhiloxCBPRNGenerator.java b/src/main/java/org/apache/sysds/runtime/util/PhiloxUniformCBPRNGenerator.java similarity index 85% rename from src/main/java/org/apache/sysds/runtime/util/PhiloxCBPRNGenerator.java rename to src/main/java/org/apache/sysds/runtime/util/PhiloxUniformCBPRNGenerator.java index 1f2e367be76..9815eca8682 100644 --- a/src/main/java/org/apache/sysds/runtime/util/PhiloxCBPRNGenerator.java +++ b/src/main/java/org/apache/sysds/runtime/util/PhiloxUniformCBPRNGenerator.java @@ -20,20 +20,20 @@ package org.apache.sysds.runtime.util; -public class PhiloxCBPRNGenerator extends CounterBasedPRNGenerator { +public class PhiloxUniformCBPRNGenerator extends CounterBasedPRNGenerator { // Constants for Philox - public static final long PHILOX_M4x64_0_hi = 0xD2E7470EE14C6C93L >>> 32; - public static final long PHILOX_M4x64_0_lo = 0xD2E7470EE14C6C93L & 0xFFFFFFFFL; - public static final long PHILOX_M4x64_1_hi = 0xCA5A826395121157L >>> 32; - public static final long PHILOX_M4x64_1_lo = 0xCA5A826395121157L & 0xFFFFFFFFL; - public static final long 
PHILOX_W64_0 = 0x9E3779B97F4A7C15L; - public static final long PHILOX_W64_1 = 0xBB67AE8584CAA73BL; - private static final double ULONG_TO_11 = (1.0 / Long.MAX_VALUE); + private static final long PHILOX_M4x64_0_hi = 0xD2E7470EE14C6C93L >>> 32; + private static final long PHILOX_M4x64_0_lo = 0xD2E7470EE14C6C93L & 0xFFFFFFFFL; + private static final long PHILOX_M4x64_1_hi = 0xCA5A826395121157L >>> 32; + private static final long PHILOX_M4x64_1_lo = 0xCA5A826395121157L & 0xFFFFFFFFL; + private static final long PHILOX_W64_0 = 0x9E3779B97F4A7C15L; + private static final long PHILOX_W64_1 = 0xBB67AE8584CAA73BL; + private static final double LONG_TO_01 = 0.5 / Long.MAX_VALUE; // Default number of rounds - public static final int PHILOX4x64_DEFAULT_ROUNDS = 10; - long[] seed; + private static final int PHILOX4x64_DEFAULT_ROUNDS = 10; + private long[] seed; public void setSeed(long sd) { this.seed = new long[2]; @@ -41,6 +41,13 @@ public void setSeed(long sd) { this.seed[1] = sd; } + /** + * Generate a sequence of random doubles using the Philox4x64 counter-based PRNG. 
+ * + * @param ctr The start counter to use for the PRNG + * @param size The number of doubles to generate + * @return An array of random doubles distributed uniformly between 0 and 1 + */ public double[] getDoubles(long[] ctr, int size) { // Ensure the key is correct size if (this.seed.length != 2) { @@ -167,7 +174,7 @@ public double[] getDoubles(long[] ctr, int size) { } double[] double_result = new double[result.length]; for (int i = 0; i < result.length; i++) { - double_result[i] = result[i]; + double_result[i] = result[i] * LONG_TO_01 + .5; } return double_result; } diff --git a/src/test/java/org/apache/sysds/test/component/matrix/LibMatrixDatagenTest.java b/src/test/java/org/apache/sysds/test/component/matrix/LibMatrixDatagenTest.java new file mode 100644 index 00000000000..076382b42cc --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/matrix/LibMatrixDatagenTest.java @@ -0,0 +1,94 @@ +package org.apache.sysds.test.component.matrix; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.commons.math3.random.Well1024a; +import org.apache.sysds.runtime.matrix.data.LibMatrixDatagen; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.RandomMatrixGenerator; +import org.junit.Ignore; +import org.junit.Test; + +import java.util.Arrays; + +import static org.junit.Assert.*; + +public class LibMatrixDatagenTest { + protected static final Log LOG = LogFactory.getLog(LibMatrixDatagenTest.class.getName()); + + @Test + public void testGenerateUniformMatrixPhilox() { + MatrixBlock mb = new MatrixBlock(); + RandomMatrixGenerator rgen = new RandomMatrixGenerator(RandomMatrixGenerator.PDF.CB_UNIFORM, 10, 10, 10, 1, 0., 1.); + LibMatrixDatagen.generateRandomMatrix(mb, rgen, null, 0L); + for(int i = 0; i < 10; i++) { + for(int j = 0; j < 10; j++) { + assertTrue("Value: " + mb.get(i, j) + "needs to be less than 1", mb.get(i, j) < 1); + assertTrue("Value: " + 
mb.get(i, j) + "needs to be greater than 0", mb.get(i, j) > 0); + } + } + } + + @Test + public void testGenerateNormalMatrixPhilox() { + MatrixBlock mb = new MatrixBlock(); + RandomMatrixGenerator rgen = new RandomMatrixGenerator(RandomMatrixGenerator.PDF.CB_NORMAL, 1000, 1000, 1000 * 1000, 1); + LibMatrixDatagen.generateRandomMatrix(mb, rgen, null, 123123123123L); + double mean = mb.mean(); + double[] bv = mb.getDenseBlockValues(); + double variance = Arrays.stream(bv).map(x -> Math.pow(x - mean, 2)).sum() / bv.length; + assertEquals("Mean should be 0", 0, mean, 0.01); + assertEquals("Variance should be 1", 1, variance, 0.001); + } + + @Test + @Ignore + public void testGenerateUniformMatrixFasterUsingCounterBased() { + MatrixBlock mbPhilox = new MatrixBlock(); + RandomMatrixGenerator rgenPhilox = new RandomMatrixGenerator(RandomMatrixGenerator.PDF.CB_UNIFORM, 1000, 1000, 100, 1, 0., 1.); + long philoxStartTime = System.currentTimeMillis(); + LibMatrixDatagen.generateRandomMatrix(mbPhilox, rgenPhilox, null, 0L); + long philoxEndTime = System.currentTimeMillis(); + + RandomMatrixGenerator rgenDefault = new RandomMatrixGenerator(RandomMatrixGenerator.PDF.UNIFORM, 1000, 1000, 100, 1, 0., 1.); + MatrixBlock mbDefault = new MatrixBlock(); + long defaultStartTime = System.currentTimeMillis(); + LibMatrixDatagen.generateRandomMatrix(mbDefault, rgenDefault, new Well1024a(), 0L); + long defaultEndTime = System.currentTimeMillis(); + + System.out.println("Time taken by Philox (in ms): " + (philoxEndTime - philoxStartTime)); + System.out.println("Time taken by Default (in ms): " + (defaultEndTime - defaultStartTime)); + System.out.println("Philox is faster than default: " + ((philoxEndTime - philoxStartTime) < (defaultEndTime - defaultStartTime))); + } + + @Test + @Ignore + public void testGenerateNormalMatrixFasterUsingCounterBased() { + MatrixBlock mbPhilox = new MatrixBlock(); + RandomMatrixGenerator rgenPhilox = new 
RandomMatrixGenerator(RandomMatrixGenerator.PDF.CB_NORMAL, 1000, 1000, 100, 1, 0., 1.); + long philoxStartTime = System.currentTimeMillis(); + LibMatrixDatagen.generateRandomMatrix(mbPhilox, rgenPhilox, null, 0L); + long philoxEndTime = System.currentTimeMillis(); + + RandomMatrixGenerator rgenDefault = new RandomMatrixGenerator(RandomMatrixGenerator.PDF.NORMAL, 1000, 1000, 100, 1, 0., 1.); + MatrixBlock mbDefault = new MatrixBlock(); + long defaultStartTime = System.currentTimeMillis(); + LibMatrixDatagen.generateRandomMatrix(mbDefault, rgenDefault, new Well1024a(), 0L); + long defaultEndTime = System.currentTimeMillis(); + + System.out.println("Time taken by Philox (in ms): " + (philoxEndTime - philoxStartTime)); + System.out.println("Time taken by Default (in ms): " + (defaultEndTime - defaultStartTime)); + System.out.println("Philox is faster than default: " + ((philoxEndTime - philoxStartTime) < (defaultEndTime - defaultStartTime))); + } + + @Test + public void testGenerateUniformMatrixPhiloxShouldHaveGoodStatistics() { + MatrixBlock mb = new MatrixBlock(); + RandomMatrixGenerator rgen = new RandomMatrixGenerator(RandomMatrixGenerator.PDF.CB_UNIFORM, 1000, 1000, 100, 1, 0., 1.); + LibMatrixDatagen.generateRandomMatrix(mb, rgen, null, 0L); + + double mean = mb.mean(); + assertEquals("Mean should be 0.5", 0.5, mean, 0.001); + + } +} From 3a6df20e4bf3652fe4d4758938802bb8d55dd19d Mon Sep 17 00:00:00 2001 From: ichbinstudent <45435943+ichbinstudent@users.noreply.github.com> Date: Sun, 26 Jan 2025 13:50:04 +0100 Subject: [PATCH 03/13] fix imports --- .../sysds/runtime/matrix/data/LibMatrixDatagen.java | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java index 4e84c5014be..fa7a2c7557c 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java +++ 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java @@ -36,7 +36,17 @@ import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; -import org.apache.sysds.runtime.util.*; +import org.apache.sysds.runtime.util.CommonThreadPool; +import org.apache.sysds.runtime.util.CounterBasedPRNGenerator; +import org.apache.sysds.runtime.util.IPRNGenerator; +import org.apache.sysds.runtime.util.NormalPRNGenerator; +import org.apache.sysds.runtime.util.PRNGenerator; +import org.apache.sysds.runtime.util.PhiloxNormalCBPRNGenerator; +import org.apache.sysds.runtime.util.PhiloxUniformCBPRNGenerator; +import org.apache.sysds.runtime.util.PoissonPRNGenerator; +import org.apache.sysds.runtime.util.UniformPRNGenerator; +import org.apache.sysds.runtime.util.UtilFunctions; + public class LibMatrixDatagen { From 75f224fced01cc1ece1187c272c77b9d5266a342 Mon Sep 17 00:00:00 2001 From: ichbinstudent <45435943+ichbinstudent@users.noreply.github.com> Date: Sun, 26 Jan 2025 13:50:42 +0100 Subject: [PATCH 04/13] Improve unit test and remove unnecessary tests --- .../matrix/LibMatrixDatagenTest.java | 46 ++----------------- 1 file changed, 4 insertions(+), 42 deletions(-) diff --git a/src/test/java/org/apache/sysds/test/component/matrix/LibMatrixDatagenTest.java b/src/test/java/org/apache/sysds/test/component/matrix/LibMatrixDatagenTest.java index 076382b42cc..03a5c907391 100644 --- a/src/test/java/org/apache/sysds/test/component/matrix/LibMatrixDatagenTest.java +++ b/src/test/java/org/apache/sysds/test/component/matrix/LibMatrixDatagenTest.java @@ -2,11 +2,9 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.apache.commons.math3.random.Well1024a; import org.apache.sysds.runtime.matrix.data.LibMatrixDatagen; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import 
org.apache.sysds.runtime.matrix.data.RandomMatrixGenerator; -import org.junit.Ignore; import org.junit.Test; import java.util.Arrays; @@ -41,46 +39,6 @@ public void testGenerateNormalMatrixPhilox() { assertEquals("Variance should be 1", 1, variance, 0.001); } - @Test - @Ignore - public void testGenerateUniformMatrixFasterUsingCounterBased() { - MatrixBlock mbPhilox = new MatrixBlock(); - RandomMatrixGenerator rgenPhilox = new RandomMatrixGenerator(RandomMatrixGenerator.PDF.CB_UNIFORM, 1000, 1000, 100, 1, 0., 1.); - long philoxStartTime = System.currentTimeMillis(); - LibMatrixDatagen.generateRandomMatrix(mbPhilox, rgenPhilox, null, 0L); - long philoxEndTime = System.currentTimeMillis(); - - RandomMatrixGenerator rgenDefault = new RandomMatrixGenerator(RandomMatrixGenerator.PDF.UNIFORM, 1000, 1000, 100, 1, 0., 1.); - MatrixBlock mbDefault = new MatrixBlock(); - long defaultStartTime = System.currentTimeMillis(); - LibMatrixDatagen.generateRandomMatrix(mbDefault, rgenDefault, new Well1024a(), 0L); - long defaultEndTime = System.currentTimeMillis(); - - System.out.println("Time taken by Philox (in ms): " + (philoxEndTime - philoxStartTime)); - System.out.println("Time taken by Default (in ms): " + (defaultEndTime - defaultStartTime)); - System.out.println("Philox is faster than default: " + ((philoxEndTime - philoxStartTime) < (defaultEndTime - defaultStartTime))); - } - - @Test - @Ignore - public void testGenerateNormalMatrixFasterUsingCounterBased() { - MatrixBlock mbPhilox = new MatrixBlock(); - RandomMatrixGenerator rgenPhilox = new RandomMatrixGenerator(RandomMatrixGenerator.PDF.CB_NORMAL, 1000, 1000, 100, 1, 0., 1.); - long philoxStartTime = System.currentTimeMillis(); - LibMatrixDatagen.generateRandomMatrix(mbPhilox, rgenPhilox, null, 0L); - long philoxEndTime = System.currentTimeMillis(); - - RandomMatrixGenerator rgenDefault = new RandomMatrixGenerator(RandomMatrixGenerator.PDF.NORMAL, 1000, 1000, 100, 1, 0., 1.); - MatrixBlock mbDefault = new MatrixBlock(); 
- long defaultStartTime = System.currentTimeMillis(); - LibMatrixDatagen.generateRandomMatrix(mbDefault, rgenDefault, new Well1024a(), 0L); - long defaultEndTime = System.currentTimeMillis(); - - System.out.println("Time taken by Philox (in ms): " + (philoxEndTime - philoxStartTime)); - System.out.println("Time taken by Default (in ms): " + (defaultEndTime - defaultStartTime)); - System.out.println("Philox is faster than default: " + ((philoxEndTime - philoxStartTime) < (defaultEndTime - defaultStartTime))); - } - @Test public void testGenerateUniformMatrixPhiloxShouldHaveGoodStatistics() { MatrixBlock mb = new MatrixBlock(); @@ -90,5 +48,9 @@ public void testGenerateUniformMatrixPhiloxShouldHaveGoodStatistics() { double mean = mb.mean(); assertEquals("Mean should be 0.5", 0.5, mean, 0.001); + double[] bv = mb.getDenseBlockValues(); + assertEquals(1000 * 1000, bv.length); + double variance = Arrays.stream(bv).map(x -> Math.pow(x - mean, 2)).sum() / bv.length; + assertEquals("Variance should be 1", 0.0833, variance, 0.001); } } From 30ac4e27b9e5d4096e3b625f5b3583c759104dd9 Mon Sep 17 00:00:00 2001 From: ichbinstudent <45435943+ichbinstudent@users.noreply.github.com> Date: Sun, 26 Jan 2025 13:53:06 +0100 Subject: [PATCH 05/13] Add missing license --- .../matrix/LibMatrixDatagenTest.java | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/test/java/org/apache/sysds/test/component/matrix/LibMatrixDatagenTest.java b/src/test/java/org/apache/sysds/test/component/matrix/LibMatrixDatagenTest.java index 03a5c907391..9749076d912 100644 --- a/src/test/java/org/apache/sysds/test/component/matrix/LibMatrixDatagenTest.java +++ b/src/test/java/org/apache/sysds/test/component/matrix/LibMatrixDatagenTest.java @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + package org.apache.sysds.test.component.matrix; import org.apache.commons.logging.Log; From e54c9f375173e3b1361cfdd03c67e4fc5ac3397b Mon Sep 17 00:00:00 2001 From: ichbinstudent <45435943+ichbinstudent@users.noreply.github.com> Date: Sun, 26 Jan 2025 14:16:35 +0100 Subject: [PATCH 06/13] refactor genFullyDense --- .../runtime/matrix/data/LibMatrixDatagen.java | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java index fa7a2c7557c..58fb8e4fa8d 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java @@ -21,11 +21,14 @@ package org.apache.sysds.runtime.matrix.data; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; import java.util.Random; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import java.util.stream.LongStream; +import java.util.stream.Stream; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -614,23 +617,19 @@ private static void genDense(MatrixBlock out, int clen, int blockrows, int block private static void genFullyDense(DenseBlock c, int 
blockrows, int blockcols, int rowoffset, int coloffset, double min, double range, IPRNGenerator valuePRNG, long[] ctr) { + Iterator rngStream; if (valuePRNG instanceof PRNGenerator) { - for(int i = rowoffset; i < rowoffset+blockrows; i++) { - double[] cvals = c.values(i); - int cix = c.pos(i, coloffset); - for(int j = 0; j < blockcols; j++) - cvals[cix+j] = min + (range * ((PRNGenerator)valuePRNG).nextDouble()); - } + rngStream = Stream.generate(() -> min + (range * ((PRNGenerator) valuePRNG).nextDouble())).iterator(); + } else if (valuePRNG instanceof CounterBasedPRNGenerator) { + rngStream = Arrays.stream(((CounterBasedPRNGenerator)valuePRNG).getDoubles(ctr, blockrows * blockcols)).map(i -> min + (range * i)).iterator(); } else { - double[] randomDoubles = ((CounterBasedPRNGenerator)valuePRNG).getDoubles(ctr, blockrows * blockcols); - int index = 0; - for (int i = rowoffset; i < rowoffset + blockrows; i++) { - double[] cvals = c.values(i); - int cix = c.pos(i, coloffset); - for (int j = 0; j < blockcols; j++) { - cvals[cix + j] = min + (range * randomDoubles[index]); - index++; - } + throw new DMLRuntimeException("Unsupported PRNGenerator for fully dense generation"); + } + for (int i = rowoffset; i < rowoffset + blockrows; i++) { + double[] cvals = c.values(i); + int cix = c.pos(i, coloffset); + for (int j = 0; j < blockcols; j++) { + cvals[cix + j] = rngStream.next(); } } } From 8dc6975580977f0f917a9d5e856d2ef73a223bd3 Mon Sep 17 00:00:00 2001 From: ichbinstudent <45435943+ichbinstudent@users.noreply.github.com> Date: Sun, 26 Jan 2025 14:21:05 +0100 Subject: [PATCH 07/13] Removed wildcard import --- .../sysds/test/component/matrix/LibMatrixDatagenTest.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/test/java/org/apache/sysds/test/component/matrix/LibMatrixDatagenTest.java b/src/test/java/org/apache/sysds/test/component/matrix/LibMatrixDatagenTest.java index 9749076d912..7363dfc5b88 100644 --- 
a/src/test/java/org/apache/sysds/test/component/matrix/LibMatrixDatagenTest.java +++ b/src/test/java/org/apache/sysds/test/component/matrix/LibMatrixDatagenTest.java @@ -28,7 +28,8 @@ import java.util.Arrays; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; public class LibMatrixDatagenTest { protected static final Log LOG = LogFactory.getLog(LibMatrixDatagenTest.class.getName()); From 931f3178ca589ea4397934c465ce4d5c80e6e2a8 Mon Sep 17 00:00:00 2001 From: ichbinstudent <45435943+ichbinstudent@users.noreply.github.com> Date: Sun, 26 Jan 2025 14:23:05 +0100 Subject: [PATCH 08/13] Removed wildcard import --- .../sysds/runtime/matrix/data/RandomMatrixGenerator.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/RandomMatrixGenerator.java b/src/main/java/org/apache/sysds/runtime/matrix/data/RandomMatrixGenerator.java index e4b8eaa5539..408f2475fb5 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/RandomMatrixGenerator.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/RandomMatrixGenerator.java @@ -20,7 +20,12 @@ package org.apache.sysds.runtime.matrix.data; import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.util.*; +import org.apache.sysds.runtime.util.IPRNGenerator; +import org.apache.sysds.runtime.util.NormalPRNGenerator; +import org.apache.sysds.runtime.util.PhiloxNormalCBPRNGenerator; +import org.apache.sysds.runtime.util.PhiloxUniformCBPRNGenerator; +import org.apache.sysds.runtime.util.PoissonPRNGenerator; +import org.apache.sysds.runtime.util.UniformPRNGenerator; public class RandomMatrixGenerator { From 79c58c4eb840409ae14a4af656b45c8a77cf29f5 Mon Sep 17 00:00:00 2001 From: chris-1187 Date: Sun, 26 Jan 2025 15:10:54 +0100 Subject: [PATCH 09/13] RandCounterBased DML Signed-off-by: chris-1187 --- .../functions/data/RandCounterBasedNormal.dml | 24 
+++++++++++++++++++ .../data/RandCounterBasedUniform.dml | 24 +++++++++++++++++++ 2 files changed, 48 insertions(+) create mode 100644 src/test/scripts/functions/data/RandCounterBasedNormal.dml create mode 100644 src/test/scripts/functions/data/RandCounterBasedUniform.dml diff --git a/src/test/scripts/functions/data/RandCounterBasedNormal.dml b/src/test/scripts/functions/data/RandCounterBasedNormal.dml new file mode 100644 index 00000000000..f8263d966a0 --- /dev/null +++ b/src/test/scripts/functions/data/RandCounterBasedNormal.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + + +A = Rand(rows=$1, cols=$2, sparsity=$3, seed=$4, pdf="CB_NORMAL"); +write(A, $5); \ No newline at end of file diff --git a/src/test/scripts/functions/data/RandCounterBasedUniform.dml b/src/test/scripts/functions/data/RandCounterBasedUniform.dml new file mode 100644 index 00000000000..65dc1f07363 --- /dev/null +++ b/src/test/scripts/functions/data/RandCounterBasedUniform.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +A = Rand(rows=$1, cols=$2, sparsity=$3, seed=$4, pdf="CB_UNIFORM"); +write(A, $5); \ No newline at end of file From a8d63ff1a2806681d88bc1f8446241614536da4d Mon Sep 17 00:00:00 2001 From: ichbinstudent <45435943+ichbinstudent@users.noreply.github.com> Date: Wed, 29 Jan 2025 21:57:57 +0100 Subject: [PATCH 10/13] Added test to make sure values generated when using streams are the same. 
--- .../test/component/matrix/LibMatrixDatagenTest.java | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/test/java/org/apache/sysds/test/component/matrix/LibMatrixDatagenTest.java b/src/test/java/org/apache/sysds/test/component/matrix/LibMatrixDatagenTest.java index 7363dfc5b88..30e0c75fbd2 100644 --- a/src/test/java/org/apache/sysds/test/component/matrix/LibMatrixDatagenTest.java +++ b/src/test/java/org/apache/sysds/test/component/matrix/LibMatrixDatagenTest.java @@ -30,6 +30,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertArrayEquals; public class LibMatrixDatagenTest { protected static final Log LOG = LogFactory.getLog(LibMatrixDatagenTest.class.getName()); @@ -73,4 +74,15 @@ public void testGenerateUniformMatrixPhiloxShouldHaveGoodStatistics() { double variance = Arrays.stream(bv).map(x -> Math.pow(x - mean, 2)).sum() / bv.length; assertEquals("Variance should be 1", 0.0833, variance, 0.001); } + + @Test + public void testGenerateUniformMatrixShouldReturnSameValuesUsingStreams() { + MatrixBlock mb = new MatrixBlock(); + RandomMatrixGenerator rgen = new RandomMatrixGenerator(RandomMatrixGenerator.PDF.UNIFORM, 1000, 1000, 100, 1, 0., 1.); + LibMatrixDatagen.generateRandomMatrix(mb, rgen, null, 0L); + + double[] bv = Arrays.copyOf(mb.getDenseBlockValues(), 100); + double[] previous = new double[] {0.24053641567148587, 0.6374174253501083, 0.5504370051176339, 0.5975452777972018, 0.3332183994766498, 0.3851891847407185, 0.984841540199809, 0.8791825178724801, 0.9412491794821144, 0.27495396603548483, 0.12889715087377673, 0.14660165764651822, 0.023238122483889456, 0.5467397571984656, 0.9644868606768501, 0.10449068625097169, 0.6251463634655593, 0.4107961954910617, 0.7763122912749325, 0.990722785714783, 0.4872328470301428, 0.7462414053223305, 0.7331520701949938, 0.8172970714093244, 0.8388903500470183, 0.5266994346048661, 0.8993350116114935, 0.13393984058689223, 
0.0830623982249149, 0.9785743401478403, 0.7223571191888487, 0.7150310138504744, 0.14322038530059678, 0.4629578184224229, 0.004485602182885184, 0.07149831487989411, 0.34842022979166454, 0.3387696535357536, 0.859356551354648, 0.9715469888517128, 0.8657458802140383, 0.6125811047098682, 0.17898798452881726, 0.21757041220968598, 0.8544871670422907, 0.009673497300974332, 0.6922930069529333, 0.7713129661706796, 0.7126874281456893, 0.2112353749298962, 0.7830924897671794, 0.945333238959629, 0.014236355103667941, 0.3942035527773311, 0.8537907753080728, 0.7860424508145526, 0.993471955005814, 0.883104405981479, 0.17029153024770394, 0.9620689182075386, 0.7242950335788688, 0.6773541612498745, 0.8043954172246357, 0.44142677367579175, 0.46208799028599445, 0.8528274665994607, 0.501834850205735, 0.9919429804102169, 0.9692699099404161, 0.35310607217911816, 0.047265869196129406, 0.0716236234178006, 0.02910751272163581, 0.48367019010510015, 0.9719501209537452, 0.9891171507514055, 0.7674421030154899, 0.5013973510122299, 0.2555253108964435, 0.30915818724818767, 0.8482805002723425, 0.052084538173983286, 0.010175454536229256, 0.35385296970871194, 0.08673785516572752, 0.8503115152643057, 0.0036769023557003955, 0.3078931676344727, 0.5316085562487977, 0.9188142018385732, 0.27721002606871137, 0.8742622102831944, 0.6098815135127635, 0.9086392096967358, 0.04449062015679506, 0.6467239010388895, 0.4968037636226561, 0.5067015959528527, 0.5206888198929495, 0.36636074451399603}; + assertArrayEquals(previous, bv, 0.0001); + } } From 543f3b2616bb1dd83eaa6aca8f76d4472cc0ffda Mon Sep 17 00:00:00 2001 From: ichbinstudent <45435943+ichbinstudent@users.noreply.github.com> Date: Wed, 29 Jan 2025 23:12:21 +0100 Subject: [PATCH 11/13] Add instructions on how to use the cuda version to the staging folder --- .../PhiloxJNvrtcExample.java | 90 ++ .../PhiloxRuntimeCompilationExample.java | 225 +++++ .../staging/cuda-counter-based-prng/kernel.cu | 260 ++++++ .../cuda-counter-based-prng/philox_kernel.ptx | 772 
++++++++++++++++++ .../staging/cuda-counter-based-prng/pom.xml | 38 + .../staging/cuda-counter-based-prng/readme.md | 410 ++++++++++ 6 files changed, 1795 insertions(+) create mode 100644 scripts/staging/cuda-counter-based-prng/PhiloxJNvrtcExample.java create mode 100644 scripts/staging/cuda-counter-based-prng/PhiloxRuntimeCompilationExample.java create mode 100644 scripts/staging/cuda-counter-based-prng/kernel.cu create mode 100644 scripts/staging/cuda-counter-based-prng/philox_kernel.ptx create mode 100644 scripts/staging/cuda-counter-based-prng/pom.xml create mode 100644 scripts/staging/cuda-counter-based-prng/readme.md diff --git a/scripts/staging/cuda-counter-based-prng/PhiloxJNvrtcExample.java b/scripts/staging/cuda-counter-based-prng/PhiloxJNvrtcExample.java new file mode 100644 index 00000000000..7d1cafee4fe --- /dev/null +++ b/scripts/staging/cuda-counter-based-prng/PhiloxJNvrtcExample.java @@ -0,0 +1,90 @@ +import jcuda.*; +import jcuda.driver.*; +import jcuda.nvrtc.*; +import jcuda.runtime.JCuda; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; + +import static jcuda.driver.JCudaDriver.cuCtxCreate; + +public class PhiloxJNvrtcExample { + + public static void main(String[] args) { + // Enable exceptions and omit error checks + JCuda.setExceptionsEnabled(true); + JCudaDriver.setExceptionsEnabled(true); + JNvrtc.setExceptionsEnabled(true); + + String ptx = ""; + try { + ptx = new String(Files.readAllBytes(Paths.get("philox_kernel.ptx"))); + } catch (IOException e) { + System.out.println(e.getMessage()); + } + + // Print the PTX for debugging + //System.out.println("Generated PTX:"); + // System.out.println(ptx); + + // Initialize the driver API and create a context + JCudaDriver.cuInit(0); + CUdevice device = new CUdevice(); + JCudaDriver.cuDeviceGet(device, 0); + CUcontext context = new CUcontext(); + cuCtxCreate(context, 0, device); + + CUmodule module = new CUmodule(); + 
JCudaDriver.cuModuleLoadData(module, ptx); + + // Get a function pointer to the kernel + CUfunction function = new CUfunction(); + JCudaDriver.cuModuleGetFunction(function, module, "philox_4_64"); + + // Prepare data + int n = 1000; // Number of random numbers to generate + long[] hostOut = new long[n]; + CUdeviceptr deviceOut = new CUdeviceptr(); + JCudaDriver.cuMemAlloc(deviceOut, n * Sizeof.LONG); + + // Direkte Werte für seed und startingCounter + long seed = 0L; // Fester Seed-Wert + long startingCounter = 0L; // Startwert für Counter + + Pointer kernelParameters = Pointer.to( + Pointer.to(deviceOut), // ulong* output + Pointer.to(new long[]{seed}), // uint64_t seed + Pointer.to(new long[]{startingCounter}), // uint64_t startingCounter + Pointer.to(new long[]{n}) // size_t numElements + ); + + // Launch the kernel + int blockSizeX = 128; + int gridSizeX = (int) Math.ceil((double)n / blockSizeX); + JCudaDriver.cuLaunchKernel( + function, + gridSizeX, 1, 1, // Grid dimension + blockSizeX, 1, 1, // Block dimension + 0, null, // Shared memory size and stream + kernelParameters, null // Kernel- und extra parameters + ); + JCudaDriver.cuCtxSynchronize(); + + // Copy result back + JCudaDriver.cuMemcpyDtoH(Pointer.to(hostOut), deviceOut, n * Sizeof.LONG); + + // Print results + System.out.println("Generated random numbers with seed=" + + String.format("0x%016X", seed) + + " and startingCounter=" + startingCounter); + for (int i = 0; i < Math.min(10, n); i++) { + System.out.printf("hostOut[%d] = 0x%016X\n", i, hostOut[i]); + } + + // Cleanup + JCudaDriver.cuMemFree(deviceOut); + JCudaDriver.cuCtxDestroy(context); + } +} diff --git a/scripts/staging/cuda-counter-based-prng/PhiloxRuntimeCompilationExample.java b/scripts/staging/cuda-counter-based-prng/PhiloxRuntimeCompilationExample.java new file mode 100644 index 00000000000..93a1840ba3e --- /dev/null +++ b/scripts/staging/cuda-counter-based-prng/PhiloxRuntimeCompilationExample.java @@ -0,0 +1,225 @@ +import jcuda.*; 
+import jcuda.driver.*; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileWriter; +import java.io.InputStreamReader; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +import static java.nio.file.Files.readAllBytes; +import static jcuda.driver.JCudaDriver.*; + +public class PhiloxRuntimeCompilationExample implements AutoCloseable { + private static String philox4x64KernelSource = "#include \n" + + "#include \n" + + "extern \"C\" __global__ void philox_4_64(ulong* output, uint64_t startingCounter, uint64_t seed, size_t numElements) {\n" + + + " uint64_t idx = blockIdx.x * blockDim.x + threadIdx.x;\n" + + " if (idx * 4 < numElements) {\n" + + " r123::Philox4x64 rng;\n" + + " r123::Philox4x64::ctr_type ctr = {{startingCounter + idx, 0, 0, 0}};\n" + + " r123::Philox4x64::key_type key = {{seed}};\n" + + " r123::Philox4x64::ctr_type result = rng(ctr, key);\n" + + " for (int i = 0; i < 4; ++i) {\n" + + " size_t outputIdx = idx * 4 + i;\n" + + " if (outputIdx < numElements) {\n" + + " output[outputIdx] = result[i];\n" + + " }\n" + + " }\n" + + " }\n" + + "}\n"; + + private final CUcontext context; + private final CUmodule module; + private final CUfunction function; + private final int blockSize; + + public PhiloxRuntimeCompilationExample() { + JCudaDriver.setExceptionsEnabled(true); + // Initialize CUDA + cuInit(0); + CUdevice device = new CUdevice(); + cuDeviceGet(device, 0); + context = new CUcontext(); + int result = cuCtxCreate(context, 0, device); + if (result != CUresult.CUDA_SUCCESS) { + throw new RuntimeException( + "Kontext-Erstellung fehlgeschlagen: " + result + ", " + CUresult.stringFor(result)); + } + + // Compile to PTX + String ptx = compileToTPX(philox4x64KernelSource); + + // Load the PTX + module = new CUmodule(); + cuModuleLoadData(module, ptx); + function = new CUfunction(); + cuModuleGetFunction(function, module, "philox_4_64"); + + // Set block size based on device capabilities + blockSize = 64; 
// Can be adjusted based on device properties + } + + private String compileToTPX(String source) { + try { + // Temporäre Dateien erstellen + File sourceFile = File.createTempFile("philox_kernel", ".cu"); + File outputFile = File.createTempFile("philox_kernel", ".ptx"); + + // CUDA-Quellcode in temporäre Datei schreiben + try (FileWriter writer = new FileWriter(sourceFile)) { + writer.write(philox4x64KernelSource); + } + + // nvcc Kommando zusammenbauen + List command = new ArrayList<>(); + command.add("/usr/local/cuda/bin/nvcc"); + command.add("-ccbin"); + command.add("gcc-8"); + command.add("--ptx"); // PTX-Output generieren + command.add("-o"); + command.add(outputFile.getAbsolutePath()); + command.add("-I"); + command.add("./lib/random123/include"); + command.add(sourceFile.getAbsolutePath()); + + // Prozess erstellen und ausführen + ProcessBuilder pb = new ProcessBuilder(command); + pb.redirectErrorStream(true); + Process process = pb.start(); + + // Output des Kompilers lesen + try (BufferedReader reader = new BufferedReader( + new InputStreamReader(process.getInputStream()))) { + String line; + StringBuilder output = new StringBuilder(); + while ((line = reader.readLine()) != null) { + output.append(line).append("\n"); + } + System.out.println("Compiler Output: " + output.toString()); + } + + // Auf Prozessende warten + int exitCode = process.waitFor(); + if (exitCode != 0) { + throw new RuntimeException("nvcc Kompilierung fehlgeschlagen mit Exit-Code: " + exitCode); + } + + // PTX-Datei einlesen + String ptxCode = new String(readAllBytes(outputFile.toPath())); + + // Aufräumen + sourceFile.delete(); + outputFile.delete(); + + return ptxCode; + + } catch (Exception e) { + throw new RuntimeException("Fehler bei der CUDA-Kompilierung: " + e.getMessage(), e); + } + } + + /** + * Generates random numbers using the Philox4x64 algorithm + * + * @param startingCounter Initial counter value + * @param seed Random seed + * @param numElements Number of random numbers 
to generate + * @return Array of random numbers + */ + public CUdeviceptr Philox4x64(long startingCounter, long seed, int numElements) { + // Allocate host memory for results + // long[] hostOutput = new long[numElements]; + + // Allocate device memory + CUdeviceptr deviceOutput = new CUdeviceptr(); + cuMemAlloc(deviceOutput, (long) numElements * Sizeof.LONG); + + try { + // Set up kernel parameters mit Debugging + System.out.printf("numElements: %d, seed: %d, startingCounter: %d%n", + numElements, seed, startingCounter); + + Pointer kernelParams = Pointer.to( + Pointer.to(deviceOutput), + Pointer.to(new long[] { startingCounter }), + Pointer.to(new long[] { seed }), + Pointer.to(new long[] { numElements })); + + // Calculate grid size + int gridSize = (numElements + (blockSize * 4) - 1) / (blockSize * 4); + + // Launch kernel mit Fehlerprüfung + int kernelResult = cuLaunchKernel(function, + gridSize, 1, 1, // Grid dimension + blockSize, 1, 1, // Block dimension + 0, null, // Shared memory size and stream + kernelParams, null // Kernel parameters and extra parameters + ); + if (kernelResult != CUresult.CUDA_SUCCESS) { + throw new RuntimeException( + "Kernel-Launch fehlgeschlagen: " + kernelResult + ", " + CUresult.stringFor(kernelResult)); + } + + // Copy results back to host + // cuMemcpyDtoH(Pointer.to(hostOutput), deviceOutput, (long) numElements * + // Sizeof.LONG); + } finally { + // Free device memory + // cuMemFree(deviceOutput); + } + + // return hostOutput; + return deviceOutput; + } + + /** + * Cleans up CUDA resources + */ + public void close() { + cuModuleUnload(module); + cuCtxDestroy(context); + } + + // Example usage + public static void main(String[] args) { + try (PhiloxRuntimeCompilationExample generator = new PhiloxRuntimeCompilationExample()) { + // Generate 1 million random numbers + int numElements = 1_000_000; + long seed = 0L; + long startingCounter = 0L; + + CUdeviceptr randomNumbers = generator.Philox4x64(startingCounter, seed, 
numElements); + + long[] elements = new long[10]; + cuMemcpyDtoH(Pointer.to(elements), randomNumbers, 10L * Sizeof.LONG); + cuMemFree(randomNumbers); + + // Print first few numbers + System.out.println("First 10 random numbers:"); + for (int i = 0; i < 10; i++) { + System.out.printf("%d: %x%n", i, elements[i]); + } + + int size = 10_000_000; + long start = System.currentTimeMillis(); + CUdeviceptr ptr = generator.Philox4x64(0L, 0L, size); + long end = System.currentTimeMillis(); + System.out.println("philox4x64 speed test: " + (end - start) * 1000 + " microseconds"); + cuMemFree(ptr); + Random r = new Random(); + long javaStart = System.currentTimeMillis(); + for (int i = 0; i < size; i++) { + r.nextLong(); + } + long javaEnd = System.currentTimeMillis(); + System.out.println("java speed test: " + (javaEnd - javaStart) * 1000 + " microseconds"); + System.out.println("philox4x64 is " + (double) (javaEnd - javaStart) / (double) (end - start) + + " times faster than java"); + + } + } +} diff --git a/scripts/staging/cuda-counter-based-prng/kernel.cu b/scripts/staging/cuda-counter-based-prng/kernel.cu new file mode 100644 index 00000000000..456cca39c03 --- /dev/null +++ b/scripts/staging/cuda-counter-based-prng/kernel.cu @@ -0,0 +1,260 @@ +#include +#include +#include +#include +#include +#include +#include + +// CUDA kernel to generate random doubles between 0 and 1 using all 4 integers from Philox +extern "C" __global__ void philox_4_64_uniform(double* output, uint64_t originalKey, r123::Philox4x64::ctr_type startingCounter, size_t numElements) { + // Calculate the thread's unique index + uint64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + double UINT_TO_ZERO_ONE = 1.0 / LONG_MAX; + + // Ensure the thread index is within bounds + if (idx * 4 < numElements) { + // Initialize the Philox generator with a unique counter and key + r123::Philox4x64 rng; + r123::Philox4x64::ctr_type ctr; + uint64_t sum0 = startingCounter[0] + idx; + uint64_t sum1 = startingCounter[1] + 
(sum0 < startingCounter[0] ? 1 : 0); // Carry-Bit + + ctr[0] = sum0; + ctr[1] = sum1; + ctr[2] = startingCounter[2]; + ctr[3] = startingCounter[3]; + r123::Philox4x64::key_type key = {{originalKey}}; // Key (seed) + + // Generate 4 random integers + r123::Philox4x64::ctr_type result = rng(ctr, key); + + // Convert each 64-bit integer to a double in [-1, 1] + for (int i = 0; i < 4; ++i) { + double randomDouble = static_cast((long)result[i]) * UINT_TO_ZERO_ONE; + size_t outputIdx = idx * 4 + i; + + // Ensure we don't exceed the output array bounds + if (outputIdx < numElements) { + output[outputIdx] = randomDouble; + } + } + } +} + +// CUDA kernel to generate random doubles between 0 and 1 using all 4 integers from Philox +extern "C" __global__ void philox_4_64_standard(double* output, uint64_t originalKey, r123::Philox4x64::ctr_type startingCounter, size_t numElements) { + // Calculate the thread's unique index + uint64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + uint64_t idx2 = idx + numElements; + double UINT_TO_ZERO_ONE = 1.0 / LONG_MAX; + + // Ensure the thread index is within bounds + if (idx * 4 < numElements) { + // Initialize the Philox generator with a unique counter and key + r123::Philox4x64 rng; + r123::Philox4x64::ctr_type ctr1; + uint64_t sum0 = startingCounter[0] + idx; + uint64_t sum1 = startingCounter[1] + (sum0 < startingCounter[0] ? 1 : 0); // Carry-Bit + + ctr1[0] = sum0; + ctr1[1] = sum1; + ctr1[2] = startingCounter[2]; + ctr1[3] = startingCounter[3]; + r123::Philox4x64::ctr_type ctr2; + sum0 = startingCounter[0] + idx2; + sum1 = startingCounter[1] + (sum0 < startingCounter[0] ? 
1 : 0); // Carry-Bit + + ctr2[0] = sum0; + ctr2[1] = sum1; + ctr2[2] = startingCounter[2]; + ctr2[3] = startingCounter[3]; + + r123::Philox4x64::key_type key1 = {{originalKey}}; + r123::Philox4x64::key_type key2 = {{originalKey}}; + + // Generate 4 random integers + r123::Philox4x64::ctr_type result1 = rng(ctr1, key1); + r123::Philox4x64::ctr_type result2 = rng(ctr2, key2); + + // Convert each 64-bit integer to a double in [-1, 1] + for (int i = 0; i < 4; ++i) { + double randomDouble1 = static_cast((long)result1[i]) * UINT_TO_ZERO_ONE; + double randomDouble2 = static_cast((long)result2[i]) * UINT_TO_ZERO_ONE; + + size_t outputIdx = idx * 4 + i; + + // Ensure we don't exceed the output array bounds + if (outputIdx < numElements) { + output[outputIdx] = (randomDouble1 + randomDouble2) / 2; + } + } + } +} + + +// CUDA kernel to generate random integers from Philox +extern "C" __global__ void philox_4_32(uint* output, uint32_t seed, uint32_t startingCounter, size_t numElements) { + // Calculate the thread's unique index + uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + + // Ensure the thread index is within bounds + if (idx * 4 < numElements) { + // Initialize the Philox generator with a unique counter and key + r123::Philox4x32 rng; + r123::Philox4x32::ctr_type ctr = {{startingCounter + idx, 0, 0, 0}}; // Counter (startingCounter + thread index) + r123::Philox4x32::key_type key = {{seed}}; // Key (seed) + + // Generate 4 random integers + r123::Philox4x32::ctr_type result = rng(ctr, key); + + for (int i = 0; i < 4; ++i) { + size_t outputIdx = idx * 4 + i; + + // Ensure we don't exceed the output array bounds + if (outputIdx < numElements) { + output[outputIdx] = result[i]; + } + } + } +} + + +// CUDA kernel to generate random longs from Philox +extern "C" __global__ void philox_4_64(ulong* output, uint64_t seed, uint64_t startingCounter, size_t numElements) { + // Calculate the thread's unique index + uint64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + + 
// Ensure the thread index is within bounds + if (idx * 4 < numElements) { + // Initialize the Philox generator with a unique counter and key + r123::Philox4x64 rng; + r123::Philox4x64::ctr_type ctr = {{startingCounter + idx, 0, 0, 0}}; // Counter (startingCounter + thread index) + r123::Philox4x64::key_type key = {{seed}}; // Key (seed) + + // Generate 4 random integers + r123::Philox4x64::ctr_type result = rng(ctr, key); + + for (int i = 0; i < 4; ++i) { + size_t outputIdx = idx * 4 + i; + + // Ensure we don't exceed the output array bounds + if (outputIdx < numElements) { + output[outputIdx] = result[i]; + } + } + } +} + +
+// Host-side driver for the three Philox kernels defined in this file.
+//
+// Usage: <program> <seed> <startingCounter> <numElements>
+//
+// Runs, in order:
+//   1. philox_4_64_standard - prints the launch-to-completion time, the first values,
+//      and the mean / standard deviation of all generated doubles,
+//   2. philox_4_32          - prints the first raw 32-bit values (hex) next to their uneg11 mapping,
+//   3. philox_4_64          - prints the first raw 64-bit values (hex) next to value / LONG_MAX.
+//
+// Returns 0 on success and 1 on a usage error or a failed CUDA runtime call.
+int main(int argc, char** argv) {
+    // Expect exactly three arguments after the program name
+    if (argc != 4) {
+        std::cerr << "Usage: " << argv[0] << " <seed> <startingCounter> <numElements>\n";
+        return 1;
+    }
+
+    // Parse the arguments once; all three sections below reuse the same values
+    const uint64_t seed = std::stoull(argv[1]);            // Seed (key)
+    const uint64_t startingCounter = std::stoull(argv[2]); // Starting counter
+    const size_t numElements = std::stoul(argv[3]);        // Number of random numbers to generate
+
+    // A zero-sized request would produce an empty launch grid and a 0/0 mean below
+    if (numElements == 0) {
+        std::cerr << "numElements must be greater than 0\n";
+        return 1;
+    }
+
+    // Never read past the end of a result buffer that holds fewer than 10 values
+    const size_t numToPrint = numElements < 10 ? numElements : 10;
+
+    // Reports a failed CUDA runtime call on stderr; returns true on success
+    auto cudaOk = [](cudaError_t status, const char* what) -> bool {
+        if (status != cudaSuccess) {
+            std::cerr << what << " failed: " << cudaGetErrorString(status) << "\n";
+            return false;
+        }
+        return true;
+    };
+
+    // Checks the launch itself, then waits for the kernel: launches are asynchronous,
+    // so without the synchronize neither errors nor timings would cover the kernel run
+    auto kernelOk = [&cudaOk](const char* kernelName) -> bool {
+        return cudaOk(cudaGetLastError(), kernelName) && cudaOk(cudaDeviceSynchronize(), kernelName);
+    };
+
+    // Launch configuration shared by all kernels; every thread writes up to 4 outputs
+    const int blockSize = 512;
+    const int gridSize = (numElements + blockSize * 4 - 1) / (blockSize * 4);
+
+    // --------------------------------------------------------------------------------
+    // philox_4_64_standard: doubles, plus timing and basic statistics
+
+    double* h_output = new double[numElements];
+    double* d_output = nullptr;
+    bool ok = cudaOk(cudaMalloc(&d_output, numElements * sizeof(double)), "cudaMalloc");
+
+    if (ok) {
+        r123::Philox4x64::ctr_type uniformCounter = {{startingCounter, 0, 0, 0}};
+
+        auto start = std::chrono::high_resolution_clock::now();
+        philox_4_64_standard<<<gridSize, blockSize>>>(d_output, seed, uniformCounter, numElements);
+        ok = kernelOk("philox_4_64_standard"); // blocks until the kernel has finished
+        auto end = std::chrono::high_resolution_clock::now();
+        auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
+
+        std::cout << "Time: " << duration.count() << " microseconds" << std::endl;
+    }
+
+    // Copy the results back to the host
+    ok = ok && cudaOk(cudaMemcpy(h_output, d_output, numElements * sizeof(double), cudaMemcpyDeviceToHost), "cudaMemcpy");
+
+    if (ok) {
+        std::cout << "First 10 random doubles:\n";
+        for (size_t i = 0; i < numToPrint; ++i) {
+            std::cout << h_output[i] << "\n";
+        }
+
+        // Mean, accumulated as a sum of pre-divided terms
+        double avg = 0.0;
+        for (size_t i = 0; i < numElements; i++) {
+            avg += (double)h_output[i] / numElements;
+        }
+        printf("Average: %f\n", avg);
+
+        // Population standard deviation around that mean
+        double standardDeviation = 0.0;
+        for (size_t i = 0; i < numElements; i++) {
+            standardDeviation += std::pow((double)h_output[i] - avg, 2);
+        }
+        standardDeviation = sqrt(standardDeviation / numElements);
+        printf("standardDeviation: %f\n", standardDeviation);
+    }
+
+    // Free memory (cudaFree accepts a null pointer if the allocation failed)
+    delete[] h_output;
+    cudaFree(d_output);
+    if (!ok) {
+        return 1;
+    }
+
+    // --------------------------------------------------------------------------------
+    // philox_4_32: raw 32-bit integers
+
+    uint* h_output_int = new uint[numElements];
+    uint* d_output_int = nullptr;
+    ok = cudaOk(cudaMalloc(&d_output_int, numElements * sizeof(uint)), "cudaMalloc");
+
+    if (ok) {
+        // The kernel takes a 32-bit seed and counter; make the narrowing of the 64-bit inputs explicit
+        philox_4_32<<<gridSize, blockSize>>>(d_output_int, (uint32_t)seed, (uint32_t)startingCounter, numElements);
+        ok = kernelOk("philox_4_32");
+    }
+
+    // Copy the results back to the host
+    ok = ok && cudaOk(cudaMemcpy(h_output_int, d_output_int, numElements * sizeof(uint), cudaMemcpyDeviceToHost), "cudaMemcpy");
+
+    if (ok) {
+        // Raw value in hex followed by its uneg11 mapping
+        std::cout << "First 10 random doubles:\n";
+        for (size_t i = 0; i < numToPrint; ++i) {
+            std::cout << std::hex << h_output_int[i] << " " << r123::uneg11<double>(h_output_int[i]) << "\n";
+        }
+    }
+
+    // Free memory
+    delete[] h_output_int;
+    cudaFree(d_output_int);
+    if (!ok) {
+        return 1;
+    }
+
+    // --------------------------------------------------------------------------------
+    // philox_4_64: raw 64-bit integers
+
+    ulong* h_output_long = new ulong[numElements];
+    ulong* d_output_long = nullptr;
+    ok = cudaOk(cudaMalloc(&d_output_long, numElements * sizeof(ulong)), "cudaMalloc");
+
+    if (ok) {
+        philox_4_64<<<gridSize, blockSize>>>(d_output_long, seed, startingCounter, numElements);
+        ok = kernelOk("philox_4_64");
+    }
+
+    // Copy the results back to the host
+    ok = ok && cudaOk(cudaMemcpy(h_output_long, d_output_long, numElements * sizeof(ulong), cudaMemcpyDeviceToHost), "cudaMemcpy");
+
+    if (ok) {
+        // Raw value in hex followed by the same bits read as a signed long and scaled by LONG_MAX
+        std::cout << "First 10 random doubles:\n";
+        for (size_t i = 0; i < numToPrint; ++i) {
+            std::cout << std::setprecision(17) << std::hex << h_output_long[i] << " " << (static_cast<double>((long)h_output_long[i]) / LONG_MAX) << "\n";
+        }
+    }
+
+    // Free memory
+    delete[] h_output_long;
+    cudaFree(d_output_long);
+
+    return ok ? 0 : 1;
+}
diff --git a/scripts/staging/cuda-counter-based-prng/philox_kernel.ptx b/scripts/staging/cuda-counter-based-prng/philox_kernel.ptx new file mode 100644 index 00000000000..90e0f4fa854 --- /dev/null +++ b/scripts/staging/cuda-counter-based-prng/philox_kernel.ptx @@ -0,0 +1,772 @@ +// +// Generated by NVIDIA NVVM Compiler +// +// Compiler Build ID: CL-27506705 +// Cuda compilation tools, release 10.2, V10.2.89 +// Based on LLVM 3.4svn +// +// @page LICENSE +// Copyright 2010-2012, D. E. Shaw Research. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions, and the following disclaimer. +// +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions, and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// * Neither the name of D. E. Shaw Research nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +.version 6.5 +.target sm_30 +.address_size 64 + + // .globl philox_4_64_uniform + +.visible .entry philox_4_64_uniform( + .param .u64 philox_4_64_uniform_param_0, + .param .u64 philox_4_64_uniform_param_1, + .param .align 8 .b8 philox_4_64_uniform_param_2[32], + .param .u64 philox_4_64_uniform_param_3 +) +{ + .reg .pred %p<6>; + .reg .b32 %r<5>; + .reg .f64 %fd<9>; + .reg .b64 %rd<110>; + + + ld.param.u64 %rd8, [philox_4_64_uniform_param_0]; + ld.param.u64 %rd9, [philox_4_64_uniform_param_1]; + ld.param.u64 %rd13, [philox_4_64_uniform_param_2+24]; + ld.param.u64 %rd1, [philox_4_64_uniform_param_2+16]; + ld.param.u64 %rd11, [philox_4_64_uniform_param_2+8]; + ld.param.u64 %rd10, [philox_4_64_uniform_param_2]; + ld.param.u64 %rd14, [philox_4_64_uniform_param_3]; + mov.u32 %r1, %ntid.x; + mov.u32 %r2, %ctaid.x; + mov.u32 %r3, %tid.x; + mad.lo.s32 %r4, %r1, %r2, %r3; + cvt.u64.u32 %rd2, %r4; + mul.wide.u32 %rd3, %r4, 4; + setp.ge.u64 %p1, %rd3, %rd14; + @%p1 bra BB0_7; + + add.s64 %rd15, %rd10, %rd2; + setp.lt.u64 %p2, %rd15, %rd10; + selp.u64 %rd16, 1, 0, %p2; + add.s64 %rd17, %rd16, %rd11; + mov.u64 %rd18, -3249550476889527149; + mul.hi.u64 %rd19, %rd18, %rd15; + mul.lo.s64 
%rd20, %rd15, -3249550476889527149; + mov.u64 %rd21, -3865633965929787049; + mul.hi.u64 %rd22, %rd21, %rd1; + xor.b64 %rd23, %rd22, %rd9; + xor.b64 %rd24, %rd23, %rd17; + xor.b64 %rd25, %rd19, %rd13; + mul.hi.u64 %rd26, %rd18, %rd24; + mul.lo.s64 %rd27, %rd24, -3249550476889527149; + mul.hi.u64 %rd28, %rd21, %rd25; + mul.lo.s64 %rd29, %rd25, -3865633965929787049; + add.s64 %rd30, %rd9, -7046029254386353131; + mul.lo.s64 %rd31, %rd1, -3865633965929787049; + xor.b64 %rd32, %rd31, %rd30; + xor.b64 %rd33, %rd32, %rd28; + xor.b64 %rd34, %rd20, %rd26; + xor.b64 %rd35, %rd34, -4942790177534073029; + mul.hi.u64 %rd36, %rd18, %rd33; + mul.lo.s64 %rd37, %rd33, -3249550476889527149; + mul.hi.u64 %rd38, %rd21, %rd35; + mul.lo.s64 %rd39, %rd35, -3865633965929787049; + add.s64 %rd40, %rd9, 4354685564936845354; + xor.b64 %rd41, %rd29, %rd40; + xor.b64 %rd42, %rd41, %rd38; + xor.b64 %rd43, %rd27, %rd36; + xor.b64 %rd44, %rd43, 8561163718641405558; + mul.hi.u64 %rd45, %rd18, %rd42; + mul.lo.s64 %rd46, %rd42, -3249550476889527149; + mul.hi.u64 %rd47, %rd21, %rd44; + mul.lo.s64 %rd48, %rd44, -3865633965929787049; + add.s64 %rd49, %rd9, -2691343689449507777; + xor.b64 %rd50, %rd39, %rd49; + xor.b64 %rd51, %rd50, %rd47; + xor.b64 %rd52, %rd37, %rd45; + xor.b64 %rd53, %rd52, 3618373541107332529; + mul.hi.u64 %rd54, %rd18, %rd51; + mul.lo.s64 %rd55, %rd51, -3249550476889527149; + mul.hi.u64 %rd56, %rd21, %rd53; + mul.lo.s64 %rd57, %rd53, -3865633965929787049; + add.s64 %rd58, %rd9, 8709371129873690708; + xor.b64 %rd59, %rd48, %rd58; + xor.b64 %rd60, %rd59, %rd56; + xor.b64 %rd61, %rd46, %rd54; + xor.b64 %rd62, %rd61, -1324416636426740500; + mul.hi.u64 %rd63, %rd18, %rd60; + mul.lo.s64 %rd64, %rd60, -3249550476889527149; + mul.hi.u64 %rd65, %rd21, %rd62; + mul.lo.s64 %rd66, %rd62, -3865633965929787049; + add.s64 %rd67, %rd9, 1663341875487337577; + xor.b64 %rd68, %rd57, %rd67; + xor.b64 %rd69, %rd68, %rd65; + xor.b64 %rd70, %rd55, %rd63; + xor.b64 %rd71, %rd70, -6267206813960813529; + 
mul.hi.u64 %rd72, %rd18, %rd69; + mul.lo.s64 %rd73, %rd69, -3249550476889527149; + mul.hi.u64 %rd74, %rd21, %rd71; + mul.lo.s64 %rd75, %rd71, -3865633965929787049; + add.s64 %rd76, %rd9, -5382687378899015554; + xor.b64 %rd77, %rd66, %rd76; + xor.b64 %rd78, %rd77, %rd74; + xor.b64 %rd79, %rd64, %rd72; + xor.b64 %rd80, %rd79, 7236747082214665058; + mul.hi.u64 %rd81, %rd18, %rd78; + mul.lo.s64 %rd82, %rd78, -3249550476889527149; + mul.hi.u64 %rd83, %rd21, %rd80; + mul.lo.s64 %rd84, %rd80, -3865633965929787049; + add.s64 %rd85, %rd9, 6018027440424182931; + xor.b64 %rd86, %rd75, %rd85; + xor.b64 %rd87, %rd86, %rd83; + xor.b64 %rd88, %rd73, %rd81; + xor.b64 %rd89, %rd88, 2293956904680592029; + mul.hi.u64 %rd90, %rd18, %rd87; + mul.lo.s64 %rd91, %rd87, -3249550476889527149; + mul.hi.u64 %rd92, %rd21, %rd89; + mul.lo.s64 %rd93, %rd89, -3865633965929787049; + add.s64 %rd94, %rd9, -1028001813962170200; + xor.b64 %rd95, %rd84, %rd94; + xor.b64 %rd96, %rd95, %rd92; + xor.b64 %rd97, %rd82, %rd90; + xor.b64 %rd4, %rd97, -2648833272853481000; + mul.hi.u64 %rd98, %rd18, %rd96; + mul.lo.s64 %rd5, %rd96, -3249550476889527149; + mul.hi.u64 %rd99, %rd21, %rd4; + add.s64 %rd100, %rd9, -8074031068348523331; + xor.b64 %rd101, %rd93, %rd100; + xor.b64 %rd102, %rd101, %rd99; + xor.b64 %rd103, %rd91, %rd98; + xor.b64 %rd6, %rd103, -7591623450387554029; + cvt.rn.f64.s64 %fd1, %rd102; + mul.f64 %fd2, %fd1, 0d3C00000000000000; + cvta.to.global.u64 %rd104, %rd8; + shl.b64 %rd105, %rd3, 3; + add.s64 %rd7, %rd104, %rd105; + st.global.f64 [%rd7], %fd2; + add.s64 %rd106, %rd3, 1; + setp.ge.u64 %p3, %rd106, %rd14; + @%p3 bra BB0_3; + + mul.lo.s64 %rd107, %rd4, -3865633965929787049; + cvt.rn.f64.s64 %fd3, %rd107; + mul.f64 %fd4, %fd3, 0d3C00000000000000; + st.global.f64 [%rd7+8], %fd4; + +BB0_3: + add.s64 %rd108, %rd3, 2; + setp.ge.u64 %p4, %rd108, %rd14; + @%p4 bra BB0_5; + + cvt.rn.f64.s64 %fd5, %rd6; + mul.f64 %fd6, %fd5, 0d3C00000000000000; + st.global.f64 [%rd7+16], %fd6; + +BB0_5: + add.s64 
%rd109, %rd3, 3; + setp.ge.u64 %p5, %rd109, %rd14; + @%p5 bra BB0_7; + + cvt.rn.f64.s64 %fd7, %rd5; + mul.f64 %fd8, %fd7, 0d3C00000000000000; + st.global.f64 [%rd7+24], %fd8; + +BB0_7: + ret; +} + + // .globl philox_4_64_standard +.visible .entry philox_4_64_standard( + .param .u64 philox_4_64_standard_param_0, + .param .u64 philox_4_64_standard_param_1, + .param .align 8 .b8 philox_4_64_standard_param_2[32], + .param .u64 philox_4_64_standard_param_3 +) +{ + .reg .pred %p<7>; + .reg .b32 %r<5>; + .reg .f64 %fd<21>; + .reg .b64 %rd<191>; + + + ld.param.u64 %rd14, [philox_4_64_standard_param_0]; + ld.param.u64 %rd15, [philox_4_64_standard_param_1]; + ld.param.u64 %rd1, [philox_4_64_standard_param_2]; + ld.param.u64 %rd2, [philox_4_64_standard_param_2+8]; + ld.param.u64 %rd3, [philox_4_64_standard_param_2+16]; + ld.param.u64 %rd4, [philox_4_64_standard_param_2+24]; + ld.param.u64 %rd16, [philox_4_64_standard_param_3]; + mov.u32 %r1, %ntid.x; + mov.u32 %r2, %ctaid.x; + mov.u32 %r3, %tid.x; + mad.lo.s32 %r4, %r1, %r2, %r3; + cvt.u64.u32 %rd5, %r4; + mul.wide.u32 %rd6, %r4, 4; + setp.ge.u64 %p1, %rd6, %rd16; + @%p1 bra BB1_7; + + add.s64 %rd17, %rd1, %rd5; + setp.lt.u64 %p2, %rd17, %rd1; + selp.u64 %rd18, 1, 0, %p2; + add.s64 %rd19, %rd18, %rd2; + add.s64 %rd20, %rd5, %rd16; + add.s64 %rd21, %rd1, %rd20; + setp.lt.u64 %p3, %rd21, %rd1; + selp.u64 %rd22, 1, 0, %p3; + add.s64 %rd23, %rd22, %rd2; + mov.u64 %rd24, -3249550476889527149; + mul.hi.u64 %rd25, %rd24, %rd17; + mul.lo.s64 %rd26, %rd17, -3249550476889527149; + xor.b64 %rd27, %rd19, %rd15; + mov.u64 %rd28, -3865633965929787049; + mul.hi.u64 %rd29, %rd28, %rd3; + xor.b64 %rd30, %rd27, %rd29; + xor.b64 %rd31, %rd25, %rd4; + mul.hi.u64 %rd32, %rd24, %rd30; + mul.lo.s64 %rd33, %rd30, -3249550476889527149; + mul.hi.u64 %rd34, %rd28, %rd31; + mul.lo.s64 %rd35, %rd31, -3865633965929787049; + add.s64 %rd36, %rd15, -7046029254386353131; + mul.lo.s64 %rd37, %rd3, -3865633965929787049; + xor.b64 %rd38, %rd37, %rd36; + xor.b64 
%rd39, %rd38, %rd34; + xor.b64 %rd40, %rd26, %rd32; + xor.b64 %rd41, %rd40, -4942790177534073029; + mul.hi.u64 %rd42, %rd24, %rd39; + mul.lo.s64 %rd43, %rd39, -3249550476889527149; + mul.hi.u64 %rd44, %rd28, %rd41; + mul.lo.s64 %rd45, %rd41, -3865633965929787049; + add.s64 %rd46, %rd15, 4354685564936845354; + xor.b64 %rd47, %rd35, %rd46; + xor.b64 %rd48, %rd47, %rd44; + xor.b64 %rd49, %rd33, %rd42; + xor.b64 %rd50, %rd49, 8561163718641405558; + mul.hi.u64 %rd51, %rd24, %rd48; + mul.lo.s64 %rd52, %rd48, -3249550476889527149; + mul.hi.u64 %rd53, %rd28, %rd50; + mul.lo.s64 %rd54, %rd50, -3865633965929787049; + add.s64 %rd55, %rd15, -2691343689449507777; + xor.b64 %rd56, %rd45, %rd55; + xor.b64 %rd57, %rd56, %rd53; + xor.b64 %rd58, %rd43, %rd51; + xor.b64 %rd59, %rd58, 3618373541107332529; + mul.hi.u64 %rd60, %rd24, %rd57; + mul.lo.s64 %rd61, %rd57, -3249550476889527149; + mul.hi.u64 %rd62, %rd28, %rd59; + mul.lo.s64 %rd63, %rd59, -3865633965929787049; + add.s64 %rd64, %rd15, 8709371129873690708; + xor.b64 %rd65, %rd54, %rd64; + xor.b64 %rd66, %rd65, %rd62; + xor.b64 %rd67, %rd52, %rd60; + xor.b64 %rd68, %rd67, -1324416636426740500; + mul.hi.u64 %rd69, %rd24, %rd66; + mul.lo.s64 %rd70, %rd66, -3249550476889527149; + mul.hi.u64 %rd71, %rd28, %rd68; + mul.lo.s64 %rd72, %rd68, -3865633965929787049; + add.s64 %rd73, %rd15, 1663341875487337577; + xor.b64 %rd74, %rd63, %rd73; + xor.b64 %rd75, %rd74, %rd71; + xor.b64 %rd76, %rd61, %rd69; + xor.b64 %rd77, %rd76, -6267206813960813529; + mul.hi.u64 %rd78, %rd24, %rd75; + mul.lo.s64 %rd79, %rd75, -3249550476889527149; + mul.hi.u64 %rd80, %rd28, %rd77; + mul.lo.s64 %rd81, %rd77, -3865633965929787049; + add.s64 %rd82, %rd15, -5382687378899015554; + xor.b64 %rd83, %rd72, %rd82; + xor.b64 %rd84, %rd83, %rd80; + xor.b64 %rd85, %rd70, %rd78; + xor.b64 %rd86, %rd85, 7236747082214665058; + mul.hi.u64 %rd87, %rd24, %rd84; + mul.lo.s64 %rd88, %rd84, -3249550476889527149; + mul.hi.u64 %rd89, %rd28, %rd86; + mul.lo.s64 %rd90, %rd86, 
-3865633965929787049; + add.s64 %rd91, %rd15, 6018027440424182931; + xor.b64 %rd92, %rd81, %rd91; + xor.b64 %rd93, %rd92, %rd89; + xor.b64 %rd94, %rd79, %rd87; + xor.b64 %rd95, %rd94, 2293956904680592029; + mul.hi.u64 %rd96, %rd24, %rd93; + mul.lo.s64 %rd97, %rd93, -3249550476889527149; + mul.hi.u64 %rd98, %rd28, %rd95; + mul.lo.s64 %rd99, %rd95, -3865633965929787049; + add.s64 %rd100, %rd15, -1028001813962170200; + xor.b64 %rd101, %rd90, %rd100; + xor.b64 %rd102, %rd101, %rd98; + xor.b64 %rd103, %rd88, %rd96; + xor.b64 %rd7, %rd103, -2648833272853481000; + mul.hi.u64 %rd104, %rd24, %rd102; + mul.lo.s64 %rd8, %rd102, -3249550476889527149; + mul.hi.u64 %rd105, %rd28, %rd7; + add.s64 %rd106, %rd15, -8074031068348523331; + xor.b64 %rd107, %rd99, %rd106; + xor.b64 %rd108, %rd107, %rd105; + xor.b64 %rd109, %rd97, %rd104; + xor.b64 %rd9, %rd109, -7591623450387554029; + mul.hi.u64 %rd110, %rd24, %rd21; + mul.lo.s64 %rd111, %rd21, -3249550476889527149; + xor.b64 %rd112, %rd29, %rd15; + xor.b64 %rd113, %rd112, %rd23; + xor.b64 %rd114, %rd110, %rd4; + mul.hi.u64 %rd115, %rd24, %rd113; + mul.lo.s64 %rd116, %rd113, -3249550476889527149; + mul.hi.u64 %rd117, %rd28, %rd114; + mul.lo.s64 %rd118, %rd114, -3865633965929787049; + xor.b64 %rd119, %rd38, %rd117; + xor.b64 %rd120, %rd111, %rd115; + xor.b64 %rd121, %rd120, -4942790177534073029; + mul.hi.u64 %rd122, %rd24, %rd119; + mul.lo.s64 %rd123, %rd119, -3249550476889527149; + mul.hi.u64 %rd124, %rd28, %rd121; + mul.lo.s64 %rd125, %rd121, -3865633965929787049; + xor.b64 %rd126, %rd118, %rd46; + xor.b64 %rd127, %rd126, %rd124; + xor.b64 %rd128, %rd116, %rd122; + xor.b64 %rd129, %rd128, 8561163718641405558; + mul.hi.u64 %rd130, %rd24, %rd127; + mul.lo.s64 %rd131, %rd127, -3249550476889527149; + mul.hi.u64 %rd132, %rd28, %rd129; + mul.lo.s64 %rd133, %rd129, -3865633965929787049; + xor.b64 %rd134, %rd125, %rd55; + xor.b64 %rd135, %rd134, %rd132; + xor.b64 %rd136, %rd123, %rd130; + xor.b64 %rd137, %rd136, 3618373541107332529; + 
mul.hi.u64 %rd138, %rd24, %rd135; + mul.lo.s64 %rd139, %rd135, -3249550476889527149; + mul.hi.u64 %rd140, %rd28, %rd137; + mul.lo.s64 %rd141, %rd137, -3865633965929787049; + xor.b64 %rd142, %rd133, %rd64; + xor.b64 %rd143, %rd142, %rd140; + xor.b64 %rd144, %rd131, %rd138; + xor.b64 %rd145, %rd144, -1324416636426740500; + mul.hi.u64 %rd146, %rd24, %rd143; + mul.lo.s64 %rd147, %rd143, -3249550476889527149; + mul.hi.u64 %rd148, %rd28, %rd145; + mul.lo.s64 %rd149, %rd145, -3865633965929787049; + xor.b64 %rd150, %rd141, %rd73; + xor.b64 %rd151, %rd150, %rd148; + xor.b64 %rd152, %rd139, %rd146; + xor.b64 %rd153, %rd152, -6267206813960813529; + mul.hi.u64 %rd154, %rd24, %rd151; + mul.lo.s64 %rd155, %rd151, -3249550476889527149; + mul.hi.u64 %rd156, %rd28, %rd153; + mul.lo.s64 %rd157, %rd153, -3865633965929787049; + xor.b64 %rd158, %rd149, %rd82; + xor.b64 %rd159, %rd158, %rd156; + xor.b64 %rd160, %rd147, %rd154; + xor.b64 %rd161, %rd160, 7236747082214665058; + mul.hi.u64 %rd162, %rd24, %rd159; + mul.lo.s64 %rd163, %rd159, -3249550476889527149; + mul.hi.u64 %rd164, %rd28, %rd161; + mul.lo.s64 %rd165, %rd161, -3865633965929787049; + xor.b64 %rd166, %rd157, %rd91; + xor.b64 %rd167, %rd166, %rd164; + xor.b64 %rd168, %rd155, %rd162; + xor.b64 %rd169, %rd168, 2293956904680592029; + mul.hi.u64 %rd170, %rd24, %rd167; + mul.lo.s64 %rd171, %rd167, -3249550476889527149; + mul.hi.u64 %rd172, %rd28, %rd169; + mul.lo.s64 %rd173, %rd169, -3865633965929787049; + xor.b64 %rd174, %rd165, %rd100; + xor.b64 %rd175, %rd174, %rd172; + xor.b64 %rd176, %rd163, %rd170; + xor.b64 %rd10, %rd176, -2648833272853481000; + mul.hi.u64 %rd177, %rd24, %rd175; + mul.lo.s64 %rd11, %rd175, -3249550476889527149; + mul.hi.u64 %rd178, %rd28, %rd10; + xor.b64 %rd179, %rd173, %rd106; + xor.b64 %rd180, %rd179, %rd178; + xor.b64 %rd181, %rd171, %rd177; + xor.b64 %rd12, %rd181, -7591623450387554029; + cvt.rn.f64.s64 %fd1, %rd180; + cvt.rn.f64.s64 %fd2, %rd108; + mul.f64 %fd3, %fd2, 0d3C00000000000000; + fma.rn.f64 
%fd4, %fd1, 0d3C00000000000000, %fd3; + mul.f64 %fd5, %fd4, 0d3FE0000000000000; + cvta.to.global.u64 %rd182, %rd14; + shl.b64 %rd183, %rd6, 3; + add.s64 %rd13, %rd182, %rd183; + st.global.f64 [%rd13], %fd5; + add.s64 %rd184, %rd6, 1; + setp.ge.u64 %p4, %rd184, %rd16; + @%p4 bra BB1_3; + + mul.lo.s64 %rd185, %rd10, -3865633965929787049; + cvt.rn.f64.s64 %fd6, %rd185; + mul.lo.s64 %rd186, %rd7, -3865633965929787049; + cvt.rn.f64.s64 %fd7, %rd186; + mul.f64 %fd8, %fd7, 0d3C00000000000000; + fma.rn.f64 %fd9, %fd6, 0d3C00000000000000, %fd8; + mul.f64 %fd10, %fd9, 0d3FE0000000000000; + st.global.f64 [%rd13+8], %fd10; + +BB1_3: + ld.param.u64 %rd189, [philox_4_64_standard_param_3]; + add.s64 %rd187, %rd6, 2; + setp.ge.u64 %p5, %rd187, %rd189; + @%p5 bra BB1_5; + + cvt.rn.f64.s64 %fd11, %rd12; + cvt.rn.f64.s64 %fd12, %rd9; + mul.f64 %fd13, %fd12, 0d3C00000000000000; + fma.rn.f64 %fd14, %fd11, 0d3C00000000000000, %fd13; + mul.f64 %fd15, %fd14, 0d3FE0000000000000; + st.global.f64 [%rd13+16], %fd15; + +BB1_5: + ld.param.u64 %rd190, [philox_4_64_standard_param_3]; + add.s64 %rd188, %rd6, 3; + setp.ge.u64 %p6, %rd188, %rd190; + @%p6 bra BB1_7; + + cvt.rn.f64.s64 %fd16, %rd11; + cvt.rn.f64.s64 %fd17, %rd8; + mul.f64 %fd18, %fd17, 0d3C00000000000000; + fma.rn.f64 %fd19, %fd16, 0d3C00000000000000, %fd18; + mul.f64 %fd20, %fd19, 0d3FE0000000000000; + st.global.f64 [%rd13+24], %fd20; + +BB1_7: + ret; +} + + // .globl philox_4_32 +.visible .entry philox_4_32( + .param .u64 philox_4_32_param_0, + .param .u32 philox_4_32_param_1, + .param .u32 philox_4_32_param_2, + .param .u64 philox_4_32_param_3 +) +{ + .reg .pred %p<5>; + .reg .b32 %r<60>; + .reg .b64 %rd<74>; + + + ld.param.u64 %rd5, [philox_4_32_param_0]; + ld.param.u32 %r5, [philox_4_32_param_1]; + ld.param.u32 %r6, [philox_4_32_param_2]; + ld.param.u64 %rd6, [philox_4_32_param_3]; + mov.u32 %r7, %ntid.x; + mov.u32 %r8, %ctaid.x; + mov.u32 %r9, %tid.x; + mad.lo.s32 %r1, %r7, %r8, %r9; + shl.b32 %r2, %r1, 2; + cvt.u64.u32 %rd1, 
%r2; + setp.ge.u64 %p1, %rd1, %rd6; + @%p1 bra BB2_7; + + cvta.to.global.u64 %rd7, %rd5; + add.s32 %r10, %r1, %r6; + mul.wide.u32 %rd8, %r10, -766435501; + shr.u64 %rd9, %rd8, 32; + mul.wide.u32 %rd10, %r5, -766435501; + shr.u64 %rd11, %rd10, 32; + mul.lo.s64 %rd12, %rd9, 3449720151; + shr.u64 %rd13, %rd12, 32; + cvt.u32.u64 %r11, %rd13; + cvt.u32.u64 %r12, %rd12; + add.s32 %r13, %r5, -1640531527; + xor.b32 %r14, %r11, %r13; + mul.wide.u32 %rd14, %r14, -766435501; + shr.u64 %rd15, %rd14, 32; + and.b64 %rd16, %rd8, 4294967295; + xor.b64 %rd17, %rd11, %rd16; + xor.b64 %rd18, %rd17, 3144134277; + mul.lo.s64 %rd19, %rd18, 3449720151; + shr.u64 %rd20, %rd19, 32; + cvt.u32.u64 %r15, %rd20; + cvt.u32.u64 %r16, %rd19; + add.s32 %r17, %r5, 1013904242; + xor.b32 %r18, %r12, %r17; + xor.b32 %r19, %r18, %r15; + mul.wide.u32 %rd21, %r19, -766435501; + shr.u64 %rd22, %rd21, 32; + and.b64 %rd23, %rd10, 4294967295; + xor.b64 %rd24, %rd23, %rd15; + xor.b64 %rd25, %rd24, 1993301258; + mul.lo.s64 %rd26, %rd25, 3449720151; + shr.u64 %rd27, %rd26, 32; + cvt.u32.u64 %r20, %rd27; + cvt.u32.u64 %r21, %rd26; + add.s32 %r22, %r5, -626627285; + xor.b32 %r23, %r16, %r22; + xor.b32 %r24, %r23, %r20; + mul.wide.u32 %rd28, %r24, -766435501; + shr.u64 %rd29, %rd28, 32; + and.b64 %rd30, %rd14, 4294967295; + xor.b64 %rd31, %rd30, %rd22; + xor.b64 %rd32, %rd31, 842468239; + mul.lo.s64 %rd33, %rd32, 3449720151; + shr.u64 %rd34, %rd33, 32; + cvt.u32.u64 %r25, %rd34; + cvt.u32.u64 %r26, %rd33; + add.s32 %r27, %r5, 2027808484; + xor.b32 %r28, %r21, %r27; + xor.b32 %r29, %r28, %r25; + mul.wide.u32 %rd35, %r29, -766435501; + shr.u64 %rd36, %rd35, 32; + and.b64 %rd37, %rd21, 4294967295; + xor.b64 %rd38, %rd37, %rd29; + xor.b64 %rd39, %rd38, 3986602516; + mul.lo.s64 %rd40, %rd39, 3449720151; + shr.u64 %rd41, %rd40, 32; + cvt.u32.u64 %r30, %rd41; + cvt.u32.u64 %r31, %rd40; + add.s32 %r32, %r5, 387276957; + xor.b32 %r33, %r26, %r32; + xor.b32 %r34, %r33, %r30; + mul.wide.u32 %rd42, %r34, -766435501; + shr.u64 
%rd43, %rd42, 32; + and.b64 %rd44, %rd28, 4294967295; + xor.b64 %rd45, %rd44, %rd36; + xor.b64 %rd46, %rd45, 2835769497; + mul.lo.s64 %rd47, %rd46, 3449720151; + shr.u64 %rd48, %rd47, 32; + cvt.u32.u64 %r35, %rd48; + cvt.u32.u64 %r36, %rd47; + add.s32 %r37, %r5, -1253254570; + xor.b32 %r38, %r31, %r37; + xor.b32 %r39, %r38, %r35; + mul.wide.u32 %rd49, %r39, -766435501; + shr.u64 %rd50, %rd49, 32; + and.b64 %rd51, %rd35, 4294967295; + xor.b64 %rd52, %rd51, %rd43; + xor.b64 %rd53, %rd52, 1684936478; + mul.lo.s64 %rd54, %rd53, 3449720151; + shr.u64 %rd55, %rd54, 32; + cvt.u32.u64 %r40, %rd55; + cvt.u32.u64 %r41, %rd54; + add.s32 %r42, %r5, 1401181199; + xor.b32 %r43, %r36, %r42; + xor.b32 %r44, %r43, %r40; + mul.wide.u32 %rd56, %r44, -766435501; + shr.u64 %rd57, %rd56, 32; + cvt.u32.u64 %r3, %rd56; + and.b64 %rd58, %rd42, 4294967295; + xor.b64 %rd59, %rd58, %rd50; + xor.b64 %rd60, %rd59, 534103459; + mul.lo.s64 %rd61, %rd60, 3449720151; + shr.u64 %rd62, %rd61, 32; + cvt.u32.u64 %r45, %rd62; + cvt.u32.u64 %r46, %rd61; + add.s32 %r47, %r5, -239350328; + xor.b32 %r48, %r41, %r47; + xor.b32 %r4, %r48, %r45; + and.b64 %rd63, %rd49, 4294967295; + xor.b64 %rd64, %rd63, %rd57; + xor.b64 %rd65, %rd64, 3678237736; + mul.lo.s64 %rd2, %rd65, 3449720151; + shr.u64 %rd66, %rd2, 32; + cvt.u32.u64 %r49, %rd66; + add.s32 %r50, %r5, -1879881855; + xor.b32 %r51, %r46, %r50; + xor.b32 %r52, %r51, %r49; + shl.b64 %rd67, %rd1, 2; + add.s64 %rd68, %rd7, %rd67; + st.global.u32 [%rd68], %r52; + add.s32 %r53, %r2, 1; + cvt.u64.u32 %rd69, %r53; + mul.wide.u32 %rd70, %r2, 4; + add.s64 %rd3, %rd7, %rd70; + setp.ge.u64 %p2, %rd69, %rd6; + @%p2 bra BB2_3; + + st.global.u32 [%rd3+4], %rd2; + +BB2_3: + mul.wide.u32 %rd4, %r4, -766435501; + add.s32 %r54, %r2, 2; + cvt.u64.u32 %rd71, %r54; + setp.ge.u64 %p3, %rd71, %rd6; + @%p3 bra BB2_5; + + shr.u64 %rd72, %rd4, 32; + cvt.u32.u64 %r55, %rd72; + xor.b32 %r56, %r3, %r55; + xor.b32 %r57, %r56, -1767562579; + st.global.u32 [%rd3+8], %r57; + +BB2_5: + 
cvt.u32.u64 %r58, %rd1; + add.s32 %r59, %r58, 3; + cvt.u64.u32 %rd73, %r59; + setp.ge.u64 %p4, %rd73, %rd6; + @%p4 bra BB2_7; + + st.global.u32 [%rd3+12], %rd4; + +BB2_7: + ret; +} + + // .globl philox_4_64 +.visible .entry philox_4_64( + .param .u64 philox_4_64_param_0, + .param .u64 philox_4_64_param_1, + .param .u64 philox_4_64_param_2, + .param .u64 philox_4_64_param_3 +) +{ + .reg .pred %p<5>; + .reg .b32 %r<5>; + .reg .b64 %rd<101>; + + + ld.param.u64 %rd7, [philox_4_64_param_0]; + ld.param.u64 %rd8, [philox_4_64_param_1]; + ld.param.u64 %rd9, [philox_4_64_param_2]; + ld.param.u64 %rd10, [philox_4_64_param_3]; + mov.u32 %r1, %ntid.x; + mov.u32 %r2, %ctaid.x; + mov.u32 %r3, %tid.x; + mad.lo.s32 %r4, %r1, %r2, %r3; + cvt.u64.u32 %rd1, %r4; + mul.wide.u32 %rd2, %r4, 4; + setp.ge.u64 %p1, %rd2, %rd10; + @%p1 bra BB3_7; + + add.s64 %rd11, %rd1, %rd9; + mov.u64 %rd12, -3249550476889527149; + mul.hi.u64 %rd13, %rd12, %rd11; + mul.lo.s64 %rd14, %rd11, -3249550476889527149; + mov.u64 %rd15, 0; + mov.u64 %rd16, -3865633965929787049; + mul.hi.u64 %rd17, %rd16, %rd15; + xor.b64 %rd18, %rd17, %rd8; + mul.hi.u64 %rd19, %rd12, %rd18; + mul.lo.s64 %rd20, %rd18, -3249550476889527149; + mul.hi.u64 %rd21, %rd16, %rd13; + mul.lo.s64 %rd22, %rd13, -3865633965929787049; + add.s64 %rd23, %rd8, -7046029254386353131; + xor.b64 %rd24, %rd21, %rd23; + xor.b64 %rd25, %rd14, %rd19; + xor.b64 %rd26, %rd25, -4942790177534073029; + mul.hi.u64 %rd27, %rd12, %rd24; + mul.lo.s64 %rd28, %rd24, -3249550476889527149; + mul.hi.u64 %rd29, %rd16, %rd26; + mul.lo.s64 %rd30, %rd26, -3865633965929787049; + add.s64 %rd31, %rd8, 4354685564936845354; + xor.b64 %rd32, %rd22, %rd31; + xor.b64 %rd33, %rd32, %rd29; + xor.b64 %rd34, %rd20, %rd27; + xor.b64 %rd35, %rd34, 8561163718641405558; + mul.hi.u64 %rd36, %rd12, %rd33; + mul.lo.s64 %rd37, %rd33, -3249550476889527149; + mul.hi.u64 %rd38, %rd16, %rd35; + mul.lo.s64 %rd39, %rd35, -3865633965929787049; + add.s64 %rd40, %rd8, -2691343689449507777; + xor.b64 
%rd41, %rd30, %rd40; + xor.b64 %rd42, %rd41, %rd38; + xor.b64 %rd43, %rd28, %rd36; + xor.b64 %rd44, %rd43, 3618373541107332529; + mul.hi.u64 %rd45, %rd12, %rd42; + mul.lo.s64 %rd46, %rd42, -3249550476889527149; + mul.hi.u64 %rd47, %rd16, %rd44; + mul.lo.s64 %rd48, %rd44, -3865633965929787049; + add.s64 %rd49, %rd8, 8709371129873690708; + xor.b64 %rd50, %rd39, %rd49; + xor.b64 %rd51, %rd50, %rd47; + xor.b64 %rd52, %rd37, %rd45; + xor.b64 %rd53, %rd52, -1324416636426740500; + mul.hi.u64 %rd54, %rd12, %rd51; + mul.lo.s64 %rd55, %rd51, -3249550476889527149; + mul.hi.u64 %rd56, %rd16, %rd53; + mul.lo.s64 %rd57, %rd53, -3865633965929787049; + add.s64 %rd58, %rd8, 1663341875487337577; + xor.b64 %rd59, %rd48, %rd58; + xor.b64 %rd60, %rd59, %rd56; + xor.b64 %rd61, %rd46, %rd54; + xor.b64 %rd62, %rd61, -6267206813960813529; + mul.hi.u64 %rd63, %rd12, %rd60; + mul.lo.s64 %rd64, %rd60, -3249550476889527149; + mul.hi.u64 %rd65, %rd16, %rd62; + mul.lo.s64 %rd66, %rd62, -3865633965929787049; + add.s64 %rd67, %rd8, -5382687378899015554; + xor.b64 %rd68, %rd57, %rd67; + xor.b64 %rd69, %rd68, %rd65; + xor.b64 %rd70, %rd55, %rd63; + xor.b64 %rd71, %rd70, 7236747082214665058; + mul.hi.u64 %rd72, %rd12, %rd69; + mul.lo.s64 %rd73, %rd69, -3249550476889527149; + mul.hi.u64 %rd74, %rd16, %rd71; + mul.lo.s64 %rd75, %rd71, -3865633965929787049; + add.s64 %rd76, %rd8, 6018027440424182931; + xor.b64 %rd77, %rd66, %rd76; + xor.b64 %rd78, %rd77, %rd74; + xor.b64 %rd79, %rd64, %rd72; + xor.b64 %rd80, %rd79, 2293956904680592029; + mul.hi.u64 %rd81, %rd12, %rd78; + mul.lo.s64 %rd82, %rd78, -3249550476889527149; + mul.hi.u64 %rd83, %rd16, %rd80; + mul.lo.s64 %rd84, %rd80, -3865633965929787049; + add.s64 %rd85, %rd8, -1028001813962170200; + xor.b64 %rd86, %rd75, %rd85; + xor.b64 %rd87, %rd86, %rd83; + xor.b64 %rd88, %rd73, %rd81; + xor.b64 %rd3, %rd88, -2648833272853481000; + mul.hi.u64 %rd89, %rd12, %rd87; + mul.lo.s64 %rd4, %rd87, -3249550476889527149; + mul.hi.u64 %rd90, %rd16, %rd3; + add.s64 
%rd91, %rd8, -8074031068348523331; + xor.b64 %rd92, %rd84, %rd91; + xor.b64 %rd93, %rd92, %rd90; + xor.b64 %rd94, %rd82, %rd89; + xor.b64 %rd5, %rd94, -7591623450387554029; + cvta.to.global.u64 %rd95, %rd7; + shl.b64 %rd96, %rd2, 3; + add.s64 %rd6, %rd95, %rd96; + st.global.u64 [%rd6], %rd93; + add.s64 %rd97, %rd2, 1; + setp.ge.u64 %p2, %rd97, %rd10; + @%p2 bra BB3_3; + + mul.lo.s64 %rd98, %rd3, -3865633965929787049; + st.global.u64 [%rd6+8], %rd98; + +BB3_3: + add.s64 %rd99, %rd2, 2; + setp.ge.u64 %p3, %rd99, %rd10; + @%p3 bra BB3_5; + + st.global.u64 [%rd6+16], %rd5; + +BB3_5: + add.s64 %rd100, %rd2, 3; + setp.ge.u64 %p4, %rd100, %rd10; + @%p4 bra BB3_7; + + st.global.u64 [%rd6+24], %rd4; + +BB3_7: + ret; +} + + diff --git a/scripts/staging/cuda-counter-based-prng/pom.xml b/scripts/staging/cuda-counter-based-prng/pom.xml new file mode 100644 index 00000000000..58f3530c380 --- /dev/null +++ b/scripts/staging/cuda-counter-based-prng/pom.xml @@ -0,0 +1,38 @@ + + + + 4.0.0 + + org.apache.systemds + CudaCounterBasedRandom + 1.0-SNAPSHOT + jar + + + + org.jcuda + jcuda + 10.2.0 + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.8.1 + + 1.8 + 1.8 + + + + org.apache.maven.plugins + maven-jar-plugin + + + + diff --git a/scripts/staging/cuda-counter-based-prng/readme.md b/scripts/staging/cuda-counter-based-prng/readme.md new file mode 100644 index 00000000000..1cd04be6b7b --- /dev/null +++ b/scripts/staging/cuda-counter-based-prng/readme.md @@ -0,0 +1,410 @@ +# CUDA counter based PRNG + +Currently, random matrix generation is done using Java implementations. Either the Java Random class or the custom +counter based Philox4x64 implementation is used. This is not efficient for large matrices because first, Java is slow +and second, the matrix has to be copied from the main memory to the GPUs memory for performing matrix operations there. +We propose to implement a counter-based PRNG on CUDA to generate random matrices directly on the GPU. 
+ +To be consistent with the current counter-based PRNG implementation, we will use the Philox4x64 algorithm. +Unfortunately, the CUDA cuRAND library is not open source, and we failed to replicate the numbers generated by +cuRAND using a Java implementation. We therefore propose to use the Random123 library, an open-source +library that implements the Philox4x64 algorithm under the BSD 3-Clause license. The Random123 library is available +at https://github.com/DEShawResearch/random123. It is well tested using statistical tests as described in the +paper [Parallel random numbers: as easy as 1, 2, 3](https://doi.org/10.1145/2063384.2063405). + +## How to implement + +There are two ways to integrate CUDA kernels into the SystemDS project. The first is to ship a precompiled +PTX file with the SystemDS project. This has the drawback that the PTX file has to be compiled for each +CUDA version and each GPU architecture. + +The second is to compile the CUDA kernels at runtime. This means that the CUDA build tools need to be installed +on the system where SystemDS is running, but the PTX file can then be compiled for the specific CUDA +version and GPU architecture of that system. 
+ +### Precompiled cuda ptx file + +Example cuda kernel: + +```c++ +extern "C" __global__ void philox_4_64(ulong* output, uint64_t seed, uint64_t startingCounter, size_t numElements) { + // Calculate the thread's unique index + uint64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + + // Ensure the thread index is within bounds + if (idx * 4 < numElements) { + // Initialize the Philox generator with a unique counter and key + r123::Philox4x64 rng; + r123::Philox4x64::ctr_type ctr = {{startingCounter + idx, 0, 0, 0}}; // Counter (startingCounter + thread index) + r123::Philox4x64::key_type key = {{seed}}; // Key (seed) + + // Generate 4 random integers + r123::Philox4x64::ctr_type result = rng(ctr, key); + + for (int i = 0; i < 4; ++i) { + size_t outputIdx = idx * 4 + i; + + // Ensure we don't exceed the output array bounds + if (outputIdx < numElements) { + output[outputIdx] = result[i]; + } + } + } +} +``` + +To compile the cuda kernel to a ptx file, you can use the following command: + +```bash +/usr/local/cuda/bin/nvcc kernel.cu -ccbin gcc-8 -lstdc++ -I ./random123/include -o cuda_test.ptx -lm --ptx -std=c++11 --gpu-architecture=sm_70 +``` + +This will compile the cuda kernel to a ptx file that can be shipped with the SystemDS project. + +```ptx +.version 6.5 +.target sm_70 +.address_size 64 + +.visible .entry philox_4_64( + .param .u64 philox_4_64_param_0, + .param .u64 philox_4_64_param_1, + .param .u64 philox_4_64_param_2, + .param .u64 philox_4_64_param_3 +) +{ + ... cuda kernel code ... 
+} + +``` +To use this ptx file in the SystemDS project, you can use the following code: + +```java +import jcuda.*; +import jcuda.driver.*; +import jcuda.nvrtc.*; +import jcuda.runtime.JCuda; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; + +import static jcuda.driver.JCudaDriver.cuCtxCreate; + +public class PhiloxJNvrtcExample { + + public static void main(String[] args) { + // Enable exceptions and omit error checks + JCuda.setExceptionsEnabled(true); + JCudaDriver.setExceptionsEnabled(true); + JNvrtc.setExceptionsEnabled(true); + + String ptx = ""; + try { + ptx = new String(Files.readAllBytes(Paths.get("philox_kernel.ptx"))); + } catch (IOException e) { + System.out.println(e.getMessage()); + } + + // Initialize the driver API and create a context + JCudaDriver.cuInit(0); + CUdevice device = new CUdevice(); + JCudaDriver.cuDeviceGet(device, 0); + CUcontext context = new CUcontext(); + cuCtxCreate(context, 0, device); + + CUmodule module = new CUmodule(); + JCudaDriver.cuModuleLoadData(module, ptx); + + // Get a function pointer to the kernel + CUfunction function = new CUfunction(); + JCudaDriver.cuModuleGetFunction(function, module, "philox_4_64"); + + // Prepare data + int n = 1000; // Number of random numbers to generate + long[] hostOut = new long[n]; + CUdeviceptr deviceOut = new CUdeviceptr(); + JCudaDriver.cuMemAlloc(deviceOut, n * Sizeof.LONG); + + // Direkte Werte für seed und startingCounter + long seed = 0L; // Fester Seed-Wert + long startingCounter = 0L; // Startwert für Counter + + Pointer kernelParameters = Pointer.to( + Pointer.to(deviceOut), // ulong* output + Pointer.to(new long[]{seed}), // uint64_t seed + Pointer.to(new long[]{startingCounter}), // uint64_t startingCounter + Pointer.to(new long[]{n}) // size_t numElements + ); + + // Launch the kernel + int blockSizeX = 128; + int gridSizeX = (int) Math.ceil((double)n / blockSizeX); + 
JCudaDriver.cuLaunchKernel( + function, + gridSizeX, 1, 1, // Grid dimension + blockSizeX, 1, 1, // Block dimension + 0, null, // Shared memory size and stream + kernelParameters, null // Kernel- und extra parameters + ); + JCudaDriver.cuCtxSynchronize(); + + // Copy result back + JCudaDriver.cuMemcpyDtoH(Pointer.to(hostOut), deviceOut, n * Sizeof.LONG); + + // Print results + System.out.println("Generated random numbers with seed=" + + String.format("0x%016X", seed) + + " and startingCounter=" + startingCounter); + for (int i = 0; i < Math.min(10, n); i++) { + System.out.printf("hostOut[%d] = 0x%016X\n", i, hostOut[i]); + } + + // Cleanup + JCudaDriver.cuMemFree(deviceOut); + JCudaDriver.cuCtxDestroy(context); + } +} +``` + +Run the code with the following command: + +```bash +javac -cp .:./target/dependency/jcuda-10.2.0.jar:./target/dependency/jcuda-natives-10.2.0-linux-x86_64.jar PhiloxJNvrtcExample.java && java -cp .:./target/dependency/jcuda-10.2.0.jar:./target/dependency/jcuda-natives-10.2.0-linux-x86_64.jar PhiloxJNvrtcExample +``` + +### Compile cuda kernels during runtime + +To compile the cuda kernel during runtime, you can use the following code: + +```java +import jcuda.*; +import jcuda.driver.*; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileWriter; +import java.io.InputStreamReader; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +import static java.nio.file.Files.readAllBytes; +import static jcuda.driver.JCudaDriver.*; + +public class Random123_cuda implements AutoCloseable { + private static String philox4x64KernelSource = "#include \n" + + "#include \n" + + "extern \"C\" __global__ void philox_4_64(ulong* output, uint64_t startingCounter, uint64_t seed, size_t numElements) {\n" + + + " uint64_t idx = blockIdx.x * blockDim.x + threadIdx.x;\n" + + " if (idx * 4 < numElements) {\n" + + " r123::Philox4x64 rng;\n" + + " r123::Philox4x64::ctr_type ctr = {{startingCounter + idx, 0, 0, 0}};\n" + 
+ " r123::Philox4x64::key_type key = {{seed}};\n" + + " r123::Philox4x64::ctr_type result = rng(ctr, key);\n" + + " for (int i = 0; i < 4; ++i) {\n" + + " size_t outputIdx = idx * 4 + i;\n" + + " if (outputIdx < numElements) {\n" + + " output[outputIdx] = result[i];\n" + + " }\n" + + " }\n" + + " }\n" + + "}\n"; + + private final CUcontext context; + private final CUmodule module; + private final CUfunction function; + private final int blockSize; + + public Random123_cuda() { + JCudaDriver.setExceptionsEnabled(true); + // Initialize CUDA + cuInit(0); + CUdevice device = new CUdevice(); + cuDeviceGet(device, 0); + context = new CUcontext(); + int result = cuCtxCreate(context, 0, device); + if (result != CUresult.CUDA_SUCCESS) { + throw new RuntimeException( + "Faild to create CUDA context: " + result + ", " + CUresult.stringFor(result)); + } + + // Compile to PTX + String ptx = compileToTPX(philox4x64KernelSource); + + // Load the PTX + module = new CUmodule(); + cuModuleLoadData(module, ptx); + function = new CUfunction(); + cuModuleGetFunction(function, module, "philox_4_64"); + + // Set block size based on device capabilities + blockSize = 64; // Can be adjusted based on device properties + } + + private String compileToTPX(String source) { + try { + // create temp files + File sourceFile = File.createTempFile("philox_kernel", ".cu"); + File outputFile = File.createTempFile("philox_kernel", ".ptx"); + + // Write cuda source to temp file + try (FileWriter writer = new FileWriter(sourceFile)) { + writer.write(philox4x64KernelSource); + } + + // build nvcc command + List command = new ArrayList<>(); + command.add("/usr/local/cuda/bin/nvcc"); + command.add("-ccbin"); + command.add("gcc-8"); + command.add("--ptx"); // PTX-Output generieren + command.add("-o"); + command.add(outputFile.getAbsolutePath()); + command.add("-I"); + command.add("./lib/random123/include"); + command.add(sourceFile.getAbsolutePath()); + + ProcessBuilder pb = new ProcessBuilder(command); + 
pb.redirectErrorStream(true); + Process process = pb.start(); + + try (BufferedReader reader = new BufferedReader( + new InputStreamReader(process.getInputStream()))) { + String line; + StringBuilder output = new StringBuilder(); + while ((line = reader.readLine()) != null) { + output.append(line).append("\n"); + } + System.out.println("Compiler Output: " + output.toString()); + } + + int exitCode = process.waitFor(); + if (exitCode != 0) { + throw new RuntimeException("nvcc compiler returned non-zero exit code: " + exitCode); + } + + // Read PTX code + String ptxCode = new String(readAllBytes(outputFile.toPath())); + + // Cleanup + sourceFile.delete(); + outputFile.delete(); + + return ptxCode; + + } catch (Exception e) { + throw new RuntimeException("CUDA-compilation failed: " + e.getMessage(), e); + } + } + + /** + * Generates random numbers using the Philox4x64 algorithm + * + * @param startingCounter Initial counter value + * @param seed Random seed + * @param numElements Number of random numbers to generate + * @return Array of random numbers + */ + public CUdeviceptr Philox4x64(long startingCounter, long seed, int numElements) { + // Allocate host memory for results + // long[] hostOutput = new long[numElements]; + + // Allocate device memory + CUdeviceptr deviceOutput = new CUdeviceptr(); + cuMemAlloc(deviceOutput, (long) numElements * Sizeof.LONG); + + try { + System.out.printf("numElements: %d, seed: %d, startingCounter: %d%n", + numElements, seed, startingCounter); + + Pointer kernelParams = Pointer.to( + Pointer.to(deviceOutput), + Pointer.to(new long[] { startingCounter }), + Pointer.to(new long[] { seed }), + Pointer.to(new long[] { numElements })); + + // Calculate grid size + int gridSize = (numElements + (blockSize * 4) - 1) / (blockSize * 4); + + int kernelResult = cuLaunchKernel(function, + gridSize, 1, 1, // Grid dimension + blockSize, 1, 1, // Block dimension + 0, null, // Shared memory size and stream + kernelParams, null // Kernel parameters 
and extra parameters + ); + if (kernelResult != CUresult.CUDA_SUCCESS) { + throw new RuntimeException( + "Kernel-launch failed: " + kernelResult + ", " + CUresult.stringFor(kernelResult)); + } + + // Copy results back to host + // cuMemcpyDtoH(Pointer.to(hostOutput), deviceOutput, (long) numElements * + // Sizeof.LONG); + } finally { + // Free device memory + // cuMemFree(deviceOutput); + } + + // return hostOutput; + return deviceOutput; + } + + /** + * Cleans up CUDA resources + */ + public void close() { + cuModuleUnload(module); + cuCtxDestroy(context); + } + + // Example usage + public static void main(String[] args) { + try (Random123_cuda generator = new Random123_cuda()) { + // Generate 1 million random numbers + int numElements = 1_000_000; + long seed = 0L; + long startingCounter = 0L; + + CUdeviceptr randomNumbers = generator.Philox4x64(startingCounter, seed, numElements); + + long[] elements = new long[10]; + cuMemcpyDtoH(Pointer.to(elements), randomNumbers, 10L * Sizeof.LONG); + cuMemFree(randomNumbers); + + // Print first few numbers + System.out.println("First 10 random numbers:"); + for (int i = 0; i < 10; i++) { + System.out.printf("%d: %x%n", i, elements[i]); + } + + int size = 10_000_000; + long start = System.currentTimeMillis(); + CUdeviceptr ptr = generator.Philox4x64(0L, 0L, size); + long end = System.currentTimeMillis(); + System.out.println("philox4x64 speed test: " + (end - start) * 1000 + " microseconds"); + cuMemFree(ptr); + Random r = new Random(); + long javaStart = System.currentTimeMillis(); + for (int i = 0; i < size; i++) { + r.nextLong(); + } + long javaEnd = System.currentTimeMillis(); + System.out.println("java speed test: " + (javaEnd - javaStart) * 1000 + " microseconds"); + System.out.println("philox4x64 is " + (double) (javaEnd - javaStart) / (double) (end - start) + + " times faster than java"); + + } + } +} +``` + +Run the code with the following command: + +```bash +javac -cp 
.:./target/dependency/jcuda-10.2.0.jar:./target/dependency/jcuda-natives-10.2.0-linux-x86_64.jar Random123_cuda.java && java -cp .:./target/dependency/jcuda-10.2.0.jar:./target/dependency/jcuda-natives-10.2.0-linux-x86_64.jar Random123_cuda +``` + + From fc578c6cee5207ed1a381bdcd8dd58a92f95b269 Mon Sep 17 00:00:00 2001 From: chris-1187 Date: Tue, 4 Feb 2025 17:19:22 +0100 Subject: [PATCH 12/13] Added licenses, shortened readme.md Signed-off-by: chris-1187 --- .../PhiloxJNvrtcExample.java | 19 + .../PhiloxRuntimeCompilationExample.java | 19 + .../staging/cuda-counter-based-prng/kernel.cu | 19 + .../staging/cuda-counter-based-prng/pom.xml | 18 + .../staging/cuda-counter-based-prng/readme.md | 333 ++---------------- 5 files changed, 97 insertions(+), 311 deletions(-) diff --git a/scripts/staging/cuda-counter-based-prng/PhiloxJNvrtcExample.java b/scripts/staging/cuda-counter-based-prng/PhiloxJNvrtcExample.java index 7d1cafee4fe..04855bd83ea 100644 --- a/scripts/staging/cuda-counter-based-prng/PhiloxJNvrtcExample.java +++ b/scripts/staging/cuda-counter-based-prng/PhiloxJNvrtcExample.java @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + import jcuda.*; import jcuda.driver.*; import jcuda.nvrtc.*; diff --git a/scripts/staging/cuda-counter-based-prng/PhiloxRuntimeCompilationExample.java b/scripts/staging/cuda-counter-based-prng/PhiloxRuntimeCompilationExample.java index 93a1840ba3e..72fb5a05c45 100644 --- a/scripts/staging/cuda-counter-based-prng/PhiloxRuntimeCompilationExample.java +++ b/scripts/staging/cuda-counter-based-prng/PhiloxRuntimeCompilationExample.java @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + import jcuda.*; import jcuda.driver.*; diff --git a/scripts/staging/cuda-counter-based-prng/kernel.cu b/scripts/staging/cuda-counter-based-prng/kernel.cu index 456cca39c03..8ecce451c66 100644 --- a/scripts/staging/cuda-counter-based-prng/kernel.cu +++ b/scripts/staging/cuda-counter-based-prng/kernel.cu @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + #include #include #include diff --git a/scripts/staging/cuda-counter-based-prng/pom.xml b/scripts/staging/cuda-counter-based-prng/pom.xml index 58f3530c380..ff96112cfc4 100644 --- a/scripts/staging/cuda-counter-based-prng/pom.xml +++ b/scripts/staging/cuda-counter-based-prng/pom.xml @@ -1,4 +1,22 @@ + diff --git a/scripts/staging/cuda-counter-based-prng/readme.md b/scripts/staging/cuda-counter-based-prng/readme.md index 1cd04be6b7b..b8649d6a913 100644 --- a/scripts/staging/cuda-counter-based-prng/readme.md +++ b/scripts/staging/cuda-counter-based-prng/readme.md @@ -1,3 +1,22 @@ + + # CUDA counter based PRNG Currently, random matrix generation is done using Java implementations. 
Either the Java Random class or the custom @@ -77,96 +96,9 @@ This will compile the cuda kernel to a ptx file that can be shipped with the Sys } ``` -To use this ptx file in the SystemDS project, you can use the following code: - -```java -import jcuda.*; -import jcuda.driver.*; -import jcuda.nvrtc.*; -import jcuda.runtime.JCuda; - -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.nio.file.Files; -import java.nio.file.Paths; - -import static jcuda.driver.JCudaDriver.cuCtxCreate; - -public class PhiloxJNvrtcExample { - - public static void main(String[] args) { - // Enable exceptions and omit error checks - JCuda.setExceptionsEnabled(true); - JCudaDriver.setExceptionsEnabled(true); - JNvrtc.setExceptionsEnabled(true); - - String ptx = ""; - try { - ptx = new String(Files.readAllBytes(Paths.get("philox_kernel.ptx"))); - } catch (IOException e) { - System.out.println(e.getMessage()); - } - - // Initialize the driver API and create a context - JCudaDriver.cuInit(0); - CUdevice device = new CUdevice(); - JCudaDriver.cuDeviceGet(device, 0); - CUcontext context = new CUcontext(); - cuCtxCreate(context, 0, device); - - CUmodule module = new CUmodule(); - JCudaDriver.cuModuleLoadData(module, ptx); +To use this ptx file in the SystemDS project, you can use this code: - // Get a function pointer to the kernel - CUfunction function = new CUfunction(); - JCudaDriver.cuModuleGetFunction(function, module, "philox_4_64"); - - // Prepare data - int n = 1000; // Number of random numbers to generate - long[] hostOut = new long[n]; - CUdeviceptr deviceOut = new CUdeviceptr(); - JCudaDriver.cuMemAlloc(deviceOut, n * Sizeof.LONG); - - // Direkte Werte für seed und startingCounter - long seed = 0L; // Fester Seed-Wert - long startingCounter = 0L; // Startwert für Counter - - Pointer kernelParameters = Pointer.to( - Pointer.to(deviceOut), // ulong* output - Pointer.to(new long[]{seed}), // uint64_t seed - Pointer.to(new long[]{startingCounter}), // 
uint64_t startingCounter - Pointer.to(new long[]{n}) // size_t numElements - ); - - // Launch the kernel - int blockSizeX = 128; - int gridSizeX = (int) Math.ceil((double)n / blockSizeX); - JCudaDriver.cuLaunchKernel( - function, - gridSizeX, 1, 1, // Grid dimension - blockSizeX, 1, 1, // Block dimension - 0, null, // Shared memory size and stream - kernelParameters, null // Kernel- und extra parameters - ); - JCudaDriver.cuCtxSynchronize(); - - // Copy result back - JCudaDriver.cuMemcpyDtoH(Pointer.to(hostOut), deviceOut, n * Sizeof.LONG); - - // Print results - System.out.println("Generated random numbers with seed=" + - String.format("0x%016X", seed) + - " and startingCounter=" + startingCounter); - for (int i = 0; i < Math.min(10, n); i++) { - System.out.printf("hostOut[%d] = 0x%016X\n", i, hostOut[i]); - } - - // Cleanup - JCudaDriver.cuMemFree(deviceOut); - JCudaDriver.cuCtxDestroy(context); - } -} -``` +[PhiloxJNvrtcExample.java](/scripts/staging/cuda-counter-based-prng/PhiloxJNvrtcExample.java) Run the code with the following command: @@ -178,228 +110,7 @@ javac -cp .:./target/dependency/jcuda-10.2.0.jar:./target/dependency/jcuda-nativ To compile the cuda kernel during runtime, you can use the following code: -```java -import jcuda.*; -import jcuda.driver.*; - -import java.io.BufferedReader; -import java.io.File; -import java.io.FileWriter; -import java.io.InputStreamReader; -import java.util.ArrayList; -import java.util.List; -import java.util.Random; - -import static java.nio.file.Files.readAllBytes; -import static jcuda.driver.JCudaDriver.*; - -public class Random123_cuda implements AutoCloseable { - private static String philox4x64KernelSource = "#include \n" + - "#include \n" + - "extern \"C\" __global__ void philox_4_64(ulong* output, uint64_t startingCounter, uint64_t seed, size_t numElements) {\n" - + - " uint64_t idx = blockIdx.x * blockDim.x + threadIdx.x;\n" + - " if (idx * 4 < numElements) {\n" + - " r123::Philox4x64 rng;\n" + - " 
r123::Philox4x64::ctr_type ctr = {{startingCounter + idx, 0, 0, 0}};\n" + - " r123::Philox4x64::key_type key = {{seed}};\n" + - " r123::Philox4x64::ctr_type result = rng(ctr, key);\n" + - " for (int i = 0; i < 4; ++i) {\n" + - " size_t outputIdx = idx * 4 + i;\n" + - " if (outputIdx < numElements) {\n" + - " output[outputIdx] = result[i];\n" + - " }\n" + - " }\n" + - " }\n" + - "}\n"; - - private final CUcontext context; - private final CUmodule module; - private final CUfunction function; - private final int blockSize; - - public Random123_cuda() { - JCudaDriver.setExceptionsEnabled(true); - // Initialize CUDA - cuInit(0); - CUdevice device = new CUdevice(); - cuDeviceGet(device, 0); - context = new CUcontext(); - int result = cuCtxCreate(context, 0, device); - if (result != CUresult.CUDA_SUCCESS) { - throw new RuntimeException( - "Faild to create CUDA context: " + result + ", " + CUresult.stringFor(result)); - } - - // Compile to PTX - String ptx = compileToTPX(philox4x64KernelSource); - - // Load the PTX - module = new CUmodule(); - cuModuleLoadData(module, ptx); - function = new CUfunction(); - cuModuleGetFunction(function, module, "philox_4_64"); - - // Set block size based on device capabilities - blockSize = 64; // Can be adjusted based on device properties - } - - private String compileToTPX(String source) { - try { - // create temp files - File sourceFile = File.createTempFile("philox_kernel", ".cu"); - File outputFile = File.createTempFile("philox_kernel", ".ptx"); - - // Write cuda source to temp file - try (FileWriter writer = new FileWriter(sourceFile)) { - writer.write(philox4x64KernelSource); - } - - // build nvcc command - List command = new ArrayList<>(); - command.add("/usr/local/cuda/bin/nvcc"); - command.add("-ccbin"); - command.add("gcc-8"); - command.add("--ptx"); // PTX-Output generieren - command.add("-o"); - command.add(outputFile.getAbsolutePath()); - command.add("-I"); - command.add("./lib/random123/include"); - 
command.add(sourceFile.getAbsolutePath()); - - ProcessBuilder pb = new ProcessBuilder(command); - pb.redirectErrorStream(true); - Process process = pb.start(); - - try (BufferedReader reader = new BufferedReader( - new InputStreamReader(process.getInputStream()))) { - String line; - StringBuilder output = new StringBuilder(); - while ((line = reader.readLine()) != null) { - output.append(line).append("\n"); - } - System.out.println("Compiler Output: " + output.toString()); - } - - int exitCode = process.waitFor(); - if (exitCode != 0) { - throw new RuntimeException("nvcc compiler returned non-zero exit code: " + exitCode); - } - - // Read PTX code - String ptxCode = new String(readAllBytes(outputFile.toPath())); - - // Cleanup - sourceFile.delete(); - outputFile.delete(); - - return ptxCode; - - } catch (Exception e) { - throw new RuntimeException("CUDA-compilation failed: " + e.getMessage(), e); - } - } - - /** - * Generates random numbers using the Philox4x64 algorithm - * - * @param startingCounter Initial counter value - * @param seed Random seed - * @param numElements Number of random numbers to generate - * @return Array of random numbers - */ - public CUdeviceptr Philox4x64(long startingCounter, long seed, int numElements) { - // Allocate host memory for results - // long[] hostOutput = new long[numElements]; - - // Allocate device memory - CUdeviceptr deviceOutput = new CUdeviceptr(); - cuMemAlloc(deviceOutput, (long) numElements * Sizeof.LONG); - - try { - System.out.printf("numElements: %d, seed: %d, startingCounter: %d%n", - numElements, seed, startingCounter); - - Pointer kernelParams = Pointer.to( - Pointer.to(deviceOutput), - Pointer.to(new long[] { startingCounter }), - Pointer.to(new long[] { seed }), - Pointer.to(new long[] { numElements })); - - // Calculate grid size - int gridSize = (numElements + (blockSize * 4) - 1) / (blockSize * 4); - - int kernelResult = cuLaunchKernel(function, - gridSize, 1, 1, // Grid dimension - blockSize, 1, 1, // 
Block dimension - 0, null, // Shared memory size and stream - kernelParams, null // Kernel parameters and extra parameters - ); - if (kernelResult != CUresult.CUDA_SUCCESS) { - throw new RuntimeException( - "Kernel-launch failed: " + kernelResult + ", " + CUresult.stringFor(kernelResult)); - } - - // Copy results back to host - // cuMemcpyDtoH(Pointer.to(hostOutput), deviceOutput, (long) numElements * - // Sizeof.LONG); - } finally { - // Free device memory - // cuMemFree(deviceOutput); - } - - // return hostOutput; - return deviceOutput; - } - - /** - * Cleans up CUDA resources - */ - public void close() { - cuModuleUnload(module); - cuCtxDestroy(context); - } - - // Example usage - public static void main(String[] args) { - try (Random123_cuda generator = new Random123_cuda()) { - // Generate 1 million random numbers - int numElements = 1_000_000; - long seed = 0L; - long startingCounter = 0L; - - CUdeviceptr randomNumbers = generator.Philox4x64(startingCounter, seed, numElements); - - long[] elements = new long[10]; - cuMemcpyDtoH(Pointer.to(elements), randomNumbers, 10L * Sizeof.LONG); - cuMemFree(randomNumbers); - - // Print first few numbers - System.out.println("First 10 random numbers:"); - for (int i = 0; i < 10; i++) { - System.out.printf("%d: %x%n", i, elements[i]); - } - - int size = 10_000_000; - long start = System.currentTimeMillis(); - CUdeviceptr ptr = generator.Philox4x64(0L, 0L, size); - long end = System.currentTimeMillis(); - System.out.println("philox4x64 speed test: " + (end - start) * 1000 + " microseconds"); - cuMemFree(ptr); - Random r = new Random(); - long javaStart = System.currentTimeMillis(); - for (int i = 0; i < size; i++) { - r.nextLong(); - } - long javaEnd = System.currentTimeMillis(); - System.out.println("java speed test: " + (javaEnd - javaStart) * 1000 + " microseconds"); - System.out.println("philox4x64 is " + (double) (javaEnd - javaStart) / (double) (end - start) - + " times faster than java"); - - } - } -} -``` 
+[PhiloxRuntimeCompilationExample.java](/scripts/staging/cuda-counter-based-prng/PhiloxRuntimeCompilationExample.java) Run the code with the following command: From fe8a9860be6f0adb24cf81080a99e0d29bbcc40f Mon Sep 17 00:00:00 2001 From: ichbinstudent <45435943+ichbinstudent@users.noreply.github.com> Date: Fri, 14 Feb 2025 21:10:22 +0100 Subject: [PATCH 13/13] Delete scripts/staging/cuda-counter-based-prng/philox_kernel.ptx Removed because can be compiled from the provided sources --- .../cuda-counter-based-prng/philox_kernel.ptx | 772 ------------------ 1 file changed, 772 deletions(-) delete mode 100644 scripts/staging/cuda-counter-based-prng/philox_kernel.ptx diff --git a/scripts/staging/cuda-counter-based-prng/philox_kernel.ptx b/scripts/staging/cuda-counter-based-prng/philox_kernel.ptx deleted file mode 100644 index 90e0f4fa854..00000000000 --- a/scripts/staging/cuda-counter-based-prng/philox_kernel.ptx +++ /dev/null @@ -1,772 +0,0 @@ -// -// Generated by NVIDIA NVVM Compiler -// -// Compiler Build ID: CL-27506705 -// Cuda compilation tools, release 10.2, V10.2.89 -// Based on LLVM 3.4svn -// -// @page LICENSE -// Copyright 2010-2012, D. E. Shaw Research. -// All rights reserved. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are -// met: -// -// * Redistributions of source code must retain the above copyright -// notice, this list of conditions, and the following disclaimer. -// -// * Redistributions in binary form must reproduce the above copyright -// notice, this list of conditions, and the following disclaimer in the -// documentation and/or other materials provided with the distribution. -// -// * Neither the name of D. E. Shaw Research nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. 
-// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -.version 6.5 -.target sm_30 -.address_size 64 - - // .globl philox_4_64_uniform - -.visible .entry philox_4_64_uniform( - .param .u64 philox_4_64_uniform_param_0, - .param .u64 philox_4_64_uniform_param_1, - .param .align 8 .b8 philox_4_64_uniform_param_2[32], - .param .u64 philox_4_64_uniform_param_3 -) -{ - .reg .pred %p<6>; - .reg .b32 %r<5>; - .reg .f64 %fd<9>; - .reg .b64 %rd<110>; - - - ld.param.u64 %rd8, [philox_4_64_uniform_param_0]; - ld.param.u64 %rd9, [philox_4_64_uniform_param_1]; - ld.param.u64 %rd13, [philox_4_64_uniform_param_2+24]; - ld.param.u64 %rd1, [philox_4_64_uniform_param_2+16]; - ld.param.u64 %rd11, [philox_4_64_uniform_param_2+8]; - ld.param.u64 %rd10, [philox_4_64_uniform_param_2]; - ld.param.u64 %rd14, [philox_4_64_uniform_param_3]; - mov.u32 %r1, %ntid.x; - mov.u32 %r2, %ctaid.x; - mov.u32 %r3, %tid.x; - mad.lo.s32 %r4, %r1, %r2, %r3; - cvt.u64.u32 %rd2, %r4; - mul.wide.u32 %rd3, %r4, 4; - setp.ge.u64 %p1, %rd3, %rd14; - @%p1 bra BB0_7; - - add.s64 %rd15, %rd10, %rd2; - setp.lt.u64 %p2, %rd15, %rd10; - selp.u64 %rd16, 1, 0, %p2; - add.s64 %rd17, %rd16, %rd11; - mov.u64 %rd18, -3249550476889527149; - mul.hi.u64 %rd19, %rd18, %rd15; - mul.lo.s64 
%rd20, %rd15, -3249550476889527149; - mov.u64 %rd21, -3865633965929787049; - mul.hi.u64 %rd22, %rd21, %rd1; - xor.b64 %rd23, %rd22, %rd9; - xor.b64 %rd24, %rd23, %rd17; - xor.b64 %rd25, %rd19, %rd13; - mul.hi.u64 %rd26, %rd18, %rd24; - mul.lo.s64 %rd27, %rd24, -3249550476889527149; - mul.hi.u64 %rd28, %rd21, %rd25; - mul.lo.s64 %rd29, %rd25, -3865633965929787049; - add.s64 %rd30, %rd9, -7046029254386353131; - mul.lo.s64 %rd31, %rd1, -3865633965929787049; - xor.b64 %rd32, %rd31, %rd30; - xor.b64 %rd33, %rd32, %rd28; - xor.b64 %rd34, %rd20, %rd26; - xor.b64 %rd35, %rd34, -4942790177534073029; - mul.hi.u64 %rd36, %rd18, %rd33; - mul.lo.s64 %rd37, %rd33, -3249550476889527149; - mul.hi.u64 %rd38, %rd21, %rd35; - mul.lo.s64 %rd39, %rd35, -3865633965929787049; - add.s64 %rd40, %rd9, 4354685564936845354; - xor.b64 %rd41, %rd29, %rd40; - xor.b64 %rd42, %rd41, %rd38; - xor.b64 %rd43, %rd27, %rd36; - xor.b64 %rd44, %rd43, 8561163718641405558; - mul.hi.u64 %rd45, %rd18, %rd42; - mul.lo.s64 %rd46, %rd42, -3249550476889527149; - mul.hi.u64 %rd47, %rd21, %rd44; - mul.lo.s64 %rd48, %rd44, -3865633965929787049; - add.s64 %rd49, %rd9, -2691343689449507777; - xor.b64 %rd50, %rd39, %rd49; - xor.b64 %rd51, %rd50, %rd47; - xor.b64 %rd52, %rd37, %rd45; - xor.b64 %rd53, %rd52, 3618373541107332529; - mul.hi.u64 %rd54, %rd18, %rd51; - mul.lo.s64 %rd55, %rd51, -3249550476889527149; - mul.hi.u64 %rd56, %rd21, %rd53; - mul.lo.s64 %rd57, %rd53, -3865633965929787049; - add.s64 %rd58, %rd9, 8709371129873690708; - xor.b64 %rd59, %rd48, %rd58; - xor.b64 %rd60, %rd59, %rd56; - xor.b64 %rd61, %rd46, %rd54; - xor.b64 %rd62, %rd61, -1324416636426740500; - mul.hi.u64 %rd63, %rd18, %rd60; - mul.lo.s64 %rd64, %rd60, -3249550476889527149; - mul.hi.u64 %rd65, %rd21, %rd62; - mul.lo.s64 %rd66, %rd62, -3865633965929787049; - add.s64 %rd67, %rd9, 1663341875487337577; - xor.b64 %rd68, %rd57, %rd67; - xor.b64 %rd69, %rd68, %rd65; - xor.b64 %rd70, %rd55, %rd63; - xor.b64 %rd71, %rd70, -6267206813960813529; - 
mul.hi.u64 %rd72, %rd18, %rd69; - mul.lo.s64 %rd73, %rd69, -3249550476889527149; - mul.hi.u64 %rd74, %rd21, %rd71; - mul.lo.s64 %rd75, %rd71, -3865633965929787049; - add.s64 %rd76, %rd9, -5382687378899015554; - xor.b64 %rd77, %rd66, %rd76; - xor.b64 %rd78, %rd77, %rd74; - xor.b64 %rd79, %rd64, %rd72; - xor.b64 %rd80, %rd79, 7236747082214665058; - mul.hi.u64 %rd81, %rd18, %rd78; - mul.lo.s64 %rd82, %rd78, -3249550476889527149; - mul.hi.u64 %rd83, %rd21, %rd80; - mul.lo.s64 %rd84, %rd80, -3865633965929787049; - add.s64 %rd85, %rd9, 6018027440424182931; - xor.b64 %rd86, %rd75, %rd85; - xor.b64 %rd87, %rd86, %rd83; - xor.b64 %rd88, %rd73, %rd81; - xor.b64 %rd89, %rd88, 2293956904680592029; - mul.hi.u64 %rd90, %rd18, %rd87; - mul.lo.s64 %rd91, %rd87, -3249550476889527149; - mul.hi.u64 %rd92, %rd21, %rd89; - mul.lo.s64 %rd93, %rd89, -3865633965929787049; - add.s64 %rd94, %rd9, -1028001813962170200; - xor.b64 %rd95, %rd84, %rd94; - xor.b64 %rd96, %rd95, %rd92; - xor.b64 %rd97, %rd82, %rd90; - xor.b64 %rd4, %rd97, -2648833272853481000; - mul.hi.u64 %rd98, %rd18, %rd96; - mul.lo.s64 %rd5, %rd96, -3249550476889527149; - mul.hi.u64 %rd99, %rd21, %rd4; - add.s64 %rd100, %rd9, -8074031068348523331; - xor.b64 %rd101, %rd93, %rd100; - xor.b64 %rd102, %rd101, %rd99; - xor.b64 %rd103, %rd91, %rd98; - xor.b64 %rd6, %rd103, -7591623450387554029; - cvt.rn.f64.s64 %fd1, %rd102; - mul.f64 %fd2, %fd1, 0d3C00000000000000; - cvta.to.global.u64 %rd104, %rd8; - shl.b64 %rd105, %rd3, 3; - add.s64 %rd7, %rd104, %rd105; - st.global.f64 [%rd7], %fd2; - add.s64 %rd106, %rd3, 1; - setp.ge.u64 %p3, %rd106, %rd14; - @%p3 bra BB0_3; - - mul.lo.s64 %rd107, %rd4, -3865633965929787049; - cvt.rn.f64.s64 %fd3, %rd107; - mul.f64 %fd4, %fd3, 0d3C00000000000000; - st.global.f64 [%rd7+8], %fd4; - -BB0_3: - add.s64 %rd108, %rd3, 2; - setp.ge.u64 %p4, %rd108, %rd14; - @%p4 bra BB0_5; - - cvt.rn.f64.s64 %fd5, %rd6; - mul.f64 %fd6, %fd5, 0d3C00000000000000; - st.global.f64 [%rd7+16], %fd6; - -BB0_5: - add.s64 
%rd109, %rd3, 3; - setp.ge.u64 %p5, %rd109, %rd14; - @%p5 bra BB0_7; - - cvt.rn.f64.s64 %fd7, %rd5; - mul.f64 %fd8, %fd7, 0d3C00000000000000; - st.global.f64 [%rd7+24], %fd8; - -BB0_7: - ret; -} - - // .globl philox_4_64_standard -.visible .entry philox_4_64_standard( - .param .u64 philox_4_64_standard_param_0, - .param .u64 philox_4_64_standard_param_1, - .param .align 8 .b8 philox_4_64_standard_param_2[32], - .param .u64 philox_4_64_standard_param_3 -) -{ - .reg .pred %p<7>; - .reg .b32 %r<5>; - .reg .f64 %fd<21>; - .reg .b64 %rd<191>; - - - ld.param.u64 %rd14, [philox_4_64_standard_param_0]; - ld.param.u64 %rd15, [philox_4_64_standard_param_1]; - ld.param.u64 %rd1, [philox_4_64_standard_param_2]; - ld.param.u64 %rd2, [philox_4_64_standard_param_2+8]; - ld.param.u64 %rd3, [philox_4_64_standard_param_2+16]; - ld.param.u64 %rd4, [philox_4_64_standard_param_2+24]; - ld.param.u64 %rd16, [philox_4_64_standard_param_3]; - mov.u32 %r1, %ntid.x; - mov.u32 %r2, %ctaid.x; - mov.u32 %r3, %tid.x; - mad.lo.s32 %r4, %r1, %r2, %r3; - cvt.u64.u32 %rd5, %r4; - mul.wide.u32 %rd6, %r4, 4; - setp.ge.u64 %p1, %rd6, %rd16; - @%p1 bra BB1_7; - - add.s64 %rd17, %rd1, %rd5; - setp.lt.u64 %p2, %rd17, %rd1; - selp.u64 %rd18, 1, 0, %p2; - add.s64 %rd19, %rd18, %rd2; - add.s64 %rd20, %rd5, %rd16; - add.s64 %rd21, %rd1, %rd20; - setp.lt.u64 %p3, %rd21, %rd1; - selp.u64 %rd22, 1, 0, %p3; - add.s64 %rd23, %rd22, %rd2; - mov.u64 %rd24, -3249550476889527149; - mul.hi.u64 %rd25, %rd24, %rd17; - mul.lo.s64 %rd26, %rd17, -3249550476889527149; - xor.b64 %rd27, %rd19, %rd15; - mov.u64 %rd28, -3865633965929787049; - mul.hi.u64 %rd29, %rd28, %rd3; - xor.b64 %rd30, %rd27, %rd29; - xor.b64 %rd31, %rd25, %rd4; - mul.hi.u64 %rd32, %rd24, %rd30; - mul.lo.s64 %rd33, %rd30, -3249550476889527149; - mul.hi.u64 %rd34, %rd28, %rd31; - mul.lo.s64 %rd35, %rd31, -3865633965929787049; - add.s64 %rd36, %rd15, -7046029254386353131; - mul.lo.s64 %rd37, %rd3, -3865633965929787049; - xor.b64 %rd38, %rd37, %rd36; - xor.b64 
%rd39, %rd38, %rd34; - xor.b64 %rd40, %rd26, %rd32; - xor.b64 %rd41, %rd40, -4942790177534073029; - mul.hi.u64 %rd42, %rd24, %rd39; - mul.lo.s64 %rd43, %rd39, -3249550476889527149; - mul.hi.u64 %rd44, %rd28, %rd41; - mul.lo.s64 %rd45, %rd41, -3865633965929787049; - add.s64 %rd46, %rd15, 4354685564936845354; - xor.b64 %rd47, %rd35, %rd46; - xor.b64 %rd48, %rd47, %rd44; - xor.b64 %rd49, %rd33, %rd42; - xor.b64 %rd50, %rd49, 8561163718641405558; - mul.hi.u64 %rd51, %rd24, %rd48; - mul.lo.s64 %rd52, %rd48, -3249550476889527149; - mul.hi.u64 %rd53, %rd28, %rd50; - mul.lo.s64 %rd54, %rd50, -3865633965929787049; - add.s64 %rd55, %rd15, -2691343689449507777; - xor.b64 %rd56, %rd45, %rd55; - xor.b64 %rd57, %rd56, %rd53; - xor.b64 %rd58, %rd43, %rd51; - xor.b64 %rd59, %rd58, 3618373541107332529; - mul.hi.u64 %rd60, %rd24, %rd57; - mul.lo.s64 %rd61, %rd57, -3249550476889527149; - mul.hi.u64 %rd62, %rd28, %rd59; - mul.lo.s64 %rd63, %rd59, -3865633965929787049; - add.s64 %rd64, %rd15, 8709371129873690708; - xor.b64 %rd65, %rd54, %rd64; - xor.b64 %rd66, %rd65, %rd62; - xor.b64 %rd67, %rd52, %rd60; - xor.b64 %rd68, %rd67, -1324416636426740500; - mul.hi.u64 %rd69, %rd24, %rd66; - mul.lo.s64 %rd70, %rd66, -3249550476889527149; - mul.hi.u64 %rd71, %rd28, %rd68; - mul.lo.s64 %rd72, %rd68, -3865633965929787049; - add.s64 %rd73, %rd15, 1663341875487337577; - xor.b64 %rd74, %rd63, %rd73; - xor.b64 %rd75, %rd74, %rd71; - xor.b64 %rd76, %rd61, %rd69; - xor.b64 %rd77, %rd76, -6267206813960813529; - mul.hi.u64 %rd78, %rd24, %rd75; - mul.lo.s64 %rd79, %rd75, -3249550476889527149; - mul.hi.u64 %rd80, %rd28, %rd77; - mul.lo.s64 %rd81, %rd77, -3865633965929787049; - add.s64 %rd82, %rd15, -5382687378899015554; - xor.b64 %rd83, %rd72, %rd82; - xor.b64 %rd84, %rd83, %rd80; - xor.b64 %rd85, %rd70, %rd78; - xor.b64 %rd86, %rd85, 7236747082214665058; - mul.hi.u64 %rd87, %rd24, %rd84; - mul.lo.s64 %rd88, %rd84, -3249550476889527149; - mul.hi.u64 %rd89, %rd28, %rd86; - mul.lo.s64 %rd90, %rd86, 
-3865633965929787049; - add.s64 %rd91, %rd15, 6018027440424182931; - xor.b64 %rd92, %rd81, %rd91; - xor.b64 %rd93, %rd92, %rd89; - xor.b64 %rd94, %rd79, %rd87; - xor.b64 %rd95, %rd94, 2293956904680592029; - mul.hi.u64 %rd96, %rd24, %rd93; - mul.lo.s64 %rd97, %rd93, -3249550476889527149; - mul.hi.u64 %rd98, %rd28, %rd95; - mul.lo.s64 %rd99, %rd95, -3865633965929787049; - add.s64 %rd100, %rd15, -1028001813962170200; - xor.b64 %rd101, %rd90, %rd100; - xor.b64 %rd102, %rd101, %rd98; - xor.b64 %rd103, %rd88, %rd96; - xor.b64 %rd7, %rd103, -2648833272853481000; - mul.hi.u64 %rd104, %rd24, %rd102; - mul.lo.s64 %rd8, %rd102, -3249550476889527149; - mul.hi.u64 %rd105, %rd28, %rd7; - add.s64 %rd106, %rd15, -8074031068348523331; - xor.b64 %rd107, %rd99, %rd106; - xor.b64 %rd108, %rd107, %rd105; - xor.b64 %rd109, %rd97, %rd104; - xor.b64 %rd9, %rd109, -7591623450387554029; - mul.hi.u64 %rd110, %rd24, %rd21; - mul.lo.s64 %rd111, %rd21, -3249550476889527149; - xor.b64 %rd112, %rd29, %rd15; - xor.b64 %rd113, %rd112, %rd23; - xor.b64 %rd114, %rd110, %rd4; - mul.hi.u64 %rd115, %rd24, %rd113; - mul.lo.s64 %rd116, %rd113, -3249550476889527149; - mul.hi.u64 %rd117, %rd28, %rd114; - mul.lo.s64 %rd118, %rd114, -3865633965929787049; - xor.b64 %rd119, %rd38, %rd117; - xor.b64 %rd120, %rd111, %rd115; - xor.b64 %rd121, %rd120, -4942790177534073029; - mul.hi.u64 %rd122, %rd24, %rd119; - mul.lo.s64 %rd123, %rd119, -3249550476889527149; - mul.hi.u64 %rd124, %rd28, %rd121; - mul.lo.s64 %rd125, %rd121, -3865633965929787049; - xor.b64 %rd126, %rd118, %rd46; - xor.b64 %rd127, %rd126, %rd124; - xor.b64 %rd128, %rd116, %rd122; - xor.b64 %rd129, %rd128, 8561163718641405558; - mul.hi.u64 %rd130, %rd24, %rd127; - mul.lo.s64 %rd131, %rd127, -3249550476889527149; - mul.hi.u64 %rd132, %rd28, %rd129; - mul.lo.s64 %rd133, %rd129, -3865633965929787049; - xor.b64 %rd134, %rd125, %rd55; - xor.b64 %rd135, %rd134, %rd132; - xor.b64 %rd136, %rd123, %rd130; - xor.b64 %rd137, %rd136, 3618373541107332529; - 
mul.hi.u64 %rd138, %rd24, %rd135; - mul.lo.s64 %rd139, %rd135, -3249550476889527149; - mul.hi.u64 %rd140, %rd28, %rd137; - mul.lo.s64 %rd141, %rd137, -3865633965929787049; - xor.b64 %rd142, %rd133, %rd64; - xor.b64 %rd143, %rd142, %rd140; - xor.b64 %rd144, %rd131, %rd138; - xor.b64 %rd145, %rd144, -1324416636426740500; - mul.hi.u64 %rd146, %rd24, %rd143; - mul.lo.s64 %rd147, %rd143, -3249550476889527149; - mul.hi.u64 %rd148, %rd28, %rd145; - mul.lo.s64 %rd149, %rd145, -3865633965929787049; - xor.b64 %rd150, %rd141, %rd73; - xor.b64 %rd151, %rd150, %rd148; - xor.b64 %rd152, %rd139, %rd146; - xor.b64 %rd153, %rd152, -6267206813960813529; - mul.hi.u64 %rd154, %rd24, %rd151; - mul.lo.s64 %rd155, %rd151, -3249550476889527149; - mul.hi.u64 %rd156, %rd28, %rd153; - mul.lo.s64 %rd157, %rd153, -3865633965929787049; - xor.b64 %rd158, %rd149, %rd82; - xor.b64 %rd159, %rd158, %rd156; - xor.b64 %rd160, %rd147, %rd154; - xor.b64 %rd161, %rd160, 7236747082214665058; - mul.hi.u64 %rd162, %rd24, %rd159; - mul.lo.s64 %rd163, %rd159, -3249550476889527149; - mul.hi.u64 %rd164, %rd28, %rd161; - mul.lo.s64 %rd165, %rd161, -3865633965929787049; - xor.b64 %rd166, %rd157, %rd91; - xor.b64 %rd167, %rd166, %rd164; - xor.b64 %rd168, %rd155, %rd162; - xor.b64 %rd169, %rd168, 2293956904680592029; - mul.hi.u64 %rd170, %rd24, %rd167; - mul.lo.s64 %rd171, %rd167, -3249550476889527149; - mul.hi.u64 %rd172, %rd28, %rd169; - mul.lo.s64 %rd173, %rd169, -3865633965929787049; - xor.b64 %rd174, %rd165, %rd100; - xor.b64 %rd175, %rd174, %rd172; - xor.b64 %rd176, %rd163, %rd170; - xor.b64 %rd10, %rd176, -2648833272853481000; - mul.hi.u64 %rd177, %rd24, %rd175; - mul.lo.s64 %rd11, %rd175, -3249550476889527149; - mul.hi.u64 %rd178, %rd28, %rd10; - xor.b64 %rd179, %rd173, %rd106; - xor.b64 %rd180, %rd179, %rd178; - xor.b64 %rd181, %rd171, %rd177; - xor.b64 %rd12, %rd181, -7591623450387554029; - cvt.rn.f64.s64 %fd1, %rd180; - cvt.rn.f64.s64 %fd2, %rd108; - mul.f64 %fd3, %fd2, 0d3C00000000000000; - fma.rn.f64 
%fd4, %fd1, 0d3C00000000000000, %fd3; - mul.f64 %fd5, %fd4, 0d3FE0000000000000; - cvta.to.global.u64 %rd182, %rd14; - shl.b64 %rd183, %rd6, 3; - add.s64 %rd13, %rd182, %rd183; - st.global.f64 [%rd13], %fd5; - add.s64 %rd184, %rd6, 1; - setp.ge.u64 %p4, %rd184, %rd16; - @%p4 bra BB1_3; - - mul.lo.s64 %rd185, %rd10, -3865633965929787049; - cvt.rn.f64.s64 %fd6, %rd185; - mul.lo.s64 %rd186, %rd7, -3865633965929787049; - cvt.rn.f64.s64 %fd7, %rd186; - mul.f64 %fd8, %fd7, 0d3C00000000000000; - fma.rn.f64 %fd9, %fd6, 0d3C00000000000000, %fd8; - mul.f64 %fd10, %fd9, 0d3FE0000000000000; - st.global.f64 [%rd13+8], %fd10; - -BB1_3: - ld.param.u64 %rd189, [philox_4_64_standard_param_3]; - add.s64 %rd187, %rd6, 2; - setp.ge.u64 %p5, %rd187, %rd189; - @%p5 bra BB1_5; - - cvt.rn.f64.s64 %fd11, %rd12; - cvt.rn.f64.s64 %fd12, %rd9; - mul.f64 %fd13, %fd12, 0d3C00000000000000; - fma.rn.f64 %fd14, %fd11, 0d3C00000000000000, %fd13; - mul.f64 %fd15, %fd14, 0d3FE0000000000000; - st.global.f64 [%rd13+16], %fd15; - -BB1_5: - ld.param.u64 %rd190, [philox_4_64_standard_param_3]; - add.s64 %rd188, %rd6, 3; - setp.ge.u64 %p6, %rd188, %rd190; - @%p6 bra BB1_7; - - cvt.rn.f64.s64 %fd16, %rd11; - cvt.rn.f64.s64 %fd17, %rd8; - mul.f64 %fd18, %fd17, 0d3C00000000000000; - fma.rn.f64 %fd19, %fd16, 0d3C00000000000000, %fd18; - mul.f64 %fd20, %fd19, 0d3FE0000000000000; - st.global.f64 [%rd13+24], %fd20; - -BB1_7: - ret; -} - - // .globl philox_4_32 -.visible .entry philox_4_32( - .param .u64 philox_4_32_param_0, - .param .u32 philox_4_32_param_1, - .param .u32 philox_4_32_param_2, - .param .u64 philox_4_32_param_3 -) -{ - .reg .pred %p<5>; - .reg .b32 %r<60>; - .reg .b64 %rd<74>; - - - ld.param.u64 %rd5, [philox_4_32_param_0]; - ld.param.u32 %r5, [philox_4_32_param_1]; - ld.param.u32 %r6, [philox_4_32_param_2]; - ld.param.u64 %rd6, [philox_4_32_param_3]; - mov.u32 %r7, %ntid.x; - mov.u32 %r8, %ctaid.x; - mov.u32 %r9, %tid.x; - mad.lo.s32 %r1, %r7, %r8, %r9; - shl.b32 %r2, %r1, 2; - cvt.u64.u32 %rd1, 
%r2; - setp.ge.u64 %p1, %rd1, %rd6; - @%p1 bra BB2_7; - - cvta.to.global.u64 %rd7, %rd5; - add.s32 %r10, %r1, %r6; - mul.wide.u32 %rd8, %r10, -766435501; - shr.u64 %rd9, %rd8, 32; - mul.wide.u32 %rd10, %r5, -766435501; - shr.u64 %rd11, %rd10, 32; - mul.lo.s64 %rd12, %rd9, 3449720151; - shr.u64 %rd13, %rd12, 32; - cvt.u32.u64 %r11, %rd13; - cvt.u32.u64 %r12, %rd12; - add.s32 %r13, %r5, -1640531527; - xor.b32 %r14, %r11, %r13; - mul.wide.u32 %rd14, %r14, -766435501; - shr.u64 %rd15, %rd14, 32; - and.b64 %rd16, %rd8, 4294967295; - xor.b64 %rd17, %rd11, %rd16; - xor.b64 %rd18, %rd17, 3144134277; - mul.lo.s64 %rd19, %rd18, 3449720151; - shr.u64 %rd20, %rd19, 32; - cvt.u32.u64 %r15, %rd20; - cvt.u32.u64 %r16, %rd19; - add.s32 %r17, %r5, 1013904242; - xor.b32 %r18, %r12, %r17; - xor.b32 %r19, %r18, %r15; - mul.wide.u32 %rd21, %r19, -766435501; - shr.u64 %rd22, %rd21, 32; - and.b64 %rd23, %rd10, 4294967295; - xor.b64 %rd24, %rd23, %rd15; - xor.b64 %rd25, %rd24, 1993301258; - mul.lo.s64 %rd26, %rd25, 3449720151; - shr.u64 %rd27, %rd26, 32; - cvt.u32.u64 %r20, %rd27; - cvt.u32.u64 %r21, %rd26; - add.s32 %r22, %r5, -626627285; - xor.b32 %r23, %r16, %r22; - xor.b32 %r24, %r23, %r20; - mul.wide.u32 %rd28, %r24, -766435501; - shr.u64 %rd29, %rd28, 32; - and.b64 %rd30, %rd14, 4294967295; - xor.b64 %rd31, %rd30, %rd22; - xor.b64 %rd32, %rd31, 842468239; - mul.lo.s64 %rd33, %rd32, 3449720151; - shr.u64 %rd34, %rd33, 32; - cvt.u32.u64 %r25, %rd34; - cvt.u32.u64 %r26, %rd33; - add.s32 %r27, %r5, 2027808484; - xor.b32 %r28, %r21, %r27; - xor.b32 %r29, %r28, %r25; - mul.wide.u32 %rd35, %r29, -766435501; - shr.u64 %rd36, %rd35, 32; - and.b64 %rd37, %rd21, 4294967295; - xor.b64 %rd38, %rd37, %rd29; - xor.b64 %rd39, %rd38, 3986602516; - mul.lo.s64 %rd40, %rd39, 3449720151; - shr.u64 %rd41, %rd40, 32; - cvt.u32.u64 %r30, %rd41; - cvt.u32.u64 %r31, %rd40; - add.s32 %r32, %r5, 387276957; - xor.b32 %r33, %r26, %r32; - xor.b32 %r34, %r33, %r30; - mul.wide.u32 %rd42, %r34, -766435501; - shr.u64 
%rd43, %rd42, 32; - and.b64 %rd44, %rd28, 4294967295; - xor.b64 %rd45, %rd44, %rd36; - xor.b64 %rd46, %rd45, 2835769497; - mul.lo.s64 %rd47, %rd46, 3449720151; - shr.u64 %rd48, %rd47, 32; - cvt.u32.u64 %r35, %rd48; - cvt.u32.u64 %r36, %rd47; - add.s32 %r37, %r5, -1253254570; - xor.b32 %r38, %r31, %r37; - xor.b32 %r39, %r38, %r35; - mul.wide.u32 %rd49, %r39, -766435501; - shr.u64 %rd50, %rd49, 32; - and.b64 %rd51, %rd35, 4294967295; - xor.b64 %rd52, %rd51, %rd43; - xor.b64 %rd53, %rd52, 1684936478; - mul.lo.s64 %rd54, %rd53, 3449720151; - shr.u64 %rd55, %rd54, 32; - cvt.u32.u64 %r40, %rd55; - cvt.u32.u64 %r41, %rd54; - add.s32 %r42, %r5, 1401181199; - xor.b32 %r43, %r36, %r42; - xor.b32 %r44, %r43, %r40; - mul.wide.u32 %rd56, %r44, -766435501; - shr.u64 %rd57, %rd56, 32; - cvt.u32.u64 %r3, %rd56; - and.b64 %rd58, %rd42, 4294967295; - xor.b64 %rd59, %rd58, %rd50; - xor.b64 %rd60, %rd59, 534103459; - mul.lo.s64 %rd61, %rd60, 3449720151; - shr.u64 %rd62, %rd61, 32; - cvt.u32.u64 %r45, %rd62; - cvt.u32.u64 %r46, %rd61; - add.s32 %r47, %r5, -239350328; - xor.b32 %r48, %r41, %r47; - xor.b32 %r4, %r48, %r45; - and.b64 %rd63, %rd49, 4294967295; - xor.b64 %rd64, %rd63, %rd57; - xor.b64 %rd65, %rd64, 3678237736; - mul.lo.s64 %rd2, %rd65, 3449720151; - shr.u64 %rd66, %rd2, 32; - cvt.u32.u64 %r49, %rd66; - add.s32 %r50, %r5, -1879881855; - xor.b32 %r51, %r46, %r50; - xor.b32 %r52, %r51, %r49; - shl.b64 %rd67, %rd1, 2; - add.s64 %rd68, %rd7, %rd67; - st.global.u32 [%rd68], %r52; - add.s32 %r53, %r2, 1; - cvt.u64.u32 %rd69, %r53; - mul.wide.u32 %rd70, %r2, 4; - add.s64 %rd3, %rd7, %rd70; - setp.ge.u64 %p2, %rd69, %rd6; - @%p2 bra BB2_3; - - st.global.u32 [%rd3+4], %rd2; - -BB2_3: - mul.wide.u32 %rd4, %r4, -766435501; - add.s32 %r54, %r2, 2; - cvt.u64.u32 %rd71, %r54; - setp.ge.u64 %p3, %rd71, %rd6; - @%p3 bra BB2_5; - - shr.u64 %rd72, %rd4, 32; - cvt.u32.u64 %r55, %rd72; - xor.b32 %r56, %r3, %r55; - xor.b32 %r57, %r56, -1767562579; - st.global.u32 [%rd3+8], %r57; - -BB2_5: - 
cvt.u32.u64 %r58, %rd1; - add.s32 %r59, %r58, 3; - cvt.u64.u32 %rd73, %r59; - setp.ge.u64 %p4, %rd73, %rd6; - @%p4 bra BB2_7; - - st.global.u32 [%rd3+12], %rd4; - -BB2_7: - ret; -} - - // .globl philox_4_64 -.visible .entry philox_4_64( - .param .u64 philox_4_64_param_0, - .param .u64 philox_4_64_param_1, - .param .u64 philox_4_64_param_2, - .param .u64 philox_4_64_param_3 -) -{ - .reg .pred %p<5>; - .reg .b32 %r<5>; - .reg .b64 %rd<101>; - - - ld.param.u64 %rd7, [philox_4_64_param_0]; - ld.param.u64 %rd8, [philox_4_64_param_1]; - ld.param.u64 %rd9, [philox_4_64_param_2]; - ld.param.u64 %rd10, [philox_4_64_param_3]; - mov.u32 %r1, %ntid.x; - mov.u32 %r2, %ctaid.x; - mov.u32 %r3, %tid.x; - mad.lo.s32 %r4, %r1, %r2, %r3; - cvt.u64.u32 %rd1, %r4; - mul.wide.u32 %rd2, %r4, 4; - setp.ge.u64 %p1, %rd2, %rd10; - @%p1 bra BB3_7; - - add.s64 %rd11, %rd1, %rd9; - mov.u64 %rd12, -3249550476889527149; - mul.hi.u64 %rd13, %rd12, %rd11; - mul.lo.s64 %rd14, %rd11, -3249550476889527149; - mov.u64 %rd15, 0; - mov.u64 %rd16, -3865633965929787049; - mul.hi.u64 %rd17, %rd16, %rd15; - xor.b64 %rd18, %rd17, %rd8; - mul.hi.u64 %rd19, %rd12, %rd18; - mul.lo.s64 %rd20, %rd18, -3249550476889527149; - mul.hi.u64 %rd21, %rd16, %rd13; - mul.lo.s64 %rd22, %rd13, -3865633965929787049; - add.s64 %rd23, %rd8, -7046029254386353131; - xor.b64 %rd24, %rd21, %rd23; - xor.b64 %rd25, %rd14, %rd19; - xor.b64 %rd26, %rd25, -4942790177534073029; - mul.hi.u64 %rd27, %rd12, %rd24; - mul.lo.s64 %rd28, %rd24, -3249550476889527149; - mul.hi.u64 %rd29, %rd16, %rd26; - mul.lo.s64 %rd30, %rd26, -3865633965929787049; - add.s64 %rd31, %rd8, 4354685564936845354; - xor.b64 %rd32, %rd22, %rd31; - xor.b64 %rd33, %rd32, %rd29; - xor.b64 %rd34, %rd20, %rd27; - xor.b64 %rd35, %rd34, 8561163718641405558; - mul.hi.u64 %rd36, %rd12, %rd33; - mul.lo.s64 %rd37, %rd33, -3249550476889527149; - mul.hi.u64 %rd38, %rd16, %rd35; - mul.lo.s64 %rd39, %rd35, -3865633965929787049; - add.s64 %rd40, %rd8, -2691343689449507777; - xor.b64 
%rd41, %rd30, %rd40; - xor.b64 %rd42, %rd41, %rd38; - xor.b64 %rd43, %rd28, %rd36; - xor.b64 %rd44, %rd43, 3618373541107332529; - mul.hi.u64 %rd45, %rd12, %rd42; - mul.lo.s64 %rd46, %rd42, -3249550476889527149; - mul.hi.u64 %rd47, %rd16, %rd44; - mul.lo.s64 %rd48, %rd44, -3865633965929787049; - add.s64 %rd49, %rd8, 8709371129873690708; - xor.b64 %rd50, %rd39, %rd49; - xor.b64 %rd51, %rd50, %rd47; - xor.b64 %rd52, %rd37, %rd45; - xor.b64 %rd53, %rd52, -1324416636426740500; - mul.hi.u64 %rd54, %rd12, %rd51; - mul.lo.s64 %rd55, %rd51, -3249550476889527149; - mul.hi.u64 %rd56, %rd16, %rd53; - mul.lo.s64 %rd57, %rd53, -3865633965929787049; - add.s64 %rd58, %rd8, 1663341875487337577; - xor.b64 %rd59, %rd48, %rd58; - xor.b64 %rd60, %rd59, %rd56; - xor.b64 %rd61, %rd46, %rd54; - xor.b64 %rd62, %rd61, -6267206813960813529; - mul.hi.u64 %rd63, %rd12, %rd60; - mul.lo.s64 %rd64, %rd60, -3249550476889527149; - mul.hi.u64 %rd65, %rd16, %rd62; - mul.lo.s64 %rd66, %rd62, -3865633965929787049; - add.s64 %rd67, %rd8, -5382687378899015554; - xor.b64 %rd68, %rd57, %rd67; - xor.b64 %rd69, %rd68, %rd65; - xor.b64 %rd70, %rd55, %rd63; - xor.b64 %rd71, %rd70, 7236747082214665058; - mul.hi.u64 %rd72, %rd12, %rd69; - mul.lo.s64 %rd73, %rd69, -3249550476889527149; - mul.hi.u64 %rd74, %rd16, %rd71; - mul.lo.s64 %rd75, %rd71, -3865633965929787049; - add.s64 %rd76, %rd8, 6018027440424182931; - xor.b64 %rd77, %rd66, %rd76; - xor.b64 %rd78, %rd77, %rd74; - xor.b64 %rd79, %rd64, %rd72; - xor.b64 %rd80, %rd79, 2293956904680592029; - mul.hi.u64 %rd81, %rd12, %rd78; - mul.lo.s64 %rd82, %rd78, -3249550476889527149; - mul.hi.u64 %rd83, %rd16, %rd80; - mul.lo.s64 %rd84, %rd80, -3865633965929787049; - add.s64 %rd85, %rd8, -1028001813962170200; - xor.b64 %rd86, %rd75, %rd85; - xor.b64 %rd87, %rd86, %rd83; - xor.b64 %rd88, %rd73, %rd81; - xor.b64 %rd3, %rd88, -2648833272853481000; - mul.hi.u64 %rd89, %rd12, %rd87; - mul.lo.s64 %rd4, %rd87, -3249550476889527149; - mul.hi.u64 %rd90, %rd16, %rd3; - add.s64 
%rd91, %rd8, -8074031068348523331; - xor.b64 %rd92, %rd84, %rd91; - xor.b64 %rd93, %rd92, %rd90; - xor.b64 %rd94, %rd82, %rd89; - xor.b64 %rd5, %rd94, -7591623450387554029; - cvta.to.global.u64 %rd95, %rd7; - shl.b64 %rd96, %rd2, 3; - add.s64 %rd6, %rd95, %rd96; - st.global.u64 [%rd6], %rd93; - add.s64 %rd97, %rd2, 1; - setp.ge.u64 %p2, %rd97, %rd10; - @%p2 bra BB3_3; - - mul.lo.s64 %rd98, %rd3, -3865633965929787049; - st.global.u64 [%rd6+8], %rd98; - -BB3_3: - add.s64 %rd99, %rd2, 2; - setp.ge.u64 %p3, %rd99, %rd10; - @%p3 bra BB3_5; - - st.global.u64 [%rd6+16], %rd5; - -BB3_5: - add.s64 %rd100, %rd2, 3; - setp.ge.u64 %p4, %rd100, %rd10; - @%p4 bra BB3_7; - - st.global.u64 [%rd6+24], %rd4; - -BB3_7: - ret; -} - -