diff --git a/io/io.go b/io/io.go new file mode 100644 index 0000000..f85bb07 --- /dev/null +++ b/io/io.go @@ -0,0 +1,332 @@ +package io + +/* +#include +#include +#include +#include +#include + +// Safe_pwrite is a wrapper around pwrite(2) that retries on EINTR. +ssize_t safe_pwrite(int fd, const void *buf, size_t count, off_t offset) +{ + while (count > 0) { + ssize_t r = pwrite(fd, buf, count, offset); + if (r < 0) { + if (errno == EINTR) + continue; + return -errno; + } + count -= r; + buf = (char *)buf + r; + offset += r; + } + return 0; +} +ssize_t safe_pread(int fd, void *buf, size_t count, off_t offset) +{ + size_t cnt = 0; + char *b = (char*)buf; + + while (cnt < count) { + ssize_t r = pread(fd, b + cnt, count - cnt, offset + cnt); + if (r <= 0) { + if (r == 0) { + // EOF + return cnt; + } + if (errno == EINTR) + continue; + return -errno; + } + cnt += r; + } + return cnt; +} +ssize_t safe_pread_exact(int fd, void *buf, size_t count, off_t offset) +{ + ssize_t ret = safe_pread(fd, buf, count, offset); + if (ret < 0) + return ret; + if ((size_t)ret != count) + return -EDOM; + return 0; +} +// Write data to a file descriptor with O_DIRECT +int directWrite(int fd, void *buf, size_t count, off_t offset) { + return safe_pwrite(fd, buf, count, offset); +} +// Read data from a file descriptor with O_DIRECT +int directRead(int fd, void *buf, size_t count, off_t offset) { + return safe_pread(fd, buf, count, offset); +} +*/ +import "C" + +import ( + "fmt" + "os" + "sync" + "syscall" + "unsafe" +) + +const ( + baseAlignSize = 4096 + maxChunkSize = 4194304 + producerNum = 8 + workerNum = 6 + BLKGETSIZE64 = 0x80081272 +) + +type Content struct { + offset uint64 + buf []byte +} + +func Copy(src *os.File, dst *os.File, chunkSize int) error { + var ioError error + ioProducer := func(seqID int, src *os.File, chunkNums int, chunkSize int, objs chan<- Content) { + for i := seqID; i < chunkNums; i += producerNum { + buf := make([]byte, chunkSize) + _, err := PReadExact(src, buf, chunkSize, uint64(i*chunkSize)) + if err != nil { + fmt.Printf("Error reading data: %v\n", err) + ioError = err + continue + } + objs <- Content{offset: uint64(i * chunkSize), buf: buf} + } + } + + ioWorker := func(ioQueue chan Content, wg *sync.WaitGroup) { + for { + obj, ok := <-ioQueue + if !ok { + return + } + _, err := PWrite(dst, obj.buf, len(obj.buf), obj.offset, chunkSize) + wg.Done() + if err != nil { + fmt.Printf("Error writing data: %v\n", err) + ioError = err + } + } + } + + srcSize, err := getSourceVolSize(src) + if err != nil { + fmt.Printf("Error getting file size: %v\n", err) + return fmt.Errorf("error getting file size") + } + + if chunkSize > maxChunkSize { + return fmt.Errorf("chunk size is too large, max chunk size is %d", maxChunkSize) + } + if chunkSize%baseAlignSize != 0 { + return fmt.Errorf("chunk size must be a multiple of %d", baseAlignSize) + } + + // Create a channel to receive the results + ioQ := make(chan Content, producerNum) + + chunkNums := int(srcSize) / chunkSize + if int(srcSize)%chunkSize != 0 { + chunkNums++ + } + var wg sync.WaitGroup + wg.Add(chunkNums) + + for i := 0; i < producerNum; i++ { + go ioProducer(i, src, chunkNums, chunkSize, ioQ) + } + for i := 0; i < workerNum; i++ { + go ioWorker(ioQ, &wg) + } + + wg.Wait() + close(ioQ) + + if ioError != nil { + fmt.Printf("Error: %v\n", ioError) + return ioError + } + return nil + +} + +func Write(dst *os.File, data []byte, size int, chunkSize int) error { + var ioError error + ioProducer := func(seqID, chunkNums int, objs chan<- Content) { + for i := seqID; i < chunkNums; i += producerNum { + start := i * chunkSize + end := start + chunkSize + if i == chunkNums-1 { + end = len(data) + } + chunk := data[start:end] + objs <- Content{offset: uint64(start), buf: chunk} + } + } + + ioWorker := func(ioQueue chan Content, wg *sync.WaitGroup) { + for { + obj, ok := <-ioQueue + if !ok { + return + } + _, err := PWrite(dst, obj.buf, len(obj.buf), obj.offset, chunkSize) + wg.Done() + if err != nil { + fmt.Printf("Error writing data: %v\n", err) + ioError = err + } + } + } + + if chunkSize > maxChunkSize { + return fmt.Errorf("chunk size is too large, max chunk size is %d", maxChunkSize) + } + if chunkSize%baseAlignSize != 0 { + return fmt.Errorf("chunk size must be a multiple of %d", baseAlignSize) + } + + // Calculate the number of chunks based on the alignment size + numChunks := size / chunkSize + if len(data)%chunkSize != 0 { + numChunks++ + } + + // Create a channel to receive the results + ioQ := make(chan Content, producerNum) + + var wg sync.WaitGroup + wg.Add(numChunks) + + for i := 0; i < producerNum; i++ { + go ioProducer(i, numChunks, ioQ) + } + for i := 0; i < workerNum; i++ { + go ioWorker(ioQ, &wg) + } + + wg.Wait() + close(ioQ) + + if ioError != nil { + fmt.Printf("Error: %v\n", ioError) + return ioError + } + return nil +} + +func PWrite(dst *os.File, data []byte, size int, offset uint64, alignSize int) (int, error) { + // Special case, we need to handle the last chunk here. + // If last chunk < baseAlignSize, we can write it directly. + // If last baseAlignSize < last chunk < alignSize, we need to split it. + for size > baseAlignSize && size < alignSize && size != 0 { + if alignSize == baseAlignSize { + _, err := PWrite(dst, data, size, offset, alignSize) + if err != nil { + return 0, fmt.Errorf("error writing data: %v", err) + } + // final write, directly set size to zero + size = 0 + break + } + shiftBits := calShiftBits(size, alignSize) + alignSize = alignSize >> shiftBits + _, err := PWrite(dst, data[0:alignSize-1], alignSize, offset, alignSize) + if err != nil { + return 0, fmt.Errorf("error writing data: %v", err) + } + size -= alignSize + data = data[alignSize:] + offset += uint64(alignSize) + } + if size == 0 { + return 0, nil + } + + var writeBuffer unsafe.Pointer + if C.posix_memalign((*unsafe.Pointer)(unsafe.Pointer(&writeBuffer)), C.size_t(alignSize), C.size_t(size)) != 0 { + fmt.Printf("Error allocating aligned memory\n") + return 0, fmt.Errorf("error allocating aligned memory") + } + defer C.free(unsafe.Pointer(writeBuffer)) + + // Copy the Go data into the C buffer + C.memcpy(writeBuffer, unsafe.Pointer(&data[0]), C.size_t(size)) + + // Call the C function to write with O_DIRECT + ret := C.directWrite(C.int(dst.Fd()), writeBuffer, C.size_t(alignSize), C.off_t(offset)) + if ret < 0 { + fmt.Printf("Error writing data: %v\n", ret) + return 0, fmt.Errorf("error writing data") + } + + return int(ret), nil +} + +func PReadExact(src *os.File, buf []byte, count int, offset uint64) (int, error) { + var readBuffer unsafe.Pointer + if C.posix_memalign((*unsafe.Pointer)(unsafe.Pointer(&readBuffer)), C.size_t(count), C.size_t(count)) != 0 { + fmt.Printf("Error allocating aligned memory\n") + return 0, fmt.Errorf("error allocating aligned memory") + } + defer C.free(unsafe.Pointer(readBuffer)) + + // Call the C function to read with O_DIRECT + ret := C.directRead(C.int(src.Fd()), readBuffer, C.size_t(count), C.off_t(offset)) + if ret < 0 { + fmt.Printf("Error reading data: %v\n", ret) + return 0, fmt.Errorf("error reading data") + } + + // Copy the C data into the Go buffer + C.memcpy(unsafe.Pointer(&buf[0]), readBuffer, C.size_t(count)) + + return count, nil +} + +func calShiftBits(size, alignSize int) int { + shiftBits := 0 + for { + if size >= alignSize { + break + } + alignSize >>= 1 + shiftBits++ + } + + return shiftBits +} + +func getSourceVolSize(src *os.File) (uint64, error) { + + var srcSize uint64 + srcInfo, err := src.Stat() + if err != nil { + return 0, err + } + + if srcInfo.Mode().IsRegular() { + // file size should not be negative, directly return as uint64 + return uint64(srcInfo.Size()), nil + } + + if (srcInfo.Mode() & os.ModeDevice) != 0 { + _, _, err := syscall.Syscall( + syscall.SYS_IOCTL, + src.Fd(), + BLKGETSIZE64, + uintptr(unsafe.Pointer(&srcSize)), + ) + if err != 0 { + return 0, fmt.Errorf("error getting file size: %v", err) + } + return srcSize, nil + } + + return 0, fmt.Errorf("unsupported file type: %v", srcInfo.Mode()) +} diff --git a/io/io_test.go b/io/io_test.go new file mode 100644 index 0000000..fc12af4 --- /dev/null +++ b/io/io_test.go @@ -0,0 +1,327 @@ +package io + +import ( + "crypto/rand" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type IOTestSuite struct { + suite.Suite +} + +func TestIOTestSuite(t *testing.T) { + suite.Run(t, new(IOTestSuite)) +} + +func (suite *IOTestSuite) TestWriteAlignSmallFile() { + // Create a temporary srcFile for testing + srcFile, err := os.CreateTemp("", "512B_file") + suite.Require().NoError(err) + defer os.Remove(srcFile.Name()) + + // Generate random data + data := make([]byte, 512) // 512 bytes + _, err = rand.Read(data) + suite.Require().NoError(err) + + // Write the data to the file using IOWrite + err = Write(srcFile, data, len(data), 4096) + suite.Require().NoError(err) + + // Read the written data from the srcFile + readData := make([]byte, len(data)) + _, err = srcFile.ReadAt(readData, 0) + suite.Require().NoError(err) + + // Assert that the written data matches the original data + assert.Equal(suite.T(), data, readData) +} + +func (suite *IOTestSuite) TestWriteUnalignSmallFile() { + // Create a temporary srcFile for testing + srcFile, err := os.CreateTemp("", "777B_file") + suite.Require().NoError(err) + defer os.Remove(srcFile.Name()) + + // Generate random data + data := make([]byte, 777) // 777 bytes + _, err = rand.Read(data) + suite.Require().NoError(err) + + // Write the data to the file using IOWrite + err = Write(srcFile, data, len(data), 4096) + suite.Require().NoError(err) + + // Read the written data from the srcFile + readData := make([]byte, len(data)) + _, err = srcFile.ReadAt(readData, 0) + suite.Require().NoError(err) + + // Assert that the written data matches the original data + assert.Equal(suite.T(), data, readData) +} + +func (suite *IOTestSuite) TestWriteAlign() { + // Create a temporary srcFile for testing + srcFile, err := os.CreateTemp("", "4M_file") + suite.Require().NoError(err) + defer os.Remove(srcFile.Name()) + + // Generate random data + data := make([]byte, 4*1024*1024) // 4M + _, err = rand.Read(data) + suite.Require().NoError(err) + + // Write the data to the file using IOWrite + err = Write(srcFile, data, len(data), 4096) + suite.Require().NoError(err) + + // Read the written data from the srcFile + readData := make([]byte, len(data)) + _, err = srcFile.ReadAt(readData, 0) + suite.Require().NoError(err) + + // Assert that the written data matches the original data + assert.Equal(suite.T(), data, readData) +} + +func (suite *IOTestSuite) TestWriteUnalign() { + // Create a temporary srcFile for testing + srcFile, err := os.CreateTemp("", "5M_file") + suite.Require().NoError(err) + defer os.Remove(srcFile.Name()) + + // Generate random data + data := make([]byte, 5*1024*1024) // 5M + _, err = rand.Read(data) + suite.Require().NoError(err) + + // Write the data to the file using IOWrite + err = Write(srcFile, data, len(data), 4096) + suite.Require().NoError(err) + + // Read the written data from the srcFile + readData := make([]byte, len(data)) + _, err = srcFile.ReadAt(readData, 0) + suite.Require().NoError(err) + + // Assert that the written data matches the original data + assert.Equal(suite.T(), data, readData) +} + +func (suite *IOTestSuite) TestWriteAlignBigChunk() { + // Create a temporary srcFile for testing + srcFile, err := os.CreateTemp("", "128M_file") + suite.Require().NoError(err) + defer os.Remove(srcFile.Name()) + + // Generate random data + data := make([]byte, 128*1024*1024) // 128M + _, err = rand.Read(data) + suite.Require().NoError(err) + + // Write the data to the file using IOWrite + err = Write(srcFile, data, len(data), 1048576) + suite.Require().NoError(err) + + // Read the written data from the srcFile + readData := make([]byte, len(data)) + _, err = srcFile.ReadAt(readData, 0) + suite.Require().NoError(err) + + // Assert that the written data matches the original data + assert.Equal(suite.T(), data, readData) +} + +func (suite *IOTestSuite) TestWriteUnalignBigChunk() { + // Create a temporary srcFile for testing + srcFile, err := os.CreateTemp("", "135M_file") + suite.Require().NoError(err) + defer os.Remove(srcFile.Name()) + + // Generate random data + data := make([]byte, 135*1024*1024) // 135M + _, err = rand.Read(data) + suite.Require().NoError(err) + + // Write the data to the file using IOWrite + err = Write(srcFile, data, len(data), 4194304) + suite.Require().NoError(err) + + // Read the written data from the srcFile + readData := make([]byte, len(data)) + _, err = srcFile.ReadAt(readData, 0) + suite.Require().NoError(err) + + // Assert that the written data matches the original data + assert.Equal(suite.T(), data, readData) +} + +func (suite *IOTestSuite) TestCopyAlign() { + // Create a temporary srcFile for testing + srcFile, err := os.CreateTemp("", "4M_file") + suite.Require().NoError(err) + defer os.Remove(srcFile.Name()) + + // Generate random data + data := make([]byte, 4*1024*1024) // 4M + _, err = rand.Read(data) + suite.Require().NoError(err) + + // Write the data to the file using IOWrite + err = Write(srcFile, data, len(data), 4096) + suite.Require().NoError(err) + + // Read the written data from the srcFile + readData := make([]byte, len(data)) + _, err = srcFile.ReadAt(readData, 0) + suite.Require().NoError(err) + + // Assert that the written data matches the original data + assert.Equal(suite.T(), data, readData) + + // Create a temporary dstFile for testing + dstFile, err := os.CreateTemp("", "dstfile") + suite.Require().NoError(err) + defer os.Remove(dstFile.Name()) + + // Copy the data from srcFile to dstFile using Copy + err = Copy(srcFile, dstFile, 4096) + suite.Require().NoError(err) + + // Read the written data from the dstFile + dstData := make([]byte, len(data)) + _, err = dstFile.ReadAt(dstData, 0) + suite.Require().NoError(err) + // Assert that the written data in dstFile matches the original data + assert.Equal(suite.T(), data, dstData) +} + +func (suite *IOTestSuite) TestCopyUnalign() { + // Create a temporary srcFile for testing + srcFile, err := os.CreateTemp("", "5M_file") + suite.Require().NoError(err) + defer os.Remove(srcFile.Name()) + + // Generate random data + data := make([]byte, 5*1024*1024) // 5M + _, err = rand.Read(data) + suite.Require().NoError(err) + + // Write the data to the file using IOWrite + err = Write(srcFile, data, len(data), 4096) + suite.Require().NoError(err) + + // Read the written data from the srcFile + readData := make([]byte, len(data)) + _, err = srcFile.ReadAt(readData, 0) + suite.Require().NoError(err) + + // Assert that the written data matches the original data + assert.Equal(suite.T(), data, readData) + + // Create a temporary dstFile for testing + dstFile, err := os.CreateTemp("", "dstfile") + suite.Require().NoError(err) + defer os.Remove(dstFile.Name()) + + // Copy the data from srcFile to dstFile using Copy + err = Copy(srcFile, dstFile, 4096) + suite.Require().NoError(err) + + // Read the written data from the dstFile + dstData := make([]byte, len(data)) + _, err = dstFile.ReadAt(dstData, 0) + suite.Require().NoError(err) + // Assert that the written data in dstFile matches the original data + assert.Equal(suite.T(), data, dstData) +} + +func (suite *IOTestSuite) TestCopyAlignBigChunk() { + // Create a temporary srcFile for testing + srcFile, err := os.CreateTemp("", "128M_file") + suite.Require().NoError(err) + defer os.Remove(srcFile.Name()) + + // Generate random data + data := make([]byte, 128*1024*1024) // 128M + _, err = rand.Read(data) + suite.Require().NoError(err) + + // Write the data to the file using IOWrite + err = Write(srcFile, data, len(data), 4194304) + suite.Require().NoError(err) + + // Read the written data from the srcFile + readData := make([]byte, len(data)) + _, err = srcFile.ReadAt(readData, 0) + suite.Require().NoError(err) + + // Assert that the written data matches the original data + assert.Equal(suite.T(), data, readData) + + // Create a temporary dstFile for testing + dstFile, err := os.CreateTemp("", "dstfile") + suite.Require().NoError(err) + defer os.Remove(dstFile.Name()) + + // Copy the data from srcFile to dstFile using Copy + err = Copy(srcFile, dstFile, 4194304) + suite.Require().NoError(err) + + // Read the written data from the dstFile + dstData := make([]byte, len(data)) + _, err = dstFile.ReadAt(dstData, 0) + suite.Require().NoError(err) + // Assert that the written data in dstFile matches the original data + assert.Equal(suite.T(), data, dstData) +} + +func (suite *IOTestSuite) TestCopyUnalignBigChunk() { + // Create a temporary srcFile for testing + srcFile, err := os.CreateTemp("", "135M_file") + suite.Require().NoError(err) + defer os.Remove(srcFile.Name()) + + // Generate random data + data := make([]byte, 135*1024*1024) // 135M + _, err = rand.Read(data) + suite.Require().NoError(err) + + // Write the data to the file using IOWrite + err = Write(srcFile, data, len(data), 4194304) + suite.Require().NoError(err) + + // Read the written data from the srcFile + readData := make([]byte, len(data)) + _, err = srcFile.ReadAt(readData, 0) + suite.Require().NoError(err) + + // Assert that the written data matches the original data + assert.Equal(suite.T(), data, readData) + + // Create a temporary dstFile for testing + dstFile, err := os.CreateTemp("", "dstfile") + suite.Require().NoError(err) + defer os.Remove(dstFile.Name()) + + // Copy the data from srcFile to dstFile using Copy + err = Copy(srcFile, dstFile, 4194304) + suite.Require().NoError(err) + + // Read the written data from the dstFile + dstData := make([]byte, len(data)) + _, err = dstFile.ReadAt(dstData, 0) + suite.Require().NoError(err) + // Assert that the written data in dstFile matches the original data + assert.Equal(suite.T(), data, dstData) +} + +func (suite *IOTestSuite) TestCalShiftBits() { + assert.Equal(suite.T(), 1, calShiftBits(3145728, 4194304)) + assert.Equal(suite.T(), 2, calShiftBits(1048576, 4194304)) +}