diff --git a/clients/iota-go/iotaconn/websocket.go b/clients/iota-go/iotaconn/websocket.go index fbe9280a72..d0d978c8e5 100644 --- a/clients/iota-go/iotaconn/websocket.go +++ b/clients/iota-go/iotaconn/websocket.go @@ -8,7 +8,9 @@ import ( "strconv" "sync" "sync/atomic" + "time" + "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/iotaledger/hive.go/logger" @@ -16,11 +18,15 @@ import ( type WebsocketClient struct { idCounter uint32 + url string conn *websocket.Conn writeQueue chan *jsonrpcMessage readers sync.Map // id -> chan *jsonrpcMessage log *logger.Logger shutdownWaitGroup sync.WaitGroup + reconnectMx sync.Mutex + subscriptions []*subscription + pendingCalls sync.Map } func NewWebsocketClient( @@ -28,16 +34,19 @@ func NewWebsocketClient( url string, log *logger.Logger, ) (*WebsocketClient, error) { - dialer := websocket.Dialer{} - conn, _, err := dialer.DialContext(ctx, url, nil) - if err != nil { - return nil, fmt.Errorf("failed to connect to websocket server: %w", err) - } + c := &WebsocketClient{ - conn: conn, - writeQueue: make(chan *jsonrpcMessage), - log: log, + url: url, + writeQueue: make(chan *jsonrpcMessage), + log: log, + subscriptions: make([]*subscription, 0, 2), } + + err := c.reconnect(ctx) + if err != nil { + return nil, err + } + c.shutdownWaitGroup.Add(1) go c.loop(ctx) return c, nil @@ -56,15 +65,22 @@ func (c *WebsocketClient) loop(ctx context.Context) { } receivedMsgs := make(chan readMsgResult) go func() { + c.log.Infof("websocket loop started") + defer c.log.Infof("websocket loop finished") defer close(receivedMsgs) for { - m, p, err := c.conn.ReadMessage() + m, p, err := c.readMessage() if err != nil { c.log.Errorf("WebsocketClient read loop: %s", err) - return - } else { - receivedMsgs <- readMsgResult{messageType: m, p: p} + continue } + var j *jsonrpcMessage + if err := json.Unmarshal(p, &j); err != nil { + c.log.Errorf("WebsocketClient: could not unmarshal response body: %s", err) + continue + } + c.log.Debugf("ws message was read: %v, %v", j.ID, j.Method) + receivedMsgs <- readMsgResult{messageType: m, p: p} } }() @@ -79,7 +95,7 @@ func (c *WebsocketClient) loop(ctx context.Context) { c.log.Errorf("WebsocketClient: could not marshal json: %s", err) continue } - err = c.conn.WriteMessage(websocket.TextMessage, reqBody) + err = c.writeMessage(websocket.TextMessage, reqBody) if err != nil { c.log.Errorf("WebsocketClient: write error: %s", err) return @@ -88,6 +104,7 @@ func (c *WebsocketClient) loop(ctx context.Context) { if !ok { return } + switch receivedMsg.messageType { case websocket.TextMessage: var m *jsonrpcMessage @@ -99,6 +116,7 @@ func (c *WebsocketClient) loop(ctx context.Context) { if len(m.ID) > 0 { // this is a response to a method call id = string(m.ID) + c.log.Debugf("response to method call: %+v", m.ID) } else if m.Method != "" { // this is a subscription message var s struct { @@ -109,6 +127,7 @@ func (c *WebsocketClient) loop(ctx context.Context) { continue } id = fmt.Sprintf("%s:%d", m.Method, s.Subscription) + c.log.Debugf("subscription message: %v", id) } else { c.log.Errorf("WebsocketClient: cannot identify message: %s", receivedMsg.p) continue @@ -117,6 +136,7 @@ func (c *WebsocketClient) loop(ctx context.Context) { if ok { readCh.(chan *jsonrpcMessage) <- m } else { + // this can sometimes happen, but it's not an issue: the channel should be associated with the new id by now c.log.Errorf("WebsocketClient: no reader for message: %s", receivedMsg.p) continue } @@ -128,6 +148,37 @@ func (c *WebsocketClient) loop(ctx context.Context) { } } +func (c *WebsocketClient) readMessage() (messageType int, p []byte, err error) { + if c.conn == nil { + return 0, nil, fmt.Errorf("connection is nil") + } + + messageType, p, err = c.conn.ReadMessage() + if err != nil { + c.log.Warnf("read failed: %s", err) + if reconnErr := c.reconnect(context.Background()); reconnErr != nil { + return 0, nil, fmt.Errorf("read failed and reconnect failed: %w", err) + } + return c.readMessage() + } + return messageType, p, nil +} + +func (c *WebsocketClient) writeMessage(messageType int, data []byte) error { + if c.conn == nil { + return fmt.Errorf("connection is nil") + } + err := c.conn.WriteMessage(messageType, data) + if err != nil { + c.log.Warnf("write failed: %s", err) + if reconnErr := c.reconnect(context.Background()); reconnErr != nil { + return fmt.Errorf("write failed and reconnect failed: %w", err) + } + return c.writeMessage(messageType, data) + } + return nil +} + func (c *WebsocketClient) writeMsg(method JsonRPCMethod, args ...interface{}) (string, error) { msg, err := c.newMessage(method.String(), args...) if err != nil { @@ -140,6 +191,19 @@ func (c *WebsocketClient) writeMsg(method JsonRPCMethod, args ...interface{}) (s return id, nil } +type subscription struct { + method JsonRPCMethod + args []interface{} + id string + uuid uuid.UUID +} + +type call struct { + method JsonRPCMethod + args []interface{} + id string +} + func (c *WebsocketClient) CallContext( ctx context.Context, result interface{}, @@ -149,13 +213,22 @@ func (c *WebsocketClient) CallContext( if result != nil && reflect.TypeOf(result).Kind() != reflect.Ptr { return fmt.Errorf("call result parameter must be pointer or nil interface: %v", result) } + id, err := c.writeMsg(method, args...) if err != nil { return err } + + c.pendingCalls.Store(id, &call{method: method, args: args, id: id}) + defer func() { + c.pendingCalls.Delete(id) + }() + readCh, _ := c.readers.Load(id) defer c.readers.Delete(id) + c.log.Debugf("waiting for response to %s", id) respmsg := <-readCh.(chan *jsonrpcMessage) + c.log.Debugf("response to %s received", id) if respmsg.Error != nil { return respmsg.Error } @@ -180,6 +253,14 @@ func (c *WebsocketClient) Subscribe( readCh := make(chan *jsonrpcMessage) c.readers.Store(id, readCh) + c.subscriptions = append(c.subscriptions, &subscription{ + method: method, + args: args, + id: id, + uuid: uuid.New(), + }) + c.log.Debugf("subscribing to %s", method) + go func() { defer close(resultCh) defer c.readers.Delete(id) @@ -201,6 +282,7 @@ func (c *WebsocketClient) Subscribe( c.log.Errorf("could not unmarshal msg.Params: %s", err) continue } + c.log.Debugf("subscription result: %+v", params.Result) resultCh <- params.Result } } @@ -229,3 +311,111 @@ func (c *WebsocketClient) nextID() string { id := atomic.AddUint32(&c.idCounter, 1) return strconv.FormatUint(uint64(id), 10) } + +func (c *WebsocketClient) reconnect(ctx context.Context) error { + c.log.Debugf("reconnecting") + if c.reconnectMx.TryLock() { + defer c.reconnectMx.Unlock() + } else { + // already reconnecting, try again later + time.Sleep(50 * time.Millisecond) + return nil + } + + if c.conn != nil { + c.conn.Close() + } + + const retryInterval = time.Second + attempt := 1 + + for { + dialer := websocket.Dialer{} + conn, _, err := dialer.DialContext(ctx, c.url, nil) + if err != nil { + c.log.Warnf("connection attempt %d failed: %v", attempt, err) + select { + case <-ctx.Done(): + return fmt.Errorf("context cancelled while reconnecting: %w", ctx.Err()) + case <-time.After(retryInterval): + attempt++ + continue + } + } + + c.conn = conn + c.log.Debugf("new connection set after %d attempts", attempt) + + // recreating subscriptions and recreating pending calls. This should happen asynchronously because it needs the loop to be running + go c.resubscribe(ctx) + go c.recreatePendingCalls() + + return nil + } +} + +// recreatePendingCalls recreates pending calls. Errors in this function will cause particular calls to not complete, so no need to fail other calls +func (c *WebsocketClient) recreatePendingCalls() { + c.pendingCalls.Range(func(key, value interface{}) bool { + call := value.(*call) + oldId := key.(string) + + msg, err := c.newMessage(call.method.String(), call.args...) + if err != nil { + c.log.Errorf("failed to recreate pending call %s: %s", oldId, err) + return true + } + + newId := string(msg.ID) + + c.log.Debugf("recreate writing message: oldId: %s, newId: %s, %+v", oldId, newId, msg) + + ch, ok := c.readers.Load(oldId) + if !ok { + c.log.Errorf("failed to recreate pending call: reader for old id %s not found", oldId) + return true + } + readCh := ch.(chan *jsonrpcMessage) + c.readers.Store(newId, readCh) + c.writeQueue <- msg + + c.readers.Delete(oldId) + + return true + }) +} + +// resubscribe to subscriptions. Errors in this function probably mean that subscription configurations themself contain errors, so ignoring +func (c *WebsocketClient) resubscribe(ctx context.Context) { + c.log.Debugf("resubscribing to %d subscriptions", len(c.subscriptions)) + defer c.log.Debugf("resubscribed") + + for _, sub := range c.subscriptions { + c.log.Debugf("resubscribing to %s, %+v", sub.method, sub.args) + defer c.log.Debugf("resubscribed to %s", sub.method) + + method := sub.method + args := sub.args + oldId := sub.id + + var subID uint64 + err := c.CallContext(ctx, &subID, method, args...) + if err != nil { + c.log.Errorf("failed to resubscribe to %s: %s", method, err) + continue + } + newId := fmt.Sprintf("%s:%d", method, subID) + + // store reader channel with new id + ch, ok := c.readers.Load(oldId) + c.readers.Delete(oldId) + if !ok { + c.log.Errorf("reader for old id %s not found", oldId) + continue + } + c.readers.Store(newId, ch) + + // need to update subscription id so that next resubscribe works + sub.id = newId + } +}