-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Vicente Cheng <[email protected]>
- Loading branch information
1 parent
22ebbe4
commit 036be8c
Showing
2 changed files
with
579 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,300 @@ | ||
package io | ||
|
||
/* | ||
#include <fcntl.h> | ||
#include <stdlib.h> | ||
#include <unistd.h> | ||
#include <string.h> | ||
#include <errno.h> | ||
// Safe_pwrite is a wrapper around pwrite(2) that retries on EINTR. | ||
ssize_t safe_pwrite(int fd, const void *buf, size_t count, off_t offset) | ||
{ | ||
while (count > 0) { | ||
ssize_t r = pwrite(fd, buf, count, offset); | ||
if (r < 0) { | ||
if (errno == EINTR) | ||
continue; | ||
return -errno; | ||
} | ||
count -= r; | ||
buf = (char *)buf + r; | ||
offset += r; | ||
} | ||
return 0; | ||
} | ||
ssize_t safe_pread(int fd, void *buf, size_t count, off_t offset) | ||
{ | ||
size_t cnt = 0; | ||
char *b = (char*)buf; | ||
while (cnt < count) { | ||
ssize_t r = pread(fd, b + cnt, count - cnt, offset + cnt); | ||
if (r <= 0) { | ||
if (r == 0) { | ||
// EOF | ||
return cnt; | ||
} | ||
if (errno == EINTR) | ||
continue; | ||
return -errno; | ||
} | ||
cnt += r; | ||
} | ||
return cnt; | ||
} | ||
ssize_t safe_pread_exact(int fd, void *buf, size_t count, off_t offset) | ||
{ | ||
ssize_t ret = safe_pread(fd, buf, count, offset); | ||
if (ret < 0) | ||
return ret; | ||
if ((size_t)ret != count) | ||
return -EDOM; | ||
return 0; | ||
} | ||
// Write data to a file descriptor with O_DIRECT | ||
int directWrite(int fd, void *buf, size_t count, off_t offset) { | ||
return safe_pwrite(fd, buf, count, offset); | ||
} | ||
// Read data from a file descriptor with O_DIRECT | ||
int directRead(int fd, void *buf, size_t count, off_t offset) { | ||
return safe_pread(fd, buf, count, offset); | ||
} | ||
*/ | ||
import "C" | ||
|
||
import ( | ||
"fmt" | ||
"os" | ||
"sync" | ||
"unsafe" | ||
) | ||
|
||
const ( | ||
baseAlignSize = 4096 | ||
maxChunkSize = 4194304 | ||
producerNum = 8 | ||
workerNum = 6 | ||
) | ||
|
||
type Content struct { | ||
offset uint64 | ||
buf []byte | ||
} | ||
|
||
func Copy(src *os.File, dst *os.File, chunkSize int) error { | ||
var ioError error | ||
ioProducer := func(seqID int, src *os.File, chunkNums int, chunkSize int, objs chan<- Content) { | ||
for i := seqID; i < chunkNums; i += producerNum { | ||
buf := make([]byte, chunkSize) | ||
_, err := PReadExact(src, buf, chunkSize, uint64(i*chunkSize)) | ||
if err != nil { | ||
fmt.Printf("Error reading data: %v\n", err) | ||
ioError = err | ||
continue | ||
} | ||
objs <- Content{offset: uint64(i * chunkSize), buf: buf} | ||
} | ||
} | ||
|
||
ioWorker := func(ioQueue chan Content, wg *sync.WaitGroup) { | ||
for { | ||
obj, ok := <-ioQueue | ||
if !ok { | ||
return | ||
} | ||
_, err := PWrite(dst, obj.buf, len(obj.buf), obj.offset, chunkSize) | ||
wg.Done() | ||
if err != nil { | ||
fmt.Printf("Error writing data: %v\n", err) | ||
ioError = err | ||
} | ||
} | ||
} | ||
|
||
srcInfo, err := src.Stat() | ||
if err != nil { | ||
fmt.Printf("Error getting file info: %v\n", err) | ||
return err | ||
} | ||
srcSize := srcInfo.Size() | ||
|
||
if chunkSize > maxChunkSize { | ||
return fmt.Errorf("chunk size is too large, max chunk size is %d", maxChunkSize) | ||
} | ||
if chunkSize%baseAlignSize != 0 { | ||
return fmt.Errorf("chunk size must be a multiple of %d", baseAlignSize) | ||
} | ||
|
||
// Create a channel to receive the results | ||
ioQ := make(chan Content, producerNum) | ||
|
||
chunkNums := int(srcSize) / chunkSize | ||
if int(srcSize)%chunkSize != 0 { | ||
chunkNums++ | ||
} | ||
var wg sync.WaitGroup | ||
wg.Add(chunkNums) | ||
|
||
for i := 0; i < producerNum; i++ { | ||
go ioProducer(i, src, chunkNums, chunkSize, ioQ) | ||
} | ||
for i := 0; i < workerNum; i++ { | ||
go ioWorker(ioQ, &wg) | ||
} | ||
|
||
wg.Wait() | ||
close(ioQ) | ||
|
||
if ioError != nil { | ||
fmt.Printf("Error: %v\n", ioError) | ||
return ioError | ||
} | ||
return nil | ||
|
||
} | ||
|
||
func Write(dst *os.File, data []byte, size int, chunkSize int) error { | ||
var ioError error | ||
ioProducer := func(seqID, chunkNums int, objs chan<- Content) { | ||
for i := seqID; i < chunkNums; i += producerNum { | ||
start := i * chunkSize | ||
end := start + chunkSize | ||
if i == chunkNums-1 { | ||
end = len(data) | ||
} | ||
chunk := data[start:end] | ||
objs <- Content{offset: uint64(start), buf: chunk} | ||
} | ||
} | ||
|
||
ioWorker := func(ioQueue chan Content, wg *sync.WaitGroup) { | ||
for { | ||
obj, ok := <-ioQueue | ||
if !ok { | ||
return | ||
} | ||
_, err := PWrite(dst, obj.buf, len(obj.buf), obj.offset, chunkSize) | ||
wg.Done() | ||
if err != nil { | ||
fmt.Printf("Error writing data: %v\n", err) | ||
ioError = err | ||
} | ||
} | ||
} | ||
|
||
if chunkSize > maxChunkSize { | ||
return fmt.Errorf("chunk size is too large, max chunk size is %d", maxChunkSize) | ||
} | ||
if chunkSize%baseAlignSize != 0 { | ||
return fmt.Errorf("chunk size must be a multiple of %d", baseAlignSize) | ||
} | ||
|
||
// Calculate the number of chunks based on the alignment size | ||
numChunks := len(data) / chunkSize | ||
if len(data)%chunkSize != 0 { | ||
numChunks++ | ||
} | ||
|
||
// Create a channel to receive the results | ||
ioQ := make(chan Content, producerNum) | ||
|
||
var wg sync.WaitGroup | ||
wg.Add(numChunks) | ||
|
||
for i := 0; i < producerNum; i++ { | ||
go ioProducer(i, numChunks, ioQ) | ||
} | ||
for i := 0; i < workerNum; i++ { | ||
go ioWorker(ioQ, &wg) | ||
} | ||
|
||
wg.Wait() | ||
close(ioQ) | ||
|
||
if ioError != nil { | ||
fmt.Printf("Error: %v\n", ioError) | ||
return ioError | ||
} | ||
return nil | ||
} | ||
|
||
func PWrite(dst *os.File, data []byte, size int, offset uint64, alignSize int) (int, error) { | ||
// Special case, we need to handle the last chunk here. | ||
for size < alignSize && size != 0 { | ||
if alignSize == baseAlignSize { | ||
_, err := PWrite(dst, data, size, offset, alignSize) | ||
if err != nil { | ||
return 0, fmt.Errorf("error writing data: %v", err) | ||
} | ||
// final write, directly set size to zero | ||
size = 0 | ||
break | ||
} | ||
shiftBits := calShiftBits(size, alignSize) | ||
alignSize = alignSize >> shiftBits | ||
_, err := PWrite(dst, data[0:alignSize-1], alignSize, offset, alignSize) | ||
if err != nil { | ||
return 0, fmt.Errorf("error writing data: %v", err) | ||
} | ||
size -= alignSize | ||
data = data[alignSize:] | ||
offset += uint64(alignSize) | ||
} | ||
if size == 0 { | ||
return 0, nil | ||
} | ||
|
||
var writeBuffer unsafe.Pointer | ||
if C.posix_memalign((*unsafe.Pointer)(unsafe.Pointer(&writeBuffer)), C.size_t(alignSize), C.size_t(size)) != 0 { | ||
fmt.Printf("Error allocating aligned memory\n") | ||
return 0, fmt.Errorf("error allocating aligned memory") | ||
} | ||
defer C.free(unsafe.Pointer(writeBuffer)) | ||
|
||
// Copy the Go data into the C buffer | ||
C.memcpy(writeBuffer, unsafe.Pointer(&data[0]), C.size_t(size)) | ||
|
||
// Call the C function to write with O_DIRECT | ||
ret := C.directWrite(C.int(dst.Fd()), writeBuffer, C.size_t(alignSize), C.off_t(offset)) | ||
if ret < 0 { | ||
fmt.Printf("Error writing data: %v\n", ret) | ||
return 0, fmt.Errorf("error writing data") | ||
} | ||
|
||
return int(ret), nil | ||
} | ||
|
||
func PReadExact(src *os.File, buf []byte, count int, offset uint64) (int, error) { | ||
var readBuffer unsafe.Pointer | ||
if C.posix_memalign((*unsafe.Pointer)(unsafe.Pointer(&readBuffer)), C.size_t(count), C.size_t(count)) != 0 { | ||
fmt.Printf("Error allocating aligned memory\n") | ||
return 0, fmt.Errorf("error allocating aligned memory") | ||
} | ||
defer C.free(unsafe.Pointer(readBuffer)) | ||
|
||
// Call the C function to read with O_DIRECT | ||
ret := C.directRead(C.int(src.Fd()), readBuffer, C.size_t(count), C.off_t(offset)) | ||
if ret < 0 { | ||
fmt.Printf("Error reading data: %v\n", ret) | ||
return 0, fmt.Errorf("error reading data") | ||
} | ||
|
||
// Copy the C data into the Go buffer | ||
C.memcpy(unsafe.Pointer(&buf[0]), readBuffer, C.size_t(count)) | ||
|
||
return count, nil | ||
} | ||
|
||
func calShiftBits(size, alignSize int) int { | ||
shiftBits := 0 | ||
for { | ||
if size >= alignSize { | ||
break | ||
} | ||
alignSize >>= 1 | ||
shiftBits++ | ||
} | ||
|
||
return shiftBits | ||
} |
Oops, something went wrong.