mirror of
https://github.com/ente-io/ente.git
synced 2025-05-01 20:03:07 +00:00
415 lines
13 KiB
Go
415 lines
13 KiB
Go
package user
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/ente-io/museum/pkg/utils/random"
|
|
|
|
"github.com/ente-io/museum/pkg/utils/config"
|
|
"github.com/ente-io/museum/pkg/utils/network"
|
|
"github.com/gin-contrib/requestid"
|
|
"github.com/spf13/viper"
|
|
|
|
"github.com/ente-io/museum/ente"
|
|
"github.com/ente-io/museum/pkg/utils/auth"
|
|
"github.com/ente-io/museum/pkg/utils/crypto"
|
|
emailUtil "github.com/ente-io/museum/pkg/utils/email"
|
|
"github.com/ente-io/museum/pkg/utils/time"
|
|
"github.com/ente-io/stacktrace"
|
|
"github.com/gin-gonic/gin"
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// HardCodedOTTEmail pairs a single email address with the fixed one-time
// token that should be issued for it instead of a randomly generated one.
type HardCodedOTTEmail struct {
	Email, Value string
}

// HardCodedOTT gathers every configured fixed one-time token.
type HardCodedOTT struct {
	// Emails holds per-address overrides.
	Emails []HardCodedOTTEmail
	// LocalDomainSuffix, when non-empty, is matched against the end of an
	// address; matching addresses receive LocalDomainValue.
	LocalDomainSuffix string
	LocalDomainValue  string
}
|
|
|
|
func ReadHardCodedOTTFromConfig() HardCodedOTT {
|
|
emails := make([]HardCodedOTTEmail, 0)
|
|
emailsSlice := viper.GetStringSlice("internal.hardcoded-ott.emails")
|
|
for _, entry := range emailsSlice {
|
|
xs := strings.Split(entry, ",")
|
|
if len(xs) == 2 && xs[0] != "" && xs[1] != "" {
|
|
emails = append(emails, HardCodedOTTEmail{
|
|
Email: xs[0],
|
|
Value: xs[1],
|
|
})
|
|
} else {
|
|
log.Errorf("Ignoring malformed internal.hardcoded-ott.emails entry %s", entry)
|
|
}
|
|
}
|
|
|
|
localDomainSuffix := ""
|
|
localDomainValue := ""
|
|
if config.IsLocalEnvironment() {
|
|
localDomainSuffix = viper.GetString("internal.hardcoded-ott.local-domain-suffix")
|
|
localDomainValue = viper.GetString("internal.hardcoded-ott.local-domain-value")
|
|
}
|
|
|
|
return HardCodedOTT{
|
|
Emails: emails,
|
|
LocalDomainSuffix: localDomainSuffix,
|
|
LocalDomainValue: localDomainValue,
|
|
}
|
|
}
|
|
|
|
func hardcodedOTTForEmail(hardCodedOTT HardCodedOTT, email string) string {
|
|
for _, entry := range hardCodedOTT.Emails {
|
|
if email == entry.Email {
|
|
return entry.Value
|
|
}
|
|
}
|
|
|
|
if hardCodedOTT.LocalDomainSuffix != "" && strings.HasSuffix(email, hardCodedOTT.LocalDomainSuffix) {
|
|
return hardCodedOTT.LocalDomainValue
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
// SendEmailOTT generates and sends an OTT to the provided email address
|
|
func (c *UserController) SendEmailOTT(context *gin.Context, email string, purpose string) error {
|
|
if purpose == ente.ChangeEmailOTTPurpose {
|
|
_, err := c.UserRepo.GetUserIDWithEmail(email)
|
|
if err == nil {
|
|
// email already owned by a user
|
|
return stacktrace.Propagate(ente.ErrPermissionDenied, "")
|
|
}
|
|
if !errors.Is(err, sql.ErrNoRows) {
|
|
// unknown error, rethrow
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
}
|
|
ott, err := random.GenerateSixDigitOtp()
|
|
if err != nil {
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
// for hard-coded ott, adding same OTT in db can throw error
|
|
hasHardcodedOTT := false
|
|
if purpose != ente.ChangeEmailOTTPurpose {
|
|
hardCodedOTT := hardcodedOTTForEmail(c.HardCodedOTT, email)
|
|
if hardCodedOTT != "" {
|
|
log.Warn(fmt.Sprintf("returning hardcoded ott for %s", email))
|
|
hasHardcodedOTT = true
|
|
ott = hardCodedOTT
|
|
}
|
|
}
|
|
emailHash, err := crypto.GetHash(email, c.HashingKey)
|
|
if err != nil {
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
// check if user has already requested for more than 10 codes in last 10mins
|
|
otts, _ := c.UserAuthRepo.GetValidOTTs(emailHash, auth.GetApp(context))
|
|
if len(otts) >= OTTActiveCodeLimit {
|
|
msg := "Too many ott requests in a short duration"
|
|
go c.DiscordController.NotifyPotentialAbuse(msg)
|
|
return stacktrace.Propagate(ente.ErrTooManyBadRequest, msg)
|
|
}
|
|
|
|
err = c.UserAuthRepo.AddOTT(emailHash, auth.GetApp(context), ott, time.Microseconds()+OTTValidityDurationInMicroSeconds)
|
|
if !hasHardcodedOTT {
|
|
// ignore error for AddOTT for hardcode OTT. This is to avoid error when unique OTT check fails at db layer
|
|
if err != nil {
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
log.Info("Added ott for " + emailHash + ": " + ott)
|
|
err = emailOTT(email, ott, purpose)
|
|
if err != nil {
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
} else {
|
|
log.Info("Added hard coded ott for " + email + " : " + ott)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *UserController) AddAdminOtt(req ente.AdminOttReq) error {
|
|
emailHash, err := crypto.GetHash(req.Email, c.HashingKey)
|
|
if err != nil {
|
|
log.WithError(err).Error("Failed to get hash")
|
|
return nil
|
|
}
|
|
err = c.UserAuthRepo.AddOTT(emailHash, req.App, req.Code, req.ExpiryTime)
|
|
if err != nil {
|
|
log.WithError(err).Error("Failed to add ott")
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// verifyEmailOtt should be deprecated in favor of verifyEmailOttWithSession once clients are updated.
|
|
func (c *UserController) verifyEmailOtt(context *gin.Context, email string, ott string) error {
|
|
ott = strings.TrimSpace(ott)
|
|
app := auth.GetApp(context)
|
|
emailHash, err := crypto.GetHash(email, c.HashingKey)
|
|
if err != nil {
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
wrongAttempt, err := c.UserAuthRepo.GetMaxWrongAttempts(emailHash, app)
|
|
if err != nil {
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
|
|
if wrongAttempt >= OTTWrongAttemptLimit {
|
|
msg := fmt.Sprintf("Too many wrong ott verification attemp for app %s", app)
|
|
go c.DiscordController.NotifyPotentialAbuse(msg)
|
|
return stacktrace.Propagate(ente.ErrTooManyBadRequest, "User needs to wait before active ott are expired")
|
|
}
|
|
|
|
otts, err := c.UserAuthRepo.GetValidOTTs(emailHash, app)
|
|
log.Infof("Valid ott (app: %s) for %s are %s", app, emailHash, strings.Join(otts, ","))
|
|
if err != nil {
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
if len(otts) < 1 {
|
|
return stacktrace.Propagate(ente.ErrExpiredOTT, "")
|
|
}
|
|
isValidOTT := false
|
|
for _, validOTT := range otts {
|
|
if ott == validOTT {
|
|
isValidOTT = true
|
|
}
|
|
}
|
|
if !isValidOTT {
|
|
if err = c.UserAuthRepo.RecordWrongAttemptForActiveOtt(emailHash, app); err != nil {
|
|
log.WithError(err).Warn("Failed to track wrong attempt")
|
|
}
|
|
return stacktrace.Propagate(ente.ErrIncorrectOTT, "")
|
|
}
|
|
err = c.UserAuthRepo.RemoveOTT(emailHash, ott, app)
|
|
if err != nil {
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// VerifyEmail validates that the OTT provided in the request is valid for the
|
|
// provided email address and if yes returns the users credentials
|
|
func (c *UserController) VerifyEmail(context *gin.Context, request ente.EmailVerificationRequest) (ente.EmailAuthorizationResponse, error) {
|
|
email := strings.ToLower(request.Email)
|
|
err := c.verifyEmailOtt(context, email, request.OTT)
|
|
if err != nil {
|
|
return ente.EmailAuthorizationResponse{}, stacktrace.Propagate(err, "")
|
|
}
|
|
return c.onVerificationSuccess(context, email, request.Source)
|
|
}
|
|
|
|
// ChangeEmail validates that the OTT provided in the request is valid for the
|
|
// provided email address and if yes updates the user's existing email address
|
|
func (c *UserController) ChangeEmail(ctx *gin.Context, request ente.EmailVerificationRequest) error {
|
|
email := strings.ToLower(request.Email)
|
|
err := c.verifyEmailOtt(ctx, email, request.OTT)
|
|
if err != nil {
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
|
|
return c.UpdateEmail(ctx, auth.GetUserID(ctx.Request.Header), email)
|
|
}
|
|
|
|
// UpdateEmail updates the email address of the user with the provided userID
|
|
func (c *UserController) UpdateEmail(ctx *gin.Context, userID int64, email string) error {
|
|
_, err := c.UserRepo.GetUserIDWithEmail(email)
|
|
if err == nil {
|
|
// email already owned by a user
|
|
return stacktrace.Propagate(ente.ErrPermissionDenied, "")
|
|
}
|
|
if !errors.Is(err, sql.ErrNoRows) {
|
|
// unknown error, rethrow
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
user, err := c.UserRepo.Get(userID)
|
|
if err != nil {
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
oldEmail := user.Email
|
|
encryptedEmail, err := crypto.Encrypt(email, c.SecretEncryptionKey)
|
|
if err != nil {
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
emailHash, err := crypto.GetHash(email, c.HashingKey)
|
|
if err != nil {
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
err = c.UserRepo.UpdateEmail(userID, encryptedEmail, emailHash)
|
|
if err != nil {
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
_ = emailUtil.SendTemplatedEmail([]string{user.Email}, "ente", "team@ente.io",
|
|
ente.EmailChangedSubject, ente.EmailChangedTemplate, map[string]interface{}{
|
|
"NewEmail": email,
|
|
}, nil)
|
|
|
|
err = c.BillingController.UpdateBillingEmail(userID, email)
|
|
if err != nil {
|
|
log.WithError(err).
|
|
WithFields(log.Fields{
|
|
"req_id": requestid.Get(ctx),
|
|
"user_id": userID,
|
|
}).Error("stripe update email failed")
|
|
}
|
|
|
|
// Unsubscribe the old email, subscribe the new one.
|
|
//
|
|
// Note that resubscribing the same email after it has been unsubscribed
|
|
// once works fine.
|
|
//
|
|
// See also: Do not block on mailing list errors
|
|
go func() {
|
|
_ = c.MailingListsController.Unsubscribe(oldEmail)
|
|
_ = c.MailingListsController.Subscribe(email)
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
// Logout removes the token from the cache and database.
|
|
// known issue: the token may be still cached in other instances till the expiry time (10min), JWTs might remain too
|
|
func (c *UserController) Logout(ctx *gin.Context) error {
|
|
token := auth.GetToken(ctx)
|
|
userID := auth.GetUserID(ctx.Request.Header)
|
|
return c.TerminateSession(userID, token)
|
|
}
|
|
|
|
// GetActiveSessions returns the list of active tokens for userID
|
|
func (c *UserController) GetActiveSessions(context *gin.Context, userID int64) ([]ente.Session, error) {
|
|
tokens, err := c.UserAuthRepo.GetActiveSessions(userID, auth.GetApp(context))
|
|
if err != nil {
|
|
return nil, stacktrace.Propagate(err, "")
|
|
}
|
|
return tokens, nil
|
|
}
|
|
|
|
// TerminateSession removes the token for a user from cache and database
|
|
func (c *UserController) TerminateSession(userID int64, token string) error {
|
|
c.Cache.Delete(fmt.Sprintf("%s:%s", ente.Photos, token))
|
|
c.Cache.Delete(fmt.Sprintf("%s:%s", ente.Auth, token))
|
|
return stacktrace.Propagate(c.UserAuthRepo.RemoveToken(userID, token), "")
|
|
}
|
|
|
|
func emailOTT(to string, ott string, purpose string) error {
|
|
var templateName string
|
|
if purpose == ente.ChangeEmailOTTPurpose {
|
|
templateName = ente.ChangeEmailOTTTemplate
|
|
} else {
|
|
templateName = ente.OTTTemplate
|
|
}
|
|
subject := fmt.Sprintf("Verification code: %s", ott)
|
|
err := emailUtil.SendTemplatedEmail([]string{to}, "Ente", "verify@ente.io",
|
|
subject, templateName, map[string]interface{}{
|
|
"VerificationCode": ott,
|
|
}, nil)
|
|
if err != nil {
|
|
return stacktrace.Propagate(err, "")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// onVerificationSuccess is called when the user has successfully verified their email address.
|
|
// source indicates where the user came from. It can be nil.
|
|
func (c *UserController) onVerificationSuccess(context *gin.Context, email string, source *string) (ente.EmailAuthorizationResponse, error) {
|
|
isTwoFactorEnabled := false
|
|
|
|
userID, err := c.UserRepo.GetUserIDWithEmail(email)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
userID, _, err = c.createUser(email, source)
|
|
if err != nil {
|
|
return ente.EmailAuthorizationResponse{}, stacktrace.Propagate(err, "")
|
|
}
|
|
} else {
|
|
return ente.EmailAuthorizationResponse{}, stacktrace.Propagate(err, "")
|
|
}
|
|
} else {
|
|
isTwoFactorEnabled, err = c.UserRepo.IsTwoFactorEnabled(userID)
|
|
if err != nil {
|
|
return ente.EmailAuthorizationResponse{}, err
|
|
}
|
|
}
|
|
hasPasskeys, err := c.UserRepo.HasPasskeys(userID)
|
|
if err != nil {
|
|
return ente.EmailAuthorizationResponse{}, stacktrace.Propagate(err, "")
|
|
}
|
|
|
|
// if the user has passkeys, we will prioritize that over secret TOTP
|
|
if hasPasskeys {
|
|
passKeySessionID, err := auth.GenerateURLSafeRandomString(PassKeySessionIDLength)
|
|
if err != nil {
|
|
return ente.EmailAuthorizationResponse{}, stacktrace.Propagate(err, "")
|
|
}
|
|
err = c.PasskeyRepo.AddPasskeyTwoFactorSession(userID, passKeySessionID, time.Microseconds()+TwoFactorValidityDurationInMicroSeconds)
|
|
if err != nil {
|
|
return ente.EmailAuthorizationResponse{}, stacktrace.Propagate(err, "")
|
|
}
|
|
return ente.EmailAuthorizationResponse{ID: userID, PasskeySessionID: passKeySessionID}, nil
|
|
} else {
|
|
if isTwoFactorEnabled {
|
|
twoFactorSessionID, err := auth.GenerateURLSafeRandomString(TwoFactorSessionIDLength)
|
|
if err != nil {
|
|
return ente.EmailAuthorizationResponse{}, stacktrace.Propagate(err, "")
|
|
}
|
|
err = c.TwoFactorRepo.AddTwoFactorSession(userID, twoFactorSessionID, time.Microseconds()+TwoFactorValidityDurationInMicroSeconds)
|
|
if err != nil {
|
|
return ente.EmailAuthorizationResponse{}, stacktrace.Propagate(err, "")
|
|
}
|
|
return ente.EmailAuthorizationResponse{ID: userID, TwoFactorSessionID: twoFactorSessionID}, nil
|
|
}
|
|
|
|
}
|
|
|
|
token, err := auth.GenerateURLSafeRandomString(TokenLength)
|
|
if err != nil {
|
|
return ente.EmailAuthorizationResponse{}, stacktrace.Propagate(err, "")
|
|
}
|
|
keyAttributes, err := c.UserRepo.GetKeyAttributes(userID)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
err = c.UserAuthRepo.AddToken(userID, auth.GetApp(context), token,
|
|
network.GetClientIP(context), context.Request.UserAgent())
|
|
if err != nil {
|
|
return ente.EmailAuthorizationResponse{}, stacktrace.Propagate(err, "")
|
|
}
|
|
return ente.EmailAuthorizationResponse{ID: userID, Token: token}, nil
|
|
} else {
|
|
return ente.EmailAuthorizationResponse{}, stacktrace.Propagate(err, "")
|
|
}
|
|
}
|
|
encryptedToken, err := crypto.GetEncryptedToken(token, keyAttributes.PublicKey)
|
|
if err != nil {
|
|
return ente.EmailAuthorizationResponse{}, stacktrace.Propagate(err, "")
|
|
}
|
|
err = c.UserAuthRepo.AddToken(userID, auth.GetApp(context), token,
|
|
network.GetClientIP(context), context.Request.UserAgent())
|
|
if err != nil {
|
|
return ente.EmailAuthorizationResponse{}, stacktrace.Propagate(err, "")
|
|
}
|
|
return ente.EmailAuthorizationResponse{
|
|
ID: userID,
|
|
KeyAttributes: &keyAttributes,
|
|
EncryptedToken: encryptedToken,
|
|
}, nil
|
|
|
|
}
|
|
|
|
func convertStringToBytes(s string) []byte {
|
|
b, err := base64.StdEncoding.DecodeString(s)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
return b
|
|
}
|
|
|
|
// convertBytesToString returns the standard (padded) base64 encoding of b.
// A nil or empty slice encodes to the empty string.
func convertBytesToString(b []byte) string {
	enc := base64.StdEncoding
	out := make([]byte, enc.EncodedLen(len(b)))
	enc.Encode(out, b)
	return string(out)
}
|