Bug 1950466 - Add Clustering Perftests for Suggest Similar Tabs r=rrando

Differential Revision: https://phabricator.services.mozilla.com/D239634
This commit is contained in:
Vasish Baungally
2025-03-02 00:00:51 +00:00
parent 49dde450e6
commit 6f2179152b
7 changed files with 393 additions and 144 deletions

View File

@@ -8,14 +8,6 @@ const HOST_PREFIX =
const CLUSTERING_TEST_IDS = ["pgh_trip", "gen_set_2", "animal"];
const {
getBestAnchorClusterInfo,
SmartTabGroupingManager,
CLUSTER_METHODS,
DIM_REDUCTION_METHODS,
ANCHOR_METHODS,
} = ChromeUtils.importESModule("resource:///modules/SmartTabGrouping.sys.mjs");
async function getGroupScore(
clusterMethod,
umapMethod,
@@ -42,56 +34,6 @@ async function getGroupScore(
return total_score / iterations;
}
async function testAugmentGroup(
clusterMethod,
umapMethod,
tabs,
embeddings,
iterations = 1,
preGroupedTabIndices,
anchorMethod = ANCHOR_METHODS.FIXED,
silBoost = undefined
) {
const groupManager = new SmartTabGroupingManager();
groupManager.setAnchorMethod(anchorMethod);
if (silBoost !== undefined) {
groupManager.setSilBoost(silBoost);
}
const randFunc = simpleNumberSequence();
groupManager.setDataTitleKey("title");
groupManager.setClusteringMethod(clusterMethod);
groupManager.setDimensionReductionMethod(umapMethod);
const allScores = [];
for (let i = 0; i < iterations; i++) {
const groupingResult = await groupManager.generateClusters(
tabs,
embeddings,
0,
randFunc,
preGroupedTabIndices
);
const titleKey = "title";
const centralClusterTitles = new Set(
groupingResult.getAnchorCluster().tabs.map(a => a[titleKey])
);
groupingResult.getAnchorCluster().print();
const anchorTitleSet = new Set(
preGroupedTabIndices.map(a => tabs[a][titleKey])
);
Assert.equal(
centralClusterTitles.intersection(anchorTitleSet).size,
anchorTitleSet.size,
`All anchor indices in target cluster`
);
const scoreInfo = groupingResult.getAccuracyStatsForCluster(
"smart_group_label",
groupingResult.getAnchorCluster().tabs[0].smart_group_label
);
allScores.push(scoreInfo);
}
return averageStatsValues(allScores);
}
async function runClusteringTest(data, precomputedEmbeddings = null) {
if (!precomputedEmbeddings) {
shuffleArray(data, simpleNumberSequence(0));
@@ -118,36 +60,6 @@ async function runClusteringTest(data, precomputedEmbeddings = null) {
return 1;
}
async function runAnchorTabTest(
data,
precomputedEmbeddings = null,
anchorGroupIndices,
anchorMethod = ANCHOR_METHODS.FIXED,
silBoost = undefined
) {
const testParams = [[CLUSTER_METHODS.KMEANS]];
let scoreInfo;
for (let testP of testParams) {
scoreInfo = await testAugmentGroup(
testP[0],
testP[1],
data,
precomputedEmbeddings,
1,
anchorGroupIndices,
anchorMethod,
silBoost
);
}
if (testParams.length === 1) {
return scoreInfo;
}
console.warn(
"Test checks on score not enabled because we are testing multiple methods"
);
return null;
}
async function setup({ disabled = false, prefs = [] } = {}) {
await SpecialPowers.pushPrefEnv({
set: [
@@ -160,45 +72,6 @@ async function setup({ disabled = false, prefs = [] } = {}) {
});
}
function parseTsvStructured(tsvString) {
const rows = tsvString.trim().split("\n");
const keys = rows[0].split("\t");
const arrayOfDicts = rows.slice(1).map(row => {
const values = row.split("\t");
// Map keys to corresponding values
const dict = {};
keys.forEach((key, index) => {
dict[key] = values[index];
});
return dict;
});
return arrayOfDicts;
}
function parseTsvEmbeddings(tsvString) {
const rows = tsvString.trim().split("\n");
return rows.map(row => {
return row.split("\t").map(value => parseFloat(value));
});
}
function fetchFile(filename) {
return new Promise((resolve, reject) => {
const xhr = new XMLHttpRequest();
const url = `${HOST_PREFIX}${filename}`;
xhr.open("GET", url, true);
xhr.onload = () => {
if (xhr.status === 200) {
resolve(xhr.responseText);
} else {
reject(new Error(`Failed to fetch data: ${xhr.statusText}`));
}
};
xhr.onerror = () => reject(new Error(`Network error getting ${url}`));
xhr.send();
});
}
add_task(function testGetBestAnchorClusterInfo() {
const { anchorClusterIndex, numAnchorItemsInCluster } =
getBestAnchorClusterInfo(
@@ -224,9 +97,9 @@ add_task(async function testClustering() {
`${test_id}_labels.tsv`,
]);
for (const test of testSets) {
const rawEmbeddings = await fetchFile(test[0]);
const rawEmbeddings = await fetchFile(HOST_PREFIX, test[0]);
const embeddings = parseTsvEmbeddings(rawEmbeddings);
const rawLabels = await fetchFile(test[1]);
const rawLabels = await fetchFile(HOST_PREFIX, test[1]);
const labels = parseTsvStructured(rawLabels);
const score = await runClusteringTest(labels, embeddings);
Assert.greater(score, 0.5, `Clustering ok for dataset ${test[0]}`);
@@ -289,9 +162,9 @@ add_task(async function testAnchorClustering() {
const scoreInfo = [];
for (const test of testSets) {
const rawEmbeddings = await fetchFile(test[0]);
const rawEmbeddings = await fetchFile(HOST_PREFIX, test[0]);
const embeddings = parseTsvEmbeddings(rawEmbeddings);
const rawLabels = await fetchFile(test[1]);
const rawLabels = await fetchFile(HOST_PREFIX, test[1]);
const labels = parseTsvStructured(rawLabels);
const labelClusterList = labels.map(a => a[LABEL_DICT_KEY]);
const uniqueLabels = Array.from(new Set(labelClusterList));

View File

@@ -2,9 +2,6 @@
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
const { SmartTabGroupingManager } = ChromeUtils.importESModule(
"resource:///modules/SmartTabGrouping.sys.mjs"
);
const { sinon } = ChromeUtils.importESModule(
"resource://testing-common/Sinon.sys.mjs"
);
@@ -218,7 +215,7 @@ add_task(async function test_pref_off_should_not_create_events() {
Assert.equal(
Glean.browserMlInteraction.smartTabTopic.testGetValue() ?? "none",
"none",
"No event if feature is off"
"No event if the feature is off"
);
cleanup();
});

View File

@@ -85,7 +85,7 @@ mozilla-smart-tab-topic:
fetch:
type: git
repo: https://huggingface.co/Mozilla/smart-tab-topic
revision: a537e5109ccc57a3a1e7acd71963078fbe03c21b
revision: 4d6d7848b9ee62ce4a6db08c69f6ab698cb671d1
path-prefix: "onnx-models/Mozilla/smart-tab-topic/main/"
artifact-name: smart-tab-topic.tar.zst

View File

@@ -176,3 +176,29 @@ ml-perf-autofill:
--output $MOZ_FETCHES_DIR/../artifacts
--hooks toolkit/components/ml/tests/tools/hooks_local_hub.py
toolkit/components/ml/tests/browser/browser_ml_autofill_perf.js
ml-perf-smart-tab:
fetches:
fetch:
- ort.wasm
- ort.jsep.wasm
- mozilla-smart-tab-topic
- mozilla-smart-tab-emb
description: Run ML Smart Tab Model
treeherder:
symbol: perftest(win-ml-perf-smart-tab)
tier: 2
attributes:
batch: false
cron: false
run-on-projects: [autoland, mozilla-central]
run:
command: >-
mkdir -p $MOZ_FETCHES_DIR/../artifacts &&
cd $MOZ_FETCHES_DIR &&
python3 python/mozperftest/mozperftest/runner.py
--mochitest-binary ${MOZ_FETCHES_DIR}/firefox/firefox.exe
--flavor mochitest
--output $MOZ_FETCHES_DIR/../artifacts
--hooks toolkit/components/ml/tests/tools/hooks_local_hub.py
toolkit/components/ml/tests/browser/browser_ml_smart_tab_perf.js

View File

@@ -176,3 +176,29 @@ ml-perf-autofill:
--output $MOZ_FETCHES_DIR/../artifacts
--hooks toolkit/components/ml/tests/tools/hooks_local_hub.py
toolkit/components/ml/tests/browser/browser_ml_autofill_perf.js
ml-perf-smart-tab:
fetches:
fetch:
- ort.wasm
- ort.jsep.wasm
- mozilla-smart-tab-topic
- mozilla-smart-tab-emb
description: Run ML Smart Tab Model
treeherder:
symbol: perftest(win-ml-perf-smart-tab)
tier: 2
attributes:
batch: false
cron: false
run-on-projects: [autoland, mozilla-central]
run:
command: >-
mkdir -p $MOZ_FETCHES_DIR/../artifacts &&
cd $MOZ_FETCHES_DIR &&
python3 python/mozperftest/mozperftest/runner.py
--mochitest-binary ${MOZ_FETCHES_DIR}/firefox/firefox.exe
--flavor mochitest
--output $MOZ_FETCHES_DIR/../artifacts
--hooks toolkit/components/ml/tests/tools/hooks_local_hub.py
toolkit/components/ml/tests/browser/browser_ml_smart_tab_perf.js

View File

@@ -41,11 +41,13 @@ const perfMetadata = {
},
};
requestLongerTimeout(120);
requestLongerTimeout(200);
/**
* Tests local Autofill model
*/
const { sinon } = ChromeUtils.importESModule(
"resource://testing-common/Sinon.sys.mjs"
);
// Topic model tests
add_task(async function test_ml_smart_tab_topic() {
const options = new PipelineOptions({
taskName: "text2text-generation",
@@ -73,11 +75,12 @@ add_task(async function test_ml_smart_tab_topic() {
name: "smart-tab-topic",
options,
request,
iterations: ITERATIONS,
iterations: 2,
addColdStart: true,
});
});
// Embedding model tests
async function testEmbedding(trackPeakMemory = false) {
const options = new PipelineOptions({
taskName: "feature-extraction",
@@ -127,7 +130,7 @@ async function testEmbedding(trackPeakMemory = false) {
name: "smart-tab-embedding",
options,
request,
iterations: ITERATIONS,
iterations: 2,
addColdStart: true,
trackPeakMemory,
peakMemoryInterval: 10,
@@ -138,6 +141,165 @@ add_task(async function test_ml_smart_tab_embedding() {
await testEmbedding(false);
});
add_task(async function test_ml_smart_tab_embedding_peak_mem() {
await testEmbedding(true);
// Clustering / Nearest Neighbor tests
const ROOT_URL =
"chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data/tab_grouping/";
/*
* Generate n random samples by loading existing labels and embeddings
*/
function generateSamples(labels, embeddings, n) {
let generatedLabels = [];
let generatedEmbeddings = [];
for (let i = 0; i < n; i++) {
const randomIndex = Math.floor(Math.random() * labels.length);
generatedLabels.push(labels[randomIndex]);
if (embeddings) {
generatedEmbeddings.push(embeddings[randomIndex]);
}
}
return {
labels: generatedLabels,
embeddings: generatedEmbeddings,
};
}
async function generateEmbeddings(textList) {
const options = new PipelineOptions({
taskName: "feature-extraction",
modelId: "Mozilla/smart-tab-embedding",
modelHubUrlTemplate: "{model}/{revision}",
modelRevision: "main",
dtype: "q8",
timeoutMS: 2 * 60 * 1000,
});
const requestInfo = {
inputArgs: textList,
runOptions: {
pooling: "mean",
normalize: true,
},
};
const request = {
args: [requestInfo.inputArgs],
options: requestInfo.runOptions,
};
const mlEngineParent = await EngineProcess.getMLEngineParent();
const engine = await mlEngineParent.getEngine(options);
const output = await engine.run(request);
return output;
}
const singleTabMetrics = {};
singleTabMetrics["SINGLE-TAB-LATENCY"] = [];
add_task(async function test_clustering() {
const modelHubRootUrl = Services.env.get("MOZ_MODELS_HUB");
const { cleanup } = await perfSetup({
prefs: [["browser.ml.modelHubRootUrl", modelHubRootUrl]],
});
const stgManager = new SmartTabGroupingManager();
let generateEmbeddingsStub = sinon.stub(
SmartTabGroupingManager.prototype,
"_generateEmbeddings"
);
generateEmbeddingsStub.callsFake(async textList => {
return await generateEmbeddings(textList);
});
const labelsPath = `gen_set_2_labels.tsv`;
const rawLabels = await fetchFile(ROOT_URL, labelsPath);
let labels = parseTsvStructured(rawLabels);
labels = labels.map(l => ({ ...l, label: l.smart_group_label }));
const startTime = performance.now();
const similarTabs = await stgManager.findNearestNeighbors(
labels,
[1],
[],
0.3
);
const endTime = performance.now();
singleTabMetrics["SINGLE-TAB-LATENCY"].push(endTime - startTime);
const titles = similarTabs.map(s => s.label);
Assert.equal(
titles[0],
"Impact of Tourism on Local Communities - Google Scholar"
);
Assert.equal(
titles[1],
"Tourist Behavior and Decision Making: A Research Overview"
);
Assert.equal(titles[2], "Global Health Outlook - Reuters");
Assert.equal(titles[3], "Hotel Deals: Save Big on Hotels with Expedia");
reportMetrics(singleTabMetrics);
generateEmbeddingsStub.restore();
await EngineProcess.destroyMLEngine();
await cleanup();
});
const N_TABS = [10, 25, 50];
const methods = ["KMEANS_ANCHOR", "NEAREST_NEIGHBORS_ANCHOR"];
const nTabMetrics = {};
for (let method of methods) {
for (let n of N_TABS) {
if (method === "KMEANS_ANCHOR" && n > 25) {
break;
}
nTabMetrics[`${method}-${n}-TABS-latency`] = [];
}
}
add_task(async function test_n_clustering() {
const modelHubRootUrl = Services.env.get("MOZ_MODELS_HUB");
const { cleanup } = await perfSetup({
prefs: [["browser.ml.modelHubRootUrl", modelHubRootUrl]],
});
const stgManager = new SmartTabGroupingManager();
let generateEmbeddingsStub = sinon.stub(
SmartTabGroupingManager.prototype,
"_generateEmbeddings"
);
generateEmbeddingsStub.callsFake(async textList => {
return await generateEmbeddings(textList);
});
const labelsPath = `gen_set_2_labels.tsv`;
const rawLabels = await fetchFile(ROOT_URL, labelsPath);
const labels = parseTsvStructured(rawLabels);
for (let n of N_TABS) {
for (let method of methods) {
for (let i = 0; i < 1; i++) {
const samples = generateSamples(labels, null, n);
let startTime = performance.now();
if (method === "KMEANS_ANCHOR" && n <= 50) {
await stgManager.generateClusters(
samples.labels,
null,
0,
null,
[0],
[]
);
} else if (method === "NEAREST_NEIGHBORS_ANCHOR") {
await stgManager.findNearestNeighbors(samples.labels, [0], []);
}
let endTime = performance.now();
const key = `${method}-${n}-TABS-latency`;
if (key in nTabMetrics) {
nTabMetrics[key].push(endTime - startTime);
}
await EngineProcess.destroyMLEngine();
}
}
}
reportMetrics(nTabMetrics);
generateEmbeddingsStub.restore();
await cleanup();
});

View File

@@ -2,6 +2,13 @@
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
const {
SmartTabGroupingManager,
CLUSTER_METHODS,
ANCHOR_METHODS,
getBestAnchorClusterInfo,
} = ChromeUtils.importESModule("resource:///modules/SmartTabGrouping.sys.mjs");
/**
* Checks if numbers are close up to decimalPoints decimal points
*
@@ -80,3 +87,161 @@ function averageStatsValues(itemArray) {
}
return result;
}
/**
* Read tsv file from string
*
* @param {string} tsvString string to read from
* @returns {object} Object with parsed tsv string
*/
function parseTsvStructured(tsvString) {
const rows = tsvString.trim().split("\n");
const keys = rows[0].split("\t");
const arrayOfDicts = rows.slice(1).map(row => {
const values = row.split("\t");
// Map keys to corresponding values
const dict = {};
keys.forEach((key, index) => {
dict[key] = values[index];
});
return dict;
});
return arrayOfDicts;
}
/**
* Read tsv string with embeddings
*
* @param {string} tsvString string with embeddings present
* @returns {object} Object containing the embeddings
*/
function parseTsvEmbeddings(tsvString) {
const rows = tsvString.trim().split("\n");
return rows.map(row => {
return row.split("\t").map(value => parseFloat(value));
});
}
/**
*
* @param {string} clusterMethod kmeans or kmeans with anchor
* @param {string} umapMethod umap or dbscan
* @param {object[]} tabs tabs to cluster
* @param {object[]} embeddings precomputed embeddings for the tabs
* @param {number} iterations number of iterations before stopping clustering
* @param {number[]} preGroupedTabIndices indices of tabs that are present in the group
* @param {string} anchorMethod fixed or drift anchor methods
* @param {number} silBoost what value to multiply silhouette score
* @returns {Promise<{object}>} average of metric results
*/
async function testAugmentGroup(
clusterMethod,
umapMethod,
tabs,
embeddings,
iterations = 1,
preGroupedTabIndices,
anchorMethod = ANCHOR_METHODS.FIXED,
silBoost = undefined
) {
const groupManager = new SmartTabGroupingManager();
groupManager.setAnchorMethod(anchorMethod);
if (silBoost !== undefined) {
groupManager.setSilBoost(silBoost);
}
const randFunc = simpleNumberSequence();
groupManager.setDataTitleKey("title");
groupManager.setClusteringMethod(clusterMethod);
groupManager.setDimensionReductionMethod(umapMethod);
const allScores = [];
for (let i = 0; i < iterations; i++) {
const groupingResult = await groupManager.generateClusters(
tabs,
embeddings,
0,
randFunc,
preGroupedTabIndices
);
const titleKey = "title";
const centralClusterTitles = new Set(
groupingResult.getAnchorCluster().tabs.map(a => a[titleKey])
);
groupingResult.getAnchorCluster().print();
const anchorTitleSet = new Set(
preGroupedTabIndices.map(a => tabs[a][titleKey])
);
Assert.equal(
centralClusterTitles.intersection(anchorTitleSet).size,
anchorTitleSet.size,
`All anchor indices in target cluster`
);
const scoreInfo = groupingResult.getAccuracyStatsForCluster(
"smart_group_label",
groupingResult.getAnchorCluster().tabs[0].smart_group_label
);
allScores.push(scoreInfo);
}
return averageStatsValues(allScores);
}
/**
* Runs clustering test with multiple anchor tabs
*
* @param {object[]} data tabs to run test on
* @param {object []} precomputedEmbeddings embeddings for the tabs
* @param {number[]} anchorGroupIndices indices of tabs already present in the group
* @param {string} anchorMethod fixed or drift anchor method
* @param {number} silBoost value with which to boost silhouette score
* @returns {Promise<{}|null>} metric stats from running the clustering test
*/
async function runAnchorTabTest(
data,
precomputedEmbeddings = null,
anchorGroupIndices,
anchorMethod = ANCHOR_METHODS.FIXED,
silBoost = undefined
) {
const testParams = [[CLUSTER_METHODS.KMEANS]];
let scoreInfo;
for (let testP of testParams) {
scoreInfo = await testAugmentGroup(
testP[0],
testP[1],
data,
precomputedEmbeddings,
1,
anchorGroupIndices,
anchorMethod,
silBoost
);
}
if (testParams.length === 1) {
return scoreInfo;
}
return null;
}
/**
* Fetches a local file from prefix and filename
*
* @param {string} host_prefix root data folder path
* @param {string} filename name of file
* @returns {Promise}
*/
function fetchFile(host_prefix, filename) {
return new Promise((resolve, reject) => {
const xhr = new XMLHttpRequest();
// const url = `${HOST_PREFIX}${filename}`;
const url = `${host_prefix}${filename}`;
xhr.open("GET", url, true);
xhr.onload = () => {
if (xhr.status === 200) {
resolve(xhr.responseText);
} else {
reject(new Error(`Failed to fetch data: ${xhr.statusText}`));
}
};
xhr.onerror = () => reject(new Error(`Network error getting ${url}`));
xhr.send();
});
}