diff --git a/src/geneweaver/api/controller/genesets.py b/src/geneweaver/api/controller/genesets.py index b3d958b..74b41d8 100644 --- a/src/geneweaver/api/controller/genesets.py +++ b/src/geneweaver/api/controller/genesets.py @@ -1,13 +1,12 @@ """Endpoints related to genesets.""" import json -import os import time from tempfile import TemporaryDirectory from typing import Optional, Set from fastapi import APIRouter, Depends, HTTPException, Path, Query, Security -from fastapi.responses import FileResponse +from fastapi.responses import FileResponse, StreamingResponse from geneweaver.api import dependencies as deps from geneweaver.api.schemas.apimodels import GeneValueReturn from geneweaver.api.schemas.auth import UserInternal @@ -189,7 +188,7 @@ def get_export_geneset_by_id_type( cursor: Optional[deps.Cursor] = Depends(deps.cursor), temp_dir: TemporaryDirectory = Depends(deps.get_temp_dir), gene_id_type: Optional[GeneIdentifier] = None, -) -> FileResponse: +) -> StreamingResponse: """Export geneset into JSON file. Search by ID and optional gene identifier type.""" timestr = time.strftime("%Y%m%d-%H%M%S") @@ -214,15 +213,19 @@ def get_export_geneset_by_id_type( geneset_filename = f"geneset_{geneset_id}_{timestr}.json" # Write the data to temp file - temp_file_path = os.path.join(temp_dir, geneset_filename) - with open(temp_file_path, "w") as f: - json.dump(response, f, default=str) + from io import StringIO + + buffer = StringIO() + + json.dump(response, buffer, default=str) + + buffer.seek(0) # Return as a download - return FileResponse( - path=temp_file_path, + return StreamingResponse( + buffer, media_type="application/octet-stream", - filename=geneset_filename, + headers={"Content-Disposition": f"attachment; filename={geneset_filename}"}, ) diff --git a/tests/controllers/test_genesets.py b/tests/controllers/test_genesets.py index 40f9f55..32c4745 100644 --- a/tests/controllers/test_genesets.py +++ b/tests/controllers/test_genesets.py @@ -66,7 +66,6 @@ def test_export_geneset_w_gene_id_type(mock_service_get_geneset_w_gene_id_type, response = client.get("/api/genesets/1234/file?gene_id_type=2") assert response.headers.get("content-type") == "application/octet-stream" - assert int(response.headers.get("content-length")) > 0 assert response.status_code == 200