diff --git a/src/github.com/matrix-org/dendrite/mediaapi/routing/routing.go b/src/github.com/matrix-org/dendrite/mediaapi/routing/routing.go index 0c1dce6f8..7641109c3 100644 --- a/src/github.com/matrix-org/dendrite/mediaapi/routing/routing.go +++ b/src/github.com/matrix-org/dendrite/mediaapi/routing/routing.go @@ -16,7 +16,6 @@ package routing import ( "net/http" - "sync" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/common" @@ -42,7 +41,7 @@ func Setup(servMux *http.ServeMux, httpClient *http.Client, cfg *config.MediaAPI })) activeRemoteRequests := &types.ActiveRemoteRequests{ - MXCToCond: map[string]*sync.Cond{}, + MXCToResult: map[string]*types.RemoteRequestResult{}, } r0mux.Handle("/download/{serverName}/{mediaId}", prometheus.InstrumentHandler("download", http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { diff --git a/src/github.com/matrix-org/dendrite/mediaapi/types/types.go b/src/github.com/matrix-org/dendrite/mediaapi/types/types.go index ac18f5fe2..82cc1d7c3 100644 --- a/src/github.com/matrix-org/dendrite/mediaapi/types/types.go +++ b/src/github.com/matrix-org/dendrite/mediaapi/types/types.go @@ -59,10 +59,18 @@ type MediaMetadata struct { UserID MatrixUserID } +// RemoteRequestResult is used for broadcasting the result of a request for a remote file to routines waiting on the condition +type RemoteRequestResult struct { + // Condition used for the requester to signal the result to all other routines waiting on this condition + Cond *sync.Cond + // Resulting HTTP status code from the request + Result int +} + // ActiveRemoteRequests is a lockable map of media URIs requested from remote homeservers // It is used for ensuring multiple requests for the same file do not clobber each other. type ActiveRemoteRequests struct { sync.Mutex // The string key is an mxc:// URL - MXCToCond map[string]*sync.Cond + MXCToResult map[string]*RemoteRequestResult } diff --git a/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go b/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go index 024f755b2..35eca6b6e 100644 --- a/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go +++ b/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go @@ -18,10 +18,15 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" + "net/url" "os" + "path/filepath" "regexp" "strconv" + "strings" + "sync" log "github.com/Sirupsen/logrus" "github.com/matrix-org/dendrite/clientapi/jsonerror" @@ -47,6 +52,10 @@ type downloadRequest struct { // Download implements /download // Files from this server (i.e. origin == cfg.ServerName) are served directly +// Files from remote servers (i.e. origin != cfg.ServerName) are cached locally. +// If they are present in the cache, they are served directly. +// If they are not present in the cache, they are obtained from the remote server and +// simultaneously served back to the client and written into the cache. func Download(w http.ResponseWriter, req *http.Request, origin gomatrixserverlib.ServerName, mediaID types.MediaID, cfg *config.MediaAPI, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) { r := &downloadRequest{ MediaMetadata: &types.MediaMetadata{ @@ -130,11 +139,8 @@ func (r *downloadRequest) doDownload(w http.ResponseWriter, cfg *config.MediaAPI JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)), } } - // TODO: If we do not have a record and the origin is remote, we need to fetch it and respond with that file - return &util.JSONResponse{ - Code: 404, - JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)), - } + // If we do not have a record and the origin is remote, we need to fetch it and respond with that file + return r.respondFromRemoteFile(w, cfg, db, activeRemoteRequests) } // If we have a record, we can respond from the local file r.MediaMetadata = mediaMetadata @@ -200,3 +206,311 @@ func (r *downloadRequest) respondFromLocalFile(w http.ResponseWriter, absBasePat } return nil } + +// respondFromRemoteFile fetches the remote file, caches it locally and responds from that local file +// A hash map of active remote requests to a struct containing a sync.Cond is used to only download remote files once, +// regardless of how many download requests are received. +// Returns a util.JSONResponse error in case of error +func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, cfg *config.MediaAPI, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) *util.JSONResponse { + // Note: getMediaMetadataForRemoteFile uses mutexes and conditions from activeRemoteRequests + mediaMetadata, resErr := r.getMediaMetadataForRemoteFile(db, activeRemoteRequests) + if resErr != nil { + return resErr + } else if mediaMetadata != nil { + // If we have a record, we can respond from the local file + r.MediaMetadata = mediaMetadata + } else { + // If we do not have a record, we need to fetch the remote file first and then respond from the local file + // Note: getRemoteFile uses mutexes and conditions from activeRemoteRequests + if resErr := r.getRemoteFile(cfg.AbsBasePath, cfg.MaxFileSizeBytes, db, activeRemoteRequests); resErr != nil { + return resErr + } + } + return r.respondFromLocalFile(w, cfg.AbsBasePath) +} + +func (r *downloadRequest) getMediaMetadataForRemoteFile(db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) (*types.MediaMetadata, *util.JSONResponse) { + activeRemoteRequests.Lock() + defer activeRemoteRequests.Unlock() + + // check if we have a record of the media in our database + mediaMetadata, err := db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin) + if err != nil { + r.Logger.WithError(err).Error("Error querying the database.") + resErr := jsonerror.InternalServerError() + return nil, &resErr + } + + if mediaMetadata != nil { + // If we have a record, we can respond from the local file + return mediaMetadata, nil + } + + // No record was found + + // Check if there is an active remote request for the file + mxcURL := "mxc://" + string(r.MediaMetadata.Origin) + "/" + string(r.MediaMetadata.MediaID) + if activeRemoteRequestResult, ok := activeRemoteRequests.MXCToResult[mxcURL]; ok { + r.Logger.Info("Waiting for another goroutine to fetch the remote file.") + + activeRemoteRequestResult.Cond.Wait() + activeRemoteRequests.Unlock() + // NOTE: there is still a deferred Unlock() that will unlock this + activeRemoteRequests.Lock() + + // check if we have a record of the media in our database + mediaMetadata, err := db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin) + if err != nil { + r.Logger.WithError(err).Error("Error querying the database.") + resErr := jsonerror.InternalServerError() + return nil, &resErr + } + + if mediaMetadata != nil { + // If we have a record, we can respond from the local file + return mediaMetadata, nil + } + + // Note: if the result was 200, we shouldn't get here + switch activeRemoteRequestResult.Result { + case 404: + return nil, &util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("File not found."), + } + case 500: + r.Logger.Error("Other goroutine failed to fetch the remote file.") + resErr := jsonerror.InternalServerError() + return nil, &resErr + default: + r.Logger.Error("Other goroutine failed to fetch the remote file.") + return nil, &util.JSONResponse{ + Code: activeRemoteRequestResult.Result, + JSON: jsonerror.Unknown("Failed to fetch file from remote server."), + } + } + } + + // No active remote request so create one + activeRemoteRequests.MXCToResult[mxcURL] = &types.RemoteRequestResult{ + Cond: &sync.Cond{L: activeRemoteRequests}, + } + return nil, nil +} + +// getRemoteFile fetches the file from the remote server and stores its metadata in the database +// Only the owner of the activeRemoteRequestResult for this origin and media ID should call this function. +func (r *downloadRequest) getRemoteFile(absBasePath types.Path, maxFileSizeBytes types.FileSizeBytes, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) *util.JSONResponse { + // Wake up other goroutines after this function returns. + isError := true + var result int + defer func() { + if isError { + // If an error happens, the lock MUST NOT have been taken, isError MUST be true and so the lock is taken here. + activeRemoteRequests.Lock() + } + defer activeRemoteRequests.Unlock() + mxcURL := "mxc://" + string(r.MediaMetadata.Origin) + "/" + string(r.MediaMetadata.MediaID) + if activeRemoteRequestResult, ok := activeRemoteRequests.MXCToResult[mxcURL]; ok { + r.Logger.Info("Signalling other goroutines waiting for this goroutine to fetch the file.") + if result == 0 { + r.Logger.Error("Invalid result, treating as InternalServerError") + result = 500 + } + activeRemoteRequestResult.Result = result + activeRemoteRequestResult.Cond.Broadcast() + } + delete(activeRemoteRequests.MXCToResult, mxcURL) + }() + + finalPath, duplicate, resErr := r.fetchRemoteFile(absBasePath, maxFileSizeBytes) + if resErr != nil { + result = resErr.Code + return resErr + } + + // NOTE: Writing the metadata to the media repository database and removing the mxcURL from activeRemoteRequests needs to be atomic. + // If it were not atomic, a new request for the same file could come in in routine A and check the database before the INSERT. + // Routine B which was fetching could then have its INSERT complete and remove the mxcURL from the activeRemoteRequests. + // If routine A then checked the activeRemoteRequests it would think it needed to fetch the file when it's already in the database. + // The locking below mitigates this situation. + + // NOTE: The following two lines MUST remain together! + // isError == true causes the lock to be taken in a deferred function! + activeRemoteRequests.Lock() + isError = false + + r.Logger.WithFields(log.Fields{ + "Base64Hash": r.MediaMetadata.Base64Hash, + "UploadName": r.MediaMetadata.UploadName, + "FileSizeBytes": r.MediaMetadata.FileSizeBytes, + "Content-Type": r.MediaMetadata.ContentType, + }).Info("Storing file metadata to media repository database") + + // FIXME: timeout db request + if err := db.StoreMediaMetadata(r.MediaMetadata); err != nil { + // If the file is a duplicate (has the same hash as an existing file) then + // there is valid metadata in the database for that file. As such we only + // remove the file if it is not a duplicate. + if duplicate == false { + finalDir := filepath.Dir(string(finalPath)) + fileutils.RemoveDir(types.Path(finalDir), r.Logger) + } + // NOTE: It should really not be possible to fail the uniqueness test here so + // there is no need to handle that separately + resErr := jsonerror.InternalServerError() + result = resErr.Code + return &resErr + } + + // TODO: generate thumbnails + + r.Logger.WithFields(log.Fields{ + "UploadName": r.MediaMetadata.UploadName, + "Base64Hash": r.MediaMetadata.Base64Hash, + "FileSizeBytes": r.MediaMetadata.FileSizeBytes, + "Content-Type": r.MediaMetadata.ContentType, + }).Infof("Remote file cached") + + result = 200 + return nil +} + +func (r *downloadRequest) fetchRemoteFile(absBasePath types.Path, maxFileSizeBytes types.FileSizeBytes) (types.Path, bool, *util.JSONResponse) { + r.Logger.Info("Fetching remote file") + + // create request for remote file + resp, resErr := r.createRemoteRequest() + if resErr != nil { + return "", false, resErr + } + defer resp.Body.Close() + + // get metadata from request and set metadata on response + contentLength, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) + if err != nil { + r.Logger.WithError(err).Warn("Failed to parse content length") + return "", false, &util.JSONResponse{ + Code: 502, + JSON: jsonerror.Unknown("Invalid response from remote server"), + } + } + if contentLength > int64(maxFileSizeBytes) { + return "", false, &util.JSONResponse{ + Code: 413, + JSON: jsonerror.Unknown(fmt.Sprintf("Remote file is too large (%v > %v bytes)", contentLength, maxFileSizeBytes)), + } + } + r.MediaMetadata.FileSizeBytes = types.FileSizeBytes(contentLength) + r.MediaMetadata.ContentType = types.ContentType(resp.Header.Get("Content-Type")) + r.MediaMetadata.UploadName = types.Filename(contentDispositionToFilename(resp.Header.Get("Content-Disposition"))) + + r.Logger.Info("Transferring remote file") + + // The file data is hashed but is NOT used as the MediaID, unlike in Upload. The hash is useful as a + // method of deduplicating files to save storage, as well as a way to conduct + // integrity checks on the file data in the repository. + // Data is truncated to maxFileSizeBytes. Content-Length was reported as 0 < Content-Length <= maxFileSizeBytes so this is OK. + hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(resp.Body, maxFileSizeBytes, absBasePath) + if err != nil { + r.Logger.WithError(err).WithFields(log.Fields{ + "MaxFileSizeBytes": maxFileSizeBytes, + }).Warn("Error while downloading file from remote server") + fileutils.RemoveDir(tmpDir, r.Logger) + return "", false, &util.JSONResponse{ + Code: 502, + JSON: jsonerror.Unknown("File could not be downloaded from remote server"), + } + } + + r.Logger.Info("Remote file transferred") + + // It's possible the bytesWritten to the temporary file is different to the reported Content-Length from the remote + // request's response. bytesWritten is therefore used as it is what would be sent to clients when reading from the local + // file. + r.MediaMetadata.FileSizeBytes = types.FileSizeBytes(bytesWritten) + r.MediaMetadata.Base64Hash = hash + + // The database is the source of truth so we need to have moved the file first + finalPath, duplicate, err := fileutils.MoveFileWithHashCheck(tmpDir, r.MediaMetadata, absBasePath, r.Logger) + if err != nil { + r.Logger.WithError(err).Error("Failed to move file.") + resErr := jsonerror.InternalServerError() + return "", false, &resErr + } + if duplicate { + r.Logger.WithField("dst", finalPath).Info("File was stored previously - discarding duplicate") + // Continue on to store the metadata in the database + } + + return types.Path(finalPath), duplicate, nil +} + +func (r *downloadRequest) createRemoteRequest() (*http.Response, *util.JSONResponse) { + dnsResult, err := gomatrixserverlib.LookupServer(r.MediaMetadata.Origin) + if err != nil { + if dnsErr, ok := err.(*net.DNSError); ok && dnsErr.Timeout() { + return nil, &util.JSONResponse{ + Code: 504, + JSON: jsonerror.Unknown(fmt.Sprintf("DNS look up for homeserver at %v timed out", r.MediaMetadata.Origin)), + } + } + resErr := jsonerror.InternalServerError() + return nil, &resErr + } + url := "https://" + strings.Trim(dnsResult.SRVRecords[0].Target, ".") + ":" + strconv.Itoa(int(dnsResult.SRVRecords[0].Port)) + + r.Logger.WithField("URL", url).Info("Connecting to remote") + + remoteReqAddr := url + "/_matrix/media/v1/download/" + string(r.MediaMetadata.Origin) + "/" + string(r.MediaMetadata.MediaID) + remoteReq, err := http.NewRequest("GET", remoteReqAddr, nil) + if err != nil { + resErr := jsonerror.InternalServerError() + return nil, &resErr + } + + remoteReq.Header.Set("Host", string(r.MediaMetadata.Origin)) + + client := http.Client{} + resp, err := client.Do(remoteReq) + if err != nil { + r.Logger.Warn("Failed to execute request for remote file") + return nil, &util.JSONResponse{ + Code: 502, + JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)), + } + } + + if resp.StatusCode != 200 { + if resp.StatusCode == 404 { + return nil, &util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)), + } + } + r.Logger.WithFields(log.Fields{ + "StatusCode": resp.StatusCode, + }).Warn("Received error response") + return nil, &util.JSONResponse{ + Code: 502, + JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)), + } + } + + return resp, nil +} + +var contentDispositionRegex = regexp.MustCompile("filename([*])?=(utf-8'')?([A-Za-z0-9._-]+)") + +func contentDispositionToFilename(contentDisposition string) types.Filename { + filename := "" + if matches := contentDispositionRegex.FindStringSubmatch(contentDisposition); len(matches) == 4 { + // Note: the filename should already be escaped. If not, unescape should be close to a no-op. This way filename is sure to be safe. + unescaped, err := url.PathUnescape(matches[3]) + if err != nil { + unescaped = matches[3] + } + filename = url.PathEscape(unescaped) + } + return types.Filename(filename) +}