package storagebonus import ( "context" "database/sql" entity "github.com/ente-io/museum/ente/storagebonus" "github.com/ente-io/stacktrace" ) // Add context as first parameter in all methods in this file // GetCode returns the storagebonus code for the given userID func (r *Repository) GetCode(ctx context.Context, userID int64) (*string, error) { var code *string err := r.DB.QueryRowContext(ctx, "SELECT code FROM referral_codes WHERE user_id = $1 and is_active = TRUE", userID).Scan(&code) return code, stacktrace.Propagate(err, "failed to get storagebonus code for user %d", userID) } // InsertCode for the given userID func (r *Repository) InsertCode(ctx context.Context, userID int64, code string) error { _, err := r.DB.ExecContext(ctx, "INSERT INTO referral_codes (user_id, code) VALUES ($1, $2)", userID, code) if err != nil { if err.Error() == "pq: duplicate key value violates unique constraint \"referral_codes_pkey\"" { return stacktrace.Propagate(entity.CodeAlreadyExistsErr, "duplicate storagebonus code for user %d", userID) } return stacktrace.Propagate(err, "failed to insert storagebonus code for user %d", userID) } return nil } // AddNewCode and mark the old one as inactive for a given userID. // Note: This method is not being used in the initial MVP as we don't allow user to change the storagebonus // code func (r *Repository) AddNewCode(ctx context.Context, userID int64, code string) error { _, err := r.DB.ExecContext(ctx, "UPDATE referral_codes SET is_active = FALSE WHERE user_id = $1", userID) if err != nil { return stacktrace.Propagate(err, "failed to update storagebonus code for user %d", userID) } return r.InsertCode(ctx, userID, code) } // GetUserIDByCode returns the userID for the given storagebonus code. The method will also return the userID // if the code is inactive. func (r *Repository) GetUserIDByCode(ctx context.Context, code string) (*int64, error) { var userID int64 err := r.DB.QueryRowContext(ctx, "SELECT user_id FROM referral_codes WHERE code = $1", code).Scan(&userID) if err != nil { if err == sql.ErrNoRows { return nil, stacktrace.Propagate(entity.InvalidCodeErr, "code %s not found", code) } return nil, err } return &userID, nil }