From d0488ddda97e51a673b7d1091f253e1a0145ac5b Mon Sep 17 00:00:00 2001 From: Koen Vlaswinkel Date: Tue, 14 Nov 2023 16:12:39 +0100 Subject: [PATCH] Allow downloading multiple databases from GitHub This adds the option to download multiple databases from GitHub in the initial GitHub database download prompt. The databases will be downloaded concurrently. Unfortunately it doesn't seem possible to change the "OK" text in the quick pick to "Download", so I've left it as "OK" for now. --- .../src/databases/github-database-prompt.ts | 102 +++++++---- .../databases/github-database-prompt.test.ts | 168 ++++++++++++++---- 2 files changed, 201 insertions(+), 69 deletions(-) diff --git a/extensions/ql-vscode/src/databases/github-database-prompt.ts b/extensions/ql-vscode/src/databases/github-database-prompt.ts index f5cd79450dd..235af1f63bc 100644 --- a/extensions/ql-vscode/src/databases/github-database-prompt.ts +++ b/extensions/ql-vscode/src/databases/github-database-prompt.ts @@ -3,10 +3,7 @@ import { RestEndpointMethodTypes } from "@octokit/plugin-rest-endpoint-methods"; import { Octokit } from "@octokit/rest"; import { showNeverAskAgainDialog } from "../common/vscode/dialog"; import { getLanguageDisplayName } from "../common/query-language"; -import { - downloadGitHubDatabaseFromUrl, - promptForLanguage, -} from "./database-fetcher"; +import { downloadGitHubDatabaseFromUrl } from "./database-fetcher"; import { withProgress } from "../common/vscode/progress"; import { DatabaseManager } from "./local-databases"; import { CodeQLCliServer } from "../codeql-cli/cli"; @@ -77,38 +74,46 @@ export async function promptGitHubDatabaseDownload( return; } - const language = await promptForLanguage(languages, undefined); - if (!language) { + const selectedDatabases = await promptForDatabases(databases); + if (selectedDatabases.length === 0) { return; } - const database = databases.find((database) => database.language === language); - if (!database) { - return; - } - - await withProgress(async (progress) => { - await downloadGitHubDatabaseFromUrl( - database.url, - database.id, - database.created_at, - database.commit_oid ?? null, - owner, - repo, - octokit, - progress, - databaseManager, - storagePath, - cliServer, - true, - false, - ); - - await commandManager.execute("codeQLDatabases.focus"); - void window.showInformationMessage( - `Downloaded ${getLanguageDisplayName(language)} database from GitHub.`, - ); - }); + await Promise.all( + selectedDatabases.map((database) => + withProgress( + async (progress) => { + await downloadGitHubDatabaseFromUrl( + database.url, + database.id, + database.created_at, + database.commit_oid ?? null, + owner, + repo, + octokit, + progress, + databaseManager, + storagePath, + cliServer, + true, + false, + ); + + await commandManager.execute("codeQLDatabases.focus"); + void window.showInformationMessage( + `Downloaded ${getLanguageDisplayName( + database.language, + )} database from GitHub.`, + ); + }, + { + title: `Adding ${getLanguageDisplayName( + database.language, + )} database from GitHub`, + }, + ), + ), + ); } /** @@ -135,3 +140,34 @@ function joinLanguages(languages: string[]): string { return result; } + +async function promptForDatabases( + databases: CodeqlDatabase[], +): Promise { + if (databases.length === 1) { + return databases; + } + + const items = databases + .map((database) => { + const bytesToDisplayMB = `${(database.size / (1024 * 1024)).toFixed( + 1, + )} MB`; + + return { + label: getLanguageDisplayName(database.language), + description: bytesToDisplayMB, + database, + }; + }) + .sort((a, b) => a.label.localeCompare(b.label)); + + const selectedItems = await window.showQuickPick(items, { + title: "Select databases to download", + placeHolder: "Databases found in this repository", + ignoreFocusOut: true, + canPickMany: true, + }); + + return selectedItems?.map((selectedItem) => selectedItem.database) ?? []; +} diff --git a/extensions/ql-vscode/test/vscode-tests/no-workspace/databases/github-database-prompt.test.ts b/extensions/ql-vscode/test/vscode-tests/no-workspace/databases/github-database-prompt.test.ts index 5790ef803a0..42ca917a125 100644 --- a/extensions/ql-vscode/test/vscode-tests/no-workspace/databases/github-database-prompt.test.ts +++ b/extensions/ql-vscode/test/vscode-tests/no-workspace/databases/github-database-prompt.test.ts @@ -1,6 +1,7 @@ import { faker } from "@faker-js/faker"; import { Octokit } from "@octokit/rest"; -import { mockedObject } from "../../utils/mocking.helpers"; +import { QuickPickItem, window } from "vscode"; +import { mockedObject, mockedQuickPickItem } from "../../utils/mocking.helpers"; import { CodeqlDatabase, promptGitHubDatabaseDownload, @@ -29,6 +30,7 @@ describe("promptGitHubDatabaseDownload", () => { created_at: faker.date.past().toISOString(), commit_oid: faker.git.commitSha(), language: "swift", + size: 27389673, url: faker.internet.url({ protocol: "https", }), @@ -38,9 +40,7 @@ describe("promptGitHubDatabaseDownload", () => { let showNeverAskAgainDialogSpy: jest.SpiedFunction< typeof dialog.showNeverAskAgainDialog >; - let promptForLanguageSpy: jest.SpiedFunction< - typeof databaseFetcher.promptForLanguage - >; + let showQuickPickSpy: jest.SpiedFunction; let downloadGitHubDatabaseFromUrlSpy: jest.SpiedFunction< typeof databaseFetcher.downloadGitHubDatabaseFromUrl >; @@ -56,9 +56,13 @@ describe("promptGitHubDatabaseDownload", () => { showNeverAskAgainDialogSpy = jest .spyOn(dialog, "showNeverAskAgainDialog") .mockResolvedValue("Connect"); - promptForLanguageSpy = jest - .spyOn(databaseFetcher, "promptForLanguage") - .mockResolvedValue(databases[0].language); + showQuickPickSpy = jest.spyOn(window, "showQuickPick").mockResolvedValue( + mockedQuickPickItem([ + mockedObject({ + database: databases[0], + }), + ]), + ); downloadGitHubDatabaseFromUrlSpy = jest .spyOn(databaseFetcher, "downloadGitHubDatabaseFromUrl") .mockResolvedValue(undefined); @@ -93,7 +97,7 @@ describe("promptGitHubDatabaseDownload", () => { true, false, ); - expect(promptForLanguageSpy).toHaveBeenCalledWith(["swift"], undefined); + expect(showQuickPickSpy).not.toHaveBeenCalled(); expect(config.setDownload).not.toHaveBeenCalled(); }); @@ -180,28 +184,6 @@ describe("promptGitHubDatabaseDownload", () => { }); }); - describe("when not selecting language", () => { - beforeEach(() => { - promptForLanguageSpy.mockResolvedValue(undefined); - }); - - it("does not download the database", async () => { - await promptGitHubDatabaseDownload( - octokit, - owner, - repo, - databases, - config, - databaseManager, - storagePath, - cliServer, - commandManager, - ); - - expect(downloadGitHubDatabaseFromUrlSpy).not.toHaveBeenCalled(); - }); - }); - describe("when there are multiple languages", () => { beforeEach(() => { databases = [ @@ -210,6 +192,7 @@ describe("promptGitHubDatabaseDownload", () => { created_at: faker.date.past().toISOString(), commit_oid: faker.git.commitSha(), language: "swift", + size: 27389673, url: faker.internet.url({ protocol: "https", }), @@ -219,16 +202,23 @@ describe("promptGitHubDatabaseDownload", () => { created_at: faker.date.past().toISOString(), commit_oid: null, language: "go", + size: 2930572385, url: faker.internet.url({ protocol: "https", }), }), ]; - - promptForLanguageSpy.mockResolvedValue(databases[1].language); }); - it("downloads the correct database", async () => { + it("downloads a single selected language", async () => { + showQuickPickSpy.mockResolvedValue( + mockedQuickPickItem([ + mockedObject({ + database: databases[1], + }), + ]), + ); + await promptGitHubDatabaseDownload( octokit, owner, @@ -257,11 +247,117 @@ describe("promptGitHubDatabaseDownload", () => { true, false, ); - expect(promptForLanguageSpy).toHaveBeenCalledWith( - ["swift", "go"], - undefined, + expect(showQuickPickSpy).toHaveBeenCalledWith( + [ + expect.objectContaining({ + label: "Go", + description: "2794.8 MB", + database: databases[1], + }), + expect.objectContaining({ + label: "Swift", + description: "26.1 MB", + database: databases[0], + }), + ], + expect.anything(), ); expect(config.setDownload).not.toHaveBeenCalled(); }); + + it("downloads multiple selected languages", async () => { + showQuickPickSpy.mockResolvedValue( + mockedQuickPickItem([ + mockedObject({ + database: databases[0], + }), + mockedObject({ + database: databases[1], + }), + ]), + ); + + await promptGitHubDatabaseDownload( + octokit, + owner, + repo, + databases, + config, + databaseManager, + storagePath, + cliServer, + commandManager, + ); + + expect(downloadGitHubDatabaseFromUrlSpy).toHaveBeenCalledTimes(2); + expect(downloadGitHubDatabaseFromUrlSpy).toHaveBeenCalledWith( + databases[0].url, + databases[0].id, + databases[0].created_at, + databases[0].commit_oid, + owner, + repo, + octokit, + expect.anything(), + databaseManager, + storagePath, + cliServer, + true, + false, + ); + expect(downloadGitHubDatabaseFromUrlSpy).toHaveBeenCalledWith( + databases[1].url, + databases[1].id, + databases[1].created_at, + databases[1].commit_oid, + owner, + repo, + octokit, + expect.anything(), + databaseManager, + storagePath, + cliServer, + true, + false, + ); + expect(showQuickPickSpy).toHaveBeenCalledWith( + [ + expect.objectContaining({ + label: "Go", + description: "2794.8 MB", + database: databases[1], + }), + expect.objectContaining({ + label: "Swift", + description: "26.1 MB", + database: databases[0], + }), + ], + expect.anything(), + ); + expect(config.setDownload).not.toHaveBeenCalled(); + }); + + describe("when not selecting language", () => { + beforeEach(() => { + showQuickPickSpy.mockResolvedValue(undefined); + }); + + it("does not download the database", async () => { + await promptGitHubDatabaseDownload( + octokit, + owner, + repo, + databases, + config, + databaseManager, + storagePath, + cliServer, + commandManager, + ); + + expect(downloadGitHubDatabaseFromUrlSpy).not.toHaveBeenCalled(); + }); + }); }); });