Skip to content

Commit

Permalink
Another attempt at implementing MSC701
Browse files Browse the repository at this point in the history
This is part of #103.

There's a couple concepts introduced in this setup. It was found that storing hashes in the database can't be done because then we'll have no string to use to decrypt the user's access token. 

This also doesn't work on federation at all and instead has a short circuit which may need to be expanded upon. The token isn't sent over federation at all, however it might need to be done in plain text. The problem with an encrypted access token for the user is that we won't know what the content token is for remote media, and therefore can't decrypt the user's access token. We equally don't want to send the user's access token over federation, so we may have to settle for throwing the content token around.

More thought is needed.
  • Loading branch information
turt2live committed Jun 17, 2018
1 parent af57e97 commit c164220
Show file tree
Hide file tree
Showing 21 changed files with 253 additions and 57 deletions.
1 change: 1 addition & 0 deletions migrations/6_add_content_token_down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE media DROP COLUMN content_token;
1 change: 1 addition & 0 deletions migrations/6_add_content_token_up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE media ADD COLUMN content_token TEXT NULL DEFAULT NULL;
10 changes: 9 additions & 1 deletion src/github.com/turt2live/matrix-media-repo/api/r0/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"github.com/turt2live/matrix-media-repo/api"
"github.com/turt2live/matrix-media-repo/common"
"github.com/turt2live/matrix-media-repo/controllers/download_controller"
"github.com/turt2live/matrix-media-repo/types"
"github.com/turt2live/matrix-media-repo/util"
)

type DownloadMediaResponse struct {
Expand All @@ -27,6 +29,10 @@ func DownloadMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interf
filename := params["filename"]
allowRemote := r.URL.Query().Get("allow_remote")

encryptedToken := util.GetMediaBearerTokenFromRequest(r)
appserviceUserId := util.GetAppserviceUserIdFromRequest(r)
bearerToken := &types.BearerToken{EncryptedToken: encryptedToken, AppserviceUserId: appserviceUserId}

downloadRemote := true
if allowRemote != "" {
parsedFlag, err := strconv.ParseBool(allowRemote)
Expand All @@ -43,14 +49,16 @@ func DownloadMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interf
"allowRemote": downloadRemote,
})

streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, r.Context(), log)
streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, bearerToken, r.Context(), log)
if err != nil {
if err == common.ErrMediaNotFound {
return api.NotFoundError()
} else if err == common.ErrMediaTooLarge {
return api.RequestTooLarge()
} else if err == common.ErrMediaQuarantined {
return api.NotFoundError() // We lie for security
} else if err == common.ErrFailedAuthCheck {
return api.AuthFailed()
}
log.Error("Unexpected error locating media: " + err.Error())
return api.InternalServerError("Unexpected Error")
Expand Down
20 changes: 17 additions & 3 deletions src/github.com/turt2live/matrix-media-repo/api/r0/upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"io"
"io/ioutil"
"net/http"
"strconv"

"github.com/sirupsen/logrus"
"github.com/turt2live/matrix-media-repo/api"
Expand All @@ -12,7 +13,8 @@ import (
)

type MediaUploadedResponse struct {
ContentUri string `json:"content_uri"`
ContentUri string `json:"content_uri"`
ContentToken *string `json:"content_token,omitempty"`
}

func UploadMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} {
Expand All @@ -21,8 +23,20 @@ func UploadMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interfac
filename = "upload.bin"
}

isPublicStr := r.URL.Query().Get("public")
isPublic := true
if isPublicStr != "" {
parsedFlag, err := strconv.ParseBool(isPublicStr)
if err != nil {
return api.InternalServerError("public flag does not appear to be a boolean")
}

isPublic = parsedFlag
}

log = log.WithFields(logrus.Fields{
"filename": filename,
"isPublic": isPublic,
})

contentType := r.Header.Get("Content-Type")
Expand All @@ -36,7 +50,7 @@ func UploadMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interfac
return api.RequestTooLarge()
}

media, err := upload_controller.UploadMedia(r.Body, contentType, filename, user.UserId, r.Host, r.Context(), log)
media, err := upload_controller.UploadMedia(r.Body, contentType, filename, user.UserId, r.Host, isPublic, r.Context(), log)
if err != nil {
io.Copy(ioutil.Discard, r.Body) // Ditch the entire request
defer r.Body.Close()
Expand All @@ -49,5 +63,5 @@ func UploadMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interfac
return api.InternalServerError("Unexpected Error")
}

return &MediaUploadedResponse{media.MxcUri()}
return &MediaUploadedResponse{ContentUri: media.MxcUri(), ContentToken: media.ContentToken}
}
10 changes: 9 additions & 1 deletion src/github.com/turt2live/matrix-media-repo/api/unstable/info.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"github.com/turt2live/matrix-media-repo/api"
"github.com/turt2live/matrix-media-repo/common"
"github.com/turt2live/matrix-media-repo/controllers/download_controller"
"github.com/turt2live/matrix-media-repo/types"
"github.com/turt2live/matrix-media-repo/util"
)

type MediaInfoResponse struct {
Expand All @@ -27,6 +29,10 @@ func MediaInfo(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{
mediaId := params["mediaId"]
allowRemote := r.URL.Query().Get("allow_remote")

encryptedToken := util.GetMediaBearerTokenFromRequest(r)
appserviceUserId := util.GetAppserviceUserIdFromRequest(r)
bearerToken := &types.BearerToken{EncryptedToken: encryptedToken, AppserviceUserId: appserviceUserId}

downloadRemote := true
if allowRemote != "" {
parsedFlag, err := strconv.ParseBool(allowRemote)
Expand All @@ -42,14 +48,16 @@ func MediaInfo(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{
"allowRemote": downloadRemote,
})

streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, r.Context(), log)
streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, bearerToken, r.Context(), log)
if err != nil {
if err == common.ErrMediaNotFound {
return api.NotFoundError()
} else if err == common.ErrMediaTooLarge {
return api.RequestTooLarge()
} else if err == common.ErrMediaQuarantined {
return api.NotFoundError() // We lie for security
} else if err == common.ErrFailedAuthCheck {
return api.AuthFailed()
}
log.Error("Unexpected error locating media: " + err.Error())
return api.InternalServerError("Unexpected Error")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"github.com/turt2live/matrix-media-repo/common"
"github.com/turt2live/matrix-media-repo/controllers/download_controller"
"github.com/turt2live/matrix-media-repo/controllers/upload_controller"
"github.com/turt2live/matrix-media-repo/types"
"github.com/turt2live/matrix-media-repo/util"
)

func LocalCopy(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} {
Expand All @@ -20,6 +22,10 @@ func LocalCopy(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{
mediaId := params["mediaId"]
allowRemote := r.URL.Query().Get("allow_remote")

encryptedToken := util.GetMediaBearerTokenFromRequest(r)
appserviceUserId := util.GetAppserviceUserIdFromRequest(r)
bearerToken := &types.BearerToken{EncryptedToken: encryptedToken, AppserviceUserId: appserviceUserId}

downloadRemote := true
if allowRemote != "" {
parsedFlag, err := strconv.ParseBool(allowRemote)
Expand All @@ -37,14 +43,16 @@ func LocalCopy(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{

// TODO: There's a lot of room for improvement here. Instead of re-uploading media, we should just update the DB.

streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, r.Context(), log)
streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, bearerToken, r.Context(), log)
if err != nil {
if err == common.ErrMediaNotFound {
return api.NotFoundError()
} else if err == common.ErrMediaTooLarge {
return api.RequestTooLarge()
} else if err == common.ErrMediaQuarantined {
return api.NotFoundError() // We lie for security
} else if err == common.ErrFailedAuthCheck {
return api.AuthFailed()
}
log.Error("Unexpected error locating media: " + err.Error())
return api.InternalServerError("Unexpected Error")
Expand All @@ -56,7 +64,7 @@ func LocalCopy(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{
return &r0.MediaUploadedResponse{ContentUri: streamedMedia.Media.MxcUri()}
}

newMedia, err := upload_controller.UploadMedia(streamedMedia.Stream, streamedMedia.Media.ContentType, streamedMedia.Media.UploadName, user.UserId, r.Host, r.Context(), log)
newMedia, err := upload_controller.UploadMedia(streamedMedia.Stream, streamedMedia.Media.ContentType, streamedMedia.Media.UploadName, user.UserId, r.Host, true, r.Context(), log)
if err != nil {
if err == common.ErrMediaNotAllowed {
return api.BadRequest("Media content type not allowed on this server")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ var ErrHostNotFound = errors.New("host not found")
var ErrHostBlacklisted = errors.New("host not allowed")
var ErrMediaNotAllowed = errors.New("media content type not allowed")
var ErrMediaQuarantined = errors.New("media quarantined")
var ErrFailedAuthCheck = errors.New("failed auth checks")
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,30 @@ import (
"github.com/sirupsen/logrus"
"github.com/turt2live/matrix-media-repo/common"
"github.com/turt2live/matrix-media-repo/internal_cache"
"github.com/turt2live/matrix-media-repo/matrix"
"github.com/turt2live/matrix-media-repo/storage"
"github.com/turt2live/matrix-media-repo/types"
"github.com/turt2live/matrix-media-repo/util"
)

var localCache = cache.New(30*time.Second, 60*time.Second)

func GetMedia(origin string, mediaId string, downloadRemote bool, ctx context.Context, log *logrus.Entry) (*types.StreamedMedia, error) {
func GetMedia(origin string, mediaId string, downloadRemote bool, bearerToken *types.BearerToken, ctx context.Context, log *logrus.Entry) (*types.StreamedMedia, error) {
media, err := FindMediaRecord(origin, mediaId, downloadRemote, ctx, log)
if err != nil {
return nil, err
}

if media.ContentToken != nil {
log.Info("Media is protected by a content token - verifying request")
userId, err := matrix.GetUserIdFromBearerToken(ctx, bearerToken, *media.ContentToken)
if err != nil {
return nil, err
}

log.Info("Access token belongs to ", userId)
}

if media.Quarantined {
log.Warn("Quarantined media accessed")
return nil, common.ErrMediaQuarantined
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func downloadResourceWorkFn(request *resource_handler.WorkRequest) interface{} {
defer downloaded.Contents.Close()

userId := upload_controller.NoApplicableUploadUser
media, err := upload_controller.StoreDirect(downloaded.Contents, downloaded.ContentType, downloaded.DesiredFilename, userId, info.origin, info.mediaId, ctx, log)
media, err := upload_controller.StoreDirect(downloaded.Contents, downloaded.ContentType, downloaded.DesiredFilename, userId, info.origin, info.mediaId, nil, ctx, log)
if err != nil {
return &downloadResponse{err: err}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func urlPreviewWorkFn(request *resource_handler.WorkRequest) interface{} {
// Store the thumbnail, if there is one
if preview.Image != nil && !upload_controller.IsRequestTooLarge(preview.Image.ContentLength, preview.Image.ContentLengthHeader) {
// UploadMedia will close the read stream for the thumbnail and dedupe the image
media, err := upload_controller.UploadMedia(preview.Image.Data, preview.Image.ContentType, preview.Image.Filename, info.forUserId, info.onHost, ctx, log)
media, err := upload_controller.UploadMedia(preview.Image.Data, preview.Image.ContentType, preview.Image.Filename, info.forUserId, info.onHost, true, ctx, log)
if err != nil {
log.Warn("Non-fatal error storing preview thumbnail: " + err.Error())
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/turt2live/matrix-media-repo/common/config"
"github.com/turt2live/matrix-media-repo/controllers/download_controller"
"github.com/turt2live/matrix-media-repo/internal_cache"
"github.com/turt2live/matrix-media-repo/matrix"
"github.com/turt2live/matrix-media-repo/storage"
"github.com/turt2live/matrix-media-repo/types"
"github.com/turt2live/matrix-media-repo/util"
Expand All @@ -35,12 +36,22 @@ var animatedTypes = []string{"image/gif"}

var localCache = cache.New(30*time.Second, 60*time.Second)

func GetThumbnail(origin string, mediaId string, desiredWidth int, desiredHeight int, animated bool, method string, downloadRemote bool, ctx context.Context, log *logrus.Entry) (*types.StreamedThumbnail, error) {
func GetThumbnail(origin string, mediaId string, desiredWidth int, desiredHeight int, animated bool, method string, downloadRemote bool, bearerToken *types.BearerToken, ctx context.Context, log *logrus.Entry) (*types.StreamedThumbnail, error) {
media, err := download_controller.FindMediaRecord(origin, mediaId, downloadRemote, ctx, log)
if err != nil {
return nil, err
}

if media.ContentToken != nil {
log.Info("Media is protected by a content token - verifying request")
userId, err := matrix.GetUserIdFromBearerToken(ctx, bearerToken, *media.ContentToken)
if err != nil {
return nil, err
}

log.Info("Access token belongs to ", userId)
}

if !util.ArrayContains(supportedThumbnailTypes, media.ContentType) {
log.Warn("Cannot generate thumbnail for " + media.ContentType + " because it is not supported")
return nil, errors.New("cannot generate thumbnail for this media's content type")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func IsRequestTooLarge(contentLength int64, contentLengthHeader string) bool {
return false // We can only assume
}

func UploadMedia(contents io.ReadCloser, contentType string, filename string, userId string, origin string, ctx context.Context, log *logrus.Entry) (*types.Media, error) {
func UploadMedia(contents io.ReadCloser, contentType string, filename string, userId string, origin string, isPublic bool, ctx context.Context, log *logrus.Entry) (*types.Media, error) {
defer contents.Close()

var data io.Reader
Expand All @@ -52,10 +52,25 @@ func UploadMedia(contents io.ReadCloser, contentType string, filename string, us
return nil, err
}

return StoreDirect(data, contentType, filename, userId, origin, mediaId, ctx, log)
var contentToken *string
if !isPublic {
generatedToken, err := util.GenerateRandomString(128)
if err != nil {
return nil, err
}

contentToken = &generatedToken
}

stored, err := StoreDirect(data, contentType, filename, userId, origin, mediaId, contentToken, ctx, log)
if err != nil {
return nil, err
}

return stored, nil
}

func StoreDirect(contents io.Reader, contentType string, filename string, userId string, origin string, mediaId string, ctx context.Context, log *logrus.Entry) (*types.Media, error) {
func StoreDirect(contents io.Reader, contentType string, filename string, userId string, origin string, mediaId string, contentToken *string, ctx context.Context, log *logrus.Entry) (*types.Media, error) {
fileLocation, err := storage.PersistFile(contents, ctx, log)
if err != nil {
return nil, err
Expand All @@ -70,7 +85,7 @@ func StoreDirect(contents io.Reader, contentType string, filename string, userId

for _, allowedType := range config.Get().Uploads.AllowedTypes {
if !glob.Glob(allowedType, fileMime) {
log.Warn("Content type " + fileMime +" (reported as " + contentType+") is not allowed to be uploaded")
log.Warn("Content type " + fileMime + " (reported as " + contentType + ") is not allowed to be uploaded")

os.Remove(fileLocation) // delete temp file
return nil, common.ErrMediaNotAllowed
Expand All @@ -93,19 +108,6 @@ func StoreDirect(contents io.Reader, contentType string, filename string, userId
if len(records) > 0 {
log.Info("Duplicate media for hash ", hash)

// If the user is a real user (ie: actually uploaded media), then we'll see if there's
// an exact duplicate that we can return. Otherwise we'll just pick the first record and
// clone that.
if userId != NoApplicableUploadUser {
for _, record := range records {
if record.UserId == userId && record.Origin == origin && record.ContentType == contentType {
log.Info("User has already uploaded this media before - returning unaltered media record")
os.Remove(fileLocation) // delete temp file
return record, nil
}
}
}

// We'll use the location from the first record
media := records[0]
media.Origin = origin
Expand All @@ -114,6 +116,7 @@ func StoreDirect(contents io.Reader, contentType string, filename string, userId
media.UploadName = filename
media.ContentType = contentType
media.CreationTs = util.NowMillis()
media.ContentToken = contentToken

err = db.Insert(media)
if err != nil {
Expand Down Expand Up @@ -145,15 +148,16 @@ func StoreDirect(contents io.Reader, contentType string, filename string, userId
log.Info("Persisting new media record")

media := &types.Media{
Origin: origin,
MediaId: mediaId,
UploadName: filename,
ContentType: contentType,
UserId: userId,
Sha256Hash: hash,
SizeBytes: fileSize,
Location: fileLocation,
CreationTs: util.NowMillis(),
Origin: origin,
MediaId: mediaId,
UploadName: filename,
ContentType: contentType,
UserId: userId,
Sha256Hash: hash,
SizeBytes: fileSize,
Location: fileLocation,
CreationTs: util.NowMillis(),
ContentToken: contentToken,
}

err = db.Insert(media)
Expand All @@ -163,4 +167,4 @@ func StoreDirect(contents io.Reader, contentType string, filename string, userId
}

return media, nil
}
}
Loading

0 comments on commit c164220

Please sign in to comment.