Skip to content

Commit

Permalink
WIP investigate gw bug (#2130)
Browse files Browse the repository at this point in the history
* investigate gw bug

* fix

* fix

* fix

* fix
  • Loading branch information
tudor-malene authored Nov 8, 2024
1 parent 883ca07 commit 7547940
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 33 deletions.
6 changes: 2 additions & 4 deletions lib/gethfork/node/extract_params_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ func newHTTPParamsHandler(exposedParam string, next http.Handler) http.Handler {
func (handler *httpParamsHandler) ServeHTTP(out http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
val := q.Get(handler.exposedParam)
if len(val) > 0 {
ctx := context.WithValue(r.Context(), rpc.GWTokenKey{}, val)
handler.next.ServeHTTP(out, r.WithContext(ctx))
}
ctx := context.WithValue(r.Context(), rpc.GWTokenKey{}, val)
handler.next.ServeHTTP(out, r.WithContext(ctx))
handler.next.ServeHTTP(out, r)
}
3 changes: 2 additions & 1 deletion lib/gethfork/rpc/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,8 @@ func (h *handler) startCallProc(fn func(*callProc)) {
defer h.callWG.Done()
defer cancel()
// handle the case when normal rpc calls are made over a ws connection
if ctx.Value(GWTokenKey{}) == nil {
v, ok := ctx.Value(GWTokenKey{}).(string)
if !ok || len(v) == 0 {
ctx = context.WithValue(ctx, GWTokenKey{}, hexutils.BytesToHex(h.UserID))
}
fn(&callProc{ctx: ctx})
Expand Down
9 changes: 3 additions & 6 deletions tools/walletextension/common/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package common

import (
"github.com/ten-protocol/go-ten/go/common/viewingkey"
"golang.org/x/exp/maps"

"github.com/ethereum/go-ethereum/common"
)
Expand All @@ -19,10 +20,6 @@ type GWUser struct {
UserKey []byte
}

func (u GWUser) GetAllAddresses() []*common.Address {
accts := make([]*common.Address, 0)
for _, acc := range u.Accounts {
accts = append(accts, acc.Address)
}
return accts
func (u GWUser) GetAllAddresses() []common.Address {
return maps.Keys(u.Accounts)
}
2 changes: 1 addition & 1 deletion tools/walletextension/httpapi/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func getUserID(conn UserConn) ([]byte, error) {
return hexutils.HexToBytes(userID), nil
}

return nil, fmt.Errorf(fmt.Sprintf("wrong length of userID from URL. Got: %d, Expected: %d od %d", len(userID), common.MessageUserIDLenWithPrefix, common.MessageUserIDLen))
return nil, fmt.Errorf("wrong length of userID from URL. Got: %d, Expected: %d od %d", len(userID), common.MessageUserIDLenWithPrefix, common.MessageUserIDLen)
}

return nil, fmt.Errorf("missing token field")
Expand Down
10 changes: 5 additions & 5 deletions tools/walletextension/rpcapi/filter_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (api *FilterAPI) Logs(ctx context.Context, crit common.FilterCriteria) (*rp
errorChannels := make([]<-chan error, 0)
backendSubscriptions := make([]*rpc.ClientSubscription, 0)
for _, address := range candidateAddresses {
rpcWSClient, err := api.we.BackendRPC.ConnectWS(ctx, user.Accounts[*address])
rpcWSClient, err := api.we.BackendRPC.ConnectWS(ctx, user.Accounts[address])
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -174,14 +174,14 @@ func getUserAndNotifier(ctx context.Context, api *FilterAPI) (*rpc.Notifier, *we
return subNotifier, user, nil
}

func searchForAddressInFilterCriteria(filterCriteria common.FilterCriteria, possibleAddresses []*gethcommon.Address) []*gethcommon.Address {
result := make([]*gethcommon.Address, 0)
func searchForAddressInFilterCriteria(filterCriteria common.FilterCriteria, possibleAddresses []gethcommon.Address) []gethcommon.Address {
result := make([]gethcommon.Address, 0)
addrMap := toMap(possibleAddresses)
for _, topicCondition := range filterCriteria.Topics {
for _, topic := range topicCondition {
potentialAddr := common.ExtractPotentialAddress(topic)
if potentialAddr != nil && addrMap[*potentialAddr] != nil {
result = append(result, potentialAddr)
if potentialAddr != nil && addrMap[*potentialAddr] {
result = append(result, *potentialAddr)
}
}
}
Expand Down
12 changes: 6 additions & 6 deletions tools/walletextension/rpcapi/from_tx_args.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import (
"github.com/ten-protocol/go-ten/go/common/gethapi"
)

func searchFromAndData(possibleAddresses []*common.Address, args gethapi.TransactionArgs) *common.Address {
if args.From != nil {
func searchFromAndData(possibleAddresses []common.Address, args gethapi.TransactionArgs) *common.Address {
if args.From != nil && (*args.From != common.Address{}) {
return args.From
}

Expand All @@ -22,7 +22,7 @@ func searchFromAndData(possibleAddresses []*common.Address, args gethapi.Transac
return searchDataFieldForAccount(addressesMap, *args.Data)
}

func searchDataFieldForAccount(addressesMap map[common.Address]*common.Address, data []byte) *common.Address {
func searchDataFieldForAccount(addressesMap map[common.Address]bool, data []byte) *common.Address {
hexEncodedData := hexutils.BytesToHex(data)

// We check that the data field is long enough before removing the leading "0x" (1 bytes/2 chars) and the method ID
Expand Down Expand Up @@ -57,10 +57,10 @@ func searchDataFieldForAccount(addressesMap map[common.Address]*common.Address,
return nil
}

func toMap(possibleAddresses []*common.Address) map[common.Address]*common.Address {
addresses := map[common.Address]*common.Address{}
func toMap(possibleAddresses []common.Address) map[common.Address]bool {
addresses := map[common.Address]bool{}
for i := range possibleAddresses {
addresses[*possibleAddresses[i]] = possibleAddresses[i]
addresses[possibleAddresses[i]] = true
}
return addresses
}
5 changes: 4 additions & 1 deletion tools/walletextension/rpcapi/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func ExecAuthRPC[R any](ctx context.Context, w *services.Services, cfg *ExecCfg,
return res, err
}

func getCandidateAccounts(user *common.GWUser, _ *services.Services, cfg *ExecCfg) ([]*common.GWAccount, error) {
func getCandidateAccounts(user *common.GWUser, we *services.Services, cfg *ExecCfg) ([]*common.GWAccount, error) {
candidateAccts := make([]*common.GWAccount, 0)
// for users with multiple accounts try to determine a candidate account based on the available information
switch {
Expand All @@ -171,6 +171,9 @@ func getCandidateAccounts(user *common.GWUser, _ *services.Services, cfg *ExecCf
if acc != nil {
candidateAccts = append(candidateAccts, acc)
return candidateAccts, nil
} else {
// this should not happen, because the suggestedAddress is one of the addresses
return nil, fmt.Errorf("should not happen. From: %s . UserId: %s", suggestedAddress.Hex(), hexutils.BytesToHex(user.UserID))
}
}
}
Expand Down
12 changes: 6 additions & 6 deletions tools/walletextension/storage/database/common/db_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package common
import (
"github.com/ethereum/go-ethereum/common"
"github.com/ten-protocol/go-ten/go/common/viewingkey"
common2 "github.com/ten-protocol/go-ten/tools/walletextension/common"
wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common"
)

type GWUserDB struct {
Expand All @@ -18,22 +18,22 @@ type GWAccountDB struct {
SignatureType int `json:"signatureType"`
}

func (userDB *GWUserDB) ToGWUser() *common2.GWUser {
result := &common2.GWUser{
func (userDB *GWUserDB) ToGWUser() *wecommon.GWUser {
result := &wecommon.GWUser{
UserID: userDB.UserId,
Accounts: make(map[common.Address]*common2.GWAccount),
Accounts: make(map[common.Address]*wecommon.GWAccount),
UserKey: userDB.PrivateKey,
}

for _, accountDB := range userDB.Accounts {
address := common.BytesToAddress(accountDB.AccountAddress)
gwAccount := &common2.GWAccount{
gwAccount := wecommon.GWAccount{
User: result,
Address: &address,
Signature: accountDB.Signature,
SignatureType: viewingkey.SignatureType(accountDB.SignatureType),
}
result.Accounts[address] = gwAccount
result.Accounts[address] = &gwAccount
}

return result
Expand Down
19 changes: 16 additions & 3 deletions tools/walletextension/storage/database/sqlite/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ type SqliteDB struct {
db *sql.DB
}

const sqliteCfg = "_foreign_keys=on&_journal_mode=wal&_txlock=immediate&_synchronous=normal"

func NewSqliteDatabase(dbPath string) (*SqliteDB, error) {
// load the db file
dbFilePath, err := createOrLoad(dbPath)
Expand All @@ -36,7 +38,8 @@ func NewSqliteDatabase(dbPath string) (*SqliteDB, error) {
}

// open the db
db, err := sql.Open("sqlite3", dbFilePath)
path := fmt.Sprintf("file:%s?%s", dbFilePath, sqliteCfg)
db, err := sql.Open("sqlite3", path)
if err != nil {
fmt.Println("Error opening database: ", err)
return nil, err
Expand Down Expand Up @@ -104,7 +107,13 @@ func (s *SqliteDB) DeleteUser(userID []byte) error {

func (s *SqliteDB) AddAccount(userID []byte, accountAddress []byte, signature []byte, signatureType viewingkey.SignatureType) error {
var userDataJSON string
err := s.db.QueryRow("SELECT user_data FROM users WHERE id = ?", string(userID)).Scan(&userDataJSON)
tx, err := s.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()

err = tx.QueryRow("SELECT user_data FROM users WHERE id = ?", string(userID)).Scan(&userDataJSON)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
Expand All @@ -128,7 +137,7 @@ func (s *SqliteDB) AddAccount(userID []byte, accountAddress []byte, signature []
return fmt.Errorf("error marshaling updated user: %w", err)
}

stmt, err := s.db.Prepare("UPDATE users SET user_data = ? WHERE id = ?")
stmt, err := tx.Prepare("UPDATE users SET user_data = ? WHERE id = ?")
if err != nil {
return err
}
Expand All @@ -138,6 +147,10 @@ func (s *SqliteDB) AddAccount(userID []byte, accountAddress []byte, signature []
if err != nil {
return fmt.Errorf("failed to update user with new account: %w", err)
}
err = tx.Commit()
if err != nil {
return err
}

return nil
}
Expand Down

0 comments on commit 7547940

Please sign in to comment.