ente/server/pkg/repo/storagebonus/referral_codes.go
2024-03-01 13:37:01 +05:30

56 lines
2.2 KiB
Go

package storagebonus
import (
"context"
"database/sql"
entity "github.com/ente-io/museum/ente/storagebonus"
"github.com/ente-io/stacktrace"
)
// Note: every method in this file takes a context.Context as its first parameter.
// GetCode looks up the storagebonus (referral) code belonging to userID,
// considering only rows where is_active = TRUE.
//
// It returns the scanned code together with whatever error Scan produced,
// passed through stacktrace.Propagate. When no matching row exists, Scan
// reports an error (sql.ErrNoRows per database/sql) and the returned code is
// left at its initial nil value.
func (r *Repository) GetCode(ctx context.Context, userID int64) (*string, error) {
	const activeCodeQuery = "SELECT code FROM referral_codes WHERE user_id = $1 and is_active = TRUE"

	var activeCode *string
	row := r.DB.QueryRowContext(ctx, activeCodeQuery, userID)
	scanErr := row.Scan(&activeCode)

	// Propagate is invoked unconditionally, including on success.
	// NOTE(review): this relies on Propagate yielding nil for a nil cause — confirm.
	return activeCode, stacktrace.Propagate(scanErr, "failed to get storagebonus code for user %d", userID)
}
// InsertCode stores a new (userID, code) row in referral_codes.
//
// If the insert fails with the driver's duplicate-key message for the
// "referral_codes_pkey" constraint, entity.CodeAlreadyExistsErr is returned
// (wrapped via stacktrace.Propagate). Any other failure is propagated with
// context. A successful insert returns nil.
func (r *Repository) InsertCode(ctx context.Context, userID int64, code string) error {
	// Exact text of the error this method treats as "code already exists".
	// NOTE(review): matching on the full message string is brittle (it depends
	// on driver wording and the constraint name) — consider matching on the
	// driver's structured error code instead.
	const duplicateKeyMsg = "pq: duplicate key value violates unique constraint \"referral_codes_pkey\""

	_, insertErr := r.DB.ExecContext(ctx, "INSERT INTO referral_codes (user_id, code) VALUES ($1, $2)", userID, code)

	switch {
	case insertErr == nil:
		return nil
	case insertErr.Error() == duplicateKeyMsg:
		return stacktrace.Propagate(entity.CodeAlreadyExistsErr, "duplicate storagebonus code for user %d", userID)
	default:
		return stacktrace.Propagate(insertErr, "failed to insert storagebonus code for user %d", userID)
	}
}
// AddNewCode marks every existing referral_codes row of userID as inactive
// and then inserts code as a new row via InsertCode.
//
// Note: This method is not being used in the initial MVP as we don't allow
// the user to change the storagebonus code.
//
// NOTE(review): the UPDATE and the insert are issued as two separate
// statements on r.DB with no transaction visible here; if the insert fails,
// the earlier deactivation appears to remain in effect — confirm whether that
// is acceptable before this method is put to use.
func (r *Repository) AddNewCode(ctx context.Context, userID int64, code string) error {
	const deactivateQuery = "UPDATE referral_codes SET is_active = FALSE WHERE user_id = $1"

	if _, updateErr := r.DB.ExecContext(ctx, deactivateQuery, userID); updateErr != nil {
		return stacktrace.Propagate(updateErr, "failed to update storagebonus code for user %d", userID)
	}

	// Errors from the insert (including the duplicate-code case) are returned
	// exactly as InsertCode produced them.
	return r.InsertCode(ctx, userID, code)
}
// GetUserIDByCode returns the userID that owns the given storagebonus code.
// The query does not filter on is_active, so the userID is also returned
// when the code is inactive.
//
// Errors:
//   - entity.InvalidCodeErr (wrapped via stacktrace.Propagate) when no row
//     matches the code.
//   - any other database error, wrapped via stacktrace.Propagate so that it
//     carries the same context as errors from the other methods in this file.
//
// On error the returned pointer is nil.
func (r *Repository) GetUserIDByCode(ctx context.Context, code string) (*int64, error) {
	var userID int64
	err := r.DB.QueryRowContext(ctx, "SELECT user_id FROM referral_codes WHERE code = $1", code).Scan(&userID)
	if err != nil {
		// Row.Scan reports a missing row as sql.ErrNoRows; translate that into
		// the domain-level "invalid code" error.
		if err == sql.ErrNoRows {
			return nil, stacktrace.Propagate(entity.InvalidCodeErr, "code %s not found", code)
		}
		// Previously returned bare; wrap it like every other error path here.
		return nil, stacktrace.Propagate(err, "failed to get user for storagebonus code")
	}
	return &userID, nil
}