diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/PytorchInterface.java b/src/main/java/io/bioimage/modelrunner/pytorch/PytorchInterface.java index 904bec2..9c4b4a9 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/PytorchInterface.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/PytorchInterface.java @@ -390,7 +390,7 @@ private & NativeType> List encodeInputs(List encodedInputTensors = new ArrayList(); Gson gson = new Gson(); for (Tensor tt : inputTensors) { - SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt.getData(), false, true); + SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt.getData(), true, true); shmaInputList.add(shma); HashMap map = new HashMap(); map.put(NAME_KEY, tt.getName()); @@ -415,7 +415,7 @@ List encodeOutputs(List> outputTensors) { if (!tt.isEmpty()) { map.put(SHAPE_KEY, tt.getShape()); map.put(DTYPE_KEY, CommonUtils.getDataTypeFromRAI(tt.getData())); - SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt.getData(), false, true); + SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt.getData(), true, true); shmaOutputList.add(shma); map.put(MEM_NAME_KEY, shma.getName()); } else if (PlatformDetection.isWindows()){ diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/shm/ShmBuilder.java b/src/main/java/io/bioimage/modelrunner/pytorch/shm/ShmBuilder.java index 0880edb..9776c42 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/shm/ShmBuilder.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/shm/ShmBuilder.java @@ -88,7 +88,7 @@ private static void buildFromTensorUByte(NDArray tensor, String memoryName) thro if (CommonUtils.int32Overflows(arrayShape, 1)) throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1); - SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true); + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), true, true); shma.getDataBufferNoHeader().put(tensor.toByteArray()); if (PlatformDetection.isWindows()) shma.close(); } @@ -100,7 +100,7 @@ private static void buildFromTensorInt(NDArray tensor, String memoryName) throws throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4); - SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true); + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), true, true); shma.getDataBufferNoHeader().put(tensor.toByteArray()); if (PlatformDetection.isWindows()) shma.close(); } @@ -112,7 +112,7 @@ private static void buildFromTensorFloat(NDArray tensor, String memoryName) thro throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4); - SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true); + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), true, true); shma.getDataBufferNoHeader().put(tensor.toByteArray()); if (PlatformDetection.isWindows()) shma.close(); } @@ -124,7 +124,7 @@ private static void buildFromTensorDouble(NDArray tensor, String memoryName) thr throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8); - SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true); + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), true, true); shma.getDataBufferNoHeader().put(tensor.toByteArray()); if (PlatformDetection.isWindows()) shma.close(); } @@ -137,7 +137,7 @@ private static void buildFromTensorLong(NDArray tensor, String memoryName) throw + " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8); - SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true); + SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), true, true); shma.getDataBufferNoHeader().put(tensor.toByteArray()); if (PlatformDetection.isWindows()) shma.close(); }