mirror of
https://github.com/ente-io/ente.git
synced 2025-08-07 23:18:10 +00:00
clean up
This commit is contained in:
parent
4bbe1ae0d2
commit
76b2a73f9a
@ -754,7 +754,7 @@ func main() {
|
||||
pushHandler := &api.PushHandler{PushController: pushController}
|
||||
privateAPI.POST("/push/token", pushHandler.AddToken)
|
||||
|
||||
embeddingController := embeddingCtrl.New(embeddingRepo, accessCtrl, objectCleanupController, s3Config, queueRepo, taskLockingRepo, fileRepo, collectionRepo, hostName)
|
||||
embeddingController := embeddingCtrl.New(embeddingRepo, objectCleanupController, queueRepo, taskLockingRepo, fileRepo, hostName)
|
||||
|
||||
offerHandler := &api.OfferHandler{Controller: offerController}
|
||||
publicAPI.GET("/offers/black-friday", offerHandler.GetBlackFridayOffers)
|
||||
|
@ -1,57 +1,30 @@
|
||||
package embedding
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
gTime "time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3manager"
|
||||
"github.com/ente-io/museum/pkg/controller"
|
||||
"github.com/ente-io/museum/pkg/controller/access"
|
||||
"github.com/ente-io/museum/pkg/repo"
|
||||
"github.com/ente-io/museum/pkg/repo/embedding"
|
||||
"github.com/ente-io/museum/pkg/utils/s3config"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxEmbeddingDataSize is the min size of an embedding object in bytes
|
||||
minEmbeddingDataSize = 2048
|
||||
embeddingFetchTimeout = 10 * gTime.Second
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type Controller struct {
|
||||
Repo *embedding.Repository
|
||||
AccessCtrl access.Controller
|
||||
ObjectCleanupController *controller.ObjectCleanupController
|
||||
S3Config *s3config.S3Config
|
||||
QueueRepo *repo.QueueRepository
|
||||
TaskLockingRepo *repo.TaskLockRepository
|
||||
FileRepo *repo.FileRepository
|
||||
CollectionRepo *repo.CollectionRepository
|
||||
HostName string
|
||||
cleanupCronRunning bool
|
||||
derivedStorageDataCenter string
|
||||
downloadManagerCache map[string]*s3manager.Downloader
|
||||
Repo *embedding.Repository
|
||||
ObjectCleanupController *controller.ObjectCleanupController
|
||||
QueueRepo *repo.QueueRepository
|
||||
TaskLockingRepo *repo.TaskLockRepository
|
||||
FileRepo *repo.FileRepository
|
||||
HostName string
|
||||
cleanupCronRunning bool
|
||||
}
|
||||
|
||||
func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanupController *controller.ObjectCleanupController, s3Config *s3config.S3Config, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, collectionRepo *repo.CollectionRepository, hostName string) *Controller {
|
||||
embeddingDcs := []string{s3Config.GetHotBackblazeDC(), s3Config.GetHotWasabiDC(), s3Config.GetWasabiDerivedDC(), s3Config.GetDerivedStorageDataCenter()}
|
||||
cache := make(map[string]*s3manager.Downloader, len(embeddingDcs))
|
||||
for i := range embeddingDcs {
|
||||
s3Client := s3Config.GetS3Client(embeddingDcs[i])
|
||||
cache[embeddingDcs[i]] = s3manager.NewDownloaderWithClient(&s3Client)
|
||||
}
|
||||
func New(repo *embedding.Repository, objectCleanupController *controller.ObjectCleanupController, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, hostName string) *Controller {
|
||||
return &Controller{
|
||||
Repo: repo,
|
||||
AccessCtrl: accessCtrl,
|
||||
ObjectCleanupController: objectCleanupController,
|
||||
S3Config: s3Config,
|
||||
QueueRepo: queueRepo,
|
||||
TaskLockingRepo: taskLockingRepo,
|
||||
FileRepo: fileRepo,
|
||||
CollectionRepo: collectionRepo,
|
||||
HostName: hostName,
|
||||
derivedStorageDataCenter: s3Config.GetDerivedStorageDataCenter(),
|
||||
downloadManagerCache: cache,
|
||||
Repo: repo,
|
||||
ObjectCleanupController: objectCleanupController,
|
||||
QueueRepo: queueRepo,
|
||||
TaskLockingRepo: taskLockingRepo,
|
||||
FileRepo: fileRepo,
|
||||
HostName: hostName,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3,12 +3,8 @@ package embedding
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/ente-io/museum/ente"
|
||||
"github.com/ente-io/stacktrace"
|
||||
"github.com/lib/pq"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Repository defines the methods for inserting, updating and retrieving
|
||||
@ -17,74 +13,6 @@ type Repository struct {
|
||||
DB *sql.DB
|
||||
}
|
||||
|
||||
// Create inserts a new embedding
|
||||
func (r *Repository) InsertOrUpdate(ctx context.Context, ownerID int64, entry ente.InsertOrUpdateEmbeddingRequest, size int, version int, dc string) (ente.Embedding, error) {
|
||||
var updatedAt int64
|
||||
err := r.DB.QueryRowContext(ctx, `
|
||||
INSERT INTO embeddings
|
||||
(file_id, owner_id, model, size, version, datacenters)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, ARRAY[$6]::s3region[])
|
||||
ON CONFLICT ON CONSTRAINT unique_embeddings_file_id_model
|
||||
DO UPDATE
|
||||
SET
|
||||
updated_at = now_utc_micro_seconds(),
|
||||
size = $4,
|
||||
version = $5,
|
||||
datacenters = CASE
|
||||
WHEN $6 = ANY(COALESCE(embeddings.datacenters, ARRAY['b2-eu-cen']::s3region[])) THEN embeddings.datacenters
|
||||
ELSE array_append(COALESCE(embeddings.datacenters, ARRAY['b2-eu-cen']::s3region[]), $6::s3region)
|
||||
END
|
||||
RETURNING updated_at`,
|
||||
entry.FileID, ownerID, entry.Model, size, version, dc).Scan(&updatedAt)
|
||||
|
||||
if err != nil {
|
||||
// check if error is due to model enum invalid value
|
||||
if err.Error() == fmt.Sprintf("pq: invalid input value for enum model: \"%s\"", entry.Model) {
|
||||
return ente.Embedding{}, stacktrace.Propagate(ente.ErrBadRequest, "invalid model value")
|
||||
}
|
||||
return ente.Embedding{}, stacktrace.Propagate(err, "")
|
||||
}
|
||||
return ente.Embedding{
|
||||
FileID: entry.FileID,
|
||||
Model: entry.Model,
|
||||
EncryptedEmbedding: entry.EncryptedEmbedding,
|
||||
DecryptionHeader: entry.DecryptionHeader,
|
||||
UpdatedAt: updatedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetDiff returns the embeddings that have been updated since the given time
|
||||
func (r *Repository) GetDiff(ctx context.Context, ownerID int64, model ente.Model, sinceTime int64, limit int16) ([]ente.Embedding, error) {
|
||||
rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version, size
|
||||
FROM embeddings
|
||||
WHERE owner_id = $1 AND model = $2 AND updated_at > $3
|
||||
ORDER BY updated_at ASC
|
||||
LIMIT $4`, ownerID, model, sinceTime, limit)
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
return convertRowsToEmbeddings(rows)
|
||||
}
|
||||
|
||||
func (r *Repository) GetFilesEmbedding(ctx context.Context, ownerID int64, model ente.Model, fileIDs []int64) ([]ente.Embedding, error) {
|
||||
rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version, size
|
||||
FROM embeddings
|
||||
WHERE owner_id = $1 AND model = $2 AND file_id = ANY($3)`, ownerID, model, pq.Array(fileIDs))
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
return convertRowsToEmbeddings(rows)
|
||||
}
|
||||
|
||||
func (r *Repository) DeleteAll(ctx context.Context, ownerID int64) error {
|
||||
_, err := r.DB.ExecContext(ctx, "DELETE FROM embeddings WHERE owner_id = $1", ownerID)
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Repository) Delete(fileID int64) error {
|
||||
_, err := r.DB.Exec("DELETE FROM embeddings WHERE file_id = $1", fileID)
|
||||
if err != nil {
|
||||
@ -117,33 +45,6 @@ func (r *Repository) GetDatacenters(ctx context.Context, fileID int64) ([]string
|
||||
return datacenters, nil
|
||||
}
|
||||
|
||||
// GetOtherDCsForFileAndModel returns the list of datacenters where the embeddings are stored for a given file and model, excluding the ignoredDC
|
||||
func (r *Repository) GetOtherDCsForFileAndModel(ctx context.Context, fileID int64, model string, ignoredDC string) ([]string, error) {
|
||||
rows, err := r.DB.QueryContext(ctx, `SELECT datacenters FROM embeddings WHERE file_id = $1 AND model = $2`, fileID, model)
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
uniqueDatacenters := make(map[string]bool)
|
||||
for rows.Next() {
|
||||
var datacenters []string
|
||||
err = rows.Scan(pq.Array(&datacenters))
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
for _, dc := range datacenters {
|
||||
// add to uniqueDatacenters if it is not the ignoredDC
|
||||
if dc != ignoredDC {
|
||||
uniqueDatacenters[dc] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
datacenters := make([]string, 0, len(uniqueDatacenters))
|
||||
for dc := range uniqueDatacenters {
|
||||
datacenters = append(datacenters, dc)
|
||||
}
|
||||
return datacenters, nil
|
||||
}
|
||||
|
||||
// RemoveDatacenter removes the given datacenter from the list of datacenters
|
||||
func (r *Repository) RemoveDatacenter(ctx context.Context, fileID int64, dc string) error {
|
||||
_, err := r.DB.ExecContext(ctx, `UPDATE embeddings SET datacenters = array_remove(datacenters, $1) WHERE file_id = $2`, dc, fileID)
|
||||
@ -152,87 +53,3 @@ func (r *Repository) RemoveDatacenter(ctx context.Context, fileID int64, dc stri
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddNewDC adds the dc name to the list of datacenters, if it doesn't exist already, for a given file, model and user. It also updates the size of the embedding
|
||||
func (r *Repository) AddNewDC(ctx context.Context, fileID int64, model ente.Model, userID int64, size int, dc string) error {
|
||||
res, err := r.DB.ExecContext(ctx, `
|
||||
UPDATE embeddings
|
||||
SET size = $1,
|
||||
datacenters = CASE
|
||||
WHEN $2::s3region = ANY(datacenters) THEN datacenters
|
||||
ELSE array_append(datacenters, $2::s3region)
|
||||
END
|
||||
WHERE file_id = $3 AND model = $4 AND owner_id = $5`, size, dc, fileID, model, userID)
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
rowsAffected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
if rowsAffected == 0 {
|
||||
return stacktrace.Propagate(errors.New("no row got updated"), "")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Repository) GetIndexedFiles(ctx context.Context, id int64, model ente.Model, since int64, limit *int64) ([]ente.IndexedFile, error) {
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if limit == nil {
|
||||
rows, err = r.DB.QueryContext(ctx, `SELECT file_id, updated_at FROM embeddings WHERE owner_id = $1 AND model = $2 AND updated_at > $3`, id, model, since)
|
||||
} else {
|
||||
rows, err = r.DB.QueryContext(ctx, `SELECT file_id, updated_at FROM embeddings WHERE owner_id = $1 AND model = $2 AND updated_at > $3 LIMIT $4`, id, model, since, *limit)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
defer func() {
|
||||
if err := rows.Close(); err != nil {
|
||||
logrus.Error(err)
|
||||
}
|
||||
}()
|
||||
result := make([]ente.IndexedFile, 0)
|
||||
for rows.Next() {
|
||||
var meta ente.IndexedFile
|
||||
err := rows.Scan(&meta.FileID, &meta.UpdatedAt)
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
result = append(result, meta)
|
||||
}
|
||||
return result, nil
|
||||
|
||||
}
|
||||
|
||||
func convertRowsToEmbeddings(rows *sql.Rows) ([]ente.Embedding, error) {
|
||||
defer func() {
|
||||
if err := rows.Close(); err != nil {
|
||||
logrus.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
result := make([]ente.Embedding, 0)
|
||||
for rows.Next() {
|
||||
embedding := ente.Embedding{}
|
||||
var encryptedEmbedding, decryptionHeader sql.NullString
|
||||
var version sql.NullInt32
|
||||
err := rows.Scan(&embedding.FileID, &embedding.Model, &encryptedEmbedding, &decryptionHeader, &embedding.UpdatedAt, &version, &embedding.Size)
|
||||
if encryptedEmbedding.Valid && len(encryptedEmbedding.String) > 0 {
|
||||
embedding.EncryptedEmbedding = encryptedEmbedding.String
|
||||
}
|
||||
if decryptionHeader.Valid && len(decryptionHeader.String) > 0 {
|
||||
embedding.DecryptionHeader = decryptionHeader.String
|
||||
}
|
||||
v := 1
|
||||
if version.Valid {
|
||||
v = int(version.Int32)
|
||||
}
|
||||
embedding.Version = &v
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
result = append(result, embedding)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user