diff --git a/browser/components/tabbrowser/test/browser/smarttabgrouping/browser_tab_grouping.js b/browser/components/tabbrowser/test/browser/smarttabgrouping/browser_tab_grouping.js index 4df5c4d0320e..46afcceae39d 100644 --- a/browser/components/tabbrowser/test/browser/smarttabgrouping/browser_tab_grouping.js +++ b/browser/components/tabbrowser/test/browser/smarttabgrouping/browser_tab_grouping.js @@ -8,14 +8,6 @@ const HOST_PREFIX = const CLUSTERING_TEST_IDS = ["pgh_trip", "gen_set_2", "animal"]; -const { - getBestAnchorClusterInfo, - SmartTabGroupingManager, - CLUSTER_METHODS, - DIM_REDUCTION_METHODS, - ANCHOR_METHODS, -} = ChromeUtils.importESModule("resource:///modules/SmartTabGrouping.sys.mjs"); - async function getGroupScore( clusterMethod, umapMethod, @@ -42,56 +34,6 @@ async function getGroupScore( return total_score / iterations; } -async function testAugmentGroup( - clusterMethod, - umapMethod, - tabs, - embeddings, - iterations = 1, - preGroupedTabIndices, - anchorMethod = ANCHOR_METHODS.FIXED, - silBoost = undefined -) { - const groupManager = new SmartTabGroupingManager(); - groupManager.setAnchorMethod(anchorMethod); - if (silBoost !== undefined) { - groupManager.setSilBoost(silBoost); - } - const randFunc = simpleNumberSequence(); - groupManager.setDataTitleKey("title"); - groupManager.setClusteringMethod(clusterMethod); - groupManager.setDimensionReductionMethod(umapMethod); - const allScores = []; - for (let i = 0; i < iterations; i++) { - const groupingResult = await groupManager.generateClusters( - tabs, - embeddings, - 0, - randFunc, - preGroupedTabIndices - ); - const titleKey = "title"; - const centralClusterTitles = new Set( - groupingResult.getAnchorCluster().tabs.map(a => a[titleKey]) - ); - groupingResult.getAnchorCluster().print(); - const anchorTitleSet = new Set( - preGroupedTabIndices.map(a => tabs[a][titleKey]) - ); - Assert.equal( - centralClusterTitles.intersection(anchorTitleSet).size, - anchorTitleSet.size, - `All anchor indices in target cluster` - ); - const scoreInfo = groupingResult.getAccuracyStatsForCluster( - "smart_group_label", - groupingResult.getAnchorCluster().tabs[0].smart_group_label - ); - allScores.push(scoreInfo); - } - return averageStatsValues(allScores); -} - async function runClusteringTest(data, precomputedEmbeddings = null) { if (!precomputedEmbeddings) { shuffleArray(data, simpleNumberSequence(0)); @@ -118,36 +60,6 @@ async function runClusteringTest(data, precomputedEmbeddings = null) { return 1; } -async function runAnchorTabTest( - data, - precomputedEmbeddings = null, - anchorGroupIndices, - anchorMethod = ANCHOR_METHODS.FIXED, - silBoost = undefined -) { - const testParams = [[CLUSTER_METHODS.KMEANS]]; - let scoreInfo; - for (let testP of testParams) { - scoreInfo = await testAugmentGroup( - testP[0], - testP[1], - data, - precomputedEmbeddings, - 1, - anchorGroupIndices, - anchorMethod, - silBoost - ); - } - if (testParams.length === 1) { - return scoreInfo; - } - console.warn( - "Test checks on score not enabled because we are testing multiple methods" - ); - return null; -} - async function setup({ disabled = false, prefs = [] } = {}) { await SpecialPowers.pushPrefEnv({ set: [ @@ -160,45 +72,6 @@ async function setup({ disabled = false, prefs = [] } = {}) { }); } -function parseTsvStructured(tsvString) { - const rows = tsvString.trim().split("\n"); - const keys = rows[0].split("\t"); - const arrayOfDicts = rows.slice(1).map(row => { - const values = row.split("\t"); - // Map keys to corresponding values - const dict = {}; - keys.forEach((key, index) => { - dict[key] = values[index]; - }); - return dict; - }); - return arrayOfDicts; -} - -function parseTsvEmbeddings(tsvString) { - const rows = tsvString.trim().split("\n"); - return rows.map(row => { - return row.split("\t").map(value => parseFloat(value)); - }); -} - -function fetchFile(filename) { - return new Promise((resolve, reject) => { - const xhr = new XMLHttpRequest(); - const url = `${HOST_PREFIX}${filename}`; - xhr.open("GET", url, true); - xhr.onload = () => { - if (xhr.status === 200) { - resolve(xhr.responseText); - } else { - reject(new Error(`Failed to fetch data: ${xhr.statusText}`)); - } - }; - xhr.onerror = () => reject(new Error(`Network error getting ${url}`)); - xhr.send(); - }); -} - add_task(function testGetBestAnchorClusterInfo() { const { anchorClusterIndex, numAnchorItemsInCluster } = getBestAnchorClusterInfo( @@ -224,9 +97,9 @@ add_task(async function testClustering() { `${test_id}_labels.tsv`, ]); for (const test of testSets) { - const rawEmbeddings = await fetchFile(test[0]); + const rawEmbeddings = await fetchFile(HOST_PREFIX, test[0]); const embeddings = parseTsvEmbeddings(rawEmbeddings); - const rawLabels = await fetchFile(test[1]); + const rawLabels = await fetchFile(HOST_PREFIX, test[1]); const labels = parseTsvStructured(rawLabels); const score = await runClusteringTest(labels, embeddings); Assert.greater(score, 0.5, `Clustering ok for dataset ${test[0]}`); @@ -289,9 +162,9 @@ add_task(async function testAnchorClustering() { const scoreInfo = []; for (const test of testSets) { - const rawEmbeddings = await fetchFile(test[0]); + const rawEmbeddings = await fetchFile(HOST_PREFIX, test[0]); const embeddings = parseTsvEmbeddings(rawEmbeddings); - const rawLabels = await fetchFile(test[1]); + const rawLabels = await fetchFile(HOST_PREFIX, test[1]); const labels = parseTsvStructured(rawLabels); const labelClusterList = labels.map(a => a[LABEL_DICT_KEY]); const uniqueLabels = Array.from(new Set(labelClusterList)); diff --git a/browser/components/tabbrowser/test/browser/smarttabgrouping/browser_tab_grouping_telemetry.js b/browser/components/tabbrowser/test/browser/smarttabgrouping/browser_tab_grouping_telemetry.js index 48c80e947987..46fede507e46 100644 --- a/browser/components/tabbrowser/test/browser/smarttabgrouping/browser_tab_grouping_telemetry.js +++ b/browser/components/tabbrowser/test/browser/smarttabgrouping/browser_tab_grouping_telemetry.js @@ -2,9 +2,6 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ -const { SmartTabGroupingManager } = ChromeUtils.importESModule( - "resource:///modules/SmartTabGrouping.sys.mjs" -); const { sinon } = ChromeUtils.importESModule( "resource://testing-common/Sinon.sys.mjs" ); @@ -218,7 +215,7 @@ add_task(async function test_pref_off_should_not_create_events() { Assert.equal( Glean.browserMlInteraction.smartTabTopic.testGetValue() ?? "none", "none", - "No event if feature is off" + "No event if the feature is off" ); cleanup(); }); diff --git a/taskcluster/kinds/fetch/onnxruntime-web-fetch.yml b/taskcluster/kinds/fetch/onnxruntime-web-fetch.yml index eea9c5103edf..669f32a1337d 100644 --- a/taskcluster/kinds/fetch/onnxruntime-web-fetch.yml +++ b/taskcluster/kinds/fetch/onnxruntime-web-fetch.yml @@ -85,7 +85,7 @@ mozilla-smart-tab-topic: fetch: type: git repo: https://huggingface.co/Mozilla/smart-tab-topic - revision: a537e5109ccc57a3a1e7acd71963078fbe03c21b + revision: 4d6d7848b9ee62ce4a6db08c69f6ab698cb671d1 path-prefix: "onnx-models/Mozilla/smart-tab-topic/main/" artifact-name: smart-tab-topic.tar.zst diff --git a/taskcluster/kinds/perftest/windows11-24h2-ref.yml b/taskcluster/kinds/perftest/windows11-24h2-ref.yml index 6b60e3da6664..715e70ff8819 100644 --- a/taskcluster/kinds/perftest/windows11-24h2-ref.yml +++ b/taskcluster/kinds/perftest/windows11-24h2-ref.yml @@ -176,3 +176,29 @@ ml-perf-autofill: --output $MOZ_FETCHES_DIR/../artifacts --hooks toolkit/components/ml/tests/tools/hooks_local_hub.py toolkit/components/ml/tests/browser/browser_ml_autofill_perf.js + +ml-perf-smart-tab: + fetches: + fetch: + - ort.wasm + - ort.jsep.wasm + - mozilla-smart-tab-topic + - mozilla-smart-tab-emb + description: Run ML Smart Tab Model + treeherder: + symbol: perftest(win-ml-perf-smart-tab) + tier: 2 + attributes: + batch: false + cron: false + run-on-projects: [autoland, mozilla-central] + run: + command: >- + mkdir -p $MOZ_FETCHES_DIR/../artifacts && + cd $MOZ_FETCHES_DIR && + python3 python/mozperftest/mozperftest/runner.py + --mochitest-binary ${MOZ_FETCHES_DIR}/firefox/firefox.exe + --flavor mochitest + --output $MOZ_FETCHES_DIR/../artifacts + --hooks toolkit/components/ml/tests/tools/hooks_local_hub.py + toolkit/components/ml/tests/browser/browser_ml_smart_tab_perf.js diff --git a/taskcluster/kinds/perftest/windows11-ref.yml b/taskcluster/kinds/perftest/windows11-ref.yml index 76582cba8ddb..7657b2ebbb06 100644 --- a/taskcluster/kinds/perftest/windows11-ref.yml +++ b/taskcluster/kinds/perftest/windows11-ref.yml @@ -176,3 +176,29 @@ ml-perf-autofill: --output $MOZ_FETCHES_DIR/../artifacts --hooks toolkit/components/ml/tests/tools/hooks_local_hub.py toolkit/components/ml/tests/browser/browser_ml_autofill_perf.js + +ml-perf-smart-tab: + fetches: + fetch: + - ort.wasm + - ort.jsep.wasm + - mozilla-smart-tab-topic + - mozilla-smart-tab-emb + description: Run ML Smart Tab Model + treeherder: + symbol: perftest(win-ml-perf-smart-tab) + tier: 2 + attributes: + batch: false + cron: false + run-on-projects: [autoland, mozilla-central] + run: + command: >- + mkdir -p $MOZ_FETCHES_DIR/../artifacts && + cd $MOZ_FETCHES_DIR && + python3 python/mozperftest/mozperftest/runner.py + --mochitest-binary ${MOZ_FETCHES_DIR}/firefox/firefox.exe + --flavor mochitest + --output $MOZ_FETCHES_DIR/../artifacts + --hooks toolkit/components/ml/tests/tools/hooks_local_hub.py + toolkit/components/ml/tests/browser/browser_ml_smart_tab_perf.js diff --git a/toolkit/components/ml/tests/browser/browser_ml_smart_tab_perf.js b/toolkit/components/ml/tests/browser/browser_ml_smart_tab_perf.js index 3871929fdfdd..484a2fe5dc57 100644 --- a/toolkit/components/ml/tests/browser/browser_ml_smart_tab_perf.js +++ b/toolkit/components/ml/tests/browser/browser_ml_smart_tab_perf.js @@ -41,11 +41,13 @@ const perfMetadata = { }, }; -requestLongerTimeout(120); +requestLongerTimeout(200); -/** - * Tests local Autofill model - */ +const { sinon } = ChromeUtils.importESModule( + "resource://testing-common/Sinon.sys.mjs" +); + +// Topic model tests add_task(async function test_ml_smart_tab_topic() { const options = new PipelineOptions({ taskName: "text2text-generation", @@ -73,11 +75,12 @@ add_task(async function test_ml_smart_tab_topic() { name: "smart-tab-topic", options, request, - iterations: ITERATIONS, + iterations: 2, addColdStart: true, }); }); +// Embedding model tests async function testEmbedding(trackPeakMemory = false) { const options = new PipelineOptions({ taskName: "feature-extraction", @@ -127,7 +130,7 @@ async function testEmbedding(trackPeakMemory = false) { name: "smart-tab-embedding", options, request, - iterations: ITERATIONS, + iterations: 2, addColdStart: true, trackPeakMemory, peakMemoryInterval: 10, @@ -138,6 +141,165 @@ add_task(async function test_ml_smart_tab_embedding() { await testEmbedding(false); }); -add_task(async function test_ml_smart_tab_embedding_peak_mem() { - await testEmbedding(true); +// Clustering / Nearest Neighbor tests +const ROOT_URL = + "chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data/tab_grouping/"; + +/* + * Generate n random samples by loading existing labels and embeddings + */ +function generateSamples(labels, embeddings, n) { + let generatedLabels = []; + let generatedEmbeddings = []; + for (let i = 0; i < n; i++) { + const randomIndex = Math.floor(Math.random() * labels.length); + generatedLabels.push(labels[randomIndex]); + if (embeddings) { + generatedEmbeddings.push(embeddings[randomIndex]); + } + } + return { + labels: generatedLabels, + embeddings: generatedEmbeddings, + }; +} + +async function generateEmbeddings(textList) { + const options = new PipelineOptions({ + taskName: "feature-extraction", + modelId: "Mozilla/smart-tab-embedding", + modelHubUrlTemplate: "{model}/{revision}", + modelRevision: "main", + dtype: "q8", + timeoutMS: 2 * 60 * 1000, + }); + const requestInfo = { + inputArgs: textList, + runOptions: { + pooling: "mean", + normalize: true, + }, + }; + + const request = { + args: [requestInfo.inputArgs], + options: requestInfo.runOptions, + }; + const mlEngineParent = await EngineProcess.getMLEngineParent(); + const engine = await mlEngineParent.getEngine(options); + const output = await engine.run(request); + return output; +} + +const singleTabMetrics = {}; +singleTabMetrics["SINGLE-TAB-LATENCY"] = []; + +add_task(async function test_clustering() { + const modelHubRootUrl = Services.env.get("MOZ_MODELS_HUB"); + const { cleanup } = await perfSetup({ + prefs: [["browser.ml.modelHubRootUrl", modelHubRootUrl]], + }); + + const stgManager = new SmartTabGroupingManager(); + + let generateEmbeddingsStub = sinon.stub( + SmartTabGroupingManager.prototype, + "_generateEmbeddings" + ); + generateEmbeddingsStub.callsFake(async textList => { + return await generateEmbeddings(textList); + }); + + const labelsPath = `gen_set_2_labels.tsv`; + const rawLabels = await fetchFile(ROOT_URL, labelsPath); + let labels = parseTsvStructured(rawLabels); + labels = labels.map(l => ({ ...l, label: l.smart_group_label })); + const startTime = performance.now(); + const similarTabs = await stgManager.findNearestNeighbors( + labels, + [1], + [], + 0.3 + ); + const endTime = performance.now(); + singleTabMetrics["SINGLE-TAB-LATENCY"].push(endTime - startTime); + const titles = similarTabs.map(s => s.label); + Assert.equal( + titles[0], + "Impact of Tourism on Local Communities - Google Scholar" + ); + Assert.equal( + titles[1], + "Tourist Behavior and Decision Making: A Research Overview" + ); + Assert.equal(titles[2], "Global Health Outlook - Reuters"); + Assert.equal(titles[3], "Hotel Deals: Save Big on Hotels with Expedia"); + reportMetrics(singleTabMetrics); + generateEmbeddingsStub.restore(); + await EngineProcess.destroyMLEngine(); + await cleanup(); +}); + +const N_TABS = [10, 25, 50]; +const methods = ["KMEANS_ANCHOR", "NEAREST_NEIGHBORS_ANCHOR"]; +const nTabMetrics = {}; + +for (let method of methods) { + for (let n of N_TABS) { + if (method === "KMEANS_ANCHOR" && n > 25) { + break; + } + nTabMetrics[`${method}-${n}-TABS-latency`] = []; + } +} + +add_task(async function test_n_clustering() { + const modelHubRootUrl = Services.env.get("MOZ_MODELS_HUB"); + const { cleanup } = await perfSetup({ + prefs: [["browser.ml.modelHubRootUrl", modelHubRootUrl]], + }); + + const stgManager = new SmartTabGroupingManager(); + + let generateEmbeddingsStub = sinon.stub( + SmartTabGroupingManager.prototype, + "_generateEmbeddings" + ); + generateEmbeddingsStub.callsFake(async textList => { + return await generateEmbeddings(textList); + }); + + const labelsPath = `gen_set_2_labels.tsv`; + const rawLabels = await fetchFile(ROOT_URL, labelsPath); + const labels = parseTsvStructured(rawLabels); + + for (let n of N_TABS) { + for (let method of methods) { + for (let i = 0; i < 1; i++) { + const samples = generateSamples(labels, null, n); + let startTime = performance.now(); + if (method === "KMEANS_ANCHOR" && n <= 50) { + await stgManager.generateClusters( + samples.labels, + null, + 0, + null, + [0], + [] + ); + } else if (method === "NEAREST_NEIGHBORS_ANCHOR") { + await stgManager.findNearestNeighbors(samples.labels, [0], []); + } + let endTime = performance.now(); + const key = `${method}-${n}-TABS-latency`; + if (key in nTabMetrics) { + nTabMetrics[key].push(endTime - startTime); + } + await EngineProcess.destroyMLEngine(); + } + } + } + reportMetrics(nTabMetrics); + generateEmbeddingsStub.restore(); + await cleanup(); }); diff --git a/toolkit/components/ml/tests/browser/shared-head.js b/toolkit/components/ml/tests/browser/shared-head.js index 7b80034760d9..f3e19aec7d3a 100644 --- a/toolkit/components/ml/tests/browser/shared-head.js +++ b/toolkit/components/ml/tests/browser/shared-head.js @@ -2,6 +2,13 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ +const { + SmartTabGroupingManager, + CLUSTER_METHODS, + ANCHOR_METHODS, + getBestAnchorClusterInfo, +} = ChromeUtils.importESModule("resource:///modules/SmartTabGrouping.sys.mjs"); + /** * Checks if numbers are close up to decimalPoints decimal points * @@ -80,3 +87,161 @@ function averageStatsValues(itemArray) { } return result; } + +/** + * Read tsv file from string + * + * @param {string} tsvString string to read from + * @returns {object} Object with parsed tsv string + */ +function parseTsvStructured(tsvString) { + const rows = tsvString.trim().split("\n"); + const keys = rows[0].split("\t"); + const arrayOfDicts = rows.slice(1).map(row => { + const values = row.split("\t"); + // Map keys to corresponding values + const dict = {}; + keys.forEach((key, index) => { + dict[key] = values[index]; + }); + return dict; + }); + return arrayOfDicts; +} + +/** + * Read tsv string with embeddings + * + * @param {string} tsvString string with embeddings present + * @returns {object} Object containing the embeddings + */ +function parseTsvEmbeddings(tsvString) { + const rows = tsvString.trim().split("\n"); + return rows.map(row => { + return row.split("\t").map(value => parseFloat(value)); + }); +} + +/** + * + * @param {string} clusterMethod kmeans or kmeans with anchor + * @param {string} umapMethod umap or dbscan + * @param {object[]} tabs tabs to cluster + * @param {object[]} embeddings precomputed embeddings for the tabs + * @param {number} iterations number of iterations before stopping clustering + * @param {number[]} preGroupedTabIndices indices of tabs that are present in the group + * @param {string} anchorMethod fixed or drift anchor methods + * @param {number} silBoost what value to multiply silhouette score + * @returns {Promise<{object}>} average of metric results + */ +async function testAugmentGroup( + clusterMethod, + umapMethod, + tabs, + embeddings, + iterations = 1, + preGroupedTabIndices, + anchorMethod = ANCHOR_METHODS.FIXED, + silBoost = undefined +) { + const groupManager = new SmartTabGroupingManager(); + groupManager.setAnchorMethod(anchorMethod); + if (silBoost !== undefined) { + groupManager.setSilBoost(silBoost); + } + const randFunc = simpleNumberSequence(); + groupManager.setDataTitleKey("title"); + groupManager.setClusteringMethod(clusterMethod); + groupManager.setDimensionReductionMethod(umapMethod); + const allScores = []; + for (let i = 0; i < iterations; i++) { + const groupingResult = await groupManager.generateClusters( + tabs, + embeddings, + 0, + randFunc, + preGroupedTabIndices + ); + const titleKey = "title"; + const centralClusterTitles = new Set( + groupingResult.getAnchorCluster().tabs.map(a => a[titleKey]) + ); + groupingResult.getAnchorCluster().print(); + const anchorTitleSet = new Set( + preGroupedTabIndices.map(a => tabs[a][titleKey]) + ); + Assert.equal( + centralClusterTitles.intersection(anchorTitleSet).size, + anchorTitleSet.size, + `All anchor indices in target cluster` + ); + const scoreInfo = groupingResult.getAccuracyStatsForCluster( + "smart_group_label", + groupingResult.getAnchorCluster().tabs[0].smart_group_label + ); + allScores.push(scoreInfo); + } + return averageStatsValues(allScores); +} + +/** + * Runs clustering test with multiple anchor tabs + * + * @param {object[]} data tabs to run test on + * @param {object []} precomputedEmbeddings embeddings for the tabs + * @param {number[]} anchorGroupIndices indices of tabs already present in the group + * @param {string} anchorMethod fixed or drift anchor method + * @param {number} silBoost value with which to boost silhouette score + * @returns {Promise<{}|null>} metric stats from running the clustering test + */ +async function runAnchorTabTest( + data, + precomputedEmbeddings = null, + anchorGroupIndices, + anchorMethod = ANCHOR_METHODS.FIXED, + silBoost = undefined +) { + const testParams = [[CLUSTER_METHODS.KMEANS]]; + let scoreInfo; + for (let testP of testParams) { + scoreInfo = await testAugmentGroup( + testP[0], + testP[1], + data, + precomputedEmbeddings, + 1, + anchorGroupIndices, + anchorMethod, + silBoost + ); + } + if (testParams.length === 1) { + return scoreInfo; + } + return null; +} + +/** + * Fetches a local file from prefix and filename + * + * @param {string} host_prefix root data folder path + * @param {string} filename name of file + * @returns {Promise} + */ +function fetchFile(host_prefix, filename) { + return new Promise((resolve, reject) => { + const xhr = new XMLHttpRequest(); + // const url = `${HOST_PREFIX}${filename}`; + const url = `${host_prefix}${filename}`; + xhr.open("GET", url, true); + xhr.onload = () => { + if (xhr.status === 200) { + resolve(xhr.responseText); + } else { + reject(new Error(`Failed to fetch data: ${xhr.statusText}`)); + } + }; + xhr.onerror = () => reject(new Error(`Network error getting ${url}`)); + xhr.send(); + }); +}