diff --git a/server/cmd/museum/main.go b/server/cmd/museum/main.go index 9bc1beef63..a4b4425a32 100644 --- a/server/cmd/museum/main.go +++ b/server/cmd/museum/main.go @@ -754,7 +754,7 @@ func main() { pushHandler := &api.PushHandler{PushController: pushController} privateAPI.POST("/push/token", pushHandler.AddToken) - embeddingController := embeddingCtrl.New(embeddingRepo, accessCtrl, objectCleanupController, s3Config, queueRepo, taskLockingRepo, fileRepo, collectionRepo, hostName) + embeddingController := embeddingCtrl.New(embeddingRepo, objectCleanupController, queueRepo, taskLockingRepo, fileRepo, hostName) offerHandler := &api.OfferHandler{Controller: offerController} publicAPI.GET("/offers/black-friday", offerHandler.GetBlackFridayOffers) diff --git a/server/pkg/controller/embedding/controller.go b/server/pkg/controller/embedding/controller.go index d8cb4e62c0..e1f23fa50f 100644 --- a/server/pkg/controller/embedding/controller.go +++ b/server/pkg/controller/embedding/controller.go @@ -1,57 +1,30 @@ package embedding import ( - "strconv" - gTime "time" - - "github.com/aws/aws-sdk-go/service/s3/s3manager" "github.com/ente-io/museum/pkg/controller" - "github.com/ente-io/museum/pkg/controller/access" "github.com/ente-io/museum/pkg/repo" "github.com/ente-io/museum/pkg/repo/embedding" - "github.com/ente-io/museum/pkg/utils/s3config" -) - -const ( - // maxEmbeddingDataSize is the min size of an embedding object in bytes - minEmbeddingDataSize = 2048 - embeddingFetchTimeout = 10 * gTime.Second + "strconv" ) type Controller struct { - Repo *embedding.Repository - AccessCtrl access.Controller - ObjectCleanupController *controller.ObjectCleanupController - S3Config *s3config.S3Config - QueueRepo *repo.QueueRepository - TaskLockingRepo *repo.TaskLockRepository - FileRepo *repo.FileRepository - CollectionRepo *repo.CollectionRepository - HostName string - cleanupCronRunning bool - derivedStorageDataCenter string - downloadManagerCache map[string]*s3manager.Downloader + Repo *embedding.Repository + ObjectCleanupController *controller.ObjectCleanupController + QueueRepo *repo.QueueRepository + TaskLockingRepo *repo.TaskLockRepository + FileRepo *repo.FileRepository + HostName string + cleanupCronRunning bool } -func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanupController *controller.ObjectCleanupController, s3Config *s3config.S3Config, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, collectionRepo *repo.CollectionRepository, hostName string) *Controller { - embeddingDcs := []string{s3Config.GetHotBackblazeDC(), s3Config.GetHotWasabiDC(), s3Config.GetWasabiDerivedDC(), s3Config.GetDerivedStorageDataCenter()} - cache := make(map[string]*s3manager.Downloader, len(embeddingDcs)) - for i := range embeddingDcs { - s3Client := s3Config.GetS3Client(embeddingDcs[i]) - cache[embeddingDcs[i]] = s3manager.NewDownloaderWithClient(&s3Client) - } +func New(repo *embedding.Repository, objectCleanupController *controller.ObjectCleanupController, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, hostName string) *Controller { return &Controller{ - Repo: repo, - AccessCtrl: accessCtrl, - ObjectCleanupController: objectCleanupController, - S3Config: s3Config, - QueueRepo: queueRepo, - TaskLockingRepo: taskLockingRepo, - FileRepo: fileRepo, - CollectionRepo: collectionRepo, - HostName: hostName, - derivedStorageDataCenter: s3Config.GetDerivedStorageDataCenter(), - downloadManagerCache: cache, + Repo: repo, + ObjectCleanupController: objectCleanupController, + QueueRepo: queueRepo, + TaskLockingRepo: taskLockingRepo, + FileRepo: fileRepo, + HostName: hostName, } } diff --git a/server/pkg/repo/embedding/repository.go b/server/pkg/repo/embedding/repository.go index eb570b4e74..22eb92e8e1 100644 --- a/server/pkg/repo/embedding/repository.go +++ b/server/pkg/repo/embedding/repository.go @@ -3,12 +3,8 @@ package embedding import ( "context" "database/sql" - "errors" - "fmt" - "github.com/ente-io/museum/ente" "github.com/ente-io/stacktrace" "github.com/lib/pq" - "github.com/sirupsen/logrus" ) // Repository defines the methods for inserting, updating and retrieving @@ -17,74 +13,6 @@ type Repository struct { DB *sql.DB } -// Create inserts a new embedding -func (r *Repository) InsertOrUpdate(ctx context.Context, ownerID int64, entry ente.InsertOrUpdateEmbeddingRequest, size int, version int, dc string) (ente.Embedding, error) { - var updatedAt int64 - err := r.DB.QueryRowContext(ctx, ` - INSERT INTO embeddings - (file_id, owner_id, model, size, version, datacenters) - VALUES - ($1, $2, $3, $4, $5, ARRAY[$6]::s3region[]) - ON CONFLICT ON CONSTRAINT unique_embeddings_file_id_model - DO UPDATE - SET - updated_at = now_utc_micro_seconds(), - size = $4, - version = $5, - datacenters = CASE - WHEN $6 = ANY(COALESCE(embeddings.datacenters, ARRAY['b2-eu-cen']::s3region[])) THEN embeddings.datacenters - ELSE array_append(COALESCE(embeddings.datacenters, ARRAY['b2-eu-cen']::s3region[]), $6::s3region) - END - RETURNING updated_at`, - entry.FileID, ownerID, entry.Model, size, version, dc).Scan(&updatedAt) - - if err != nil { - // check if error is due to model enum invalid value - if err.Error() == fmt.Sprintf("pq: invalid input value for enum model: \"%s\"", entry.Model) { - return ente.Embedding{}, stacktrace.Propagate(ente.ErrBadRequest, "invalid model value") - } - return ente.Embedding{}, stacktrace.Propagate(err, "") - } - return ente.Embedding{ - FileID: entry.FileID, - Model: entry.Model, - EncryptedEmbedding: entry.EncryptedEmbedding, - DecryptionHeader: entry.DecryptionHeader, - UpdatedAt: updatedAt, - }, nil -} - -// GetDiff returns the embeddings that have been updated since the given time -func (r *Repository) GetDiff(ctx context.Context, ownerID int64, model ente.Model, sinceTime int64, limit int16) ([]ente.Embedding, error) { - rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version, size - FROM embeddings - WHERE owner_id = $1 AND model = $2 AND updated_at > $3 - ORDER BY updated_at ASC - LIMIT $4`, ownerID, model, sinceTime, limit) - if err != nil { - return nil, stacktrace.Propagate(err, "") - } - return convertRowsToEmbeddings(rows) -} - -func (r *Repository) GetFilesEmbedding(ctx context.Context, ownerID int64, model ente.Model, fileIDs []int64) ([]ente.Embedding, error) { - rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version, size - FROM embeddings - WHERE owner_id = $1 AND model = $2 AND file_id = ANY($3)`, ownerID, model, pq.Array(fileIDs)) - if err != nil { - return nil, stacktrace.Propagate(err, "") - } - return convertRowsToEmbeddings(rows) -} - -func (r *Repository) DeleteAll(ctx context.Context, ownerID int64) error { - _, err := r.DB.ExecContext(ctx, "DELETE FROM embeddings WHERE owner_id = $1", ownerID) - if err != nil { - return stacktrace.Propagate(err, "") - } - return nil -} - func (r *Repository) Delete(fileID int64) error { _, err := r.DB.Exec("DELETE FROM embeddings WHERE file_id = $1", fileID) if err != nil { @@ -117,33 +45,6 @@ func (r *Repository) GetDatacenters(ctx context.Context, fileID int64) ([]string return datacenters, nil } -// GetOtherDCsForFileAndModel returns the list of datacenters where the embeddings are stored for a given file and model, excluding the ignoredDC -func (r *Repository) GetOtherDCsForFileAndModel(ctx context.Context, fileID int64, model string, ignoredDC string) ([]string, error) { - rows, err := r.DB.QueryContext(ctx, `SELECT datacenters FROM embeddings WHERE file_id = $1 AND model = $2`, fileID, model) - if err != nil { - return nil, stacktrace.Propagate(err, "") - } - uniqueDatacenters := make(map[string]bool) - for rows.Next() { - var datacenters []string - err = rows.Scan(pq.Array(&datacenters)) - if err != nil { - return nil, stacktrace.Propagate(err, "") - } - for _, dc := range datacenters { - // add to uniqueDatacenters if it is not the ignoredDC - if dc != ignoredDC { - uniqueDatacenters[dc] = true - } - } - } - datacenters := make([]string, 0, len(uniqueDatacenters)) - for dc := range uniqueDatacenters { - datacenters = append(datacenters, dc) - } - return datacenters, nil -} - // RemoveDatacenter removes the given datacenter from the list of datacenters func (r *Repository) RemoveDatacenter(ctx context.Context, fileID int64, dc string) error { _, err := r.DB.ExecContext(ctx, `UPDATE embeddings SET datacenters = array_remove(datacenters, $1) WHERE file_id = $2`, dc, fileID) @@ -152,87 +53,3 @@ func (r *Repository) RemoveDatacenter(ctx context.Context, fileID int64, dc stri } return nil } - -// AddNewDC adds the dc name to the list of datacenters, if it doesn't exist already, for a given file, model and user. It also updates the size of the embedding -func (r *Repository) AddNewDC(ctx context.Context, fileID int64, model ente.Model, userID int64, size int, dc string) error { - res, err := r.DB.ExecContext(ctx, ` - UPDATE embeddings - SET size = $1, - datacenters = CASE - WHEN $2::s3region = ANY(datacenters) THEN datacenters - ELSE array_append(datacenters, $2::s3region) - END - WHERE file_id = $3 AND model = $4 AND owner_id = $5`, size, dc, fileID, model, userID) - if err != nil { - return stacktrace.Propagate(err, "") - } - rowsAffected, err := res.RowsAffected() - if err != nil { - return stacktrace.Propagate(err, "") - } - if rowsAffected == 0 { - return stacktrace.Propagate(errors.New("no row got updated"), "") - } - return nil -} - -func (r *Repository) GetIndexedFiles(ctx context.Context, id int64, model ente.Model, since int64, limit *int64) ([]ente.IndexedFile, error) { - var rows *sql.Rows - var err error - if limit == nil { - rows, err = r.DB.QueryContext(ctx, `SELECT file_id, updated_at FROM embeddings WHERE owner_id = $1 AND model = $2 AND updated_at > $3`, id, model, since) - } else { - rows, err = r.DB.QueryContext(ctx, `SELECT file_id, updated_at FROM embeddings WHERE owner_id = $1 AND model = $2 AND updated_at > $3 LIMIT $4`, id, model, since, *limit) - } - if err != nil { - return nil, stacktrace.Propagate(err, "") - } - defer func() { - if err := rows.Close(); err != nil { - logrus.Error(err) - } - }() - result := make([]ente.IndexedFile, 0) - for rows.Next() { - var meta ente.IndexedFile - err := rows.Scan(&meta.FileID, &meta.UpdatedAt) - if err != nil { - return nil, stacktrace.Propagate(err, "") - } - result = append(result, meta) - } - return result, nil - -} - -func convertRowsToEmbeddings(rows *sql.Rows) ([]ente.Embedding, error) { - defer func() { - if err := rows.Close(); err != nil { - logrus.Error(err) - } - }() - - result := make([]ente.Embedding, 0) - for rows.Next() { - embedding := ente.Embedding{} - var encryptedEmbedding, decryptionHeader sql.NullString - var version sql.NullInt32 - err := rows.Scan(&embedding.FileID, &embedding.Model, &encryptedEmbedding, &decryptionHeader, &embedding.UpdatedAt, &version, &embedding.Size) - if encryptedEmbedding.Valid && len(encryptedEmbedding.String) > 0 { - embedding.EncryptedEmbedding = encryptedEmbedding.String - } - if decryptionHeader.Valid && len(decryptionHeader.String) > 0 { - embedding.DecryptionHeader = decryptionHeader.String - } - v := 1 - if version.Valid { - v = int(version.Int32) - } - embedding.Version = &v - if err != nil { - return nil, stacktrace.Propagate(err, "") - } - result = append(result, embedding) - } - return result, nil -}