-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathservercodec.go
144 lines (117 loc) · 3.1 KB
/
servercodec.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package sunrpc
import (
	"bytes"
	"io"
	"log"
	"net/rpc"
	"sync"

	"github.com/rasky/go-xdr/xdr2"
)
type serverCodec struct {
conn io.ReadWriteCloser
closed bool
notifyClose chan<- io.ReadWriteCloser
recordReader io.Reader
}
// NewServerCodec returns a new rpc.ServerCodec using Sun RPC on conn.
// If a non-nil channel is passed as second argument, the conn is sent on
// that channel when Close() is called on conn.
func NewServerCodec(conn io.ReadWriteCloser, notifyClose chan<- io.ReadWriteCloser) rpc.ServerCodec {
return &serverCodec{conn: conn, notifyClose: notifyClose}
}
func (c *serverCodec) ReadRequestHeader(req *rpc.Request) error {
// NOTE:
// Errors returned by this function aren't relayed back to the client
// as WriteResponse() isn't called. The net/rpc package will call
// c.Close() when this function returns an error.
// Read entire RPC message from network
record, err := ReadFullRecord(c.conn)
if err != nil {
if err != io.EOF {
log.Println(err)
}
return err
}
c.recordReader = bytes.NewReader(record)
// Unmarshall RPC message
var call RPCMsg
_, err = xdr.Unmarshal(c.recordReader, &call)
if err != nil {
log.Println(err)
return err
}
if call.Type != Call {
log.Println(ErrInvalidRPCMessageType)
return ErrInvalidRPCMessageType
}
// Set req.Seq and req.ServiceMethod
req.Seq = uint64(call.Xid)
procedureID := ProcedureID{call.CBody.Program, call.CBody.Version, call.CBody.Procedure}
procedureName, ok := GetProcedureName(procedureID)
if ok {
req.ServiceMethod = procedureName
} else {
// Due to our simpler map implementation, we cannot distinguish
// between ErrProgUnavail and ErrProcUnavail
log.Printf("%s: %+v\n", ErrProcUnavail, procedureID)
return ErrProcUnavail
}
return nil
}
func (c *serverCodec) ReadRequestBody(funcArgs interface{}) error {
if funcArgs == nil {
return nil
}
if _, err := xdr.Unmarshal(c.recordReader, &funcArgs); err != nil {
c.Close()
return err
}
return nil
}
func (c *serverCodec) WriteResponse(resp *rpc.Response, result interface{}) error {
if resp.Error != "" {
// The remote function returned error (shouldn't really happen)
log.Println(resp.Error)
}
var buf bytes.Buffer
reply := RPCMsg{
Xid: uint32(resp.Seq),
Type: Reply,
RBody: ReplyBody{
Stat: MsgAccepted,
Areply: AcceptedReply{
Stat: Success,
},
},
}
if _, err := xdr.Marshal(&buf, reply); err != nil {
c.Close()
return err
}
// Marshal and fill procedure-specific reply into the buffer
if _, err := xdr.Marshal(&buf, result); err != nil {
c.Close()
return err
}
// Write buffer contents to network
if _, err := WriteFullRecord(c.conn, buf.Bytes()); err != nil {
c.Close()
return err
}
return nil
}
func (c *serverCodec) Close() error {
if c.closed {
return nil
}
err := c.conn.Close()
if err == nil {
c.closed = true
if c.notifyClose != nil {
c.notifyClose <- c.conn
}
}
return err
}