From 70e05751e8b7528bae64bb5900446160bcef77f7 Mon Sep 17 00:00:00 2001 From: "fox.cpp" Date: Fri, 8 Nov 2019 02:00:25 +0300 Subject: [PATCH] server: Simplify statusCollector data structure --- backend.go | 3 ++- conn.go | 74 ++++++++++++++++++++++++------------------------------ 2 files changed, 35 insertions(+), 42 deletions(-) diff --git a/backend.go b/backend.go index fa85ddf..75cce17 100644 --- a/backend.go +++ b/backend.go @@ -62,7 +62,8 @@ type LMTPSession interface { // LMTPData implementation sets status information using passed // StatusCollector by calling SetStatus once per each AddRcpt // call, even if AddRcpt was called multiple times with - // the same argument. + // the same argument. SetStatus must not be called after + // LMTPData returns. // // Return value of LMTPData itself is used as a status for // recipients that got no status set before using StatusCollector. diff --git a/conn.go b/conn.go index 9d05f41..595da1a 100644 --- a/conn.go +++ b/conn.go @@ -503,15 +503,9 @@ func (c *Conn) handleData(arg string) { } type statusCollector struct { - mapLock sync.Mutex - // Contains map from recipient to list of channels that are used for that // recipient. - // - // First SetStatus call uses first channel, second - second, etc. Channels - // that are already used are set to nil (otherwise we can accidentally - // reuse channel that was consumed by handleDataLMTP already). - statusMap map[string][]chan error + statusMap map[string]chan error // Contains channels from statusMap, in the same // order as Conn.recipients. @@ -519,58 +513,56 @@ type statusCollector struct { } // fillRemaining sets status for all recipients SetStatus was not called for before. -func (s statusCollector) fillRemaining(err error) { - s.mapLock.Lock() - defer s.mapLock.Unlock() - - // Since used channels in statusMap are set to nil, we can simply send - // on all non-nil channels to fill statuses not set by LMTPData. - for _, chList := range s.statusMap { - for _, ch := range chList { - if ch == nil { - continue +func (s *statusCollector) fillRemaining(err error) { + // Amount of times certain recipient was specified is indicated by the channel + // buffer size, so once we fill it, we can be confident that we sent + // at least as much statuses as needed. Extra statuses will be ignored anyway. +chLoop: + for _, ch := range s.statusMap { + for { + select { + case ch <- err: + default: + continue chLoop } - ch <- err } } } -func (s statusCollector) SetStatus(rcptTo string, err error) { - s.mapLock.Lock() - defer s.mapLock.Unlock() - - chList := s.statusMap[rcptTo] - if chList == nil { +func (s *statusCollector) SetStatus(rcptTo string, err error) { + ch := s.statusMap[rcptTo] + if ch == nil { panic("SetStatus is called for recipient that was not specified before") } - // Pick the first non-nil channel from list. - var usedCh chan error - for i, ch := range chList { - if ch != nil { - usedCh = ch - chList[i] = nil - break - } - } - if usedCh == nil { + select { + case ch <- err: + default: + // There enough buffer space to fit all statuses at once, if this is + // not the case - backend is doing something wrong. panic("SetStatus is called more times than particular recipient was specified") } - - usedCh <- err } func (c *Conn) handleDataLMTP() { r := newDataReader(c) - status := statusCollector{ - statusMap: make(map[string][]chan error, len(c.recipients)), + rcptCounts := make(map[string]int, len(c.recipients)) + + status := &statusCollector{ + statusMap: make(map[string]chan error, len(c.recipients)), status: make([]chan error, 0, len(c.recipients)), } for _, rcpt := range c.recipients { - ch := make(chan error, 1) - status.status = append(status.status, ch) - status.statusMap[rcpt] = append(status.statusMap[rcpt], ch) + rcptCounts[rcpt]++ + } + // Create channels with buffer sizes necessary to fit all + // statuses for a single recipient to avoid deadlocks. + for rcpt, count := range rcptCounts { + status.statusMap[rcpt] = make(chan error, count) + } + for _, rcpt := range c.recipients { + status.status = append(status.status, status.statusMap[rcpt]) } lmtpSession, ok := c.Session().(LMTPSession)