Works
This commit is contained in:
parent
3eeeac2c13
commit
346d8d7967
@ -585,9 +585,9 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
|
||||
return writeMatrixDiscoveryResponse(w)
|
||||
}
|
||||
|
||||
func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) {
|
||||
t := fromContext[topic](r, contextTopic)
|
||||
vrate := fromContext[visitor](r, contextRateVisitor)
|
||||
func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, error) {
|
||||
t := fromContext[*topic](r, contextTopic)
|
||||
vrate := fromContext[*visitor](r, contextRateVisitor)
|
||||
body, err := util.Peek(r.Body, s.config.MessageLimit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -670,7 +670,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
||||
}
|
||||
|
||||
func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
m, err := s.handlePublishWithoutResponse(r, v)
|
||||
m, err := s.handlePublishInternal(r, v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -678,10 +678,14 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
|
||||
}
|
||||
|
||||
func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
_, err := s.handlePublishWithoutResponse(r, v)
|
||||
_, err := s.handlePublishInternal(r, v)
|
||||
if err != nil {
|
||||
if e, ok := err.(*errHTTP); ok && e.HTTPCode == errHTTPInsufficientStorageUnifiedPush.HTTPCode {
|
||||
return writeMatrixResponse(w, e.rejectedPushKey)
|
||||
topic := fromContext[*topic](r, contextTopic)
|
||||
pushKey := fromContext[string](r, contextMatrixPushKey)
|
||||
if time.Since(topic.LastAccess()) > matrixRejectPushKeyForUnifiedPushTopicWithoutRateVisitorAfter {
|
||||
return writeMatrixResponse(w, pushKey)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
@ -1011,6 +1015,9 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
|
||||
if poll {
|
||||
for _, t := range topics {
|
||||
t.Keepalive()
|
||||
}
|
||||
return s.sendOldMessages(topics, since, scheduled, v, sub)
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@ -1037,7 +1044,12 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
||||
case <-r.Context().Done():
|
||||
return nil
|
||||
case <-time.After(s.config.KeepaliveInterval):
|
||||
logvr(v, r).Tag(tagSubscribe).Trace("Sending keepalive message")
|
||||
ev := logvr(v, r).Tag(tagSubscribe)
|
||||
if len(topics) == 1 {
|
||||
ev.With(topics[0]).Trace("Sending keepalive message to %s", topics[0].ID)
|
||||
} else {
|
||||
ev.Trace("Sending keepalive message to %d topics", len(topics))
|
||||
}
|
||||
v.Keepalive()
|
||||
for _, t := range topics {
|
||||
t.Keepalive()
|
||||
@ -1154,6 +1166,9 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||
if poll {
|
||||
for _, t := range topics {
|
||||
t.Keepalive()
|
||||
}
|
||||
return s.sendOldMessages(topics, since, scheduled, v, sub)
|
||||
}
|
||||
subscriberIDs := make([]int, 0)
|
||||
|
@ -2,8 +2,8 @@ package server
|
||||
|
||||
import (
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/util"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (s *Server) execManager() {
|
||||
@ -39,13 +39,13 @@ func (s *Server) execManager() {
|
||||
ev := log.Tag(tagManager).With(t)
|
||||
if t.Stale() {
|
||||
if ev.IsTrace() {
|
||||
ev.Trace("- topic %s: Deleting stale topic (%d subscribers, accessed %s)", t.ID, subs, lastAccess.Format(time.RFC822))
|
||||
ev.Trace("- topic %s: Deleting stale topic (%d subscribers, accessed %s)", t.ID, subs, util.FormatTime(lastAccess))
|
||||
}
|
||||
emptyTopics++
|
||||
delete(s.topics, t.ID)
|
||||
} else {
|
||||
if ev.IsTrace() {
|
||||
ev.Trace("- topic %s: %d subscribers, accessed %s", t.ID, subs, lastAccess.Format(time.RFC822))
|
||||
ev.Trace("- topic %s: %d subscribers, accessed %s", t.ID, subs, util.FormatTime(lastAccess))
|
||||
}
|
||||
subscribers += subs
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Matrix Push Gateway / UnifiedPush / ntfy integration:
|
||||
@ -71,6 +72,14 @@ type matrixResponse struct {
|
||||
Rejected []string `json:"rejected"`
|
||||
}
|
||||
|
||||
const (
|
||||
// matrixRejectPushKeyForUnifiedPushTopicWithoutRateVisitorAfter is the time after which a Matrix response
|
||||
// will return an HTTP 200 with the push key (i.e. "rejected":["<pushkey>"]}), if no rate visitor has been set on
|
||||
// the topic. Rejecting the push key will instruct the Matrix server to invalidate the pushkey and stop sending
|
||||
// messages to it. See https://spec.matrix.org/v1.6/push-gateway-api/
|
||||
matrixRejectPushKeyForUnifiedPushTopicWithoutRateVisitorAfter = 12 * time.Hour
|
||||
)
|
||||
|
||||
// errMatrixPushkeyRejected represents an error when handing Matrix gateway messages
|
||||
//
|
||||
// If the push key is set, the app server will remove it and will never send messages using the same
|
||||
@ -126,7 +135,9 @@ func newRequestFromMatrixJSON(r *http.Request, baseURL string, messageLimit int)
|
||||
if r.Header.Get("X-Forwarded-For") != "" {
|
||||
newRequest.Header.Set("X-Forwarded-For", r.Header.Get("X-Forwarded-For"))
|
||||
}
|
||||
newRequest.Header.Set("X-Matrix-Pushkey", pushKey)
|
||||
newRequest = withContext(newRequest, map[contextKey]any{
|
||||
contextMatrixPushKey: pushKey,
|
||||
})
|
||||
return newRequest, nil
|
||||
}
|
||||
|
||||
|
@ -11,6 +11,7 @@ type contextKey int
|
||||
const (
|
||||
contextRateVisitor contextKey = iota + 2586
|
||||
contextTopic
|
||||
contextMatrixPushKey
|
||||
)
|
||||
|
||||
func (s *Server) limitRequests(next handleFunc) handleFunc {
|
||||
|
@ -1172,6 +1172,56 @@ func TestServer_PublishEmailNoMailer_Fail(t *testing.T) {
|
||||
require.Equal(t, 400, response.Code)
|
||||
}
|
||||
|
||||
func TestServer_PublishAndExpungeTopicAfter16Hours(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
|
||||
subFn := func(v *visitor, msg *message) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Publish and check last access
|
||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||
"Cache": "no",
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.True(t, s.topics["mytopic"].lastAccess.Unix() >= time.Now().Unix()-2)
|
||||
require.True(t, s.topics["mytopic"].lastAccess.Unix() <= time.Now().Unix()+2)
|
||||
|
||||
// Topic won't get pruned
|
||||
s.execManager()
|
||||
require.NotNil(t, s.topics["mytopic"])
|
||||
|
||||
// Fudge with last access, but subscribe, and see that it won't get pruned (because of subscriber)
|
||||
subID := s.topics["mytopic"].Subscribe(subFn, "", func() {})
|
||||
s.topics["mytopic"].lastAccess = time.Now().Add(-17 * time.Hour)
|
||||
s.execManager()
|
||||
require.NotNil(t, s.topics["mytopic"])
|
||||
|
||||
// It'll finally get pruned now that there are no subscribers and last access is 17 hours ago
|
||||
s.topics["mytopic"].Unsubscribe(subID)
|
||||
s.execManager()
|
||||
require.Nil(t, s.topics["mytopic"])
|
||||
}
|
||||
|
||||
func TestServer_TopicKeepaliveOnPoll(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
|
||||
// Create topic by polling once
|
||||
response := request(t, s, "GET", "/mytopic/json?poll=1", "", nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
|
||||
// Mess with last access time
|
||||
s.topics["mytopic"].lastAccess = time.Now().Add(-17 * time.Hour)
|
||||
|
||||
// Poll again and check keepalive time
|
||||
response = request(t, s, "GET", "/mytopic/json?poll=1", "", nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.True(t, s.topics["mytopic"].lastAccess.Unix() >= time.Now().Unix()-2)
|
||||
require.True(t, s.topics["mytopic"].lastAccess.Unix() <= time.Now().Unix()+2)
|
||||
}
|
||||
|
||||
func TestServer_UnifiedPushDiscovery(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
response := request(t, s, "GET", "/mytopic?up=1", "", nil)
|
||||
@ -1301,6 +1351,32 @@ func TestServer_MatrixGateway_Push_Failure_NoSubscriber(t *testing.T) {
|
||||
require.Equal(t, 50701, toHTTPError(t, response.Body.String()).Code)
|
||||
}
|
||||
|
||||
func TestServer_MatrixGateway_Push_Failure_NoSubscriber_After13Hours(t *testing.T) {
|
||||
c := newTestConfig(t)
|
||||
c.VisitorSubscriberRateLimiting = true
|
||||
s := newTestServer(t, c)
|
||||
notification := `{"notification":{"devices":[{"pushkey":"http://127.0.0.1:12345/mytopic?up=1"}]}}`
|
||||
|
||||
// No success if no rate visitor set (this also creates the topic in memory
|
||||
response := request(t, s, "POST", "/_matrix/push/v1/notify", notification, nil)
|
||||
require.Equal(t, 507, response.Code)
|
||||
require.Equal(t, 50701, toHTTPError(t, response.Body.String()).Code)
|
||||
require.Nil(t, s.topics["mytopic"].rateVisitor)
|
||||
|
||||
// Fake: This topic has been around for 13 hours without a rate visitor
|
||||
s.topics["mytopic"].lastAccess = time.Now().Add(-13 * time.Hour)
|
||||
|
||||
// Same request should now return HTTP 200 with a rejected pushkey
|
||||
response = request(t, s, "POST", "/_matrix/push/v1/notify", notification, nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"rejected":["http://127.0.0.1:12345/mytopic?up=1"]}`, strings.TrimSpace(response.Body.String()))
|
||||
|
||||
// Slightly unrelated: Test that topic is pruned after 16 hours
|
||||
s.topics["mytopic"].lastAccess = time.Now().Add(-17 * time.Hour)
|
||||
s.execManager()
|
||||
require.Nil(t, s.topics["mytopic"])
|
||||
}
|
||||
|
||||
func TestServer_MatrixGateway_Push_Failure_InvalidPushkey(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
notification := `{"notification":{"devices":[{"pushkey":"http://wrong-base-url.com/mytopic?up=1"}]}}`
|
||||
|
@ -2,13 +2,18 @@ package server
|
||||
|
||||
import (
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/util"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
topicExpiryDuration = 6 * time.Hour
|
||||
// topicExpungeAfter defines how long a topic is active before it is removed from memory.
|
||||
//
|
||||
// This must be larger than matrixRejectPushKeyForUnifiedPushTopicWithoutRateVisitorAfter to give
|
||||
// time for more requests to come in, so that we can send a {"rejected":["<pushkey>"]} response back.
|
||||
topicExpungeAfter = 16 * time.Hour
|
||||
)
|
||||
|
||||
// topic represents a channel to which subscribers can subscribe, and publishers
|
||||
@ -59,7 +64,13 @@ func (t *topic) Stale() bool {
|
||||
if t.rateVisitor != nil && !t.rateVisitor.Stale() {
|
||||
return false
|
||||
}
|
||||
return len(t.subscribers) == 0 && time.Since(t.lastAccess) > topicExpiryDuration
|
||||
return len(t.subscribers) == 0 && time.Since(t.lastAccess) > topicExpungeAfter
|
||||
}
|
||||
|
||||
func (t *topic) LastAccess() time.Time {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return t.lastAccess
|
||||
}
|
||||
|
||||
func (t *topic) SetRateVisitor(v *visitor) {
|
||||
@ -148,6 +159,7 @@ func (t *topic) Context() log.Context {
|
||||
fields := map[string]any{
|
||||
"topic": t.ID,
|
||||
"topic_subscribers": len(t.subscribers),
|
||||
"topic_last_access": util.FormatTime(t.lastAccess),
|
||||
}
|
||||
if t.rateVisitor != nil {
|
||||
for k, v := range t.rateVisitor.Context() {
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTopic_CancelSubscribers(t *testing.T) {
|
||||
@ -28,3 +29,13 @@ func TestTopic_CancelSubscribers(t *testing.T) {
|
||||
require.True(t, canceled1.Load())
|
||||
require.False(t, canceled2.Load())
|
||||
}
|
||||
|
||||
func TestTopic_Keepalive(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
to := newTopic("mytopic")
|
||||
to.lastAccess = time.Now().Add(-1 * time.Hour)
|
||||
to.Keepalive()
|
||||
require.True(t, to.LastAccess().Unix() >= time.Now().Unix()-2)
|
||||
require.True(t, to.LastAccess().Unix() <= time.Now().Unix()+2)
|
||||
}
|
||||
|
@ -107,8 +107,8 @@ func withContext(r *http.Request, ctx map[contextKey]any) *http.Request {
|
||||
return r.WithContext(c)
|
||||
}
|
||||
|
||||
func fromContext[T any](r *http.Request, key contextKey) *T {
|
||||
t, ok := r.Context().Value(key).(*T)
|
||||
func fromContext[T any](r *http.Request, key contextKey) T {
|
||||
t, ok := r.Context().Value(key).(T)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("cannot find key %v in request context", key))
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user