package v1

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"sync"

	remotev1 "code.icb4dc0.de/buildr/api/generated/remote/v1"
	"github.com/google/uuid"
)

var (
	// ErrNoReplyMsgID is reported by DispatchMessages when an incoming message
	// has no RepliesTo meta field and therefore cannot be matched to a request.
	ErrNoReplyMsgID = errors.New("no repliesTo set - cannot handle message")

	// ErrNoMatchingRequest is reported by DispatchMessages when an incoming
	// reply references a message ID for which no request is pending.
	ErrNoMatchingRequest = errors.New("no matching request for given message id")
)

// NewRequestResponseClient creates a RequestResponseClient that sends requests
// through sender and logs through logger. A nil logger falls back to
// slog.Default().
func NewRequestResponseClient(logger *slog.Logger, sender StreamSender[*remotev1.ExecutionServerMessage]) *RequestResponseClient {
	if logger == nil {
		logger = slog.Default()
	}

	return &RequestResponseClient{
		logger:   logger,
		sender:   sender,
		requests: make(map[uuid.UUID]chan *remotev1.ExecutionClientMessage),
	}
}

// RequestResponseClient correlates outgoing ExecutionServerMessage requests
// with incoming ExecutionClientMessage replies. Send tags each request with a
// fresh message ID and waits; DispatchMessages routes every reply to the
// pending request whose ID matches the reply's RepliesTo field.
//
// requests maps the IDs of in-flight requests to the channel their reply is
// delivered on and is guarded by lock. Use NewRequestResponseClient to create
// an instance; the zero value has a nil requests map.
type RequestResponseClient struct {
	sender   StreamSender[*remotev1.ExecutionServerMessage]
	logger   *slog.Logger
	requests map[uuid.UUID]chan *remotev1.ExecutionClientMessage
	lock     sync.Mutex
}

// DispatchMessages starts a goroutine that reads messages from receiver and
// delivers each one to the pending Send call it replies to.
//
// Per-message problems (missing RepliesTo, unparsable message ID, no pending
// request) are reported on the returned channel and dispatching continues.
// A receive error other than io.EOF is reported and ends dispatching; io.EOF
// ends it without reporting. The returned channel is closed when the
// goroutine exits.
//
// The returned channel is unbuffered: the caller must keep draining it,
// otherwise the dispatcher blocks on the first reported error and no further
// replies are delivered.
func (rrc *RequestResponseClient) DispatchMessages(receiver StreamReceiver[*remotev1.ExecutionClientMessage]) (errs chan error) {
	errs = make(chan error)

	go func() {
		// Signal the consumer that dispatching has ended so that a
		// `for err := range errs` loop terminates instead of leaking.
		defer close(errs)

		for {
			msg, err := receiver.Receive()
			if err != nil {
				if errors.Is(err, io.EOF) {
					rrc.logger.Info("Closing request response client")
					return
				}
				errs <- err
				return
			}

			rrc.logger.Debug("Dispatching message")

			repliesTo, ok := msg.Meta.(*remotev1.ExecutionClientMessage_RepliesTo)
			if !ok {
				errs <- ErrNoReplyMsgID
				continue
			}

			msgID, err := uuid.FromBytes(repliesTo.RepliesTo)
			if err != nil {
				errs <- err
				continue
			}

			rrc.logger.Debug("Checking for pending request", slog.String("msgId", msgID.String()))

			rrc.lock.Lock()
			waitingChan, ok := rrc.requests[msgID]
			if ok {
				// Removing the entry under the lock transfers ownership of the
				// channel to this goroutine: cleanRequest no longer sees it and
				// will not close it, so the send below cannot hit a closed channel.
				delete(rrc.requests, msgID)
			}
			rrc.lock.Unlock()

			if !ok {
				errs <- fmt.Errorf("%w: %s", ErrNoMatchingRequest, msgID)
				continue
			}

			rrc.logger.Debug("Replying to pending request", slog.String("msgId", msgID.String()))
			// Send allocates the channel with capacity 1, so this never blocks,
			// even if Send has already returned because its context was done.
			waitingChan <- msg
		}
	}()

	return errs
}

// Send tags req with a freshly generated message ID (overwriting req.Meta),
// sends it through the client's sender and waits for the matching reply
// delivered by DispatchMessages.
//
// It returns ctx.Err() if ctx is done before a reply arrives, and the
// sender's error if sending fails. If the reply carries an Error envelope
// with a non-empty message, that message is returned as an error; an Error
// envelope without a message yields (nil, nil). Any other reply is returned
// unchanged.
func (rrc *RequestResponseClient) Send(
	ctx context.Context,
	req *remotev1.ExecutionServerMessage,
) (resp *remotev1.ExecutionClientMessage, err error) {
	msgID := uuid.New()
	// Capacity 1 lets the dispatcher hand over the (single) reply without
	// blocking, even when this call has already returned on ctx.Done().
	respChan := make(chan *remotev1.ExecutionClientMessage, 1)

	req.Meta = &remotev1.ExecutionServerMessage_MessageId{
		MessageId: uuidBytes(msgID),
	}

	// Register before sending so a fast reply always finds its channel.
	rrc.lock.Lock()
	rrc.requests[msgID] = respChan
	rrc.lock.Unlock()

	defer rrc.cleanRequest(msgID)

	if err = rrc.sender.Send(req); err != nil {
		return nil, err
	}

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case reply := <-respChan:
		errEnvelope, isErr := reply.Envelope.(*remotev1.ExecutionClientMessage_Error)
		if !isErr {
			return reply, nil
		}

		if errEnvelope.Error == nil || errEnvelope.Error.Error == "" {
			//nolint:nilnil // an error envelope without a message signals success without payload
			return nil, nil
		}

		return nil, errors.New(errEnvelope.Error.Error)
	}
}

// cleanRequest removes the pending entry for id if it is still registered,
// closing its channel. An entry that is still registered has not received a
// reply, and the dispatcher only sends on channels it has already removed from
// the map, so closing here cannot race with a send.
func (rrc *RequestResponseClient) cleanRequest(id uuid.UUID) {
	rrc.lock.Lock()
	defer rrc.lock.Unlock()

	if c, ok := rrc.requests[id]; ok {
		close(c)
		delete(rrc.requests, id)
	}
}

// uuidBytes returns the 16-byte binary form of uid.
func uuidBytes(uid uuid.UUID) []byte {
	// uuid.UUID.MarshalBinary always returns a nil error, so it is safe to drop.
	b, _ := uid.MarshalBinary()
	return b
}