package filedata

import (
	"context"
	"errors"
	"sync"
	gTime "time"

	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/s3/s3manager"
	"github.com/ente-io/museum/ente"
	fileData "github.com/ente-io/museum/ente/filedata"
	"github.com/ente-io/museum/pkg/controller"
	"github.com/ente-io/museum/pkg/controller/access"
	"github.com/ente-io/museum/pkg/repo"
	fileDataRepo "github.com/ente-io/museum/pkg/repo/filedata"
	"github.com/ente-io/museum/pkg/utils/array"
	"github.com/ente-io/museum/pkg/utils/auth"
	"github.com/ente-io/museum/pkg/utils/network"
	"github.com/ente-io/museum/pkg/utils/s3config"
	"github.com/ente-io/stacktrace"
	"github.com/gin-contrib/requestid"
	"github.com/gin-gonic/gin"
	log "github.com/sirupsen/logrus"
)

// _fetchConfig is the configuration for fetching objects from S3.
type _fetchConfig struct {
	RetryCount     int
	InitialTimeout gTime.Duration
	MaxTimeout     gTime.Duration
}

var _defaultFetchConfig = _fetchConfig{RetryCount: 3, InitialTimeout: 10 * gTime.Second, MaxTimeout: 30 * gTime.Second}
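
// globalFileFetchSemaphore bounds the number of concurrent S3 metadata
// fetches across all in-flight requests, so that bulk requests cannot spawn
// an unbounded number of concurrent downloads.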
var globalFileFetchSemaphore = make(chan struct{}, 400)

type bulkS3MetaFetchResult struct {
	s3MetaObject fileData.S3FileMetadata
	dbEntry      fileData.Row
	err          error
}

type Controller struct {
	Repo                    *fileDataRepo.Repository
	AccessCtrl              access.Controller
	ObjectCleanupController *controller.ObjectCleanupController
	S3Config                *s3config.S3Config
	FileRepo                *repo.FileRepository
	CollectionRepo          *repo.CollectionRepository
	downloadManagerCache    map[string]*s3manager.Downloader
	// for downloading objects from S3 for replication
	workerURL   string
	tempStorage string
}
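
// New constructs the file data Controller, pre-creating one
// s3manager.Downloader per relevant data center so that downloads reuse the
// same underlying S3 clients instead of rebuilding them per request.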
func New(repo *fileDataRepo.Repository,
	accessCtrl access.Controller,
	objectCleanupController *controller.ObjectCleanupController,
	s3Config *s3config.S3Config,
	fileRepo *repo.FileRepository,
	collectionRepo *repo.CollectionRepository,
) *Controller {
	embeddingDcs := []string{s3Config.GetHotBackblazeDC(), s3Config.GetHotWasabiDC(), s3Config.GetWasabiDerivedDC(), s3Config.GetDerivedStorageDataCenter(), "b5", "b6"}
	cache := make(map[string]*s3manager.Downloader, len(embeddingDcs))
	for i := range embeddingDcs {
		s3Client := s3Config.GetS3Client(embeddingDcs[i])
		cache[embeddingDcs[i]] = s3manager.NewDownloaderWithClient(&s3Client)
	}
	return &Controller{
		Repo:                    repo,
		AccessCtrl:              accessCtrl,
		ObjectCleanupController: objectCleanupController,
		S3Config:                s3Config,
		FileRepo:                fileRepo,
		CollectionRepo:          collectionRepo,
		downloadManagerCache:    cache,
	}
}
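
// InsertOrUpdateMetadata validates the request, verifies that the caller may
// write to the file, uploads the encrypted metadata object to S3, and then
// upserts the corresponding row in the DB. Only ente.MlData is supported.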
func (c *Controller) InsertOrUpdateMetadata(ctx *gin.Context, req *fileData.PutFileDataRequest) error {
	if err := req.Validate(); err != nil {
		return stacktrace.Propagate(err, "validation failed")
	}
	userID := auth.GetUserID(ctx.Request.Header)
	fileOwnerID, err := c.FileRepo.GetOwnerID(req.FileID)
	if err != nil {
		return stacktrace.Propagate(err, "")
	}
	if fileOwnerID != userID {
		permErr := c._checkMetadataReadOrWritePerm(ctx, userID, []int64{req.FileID})
		if permErr != nil {
			return stacktrace.Propagate(permErr, "")
		}
	}
	if req.Type != ente.MlData {
		return stacktrace.Propagate(ente.NewBadRequestWithMessage("unsupported object type "+string(req.Type)), "")
	}
	bucketID := c.S3Config.GetBucketID(req.Type)
	objectKey := fileData.ObjectMetadataKey(req.FileID, fileOwnerID, req.Type, nil)
	obj := fileData.S3FileMetadata{
		Version:          *req.Version,
		EncryptedData:    *req.EncryptedData,
		DecryptionHeader: *req.DecryptionHeader,
		Client:           network.GetClientInfo(ctx),
	}
	// Upload the object to S3 first, then record it in the DB, so that a DB
	// row never points at an object that was not uploaded.
	logger := log.WithField("objectKey", objectKey).WithField("fileID", req.FileID).WithField("type", req.Type)
	size, uploadErr := c.uploadObject(obj, objectKey, bucketID)
	if uploadErr != nil {
		logger.WithError(uploadErr).Error("upload failed")
		return uploadErr
	}
	row := fileData.Row{
		FileID:       req.FileID,
		Type:         req.Type,
		UserID:       fileOwnerID,
		Size:         size,
		LatestBucket: bucketID,
	}
	dbInsertErr := c.Repo.InsertOrUpdate(context.Background(), row)
	if dbInsertErr != nil {
		logger.WithError(dbInsertErr).Error("insert or update failed")
		return dbInsertErr
	}
	return nil
}
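
// GetFileData returns the file data entity for a single file, after checking
// that the caller is allowed to read it. Deleted rows are reported as not
// found.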
func (c *Controller) GetFileData(ctx *gin.Context, req fileData.GetFileData) (*fileData.Entity, error) {
	userID := auth.GetUserID(ctx.Request.Header)
	if err := req.Validate(); err != nil {
		return nil, stacktrace.Propagate(err, "validation failed")
	}
	if err := c._checkMetadataReadOrWritePerm(ctx, userID, []int64{req.FileID}); err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	doRows, err := c.Repo.GetFilesData(ctx, req.Type, []int64{req.FileID})
	if err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	if len(doRows) == 0 || doRows[0].IsDeleted {
		return nil, stacktrace.Propagate(&ente.ErrNotFoundError, "")
	}
	ctxLogger := log.WithFields(log.Fields{
		"objectKey":     doRows[0].S3FileMetadataObjectKey(),
		"latest_bucket": doRows[0].LatestBucket,
		"req_id":        requestid.Get(ctx),
		"file_id":       req.FileID,
	})
	s3MetaObject, err := c.fetchS3FileMetadata(context.Background(), doRows[0], ctxLogger)
	if err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	return &fileData.Entity{
		FileID:           doRows[0].FileID,
		Type:             doRows[0].Type,
		EncryptedData:    s3MetaObject.EncryptedData,
		DecryptionHeader: s3MetaObject.DecryptionHeader,
	}, nil
}
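
// GetFilesData returns the file data entities for the requested files, along
// with the IDs that have no data yet (pending indexing) and the IDs whose
// objects could not be fetched from S3.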
func (c *Controller) GetFilesData(ctx *gin.Context, req fileData.GetFilesData) (*fileData.GetFilesDataResponse, error) {
	userID := auth.GetUserID(ctx.Request.Header)
	if err := req.Validate(); err != nil {
		return nil, stacktrace.Propagate(err, "req validation failed")
	}
	if err := c._checkMetadataReadOrWritePerm(ctx, userID, req.FileIDs); err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	doRows, err := c.Repo.GetFilesData(ctx, req.Type, req.FileIDs)
	if err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	activeRows := make([]fileData.Row, 0)
	dbFileIds := make([]int64, 0)
	errFileIds := make([]int64, 0)
	for i := range doRows {
		dbFileIds = append(dbFileIds, doRows[i].FileID)
		if !doRows[i].IsDeleted {
			activeRows = append(activeRows, doRows[i])
		}
	}
	// File IDs with no DB row yet are still pending indexing.
	pendingIndexFileIds := array.FindMissingElementsInSecondList(req.FileIDs, dbFileIds)
	// Fetch the S3 metadata objects for the active rows in parallel.
	s3MetaFetchResults, err := c.getS3FileMetadataParallel(ctx, activeRows)
	if err != nil {
		return nil, stacktrace.Propagate(err, "")
	}
	fetchedEmbeddings := make([]fileData.Entity, 0)
	// Partition the results into fetched entities and failed file IDs.
	for _, obj := range s3MetaFetchResults {
		if obj.err != nil {
			errFileIds = append(errFileIds, obj.dbEntry.FileID)
		} else {
			fetchedEmbeddings = append(fetchedEmbeddings, fileData.Entity{
				FileID:           obj.dbEntry.FileID,
				Type:             obj.dbEntry.Type,
				EncryptedData:    obj.s3MetaObject.EncryptedData,
				DecryptionHeader: obj.s3MetaObject.DecryptionHeader,
			})
		}
	}
	return &fileData.GetFilesDataResponse{
		Data:                fetchedEmbeddings,
		PendingIndexFileIDs: pendingIndexFileIds,
		ErrFileIDs:          errFileIds,
	}, nil
}
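
// getS3FileMetadataParallel fetches each row's S3 metadata object in its own
// goroutine, bounded by globalFileFetchSemaphore. Every goroutine writes to a
// distinct index of the result slice, so no locking beyond the WaitGroup is
// needed.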
func (c *Controller) getS3FileMetadataParallel(ctx *gin.Context, dbRows []fileData.Row) ([]bulkS3MetaFetchResult, error) {
	var wg sync.WaitGroup
	embeddingObjects := make([]bulkS3MetaFetchResult, len(dbRows))
	for i := range dbRows {
		dbRow := dbRows[i]
		wg.Add(1)
		globalFileFetchSemaphore <- struct{}{} // Acquire from the global semaphore
		go func(i int, row fileData.Row) {
			defer wg.Done()
			defer func() { <-globalFileFetchSemaphore }() // Release back to the global semaphore
			ctxLogger := log.WithFields(log.Fields{
				"objectKey":     row.S3FileMetadataObjectKey(),
				"req_id":        requestid.Get(ctx),
				"latest_bucket": row.LatestBucket,
				"file_id":       row.FileID,
			})
			s3FileMetadata, err := c.fetchS3FileMetadata(context.Background(), row, ctxLogger)
			if err != nil {
				ctxLogger.WithError(err).Error("error fetching object: " + row.S3FileMetadataObjectKey())
				embeddingObjects[i] = bulkS3MetaFetchResult{
					err:     err,
					dbEntry: row,
				}
			} else {
				embeddingObjects[i] = bulkS3MetaFetchResult{
					s3MetaObject: *s3FileMetadata,
					dbEntry:      row,
				}
			}
		}(i, dbRow)
	}
	wg.Wait()
	return embeddingObjects, nil
}
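
// fetchS3FileMetadata downloads the metadata object from the row's latest
// bucket, retrying with a per-attempt timeout that doubles on each retry up
// to MaxTimeout. A NoSuchKey error fails immediately, since it will not
// recover on retry.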
func (c *Controller) fetchS3FileMetadata(ctx context.Context, row fileData.Row, ctxLogger *log.Entry) (*fileData.S3FileMetadata, error) {
	dc := row.LatestBucket
	// TODO(neeraj): make the preferred DC to read from configurable, with
	// fallback logic to read from a different bucket when the read from the
	// preferred DC fails.
	if dc == "b6" {
		if array.StringInList("b5", row.ReplicatedBuckets) {
			dc = "b5"
		}
	}
	opt := _defaultFetchConfig
	objectKey := row.S3FileMetadataObjectKey()
	totalAttempts := opt.RetryCount + 1
	timeout := opt.InitialTimeout
	for i := 0; i < totalAttempts; i++ {
		// Double the per-attempt timeout on each retry, capped at MaxTimeout.
		if i > 0 {
			timeout = timeout * 2
			if timeout > opt.MaxTimeout {
				timeout = opt.MaxTimeout
			}
		}
		fetchCtx, cancel := context.WithTimeout(ctx, timeout)
		select {
		case <-ctx.Done():
			cancel()
			return nil, stacktrace.Propagate(ctx.Err(), "")
		default:
			obj, err := c.downloadObject(fetchCtx, objectKey, dc)
			cancel() // Ensure cancel is called to release resources
			if err == nil {
				if i > 0 {
					ctxLogger.WithField("dc", dc).Infof("Fetched object after %d retries", i)
				}
				return &obj, nil
			}
			// Check if the failure was caused by the per-attempt timeout or
			// by cancellation.
			if fetchCtx.Err() != nil {
				ctxLogger.WithField("dc", dc).Error("Fetch timed out or cancelled: ", fetchCtx.Err())
			} else {
				// A missing object will not appear on retry; fail immediately.
				if s3Err, ok := err.(awserr.RequestFailure); ok {
					if s3Err.Code() == s3.ErrCodeNoSuchKey {
						return nil, stacktrace.Propagate(errors.New("object not found"), "")
					}
				}
				ctxLogger.WithField("dc", dc).Error("Failed to fetch object: ", err)
			}
		}
	}
	return nil, stacktrace.Propagate(errors.New("failed to fetch object"), "")
}
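
// _checkMetadataReadOrWritePerm verifies that the actor can access all of the
// given files.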
func (c *Controller) _checkMetadataReadOrWritePerm(ctx *gin.Context, userID int64, fileIDs []int64) error {
	if err := c.AccessCtrl.CanAccessFile(ctx, &access.CanAccessFileParams{
		ActorUserID: userID,
		FileIDs:     fileIDs,
	}); err != nil {
		return stacktrace.Propagate(err, "User does not own some file(s)")
	}
	return nil
}

// _checkPreviewWritePerm verifies that the actor owns the file and that the
// file is still present in at least one collection.
func (c *Controller) _checkPreviewWritePerm(ctx *gin.Context, fileID int64, actorID int64) error {
	err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
		ActorUserId: actorID,
		FileIDs:     []int64{fileID},
	})
	if err != nil {
		return stacktrace.Propagate(err, "User does not own file")
	}
	count, err := c.CollectionRepo.GetCollectionCount(fileID)
	if err != nil {
		return stacktrace.Propagate(err, "")
	}
	if count < 1 {
		return stacktrace.Propagate(ente.ErrNotFound, "")
	}
	return nil
}
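
// FileDataStatusDiff returns up to 5000 file data status entries for the user
// that changed after req.LastUpdatedAt.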
func (c *Controller) FileDataStatusDiff(ctx *gin.Context, req fileData.FDDiffRequest) ([]fileData.FDStatus, error) {
	userID := auth.GetUserID(ctx.Request.Header)
	return c.Repo.GetFDForUser(ctx, userID, *req.LastUpdatedAt, 5000)
}