ente/server/pkg/repo/embedding/repository.go
Neeraj Gupta 76b2a73f9a clean up
2025-02-24 15:00:33 +05:30

56 lines
1.6 KiB
Go

package embedding
import (
"context"
"database/sql"
"github.com/ente-io/stacktrace"
"github.com/lib/pq"
)
// Repository defines the methods for inserting, updating and retrieving
// ML embedding
type Repository struct {
DB *sql.DB
}
func (r *Repository) Delete(fileID int64) error {
_, err := r.DB.Exec("DELETE FROM embeddings WHERE file_id = $1", fileID)
if err != nil {
return stacktrace.Propagate(err, "")
}
return nil
}
// GetDatacenters returns unique list of datacenters where derived embeddings are stored
func (r *Repository) GetDatacenters(ctx context.Context, fileID int64) ([]string, error) {
rows, err := r.DB.QueryContext(ctx, `SELECT datacenters FROM embeddings WHERE file_id = $1`, fileID)
if err != nil {
return nil, stacktrace.Propagate(err, "")
}
uniqueDatacenters := make(map[string]struct{})
for rows.Next() {
var datacenters []string
err = rows.Scan(pq.Array(&datacenters))
if err != nil {
return nil, stacktrace.Propagate(err, "")
}
for _, dc := range datacenters {
uniqueDatacenters[dc] = struct{}{}
}
}
datacenters := make([]string, 0, len(uniqueDatacenters))
for dc := range uniqueDatacenters {
datacenters = append(datacenters, dc)
}
return datacenters, nil
}
// RemoveDatacenter removes the given datacenter from the list of datacenters
func (r *Repository) RemoveDatacenter(ctx context.Context, fileID int64, dc string) error {
_, err := r.DB.ExecContext(ctx, `UPDATE embeddings SET datacenters = array_remove(datacenters, $1) WHERE file_id = $2`, dc, fileID)
if err != nil {
return stacktrace.Propagate(err, "")
}
return nil
}