ente/server/pkg/repo/embedding/repository.go
2024-07-22 16:29:20 +05:30

239 lines
8.2 KiB
Go

package embedding
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/ente-io/museum/ente"
"github.com/ente-io/stacktrace"
"github.com/lib/pq"
"github.com/sirupsen/logrus"
)
// Repository groups the database operations on the embeddings table:
// inserting, updating, querying and deleting ML embedding rows.
type Repository struct {
// DB is the SQL handle that every method of Repository runs its statements on.
DB *sql.DB
}
// InsertOrUpdate upserts the embeddings row for the (file, model) pair named
// in entry and returns the resulting embedding.
//
// A new row is inserted with the given owner, size and version, and with dc as
// its only datacenter. When the unique_embeddings_file_id_model constraint is
// hit instead, the existing row gets a fresh updated_at, the new size and
// version, and dc is appended to its datacenters unless it is already listed
// (a NULL datacenters column is treated as ['b2-eu-cen']).
//
// A database error whose text reports an invalid value for the model enum is
// returned as ente.ErrBadRequest; any other error is propagated unchanged. On
// success the returned Embedding echoes the payload fields of entry together
// with the updated_at value reported by the database.
func (r *Repository) InsertOrUpdate(ctx context.Context, ownerID int64, entry ente.InsertOrUpdateEmbeddingRequest, size int, version int, dc string) (ente.Embedding, error) {
	const upsertQuery = `
INSERT INTO embeddings
(file_id, owner_id, model, size, version, datacenters)
VALUES
($1, $2, $3, $4, $5, ARRAY[$6]::s3region[])
ON CONFLICT ON CONSTRAINT unique_embeddings_file_id_model
DO UPDATE
SET
updated_at = now_utc_micro_seconds(),
size = $4,
version = $5,
datacenters = CASE
WHEN $6 = ANY(COALESCE(embeddings.datacenters, ARRAY['b2-eu-cen']::s3region[])) THEN embeddings.datacenters
ELSE array_append(COALESCE(embeddings.datacenters, ARRAY['b2-eu-cen']::s3region[]), $6::s3region)
END
RETURNING updated_at`

	var storedAt int64
	row := r.DB.QueryRowContext(ctx, upsertQuery, entry.FileID, ownerID, entry.Model, size, version, dc)
	if scanErr := row.Scan(&storedAt); scanErr != nil {
		// An unknown model is recognised by the exact text of the enum error
		// and mapped to a bad-request error for the caller.
		invalidModelMsg := fmt.Sprintf("pq: invalid input value for enum model: \"%s\"", entry.Model)
		if scanErr.Error() == invalidModelMsg {
			return ente.Embedding{}, stacktrace.Propagate(ente.ErrBadRequest, "invalid model value")
		}
		return ente.Embedding{}, stacktrace.Propagate(scanErr, "")
	}

	stored := ente.Embedding{
		FileID:             entry.FileID,
		Model:              entry.Model,
		EncryptedEmbedding: entry.EncryptedEmbedding,
		DecryptionHeader:   entry.DecryptionHeader,
		UpdatedAt:          storedAt,
	}
	return stored, nil
}
// GetDiff lists the embeddings of ownerID for the given model whose updated_at
// is strictly greater than sinceTime. Rows are ordered by updated_at ascending
// and at most limit rows are returned.
func (r *Repository) GetDiff(ctx context.Context, ownerID int64, model ente.Model, sinceTime int64, limit int16) ([]ente.Embedding, error) {
	const diffQuery = `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version, size
FROM embeddings
WHERE owner_id = $1 AND model = $2 AND updated_at > $3
ORDER BY updated_at ASC
LIMIT $4`

	rows, queryErr := r.DB.QueryContext(ctx, diffQuery, ownerID, model, sinceTime, limit)
	if queryErr != nil {
		return nil, stacktrace.Propagate(queryErr, "")
	}
	// convertRowsToEmbeddings takes over the rows and closes them.
	return convertRowsToEmbeddings(rows)
}
// GetFilesEmbedding returns the embeddings of ownerID for the given model,
// restricted to the files whose IDs appear in fileIDs.
func (r *Repository) GetFilesEmbedding(ctx context.Context, ownerID int64, model ente.Model, fileIDs []int64) ([]ente.Embedding, error) {
	const byFilesQuery = `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version, size
FROM embeddings
WHERE owner_id = $1 AND model = $2 AND file_id = ANY($3)`

	rows, queryErr := r.DB.QueryContext(ctx, byFilesQuery, ownerID, model, pq.Array(fileIDs))
	if queryErr != nil {
		return nil, stacktrace.Propagate(queryErr, "")
	}
	// convertRowsToEmbeddings takes over the rows and closes them.
	return convertRowsToEmbeddings(rows)
}
// DeleteAll deletes every embeddings row whose owner_id equals ownerID.
func (r *Repository) DeleteAll(ctx context.Context, ownerID int64) error {
	if _, execErr := r.DB.ExecContext(ctx, "DELETE FROM embeddings WHERE owner_id = $1", ownerID); execErr != nil {
		return stacktrace.Propagate(execErr, "")
	}
	return nil
}
// Delete deletes every embeddings row whose file_id equals fileID. Unlike the
// other methods it takes no context, so the statement runs without one.
func (r *Repository) Delete(fileID int64) error {
	if _, execErr := r.DB.Exec("DELETE FROM embeddings WHERE file_id = $1", fileID); execErr != nil {
		return stacktrace.Propagate(execErr, "")
	}
	return nil
}
// GetDatacenters returns the de-duplicated list of datacenters recorded across
// all embeddings rows of the given file. The order of the returned names is
// unspecified because they are collected through a map.
//
// Any query, scan or row-iteration error is propagated and no partial list is
// returned.
func (r *Repository) GetDatacenters(ctx context.Context, fileID int64) ([]string, error) {
	rows, err := r.DB.QueryContext(ctx, `SELECT datacenters FROM embeddings WHERE file_id = $1`, fileID)
	if err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	// Release the underlying connection on every path, including the early
	// return taken when a Scan fails mid-iteration.
	defer rows.Close()

	uniqueDatacenters := make(map[string]struct{})
	for rows.Next() {
		var rowDCs []string
		if err := rows.Scan(pq.Array(&rowDCs)); err != nil {
			return nil, stacktrace.Propagate(err, "")
		}
		for _, dc := range rowDCs {
			uniqueDatacenters[dc] = struct{}{}
		}
	}
	// Next also returns false when iteration fails; surface that error rather
	// than returning a silently truncated list.
	if err := rows.Err(); err != nil {
		return nil, stacktrace.Propagate(err, "")
	}

	datacenters := make([]string, 0, len(uniqueDatacenters))
	for dc := range uniqueDatacenters {
		datacenters = append(datacenters, dc)
	}
	return datacenters, nil
}
// GetOtherDCsForFileAndModel returns the de-duplicated list of datacenters
// recorded for the given file and model, leaving out ignoredDC. The order of
// the returned names is unspecified because they are collected through a map.
//
// Any query, scan or row-iteration error is propagated and no partial list is
// returned.
func (r *Repository) GetOtherDCsForFileAndModel(ctx context.Context, fileID int64, model string, ignoredDC string) ([]string, error) {
	rows, err := r.DB.QueryContext(ctx, `SELECT datacenters FROM embeddings WHERE file_id = $1 AND model = $2`, fileID, model)
	if err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	// Release the underlying connection on every path, including the early
	// return taken when a Scan fails mid-iteration.
	defer rows.Close()

	uniqueDatacenters := make(map[string]struct{})
	for rows.Next() {
		var rowDCs []string
		if err := rows.Scan(pq.Array(&rowDCs)); err != nil {
			return nil, stacktrace.Propagate(err, "")
		}
		for _, dc := range rowDCs {
			// Keep every datacenter except the one the caller asked to skip.
			if dc != ignoredDC {
				uniqueDatacenters[dc] = struct{}{}
			}
		}
	}
	// Next also returns false when iteration fails; surface that error rather
	// than returning a silently truncated list.
	if err := rows.Err(); err != nil {
		return nil, stacktrace.Propagate(err, "")
	}

	datacenters := make([]string, 0, len(uniqueDatacenters))
	for dc := range uniqueDatacenters {
		datacenters = append(datacenters, dc)
	}
	return datacenters, nil
}
// RemoveDatacenter strips dc from the datacenters array of every embeddings
// row that belongs to the given file.
func (r *Repository) RemoveDatacenter(ctx context.Context, fileID int64, dc string) error {
	const removeQuery = `UPDATE embeddings SET datacenters = array_remove(datacenters, $1) WHERE file_id = $2`

	if _, execErr := r.DB.ExecContext(ctx, removeQuery, dc, fileID); execErr != nil {
		return stacktrace.Propagate(execErr, "")
	}
	return nil
}
// AddNewDC records dc as a datacenter of the embeddings row identified by
// fileID, model and userID, appending it only when it is not already listed,
// and overwrites the row's size with the given value.
//
// It returns an error when the statement fails or when no row matched the
// (file, model, owner) triple.
func (r *Repository) AddNewDC(ctx context.Context, fileID int64, model ente.Model, userID int64, size int, dc string) error {
	const addDCQuery = `
UPDATE embeddings
SET size = $1,
datacenters = CASE
WHEN $2::s3region = ANY(datacenters) THEN datacenters
ELSE array_append(datacenters, $2::s3region)
END
WHERE file_id = $3 AND model = $4 AND owner_id = $5`

	res, err := r.DB.ExecContext(ctx, addDCQuery, size, dc, fileID, model, userID)
	if err != nil {
		return stacktrace.Propagate(err, "")
	}

	affected, err := res.RowsAffected()
	switch {
	case err != nil:
		return stacktrace.Propagate(err, "")
	case affected == 0:
		// Nothing matched the WHERE clause, so there was no row to update.
		return stacktrace.Propagate(errors.New("no row got updated"), "")
	default:
		return nil
	}
}
// GetIndexedFiles returns the file ID and updated_at of every embeddings row
// owned by id for the given model whose updated_at is strictly greater than
// since. When limit is nil all matching rows are returned; otherwise at most
// *limit rows are returned.
//
// NOTE(review): neither query has an ORDER BY, so when a limit is applied this
// code does not determine which matching rows are returned.
//
// Any query, scan or row-iteration error is propagated and no partial result
// is returned.
func (r *Repository) GetIndexedFiles(ctx context.Context, id int64, model ente.Model, since int64, limit *int64) ([]ente.IndexedFile, error) {
	var rows *sql.Rows
	var err error
	if limit == nil {
		rows, err = r.DB.QueryContext(ctx, `SELECT file_id, updated_at FROM embeddings WHERE owner_id = $1 AND model = $2 AND updated_at > $3`, id, model, since)
	} else {
		rows, err = r.DB.QueryContext(ctx, `SELECT file_id, updated_at FROM embeddings WHERE owner_id = $1 AND model = $2 AND updated_at > $3 LIMIT $4`, id, model, since, *limit)
	}
	if err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	defer func() {
		if err := rows.Close(); err != nil {
			logrus.Error(err)
		}
	}()

	result := make([]ente.IndexedFile, 0)
	for rows.Next() {
		var meta ente.IndexedFile
		if err := rows.Scan(&meta.FileID, &meta.UpdatedAt); err != nil {
			return nil, stacktrace.Propagate(err, "")
		}
		result = append(result, meta)
	}
	// Next also returns false when iteration fails; surface that error rather
	// than returning a silently truncated result.
	if err := rows.Err(); err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	return result, nil
}
// convertRowsToEmbeddings reads every row of rows into an ente.Embedding and
// closes rows before returning (a close failure is only logged).
//
// Each row must carry, in order: file_id, model, encrypted_embedding,
// decryption_header, updated_at, version and size. A NULL or empty
// encrypted_embedding or decryption_header leaves the matching field at its
// zero value, and a NULL version is reported as version 1.
//
// Any scan or row-iteration error is propagated and no partial result is
// returned.
func convertRowsToEmbeddings(rows *sql.Rows) ([]ente.Embedding, error) {
	defer func() {
		if err := rows.Close(); err != nil {
			logrus.Error(err)
		}
	}()

	result := make([]ente.Embedding, 0)
	for rows.Next() {
		embedding := ente.Embedding{}
		var encryptedEmbedding, decryptionHeader sql.NullString
		var version sql.NullInt32
		// Check the scan error before looking at any scanned value: after a
		// failed Scan the destinations may hold incomplete data.
		if err := rows.Scan(&embedding.FileID, &embedding.Model, &encryptedEmbedding, &decryptionHeader, &embedding.UpdatedAt, &version, &embedding.Size); err != nil {
			return nil, stacktrace.Propagate(err, "")
		}
		if encryptedEmbedding.Valid && len(encryptedEmbedding.String) > 0 {
			embedding.EncryptedEmbedding = encryptedEmbedding.String
		}
		if decryptionHeader.Valid && len(decryptionHeader.String) > 0 {
			embedding.DecryptionHeader = decryptionHeader.String
		}
		// Rows with no stored version are reported as version 1.
		v := 1
		if version.Valid {
			v = int(version.Int32)
		}
		embedding.Version = &v
		result = append(result, embedding)
	}
	// Next also returns false when iteration fails; surface that error rather
	// than returning a silently truncated result.
	if err := rows.Err(); err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	return result, nil
}