Logging WIP
This commit is contained in:
parent
a6641980c2
commit
5d6051c490
@ -26,6 +26,8 @@ func TestCLI_Access_Grant_And_Publish(t *testing.T) {
|
||||
stdin.WriteString("philpass\nphilpass\nbenpass\nbenpass")
|
||||
require.Nil(t, runUserCommand(app, conf, "add", "--role=admin", "phil"))
|
||||
require.Nil(t, runUserCommand(app, conf, "add", "ben"))
|
||||
|
||||
app, stdin, _, _ = newTestApp()
|
||||
require.Nil(t, runAccessCommand(app, conf, "ben", "announcements", "rw"))
|
||||
require.Nil(t, runAccessCommand(app, conf, "ben", "sometopic", "read"))
|
||||
require.Nil(t, runAccessCommand(app, conf, "everyone", "announcements", "read"))
|
||||
|
@ -76,7 +76,7 @@ func (e *Event) Fields(fields map[string]any) *Event {
|
||||
return e
|
||||
}
|
||||
|
||||
func (e *Event) Context(contexts ...Ctx) *Event {
|
||||
func (e *Event) Context(contexts ...Contexter) *Event {
|
||||
for _, c := range contexts {
|
||||
e.Fields(c.Context())
|
||||
}
|
||||
|
@ -42,7 +42,7 @@ func Trace(message string, v ...any) {
|
||||
newEvent().Trace(message, v...)
|
||||
}
|
||||
|
||||
func Context(contexts ...Ctx) *Event {
|
||||
func Context(contexts ...Contexter) *Event {
|
||||
return newEvent().Context(contexts...)
|
||||
}
|
||||
|
||||
|
@ -91,7 +91,7 @@ func ToFormat(s string) Format {
|
||||
}
|
||||
}
|
||||
|
||||
type Ctx interface {
|
||||
type Contexter interface {
|
||||
Context() map[string]any
|
||||
}
|
||||
|
||||
@ -101,7 +101,7 @@ func (f fieldsCtx) Context() map[string]any {
|
||||
return f
|
||||
}
|
||||
|
||||
func NewCtx(fields map[string]any) Ctx {
|
||||
func NewCtx(fields map[string]any) Contexter {
|
||||
return fieldsCtx(fields)
|
||||
}
|
||||
|
||||
|
@ -149,6 +149,7 @@ const (
|
||||
tagManager = "manager"
|
||||
tagResetter = "resetter"
|
||||
tagWebsocket = "websocket"
|
||||
tagMatrix = "matrix"
|
||||
)
|
||||
|
||||
// New instantiates a new Server. It creates the cache and adds a Firebase
|
||||
@ -328,9 +329,9 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
||||
if websocket.IsWebSocketUpgrade(r) {
|
||||
isNormalError := strings.Contains(err.Error(), "i/o timeout")
|
||||
if isNormalError {
|
||||
logvr(v, r).Tag(tagWebsocket).Debug("WebSocket error (this error is okay, it happens a lot): %s", err.Error())
|
||||
logvr(v, r).Tag(tagWebsocket).Err(err).Debug("WebSocket error (this error is okay, it happens a lot): %s", err.Error())
|
||||
} else {
|
||||
logvr(v, r).Tag(tagWebsocket).Info("WebSocket error: %s", err.Error())
|
||||
logvr(v, r).Tag(tagWebsocket).Err(err).Info("WebSocket error: %s", err.Error())
|
||||
}
|
||||
return // Do not attempt to write to upgraded connection
|
||||
}
|
||||
@ -711,7 +712,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
|
||||
logvm(v, m).Err(err).Warn("Unable to publish poll request")
|
||||
return
|
||||
} else if response.StatusCode != http.StatusOK {
|
||||
logvm(v, m).Err(err).Warn("Unable to publish poll request, unexpected HTTP status: %d")
|
||||
logvm(v, m).Err(err).Warn("Unable to publish poll request, unexpected HTTP status: %d", response.StatusCode)
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -1537,6 +1538,7 @@ func (s *Server) limitRequests(next handleFunc) handleFunc {
|
||||
if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
|
||||
return next(w, r, v)
|
||||
} else if err := v.RequestAllowed(); err != nil {
|
||||
logvr(v, r).Err(err).Fields(requestLimiterFields(v.RequestLimiter())).Trace("Request not allowed by rate limiter")
|
||||
return errHTTPTooManyRequestsLimitRequests
|
||||
}
|
||||
return next(w, r, v)
|
||||
@ -1601,6 +1603,7 @@ func (s *Server) transformMatrixJSON(next handleFunc) handleFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
newRequest, err := newRequestFromMatrixJSON(r, s.config.BaseURL, s.config.MessageLimit)
|
||||
if err != nil {
|
||||
logvr(v, r).Tag(tagMatrix).Err(err).Trace("Invalid Matrix request")
|
||||
return err
|
||||
}
|
||||
if err := next(w, newRequest, v); err != nil {
|
||||
@ -1630,7 +1633,7 @@ func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc
|
||||
u := v.User()
|
||||
for _, t := range topics {
|
||||
if err := s.userManager.Authorize(u, t.ID, perm); err != nil {
|
||||
logvr(v, r).Err(err).Debug("Unauthorized")
|
||||
logvr(v, r).Err(err).Field("message_topic", t.ID).Debug("Access to topic %s not authorized", t.ID)
|
||||
return errHTTPForbidden
|
||||
}
|
||||
}
|
||||
@ -1644,7 +1647,7 @@ func (s *Server) maybeAuthenticate(r *http.Request) (v *visitor, err error) {
|
||||
ip := extractIPAddress(r, s.config.BehindProxy)
|
||||
var u *user.User // may stay nil if no auth header!
|
||||
if u, err = s.authenticate(r); err != nil {
|
||||
logr(r).Debug("Authentication failed: %s", err.Error())
|
||||
logr(r).Err(err).Debug("Authentication failed: %s", err.Error())
|
||||
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
|
||||
}
|
||||
v = s.visitor(ip, u)
|
||||
|
@ -160,7 +160,7 @@ func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v *
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := s.maybeRemoveMessagesAndExcessReservations(logHTTPPrefix(v, r), u, 0); err != nil {
|
||||
if err := s.maybeRemoveMessagesAndExcessReservations(r, v, u, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
logvr(v, r).Tag(tagAccount).Info("Marking user %s as deleted", u.Name)
|
||||
@ -462,18 +462,19 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
|
||||
// maybeRemoveMessagesAndExcessReservations deletes topic reservations for the given user (if too many for tier),
|
||||
// and marks associated messages for the topics as deleted. This also eventually deletes attachments.
|
||||
// The process relies on the manager to perform the actual deletions (see runManager).
|
||||
func (s *Server) maybeRemoveMessagesAndExcessReservations(logPrefix string, u *user.User, reservationsLimit int64) error {
|
||||
func (s *Server) maybeRemoveMessagesAndExcessReservations(r *http.Request, v *visitor, u *user.User, reservationsLimit int64) error {
|
||||
reservations, err := s.userManager.Reservations(u.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if int64(len(reservations)) <= reservationsLimit {
|
||||
logvr(v, r).Tag(tagAccount).Debug("No excess reservations to remove")
|
||||
return nil
|
||||
}
|
||||
topics := make([]string, 0)
|
||||
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
|
||||
topics = append(topics, reservations[i].Topic)
|
||||
}
|
||||
log.Info("%s Removing excess reservations for topics %s", logPrefix, strings.Join(topics, ", "))
|
||||
logvr(v, r).Tag(tagAccount).Info("Removing excess reservations for topics %s", strings.Join(topics, ", "))
|
||||
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/util"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -147,7 +146,7 @@ func writeMatrixDiscoveryResponse(w http.ResponseWriter) error {
|
||||
|
||||
// writeMatrixError logs and writes the errMatrix to the given http.ResponseWriter as a matrixResponse
|
||||
func writeMatrixError(w http.ResponseWriter, r *http.Request, v *visitor, err *errMatrix) error {
|
||||
log.Debug("%s Matrix gateway error: %s", logHTTPPrefix(v, r), err.Error())
|
||||
logvr(v, r).Tag(tagMatrix).Err(err).Debug("Matrix gateway error")
|
||||
return writeMatrixResponse(w, err.pushKey)
|
||||
}
|
||||
|
||||
|
@ -2,7 +2,6 @@ package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/stripe/stripe-go/v74"
|
||||
@ -121,7 +120,13 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
||||
} else if tier.StripePriceID == "" {
|
||||
return errNotAPaidTier
|
||||
}
|
||||
log.Info("%s Creating Stripe checkout flow", logHTTPPrefix(v, r))
|
||||
logvr(v, r).
|
||||
Tag(tagPay).
|
||||
Fields(map[string]any{
|
||||
"tier": tier,
|
||||
"stripe_price_id": tier.StripePriceID,
|
||||
}).
|
||||
Info("Creating Stripe checkout flow")
|
||||
var stripeCustomerID *string
|
||||
if u.Billing.StripeCustomerID != "" {
|
||||
stripeCustomerID = &u.Billing.StripeCustomerID
|
||||
@ -190,6 +195,18 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
|
||||
return err
|
||||
}
|
||||
v.SetUser(u)
|
||||
logvr(v, r).
|
||||
Tag(tagPay).
|
||||
Fields(map[string]any{
|
||||
"tier_id": tier.ID,
|
||||
"tier_name": tier.Name,
|
||||
"stripe_price_id": tier.StripePriceID,
|
||||
"stripe_customer_id": sess.Customer.ID,
|
||||
"stripe_subscription_id": sub.ID,
|
||||
"stripe_subscription_status": string(sub.Status),
|
||||
"stripe_subscription_paid_until": sub.CurrentPeriodEnd,
|
||||
}).
|
||||
Info("Stripe checkout flow succeeded, updating user tier and subscription")
|
||||
customerParams := &stripe.CustomerParams{
|
||||
Params: stripe.Params{
|
||||
Metadata: map[string]string{
|
||||
@ -201,7 +218,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
|
||||
if _, err := s.stripe.UpdateCustomer(sess.Customer.ID, customerParams); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.updateSubscriptionAndTier(logHTTPPrefix(v, r), u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
|
||||
if err := s.updateSubscriptionAndTier(r, v, u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
|
||||
return err
|
||||
}
|
||||
http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
|
||||
@ -223,7 +240,15 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Info("%s Changing billing tier to %s (price %s) for subscription %s", logHTTPPrefix(v, r), tier.Code, tier.StripePriceID, u.Billing.StripeSubscriptionID)
|
||||
logvr(v, r).
|
||||
Tag(tagPay).
|
||||
Fields(map[string]any{
|
||||
"new_tier_id": tier.ID,
|
||||
"new_tier_name": tier.Name,
|
||||
"new_tier_stripe_price_id": tier.StripePriceID,
|
||||
// Other stripe_* fields filled by visitor context
|
||||
}).
|
||||
Info("Changing Stripe subscription and billing tier to %s/%s (price %s)", tier.ID, tier.Name, tier.StripePriceID)
|
||||
sub, err := s.stripe.GetSubscription(u.Billing.StripeSubscriptionID)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -250,8 +275,8 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
|
||||
// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
|
||||
// and cancelling the Stripe subscription entirely
|
||||
func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
logvr(v, r).Tag(tagPay).Info("Deleting Stripe subscription")
|
||||
u := v.User()
|
||||
log.Info("%s Deleting billing subscription %s", logHTTPPrefix(v, r), u.Billing.StripeSubscriptionID)
|
||||
if u.Billing.StripeSubscriptionID != "" {
|
||||
params := &stripe.SubscriptionParams{
|
||||
CancelAtPeriodEnd: stripe.Bool(true),
|
||||
@ -267,11 +292,11 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r
|
||||
// handleAccountBillingPortalSessionCreate creates a session to the customer billing portal, and returns the
|
||||
// redirect URL. The billing portal allows customers to change their payment methods, and cancel the subscription.
|
||||
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
logvr(v, r).Tag(tagPay).Info("Creating Stripe billing portal session")
|
||||
u := v.User()
|
||||
if u.Billing.StripeCustomerID == "" {
|
||||
return errHTTPBadRequestNotAPaidUser
|
||||
}
|
||||
log.Info("%s Creating billing portal session", logHTTPPrefix(v, r))
|
||||
params := &stripe.BillingPortalSessionParams{
|
||||
Customer: stripe.String(u.Billing.StripeCustomerID),
|
||||
ReturnURL: stripe.String(s.config.BaseURL),
|
||||
@ -289,7 +314,7 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter,
|
||||
// handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync
|
||||
// with the Stripe view of the world. This endpoint is authorized via the Stripe webhook secret. Note that the
|
||||
// visitor (v) in this endpoint is the Stripe API, so we don't have u available.
|
||||
func (s *Server) handleAccountBillingWebhook(_ http.ResponseWriter, r *http.Request, _ *visitor) error {
|
||||
func (s *Server) handleAccountBillingWebhook(_ http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
stripeSignature := r.Header.Get("Stripe-Signature")
|
||||
if stripeSignature == "" {
|
||||
return errHTTPBadRequestBillingRequestInvalid
|
||||
@ -308,74 +333,105 @@ func (s *Server) handleAccountBillingWebhook(_ http.ResponseWriter, r *http.Requ
|
||||
}
|
||||
switch event.Type {
|
||||
case "customer.subscription.updated":
|
||||
return s.handleAccountBillingWebhookSubscriptionUpdated(event.Data.Raw)
|
||||
return s.handleAccountBillingWebhookSubscriptionUpdated(r, v, event)
|
||||
case "customer.subscription.deleted":
|
||||
return s.handleAccountBillingWebhookSubscriptionDeleted(event.Data.Raw)
|
||||
return s.handleAccountBillingWebhookSubscriptionDeleted(r, v, event)
|
||||
default:
|
||||
log.Warn("STRIPE Unhandled webhook event %s received", event.Type)
|
||||
logvr(v, r).
|
||||
Tag(tagPay).
|
||||
Field("stripe_webhook_type", event.Type).
|
||||
Warn("Unhandled Stripe webhook event %s received", event.Type)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error {
|
||||
ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event)))
|
||||
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(r *http.Request, v *visitor, event stripe.Event) error {
|
||||
ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event.Data.Raw)))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" {
|
||||
return errHTTPBadRequestBillingRequestInvalid
|
||||
}
|
||||
subscriptionID, priceID := ev.ID, ev.Items.Data[0].Price.ID
|
||||
log.Info("%s Updating subscription to status %s, with price %s", logStripePrefix(ev.Customer, ev.ID), ev.Status, priceID)
|
||||
logvr(v, r).
|
||||
Tag(tagPay).
|
||||
Fields(map[string]any{
|
||||
"stripe_webhook_type": event.Type,
|
||||
"stripe_customer_id": ev.Customer,
|
||||
"stripe_subscription_id": ev.ID,
|
||||
"stripe_subscription_status": ev.Status,
|
||||
"stripe_subscription_paid_until": ev.CurrentPeriodEnd,
|
||||
"stripe_subscription_cancel_at": ev.CancelAt,
|
||||
"stripe_price_id": priceID,
|
||||
}).
|
||||
Info("Updating subscription to status %s, with price %s", ev.Status, priceID)
|
||||
userFn := func() (*user.User, error) {
|
||||
return s.userManager.UserByStripeCustomer(ev.Customer)
|
||||
}
|
||||
// We retry the user retrieval function, because during the Stripe checkout, there a race between the browser
|
||||
// checkout success redirect (see handleAccountBillingSubscriptionCreateSuccess), and this webhook. The checkout
|
||||
// success call is the one that updates the user with the Stripe customer ID.
|
||||
u, err := util.Retry[user.User](userFn, retryUserDelays...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v.SetUser(u)
|
||||
tier, err := s.userManager.TierByStripePrice(priceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
|
||||
if err := s.updateSubscriptionAndTier(r, v, u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
|
||||
return err
|
||||
}
|
||||
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
|
||||
ev, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event)))
|
||||
func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(r *http.Request, v *visitor, event stripe.Event) error {
|
||||
ev, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event.Data.Raw)))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if ev.Customer == "" {
|
||||
return errHTTPBadRequestBillingRequestInvalid
|
||||
}
|
||||
log.Info("%s Subscription deleted, downgrading to unpaid tier", logStripePrefix(ev.Customer, ev.ID))
|
||||
u, err := s.userManager.UserByStripeCustomer(ev.Customer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, nil, ev.Customer, "", "", 0, 0); err != nil {
|
||||
v.SetUser(u)
|
||||
logvr(v, r).
|
||||
Tag(tagPay).
|
||||
Field("stripe_webhook_type", event.Type).
|
||||
Info("Subscription deleted, downgrading to unpaid tier")
|
||||
if err := s.updateSubscriptionAndTier(r, v, u, nil, ev.Customer, "", "", 0, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) updateSubscriptionAndTier(logPrefix string, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
|
||||
func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
|
||||
reservationsLimit := visitorDefaultReservationsLimit
|
||||
if tier != nil {
|
||||
reservationsLimit = tier.ReservationLimit
|
||||
}
|
||||
if err := s.maybeRemoveMessagesAndExcessReservations(logPrefix, u, reservationsLimit); err != nil {
|
||||
if err := s.maybeRemoveMessagesAndExcessReservations(r, v, u, reservationsLimit); err != nil {
|
||||
return err
|
||||
}
|
||||
if tier == nil {
|
||||
if tier == nil && u.Tier != nil {
|
||||
logvr(v, r).Tag(tagPay).Info("Resetting tier for user %s", u.Name)
|
||||
if err := s.userManager.ResetTier(u.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
} else if tier != nil && u.TierID() != tier.ID {
|
||||
logvr(v, r).
|
||||
Tag(tagPay).
|
||||
Fields(map[string]any{
|
||||
"new_tier_id": tier.ID,
|
||||
"new_tier_name": tier.Name,
|
||||
"new_tier_stripe_price_id": tier.StripePriceID,
|
||||
}).
|
||||
Info("Changing tier to tier %s (%s) for user %s", tier.ID, tier.Name, u.Name)
|
||||
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -70,7 +70,7 @@ func (s *smtpSession) AuthPlain(username, password string) error {
|
||||
}
|
||||
|
||||
func (s *smtpSession) Mail(from string, opts smtp.MailOptions) error {
|
||||
logem(s.state).Debug("%s MAIL FROM: %s (with options: %#v)", from, opts)
|
||||
logem(s.state).Debug("MAIL FROM: %s (with options: %#v)", from, opts)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -1,15 +1,12 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/emersion/go-smtp"
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/util"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
|
||||
@ -48,90 +45,6 @@ func readQueryParam(r *http.Request, names ...string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func logr(r *http.Request) *log.Event {
|
||||
return log.Fields(logFieldsHTTP(r))
|
||||
}
|
||||
|
||||
func logv(v *visitor) *log.Event {
|
||||
return log.Context(v)
|
||||
}
|
||||
|
||||
func logvr(v *visitor, r *http.Request) *log.Event {
|
||||
return logv(v).Fields(logFieldsHTTP(r))
|
||||
}
|
||||
|
||||
func logvrm(v *visitor, r *http.Request, m *message) *log.Event {
|
||||
return logvr(v, r).Context(m)
|
||||
}
|
||||
|
||||
func logvm(v *visitor, m *message) *log.Event {
|
||||
return logv(v).Context(m)
|
||||
}
|
||||
|
||||
func logem(state *smtp.ConnectionState) *log.Event {
|
||||
return log.
|
||||
Tag(tagSMTP).
|
||||
Fields(map[string]any{
|
||||
"smtp_hostname": state.Hostname,
|
||||
"smtp_remote_addr": state.RemoteAddr.String(),
|
||||
})
|
||||
}
|
||||
|
||||
func logFieldsHTTP(r *http.Request) map[string]any {
|
||||
requestURI := r.RequestURI
|
||||
if requestURI == "" {
|
||||
requestURI = r.URL.Path
|
||||
}
|
||||
return map[string]any{
|
||||
"http_method": r.Method,
|
||||
"http_path": requestURI,
|
||||
}
|
||||
}
|
||||
|
||||
func logHTTPPrefix(v *visitor, r *http.Request) string {
|
||||
requestURI := r.RequestURI
|
||||
if requestURI == "" {
|
||||
requestURI = r.URL.Path
|
||||
}
|
||||
return fmt.Sprintf("HTTP %s %s %s", v.String(), r.Method, requestURI)
|
||||
}
|
||||
|
||||
func logStripePrefix(customerID, subscriptionID string) string {
|
||||
if subscriptionID != "" {
|
||||
return fmt.Sprintf("STRIPE %s/%s", customerID, subscriptionID)
|
||||
}
|
||||
return fmt.Sprintf("STRIPE %s", customerID)
|
||||
}
|
||||
|
||||
func renderHTTPRequest(r *http.Request) string {
|
||||
peekLimit := 4096
|
||||
lines := fmt.Sprintf("%s %s %s\n", r.Method, r.URL.RequestURI(), r.Proto)
|
||||
for key, values := range r.Header {
|
||||
for _, value := range values {
|
||||
lines += fmt.Sprintf("%s: %s\n", key, value)
|
||||
}
|
||||
}
|
||||
lines += "\n"
|
||||
body, err := util.Peek(r.Body, peekLimit)
|
||||
if err != nil {
|
||||
lines = fmt.Sprintf("(could not read body: %s)\n", err.Error())
|
||||
} else if utf8.Valid(body.PeekedBytes) {
|
||||
lines += string(body.PeekedBytes)
|
||||
if body.LimitReached {
|
||||
lines += fmt.Sprintf(" ... (peeked %d bytes)", peekLimit)
|
||||
}
|
||||
lines += "\n"
|
||||
} else {
|
||||
if body.LimitReached {
|
||||
lines += fmt.Sprintf("(peeked bytes not UTF-8, peek limit of %d bytes reached, hex: %x ...)\n", peekLimit, body.PeekedBytes)
|
||||
} else {
|
||||
lines += fmt.Sprintf("(peeked bytes not UTF-8, %d bytes, hex: %x)\n", len(body.PeekedBytes), body.PeekedBytes)
|
||||
}
|
||||
}
|
||||
r.Body = body // Important: Reset body, so it can be re-read
|
||||
return strings.TrimSpace(lines)
|
||||
}
|
||||
|
||||
func extractIPAddress(r *http.Request, behindProxy bool) netip.Addr {
|
||||
remoteAddr := r.RemoteAddr
|
||||
addrPort, err := netip.ParseAddrPort(remoteAddr)
|
||||
|
@ -159,6 +159,10 @@ func (v *visitor) Context() map[string]any {
|
||||
if v.user != nil {
|
||||
fields["user_id"] = v.user.ID
|
||||
fields["user_name"] = v.user.Name
|
||||
if v.user.Tier != nil {
|
||||
fields["tier_id"] = v.user.Tier.ID
|
||||
fields["tier_name"] = v.user.Tier.Name
|
||||
}
|
||||
if v.user.Billing.StripeCustomerID != "" {
|
||||
fields["stripe_customer_id"] = v.user.Billing.StripeCustomerID
|
||||
}
|
||||
@ -178,6 +182,12 @@ func (v *visitor) RequestAllowed() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *visitor) RequestLimiter() *rate.Limiter {
|
||||
v.mu.Lock() // limiters could be replaced!
|
||||
defer v.mu.Unlock()
|
||||
return v.requestLimiter
|
||||
}
|
||||
|
||||
func (v *visitor) FirebaseAllowed() error {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
|
Loading…
Reference in New Issue
Block a user