package repo import ( "database/sql" "github.com/ente-io/museum/ente" "github.com/ente-io/museum/pkg/utils/crypto" "github.com/ente-io/museum/pkg/utils/time" "github.com/ente-io/stacktrace" ) type TwoFactorRepository struct { DB *sql.DB SecretEncryptionKey []byte } // GetTwoFactorSecret gets the user's two factor secret func (repo *TwoFactorRepository) GetTwoFactorSecret(userID int64) (string, error) { var encryptedTwoFASecret, nonce []byte row := repo.DB.QueryRow(`SELECT encrypted_two_factor_secret, two_factor_secret_decryption_nonce FROM two_factor WHERE user_id = $1`, userID) err := row.Scan(&encryptedTwoFASecret, &nonce) if err != nil { return "", stacktrace.Propagate(err, "") } twoFASecret, err := crypto.Decrypt(encryptedTwoFASecret, repo.SecretEncryptionKey, nonce) if err != nil { return "", stacktrace.Propagate(err, "") } return twoFASecret, nil } // UpdateTwoFactorStatus the activates/deactivates user's two factor func (repo *TwoFactorRepository) UpdateTwoFactorStatus(userID int64, status bool) error { _, err := repo.DB.Exec(`UPDATE users SET is_two_factor_enabled = $1 WHERE user_id = $2`, status, userID) return stacktrace.Propagate(err, "") } // AddTwoFactorSession added a new two factor session a user func (repo *TwoFactorRepository) AddTwoFactorSession(userID int64, sessionID string, expirationTime int64) error { _, err := repo.DB.Exec(`INSERT INTO two_factor_sessions(user_id, session_id, creation_time, expiration_time) VALUES($1, $2, $3, $4)`, userID, sessionID, time.Microseconds(), expirationTime) return stacktrace.Propagate(err, "") } // RemoveExpiredTwoFactorSessions removes all two factor sessions that have expired func (repo *TwoFactorRepository) RemoveExpiredTwoFactorSessions() error { _, err := repo.DB.Exec(`DELETE FROM two_factor_sessions WHERE expiration_time <= $1`, time.Microseconds()) return stacktrace.Propagate(err, "") } // GetUserIDWithTwoFactorSession returns the userID associated with a given session func (repo *TwoFactorRepository) GetUserIDWithTwoFactorSession(sessionID string) (int64, error) { row := repo.DB.QueryRow(`SELECT user_id FROM two_factor_sessions WHERE session_id = $1 AND expiration_time > $2`, sessionID, time.Microseconds()) var id int64 err := row.Scan(&id) if err != nil { return -1, stacktrace.Propagate(err, "") } return id, nil } // GetRecoveryKeyEncryptedTwoFactorSecret gets the user two factor encrypted with recovery key func (repo *TwoFactorRepository) GetRecoveryKeyEncryptedTwoFactorSecret(userID int64) (ente.TwoFactorRecoveryResponse, error) { var response ente.TwoFactorRecoveryResponse row := repo.DB.QueryRow(`SELECT recovery_encrypted_two_factor_secret, recovery_two_factor_secret_decryption_nonce FROM two_factor WHERE user_id = $1`, userID) err := row.Scan(&response.EncryptedSecret, &response.SecretDecryptionNonce) if err != nil { return ente.TwoFactorRecoveryResponse{}, stacktrace.Propagate(err, "") } return response, nil } // VerifyTwoFactorSecret verifies the if a two secret factor secret belongs to a user func (repo *TwoFactorRepository) VerifyTwoFactorSecret(userID int64, secretHash string) (bool, error) { var exists bool row := repo.DB.QueryRow(`SELECT EXISTS( SELECT 1 FROM two_factor WHERE user_id = $1 AND two_factor_secret_hash = $2)`, userID, secretHash) err := row.Scan(&exists) if err != nil { return false, stacktrace.Propagate(err, "") } return exists, nil } // SetTempTwoFactorSecret sets the two factor secret for a user when he tries to setup a new two-factor app func (repo *TwoFactorRepository) SetTempTwoFactorSecret(userID int64, secret ente.EncryptionResult, secretHash string, expirationTime int64) error { _, err := repo.DB.Exec(`INSERT INTO temp_two_factor(user_id, encrypted_two_factor_secret, two_factor_secret_decryption_nonce, two_factor_secret_hash, creation_time, expiration_time) VALUES($1, $2, $3, $4, $5, $6)`, userID, secret.Cipher, secret.Nonce, secretHash, time.Microseconds(), expirationTime) return stacktrace.Propagate(err, "") } // GetTempTwoFactorSecret gets the user's two factor secret for validing and enabling a new two-factor configuration func (repo *TwoFactorRepository) GetTempTwoFactorSecret(userID int64) ([]ente.EncryptionResult, []string, error) { rows, err := repo.DB.Query(`SELECT encrypted_two_factor_secret, two_factor_secret_decryption_nonce, two_factor_secret_hash FROM temp_two_factor WHERE user_id = $1 AND expiration_time > $2`, userID, time.Microseconds()) if err != nil { return make([]ente.EncryptionResult, 0), make([]string, 0), stacktrace.Propagate(err, "") } defer rows.Close() encryptedSecrets := make([]ente.EncryptionResult, 0) hashedSecrets := make([]string, 0) for rows.Next() { var encryptedTwoFASecret ente.EncryptionResult var secretHash string err := rows.Scan(&encryptedTwoFASecret.Cipher, &encryptedTwoFASecret.Nonce, &secretHash) if err != nil { return make([]ente.EncryptionResult, 0), make([]string, 0), stacktrace.Propagate(err, "") } encryptedSecrets = append(encryptedSecrets, encryptedTwoFASecret) hashedSecrets = append(hashedSecrets, secretHash) } return encryptedSecrets, hashedSecrets, nil } // RemoveTempTwoFactorSecret removes the specified secret with hash value `secretHash` func (repo *TwoFactorRepository) RemoveTempTwoFactorSecret(secretHash string) error { _, err := repo.DB.Exec(`DELETE FROM temp_two_factor WHERE two_factor_secret_hash = $1`, secretHash) return stacktrace.Propagate(err, "") } // RemoveExpiredTempTwoFactorSecrets removes all two temp factor secrets that have expired func (repo *TwoFactorRepository) RemoveExpiredTempTwoFactorSecrets() error { _, err := repo.DB.Exec(`DELETE FROM temp_two_factor WHERE expiration_time <= $1`, time.Microseconds()) return stacktrace.Propagate(err, "") }