Bug 1950466 - Add Clustering Perftests for Suggest Similar Tabs r=rrando
Differential Revision: https://phabricator.services.mozilla.com/D239634
This commit is contained in:
@@ -8,14 +8,6 @@ const HOST_PREFIX =
|
||||
|
||||
const CLUSTERING_TEST_IDS = ["pgh_trip", "gen_set_2", "animal"];
|
||||
|
||||
const {
|
||||
getBestAnchorClusterInfo,
|
||||
SmartTabGroupingManager,
|
||||
CLUSTER_METHODS,
|
||||
DIM_REDUCTION_METHODS,
|
||||
ANCHOR_METHODS,
|
||||
} = ChromeUtils.importESModule("resource:///modules/SmartTabGrouping.sys.mjs");
|
||||
|
||||
async function getGroupScore(
|
||||
clusterMethod,
|
||||
umapMethod,
|
||||
@@ -42,56 +34,6 @@ async function getGroupScore(
|
||||
return total_score / iterations;
|
||||
}
|
||||
|
||||
/**
 * Clusters `tabs` around a set of pre-grouped (anchor) tabs and scores the
 * resulting anchor cluster against the `smart_group_label` field.
 *
 * @param {string} clusterMethod value handed to setClusteringMethod
 * @param {string} umapMethod value handed to setDimensionReductionMethod
 * @param {object[]} tabs tabs to cluster; `title` and `smart_group_label` are read
 * @param {object[]} embeddings precomputed embeddings forwarded to generateClusters
 * @param {number} iterations how many times clustering is run and scored
 * @param {number[]} preGroupedTabIndices indices into `tabs` of the anchor tabs
 * @param {string} anchorMethod value handed to setAnchorMethod
 * @param {number} [silBoost] forwarded to setSilBoost only when provided
 * @returns {Promise<object>} averageStatsValues() over the per-iteration stats
 */
async function testAugmentGroup(
  clusterMethod,
  umapMethod,
  tabs,
  embeddings,
  iterations = 1,
  preGroupedTabIndices,
  anchorMethod = ANCHOR_METHODS.FIXED,
  silBoost = undefined
) {
  const TITLE_FIELD = "title";

  const manager = new SmartTabGroupingManager();
  manager.setAnchorMethod(anchorMethod);
  if (silBoost !== undefined) {
    manager.setSilBoost(silBoost);
  }
  const nextRandom = simpleNumberSequence();
  manager.setDataTitleKey("title");
  manager.setClusteringMethod(clusterMethod);
  manager.setDimensionReductionMethod(umapMethod);

  const statsPerRun = [];
  for (let run = 0; run < iterations; run += 1) {
    const result = await manager.generateClusters(
      tabs,
      embeddings,
      0,
      nextRandom,
      preGroupedTabIndices
    );

    const clusterTitleList = result
      .getAnchorCluster()
      .tabs.map(tab => tab[TITLE_FIELD]);
    const clusterTitles = new Set(clusterTitleList);
    result.getAnchorCluster().print();

    const anchorTitles = new Set(
      preGroupedTabIndices.map(tabIndex => tabs[tabIndex][TITLE_FIELD])
    );
    // Every anchor title must have ended up inside the anchor cluster.
    Assert.equal(
      clusterTitles.intersection(anchorTitles).size,
      anchorTitles.size,
      `All anchor indices in target cluster`
    );

    statsPerRun.push(
      result.getAccuracyStatsForCluster(
        "smart_group_label",
        result.getAnchorCluster().tabs[0].smart_group_label
      )
    );
  }
  return averageStatsValues(statsPerRun);
}
|
||||
|
||||
async function runClusteringTest(data, precomputedEmbeddings = null) {
|
||||
if (!precomputedEmbeddings) {
|
||||
shuffleArray(data, simpleNumberSequence(0));
|
||||
@@ -118,36 +60,6 @@ async function runClusteringTest(data, precomputedEmbeddings = null) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
/**
 * Runs the anchor-tab clustering check for each configured method pair.
 *
 * @param {object[]} data tabs to run the test on
 * @param {object[]} precomputedEmbeddings embeddings forwarded to testAugmentGroup
 * @param {number[]} anchorGroupIndices indices of tabs already in the group
 * @param {string} anchorMethod anchor method forwarded to testAugmentGroup
 * @param {number} [silBoost] silhouette boost forwarded to testAugmentGroup
 * @returns {Promise<object|null>} the stats when exactly one method pair is
 *   configured, otherwise null
 */
async function runAnchorTabTest(
  data,
  precomputedEmbeddings = null,
  anchorGroupIndices,
  anchorMethod = ANCHOR_METHODS.FIXED,
  silBoost = undefined
) {
  // Each entry is [clusterMethod, umapMethod]; the second slot is left unset.
  const methodPairs = [[CLUSTER_METHODS.KMEANS]];
  let latestStats;
  for (const [clusterMethod, umapMethod] of methodPairs) {
    latestStats = await testAugmentGroup(
      clusterMethod,
      umapMethod,
      data,
      precomputedEmbeddings,
      1,
      anchorGroupIndices,
      anchorMethod,
      silBoost
    );
  }
  if (methodPairs.length !== 1) {
    console.warn(
      "Test checks on score not enabled because we are testing multiple methods"
    );
    return null;
  }
  return latestStats;
}
|
||||
|
||||
async function setup({ disabled = false, prefs = [] } = {}) {
|
||||
await SpecialPowers.pushPrefEnv({
|
||||
set: [
|
||||
@@ -160,45 +72,6 @@ async function setup({ disabled = false, prefs = [] } = {}) {
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Parses a TSV string whose first row is a header into one object per row.
 *
 * @param {string} tsvString TSV text; the first line holds the column names
 * @returns {object[]} one object per data row, keyed by column name; a column
 *   missing from a row maps to undefined
 */
function parseTsvStructured(tsvString) {
  // Accept both LF and CRLF line endings so a stray "\r" never ends up glued
  // to the last column's name or values.
  const [headerRow, ...dataRows] = tsvString.trim().split(/\r?\n/);
  const keys = headerRow.split("\t");
  return dataRows.map(row => {
    const values = row.split("\t");
    return Object.fromEntries(keys.map((key, index) => [key, values[index]]));
  });
}
|
||||
|
||||
/**
 * Parses a header-less TSV string of numbers into a list of numeric rows.
 *
 * @param {string} tsvString TSV text, one embedding per line
 * @returns {number[][]} one array of parsed floats per line; empty input
 *   yields an empty list
 */
function parseTsvEmbeddings(tsvString) {
  const trimmed = tsvString.trim();
  // "".split() would produce a single empty row that parses to [NaN].
  if (trimmed === "") {
    return [];
  }
  return trimmed
    .split(/\r?\n/)
    .map(row => row.split("\t").map(value => Number.parseFloat(value)));
}
|
||||
|
||||
/**
 * Fetches a file located under HOST_PREFIX and resolves with its text.
 *
 * @param {string} filename name appended to HOST_PREFIX to form the URL
 * @returns {Promise<string>} response text on status 200; rejects with an
 *   Error on any other status or on a network error
 */
function fetchFile(filename) {
  return new Promise((resolve, reject) => {
    const request = new XMLHttpRequest();
    const url = `${HOST_PREFIX}${filename}`;
    request.open("GET", url, true);
    request.onerror = () => reject(new Error(`Network error getting ${url}`));
    request.onload = () => {
      if (request.status !== 200) {
        reject(new Error(`Failed to fetch data: ${request.statusText}`));
        return;
      }
      resolve(request.responseText);
    };
    request.send();
  });
}
|
||||
|
||||
add_task(function testGetBestAnchorClusterInfo() {
|
||||
const { anchorClusterIndex, numAnchorItemsInCluster } =
|
||||
getBestAnchorClusterInfo(
|
||||
@@ -224,9 +97,9 @@ add_task(async function testClustering() {
|
||||
`${test_id}_labels.tsv`,
|
||||
]);
|
||||
for (const test of testSets) {
|
||||
const rawEmbeddings = await fetchFile(test[0]);
|
||||
const rawEmbeddings = await fetchFile(HOST_PREFIX, test[0]);
|
||||
const embeddings = parseTsvEmbeddings(rawEmbeddings);
|
||||
const rawLabels = await fetchFile(test[1]);
|
||||
const rawLabels = await fetchFile(HOST_PREFIX, test[1]);
|
||||
const labels = parseTsvStructured(rawLabels);
|
||||
const score = await runClusteringTest(labels, embeddings);
|
||||
Assert.greater(score, 0.5, `Clustering ok for dataset ${test[0]}`);
|
||||
@@ -289,9 +162,9 @@ add_task(async function testAnchorClustering() {
|
||||
const scoreInfo = [];
|
||||
|
||||
for (const test of testSets) {
|
||||
const rawEmbeddings = await fetchFile(test[0]);
|
||||
const rawEmbeddings = await fetchFile(HOST_PREFIX, test[0]);
|
||||
const embeddings = parseTsvEmbeddings(rawEmbeddings);
|
||||
const rawLabels = await fetchFile(test[1]);
|
||||
const rawLabels = await fetchFile(HOST_PREFIX, test[1]);
|
||||
const labels = parseTsvStructured(rawLabels);
|
||||
const labelClusterList = labels.map(a => a[LABEL_DICT_KEY]);
|
||||
const uniqueLabels = Array.from(new Set(labelClusterList));
|
||||
|
||||
@@ -2,9 +2,6 @@
|
||||
* License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
|
||||
|
||||
const { SmartTabGroupingManager } = ChromeUtils.importESModule(
|
||||
"resource:///modules/SmartTabGrouping.sys.mjs"
|
||||
);
|
||||
const { sinon } = ChromeUtils.importESModule(
|
||||
"resource://testing-common/Sinon.sys.mjs"
|
||||
);
|
||||
@@ -218,7 +215,7 @@ add_task(async function test_pref_off_should_not_create_events() {
|
||||
Assert.equal(
|
||||
Glean.browserMlInteraction.smartTabTopic.testGetValue() ?? "none",
|
||||
"none",
|
||||
"No event if feature is off"
|
||||
"No event if the feature is off"
|
||||
);
|
||||
cleanup();
|
||||
});
|
||||
|
||||
@@ -85,7 +85,7 @@ mozilla-smart-tab-topic:
|
||||
fetch:
|
||||
type: git
|
||||
repo: https://huggingface.co/Mozilla/smart-tab-topic
|
||||
revision: a537e5109ccc57a3a1e7acd71963078fbe03c21b
|
||||
revision: 4d6d7848b9ee62ce4a6db08c69f6ab698cb671d1
|
||||
path-prefix: "onnx-models/Mozilla/smart-tab-topic/main/"
|
||||
artifact-name: smart-tab-topic.tar.zst
|
||||
|
||||
|
||||
@@ -176,3 +176,29 @@ ml-perf-autofill:
|
||||
--output $MOZ_FETCHES_DIR/../artifacts
|
||||
--hooks toolkit/components/ml/tests/tools/hooks_local_hub.py
|
||||
toolkit/components/ml/tests/browser/browser_ml_autofill_perf.js
|
||||
|
||||
# Perftest task that runs browser_ml_smart_tab_perf.js through the mozperftest
# runner, using the fetched ONNX runtime and Smart Tab model artifacts.
ml-perf-smart-tab:
    fetches:
        fetch:
            - ort.wasm
            - ort.jsep.wasm
            - mozilla-smart-tab-topic
            - mozilla-smart-tab-emb
    description: Run ML Smart Tab Model
    treeherder:
        symbol: perftest(win-ml-perf-smart-tab)
        tier: 2
    attributes:
        batch: false
        cron: false
    run-on-projects: [autoland, mozilla-central]
    run:
        command: >-
            mkdir -p $MOZ_FETCHES_DIR/../artifacts &&
            cd $MOZ_FETCHES_DIR &&
            python3 python/mozperftest/mozperftest/runner.py
            --mochitest-binary ${MOZ_FETCHES_DIR}/firefox/firefox.exe
            --flavor mochitest
            --output $MOZ_FETCHES_DIR/../artifacts
            --hooks toolkit/components/ml/tests/tools/hooks_local_hub.py
            toolkit/components/ml/tests/browser/browser_ml_smart_tab_perf.js
|
||||
|
||||
@@ -176,3 +176,29 @@ ml-perf-autofill:
|
||||
--output $MOZ_FETCHES_DIR/../artifacts
|
||||
--hooks toolkit/components/ml/tests/tools/hooks_local_hub.py
|
||||
toolkit/components/ml/tests/browser/browser_ml_autofill_perf.js
|
||||
|
||||
# Perftest task that runs browser_ml_smart_tab_perf.js through the mozperftest
# runner, using the fetched ONNX runtime and Smart Tab model artifacts.
ml-perf-smart-tab:
    fetches:
        fetch:
            - ort.wasm
            - ort.jsep.wasm
            - mozilla-smart-tab-topic
            - mozilla-smart-tab-emb
    description: Run ML Smart Tab Model
    treeherder:
        symbol: perftest(win-ml-perf-smart-tab)
        tier: 2
    attributes:
        batch: false
        cron: false
    run-on-projects: [autoland, mozilla-central]
    run:
        command: >-
            mkdir -p $MOZ_FETCHES_DIR/../artifacts &&
            cd $MOZ_FETCHES_DIR &&
            python3 python/mozperftest/mozperftest/runner.py
            --mochitest-binary ${MOZ_FETCHES_DIR}/firefox/firefox.exe
            --flavor mochitest
            --output $MOZ_FETCHES_DIR/../artifacts
            --hooks toolkit/components/ml/tests/tools/hooks_local_hub.py
            toolkit/components/ml/tests/browser/browser_ml_smart_tab_perf.js
|
||||
|
||||
@@ -41,11 +41,13 @@ const perfMetadata = {
|
||||
},
|
||||
};
|
||||
|
||||
requestLongerTimeout(120);
|
||||
requestLongerTimeout(200);
|
||||
|
||||
/**
|
||||
* Tests local Smart Tab Grouping models (topic and embedding)
|
||||
*/
|
||||
const { sinon } = ChromeUtils.importESModule(
|
||||
"resource://testing-common/Sinon.sys.mjs"
|
||||
);
|
||||
|
||||
// Topic model tests
|
||||
add_task(async function test_ml_smart_tab_topic() {
|
||||
const options = new PipelineOptions({
|
||||
taskName: "text2text-generation",
|
||||
@@ -73,11 +75,12 @@ add_task(async function test_ml_smart_tab_topic() {
|
||||
name: "smart-tab-topic",
|
||||
options,
|
||||
request,
|
||||
iterations: ITERATIONS,
|
||||
iterations: 2,
|
||||
addColdStart: true,
|
||||
});
|
||||
});
|
||||
|
||||
// Embedding model tests
|
||||
async function testEmbedding(trackPeakMemory = false) {
|
||||
const options = new PipelineOptions({
|
||||
taskName: "feature-extraction",
|
||||
@@ -127,7 +130,7 @@ async function testEmbedding(trackPeakMemory = false) {
|
||||
name: "smart-tab-embedding",
|
||||
options,
|
||||
request,
|
||||
iterations: ITERATIONS,
|
||||
iterations: 2,
|
||||
addColdStart: true,
|
||||
trackPeakMemory,
|
||||
peakMemoryInterval: 10,
|
||||
@@ -138,6 +141,165 @@ add_task(async function test_ml_smart_tab_embedding() {
|
||||
await testEmbedding(false);
|
||||
});
|
||||
|
||||
add_task(async function test_ml_smart_tab_embedding_peak_mem() {
|
||||
await testEmbedding(true);
|
||||
// Clustering / Nearest Neighbor tests
|
||||
const ROOT_URL =
|
||||
"chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data/tab_grouping/";
|
||||
|
||||
/**
 * Generates n random samples, drawn with replacement, from existing labels
 * and (optionally) their embeddings.
 *
 * @param {object[]} labels pool of label rows to sample from
 * @param {Array[]|null} embeddings rows indexed like `labels`; when falsy the
 *   returned `embeddings` list stays empty
 * @param {number} n number of samples to draw
 * @param {function(): number} [randFunc] source of numbers in [0, 1); defaults
 *   to Math.random, pass a seeded generator for reproducible samples
 * @returns {{labels: object[], embeddings: Array[]}} the sampled rows
 */
function generateSamples(labels, embeddings, n, randFunc = Math.random) {
  const generatedLabels = [];
  const generatedEmbeddings = [];
  for (let i = 0; i < n; i++) {
    // The same index is used for both lists so label and embedding stay paired.
    const randomIndex = Math.floor(randFunc() * labels.length);
    generatedLabels.push(labels[randomIndex]);
    if (embeddings) {
      generatedEmbeddings.push(embeddings[randomIndex]);
    }
  }
  return {
    labels: generatedLabels,
    embeddings: generatedEmbeddings,
  };
}
|
||||
|
||||
/**
 * Runs the "Mozilla/smart-tab-embedding" feature-extraction pipeline over a
 * list of texts using the ML engine.
 *
 * @param {string[]} textList texts passed to the engine as its single argument
 * @returns {Promise<*>} whatever the engine's run() resolves with
 */
async function generateEmbeddings(textList) {
  const pipelineOptions = new PipelineOptions({
    taskName: "feature-extraction",
    modelId: "Mozilla/smart-tab-embedding",
    modelHubUrlTemplate: "{model}/{revision}",
    modelRevision: "main",
    dtype: "q8",
    timeoutMS: 2 * 60 * 1000,
  });

  // Mean-pooled, normalized embeddings for the whole list in one request.
  const engineRequest = {
    args: [textList],
    options: {
      pooling: "mean",
      normalize: true,
    },
  };

  const engineParent = await EngineProcess.getMLEngineParent();
  const engine = await engineParent.getEngine(pipelineOptions);
  return engine.run(engineRequest);
}
|
||||
|
||||
// Latency samples collected by the single-tab nearest-neighbor test.
const singleTabMetrics = { "SINGLE-TAB-LATENCY": [] };
|
||||
|
||||
/**
 * Times a single findNearestNeighbors call on the gen_set_2 labels, with
 * _generateEmbeddings stubbed to go through generateEmbeddings(), and checks
 * the first four suggested tabs.
 */
add_task(async function test_clustering() {
  const modelHubRootUrl = Services.env.get("MOZ_MODELS_HUB");
  const { cleanup } = await perfSetup({
    prefs: [["browser.ml.modelHubRootUrl", modelHubRootUrl]],
  });

  let generateEmbeddingsStub;
  try {
    const stgManager = new SmartTabGroupingManager();

    generateEmbeddingsStub = sinon.stub(
      SmartTabGroupingManager.prototype,
      "_generateEmbeddings"
    );
    generateEmbeddingsStub.callsFake(textList => generateEmbeddings(textList));

    const labelsPath = `gen_set_2_labels.tsv`;
    const rawLabels = await fetchFile(ROOT_URL, labelsPath);
    // Copy smart_group_label into `label`, the field read back from results.
    const labels = parseTsvStructured(rawLabels).map(l => ({
      ...l,
      label: l.smart_group_label,
    }));

    const startTime = performance.now();
    const similarTabs = await stgManager.findNearestNeighbors(
      labels,
      [1],
      [],
      0.3
    );
    const endTime = performance.now();
    singleTabMetrics["SINGLE-TAB-LATENCY"].push(endTime - startTime);

    const titles = similarTabs.map(s => s.label);
    const expectedTitles = [
      "Impact of Tourism on Local Communities - Google Scholar",
      "Tourist Behavior and Decision Making: A Research Overview",
      "Global Health Outlook - Reuters",
      "Hotel Deals: Save Big on Hotels with Expedia",
    ];
    expectedTitles.forEach((expected, index) => {
      Assert.equal(
        titles[index],
        expected,
        `Suggested tab ${index} has the expected title`
      );
    });
    reportMetrics(singleTabMetrics);
  } finally {
    // Always undo the prototype stub and tear down, even when the body throws,
    // so later tasks can stub _generateEmbeddings again.
    generateEmbeddingsStub?.restore();
    await EngineProcess.destroyMLEngine();
    await cleanup();
  }
});
|
||||
|
||||
// Tab counts and grouping methods exercised by the n-tab latency test.
const N_TABS = [10, 25, 50];
const methods = ["KMEANS_ANCHOR", "NEAREST_NEIGHBORS_ANCHOR"];

// One empty sample list per recorded (method, tab count) pair.
const nTabMetrics = {};
for (const method of methods) {
  for (const tabCount of N_TABS) {
    const kmeansOverLimit = method === "KMEANS_ANCHOR" && tabCount > 25;
    if (kmeansOverLimit) {
      // No KMEANS_ANCHOR metric is registered past 25 tabs.
      break;
    }
    const metricName = `${method}-${tabCount}-TABS-latency`;
    nTabMetrics[metricName] = [];
  }
}
|
||||
|
||||
/**
 * Times generateClusters / findNearestNeighbors on random samples of the
 * gen_set_2 labels for every (method, tab count) pair that has an entry in
 * nTabMetrics, then reports the collected latencies.
 */
add_task(async function test_n_clustering() {
  const modelHubRootUrl = Services.env.get("MOZ_MODELS_HUB");
  const { cleanup } = await perfSetup({
    prefs: [["browser.ml.modelHubRootUrl", modelHubRootUrl]],
  });

  // Number of timed runs per (method, tab count) pair.
  const runsPerConfig = 1;
  let generateEmbeddingsStub;
  try {
    const stgManager = new SmartTabGroupingManager();

    generateEmbeddingsStub = sinon.stub(
      SmartTabGroupingManager.prototype,
      "_generateEmbeddings"
    );
    generateEmbeddingsStub.callsFake(textList => generateEmbeddings(textList));

    const labelsPath = `gen_set_2_labels.tsv`;
    const rawLabels = await fetchFile(ROOT_URL, labelsPath);
    const labels = parseTsvStructured(rawLabels);

    for (const n of N_TABS) {
      for (const method of methods) {
        const key = `${method}-${n}-TABS-latency`;
        // Skip pairs with no registered metric: their timing would be
        // discarded, so running the clustering is wasted work.
        if (!(key in nTabMetrics)) {
          continue;
        }
        for (let i = 0; i < runsPerConfig; i++) {
          const samples = generateSamples(labels, null, n);
          try {
            const startTime = performance.now();
            if (method === "KMEANS_ANCHOR") {
              await stgManager.generateClusters(
                samples.labels,
                null,
                0,
                null,
                [0],
                []
              );
            } else if (method === "NEAREST_NEIGHBORS_ANCHOR") {
              await stgManager.findNearestNeighbors(samples.labels, [0], []);
            }
            const endTime = performance.now();
            nTabMetrics[key].push(endTime - startTime);
          } finally {
            // The engine is destroyed after every run, as before, now also
            // when the run throws.
            await EngineProcess.destroyMLEngine();
          }
        }
      }
    }
    reportMetrics(nTabMetrics);
  } finally {
    // Always undo the prototype stub and run cleanup, even on failure.
    generateEmbeddingsStub?.restore();
    await cleanup();
  }
});
|
||||
|
||||
@@ -2,6 +2,13 @@
|
||||
* License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
|
||||
|
||||
const {
|
||||
SmartTabGroupingManager,
|
||||
CLUSTER_METHODS,
|
||||
ANCHOR_METHODS,
|
||||
getBestAnchorClusterInfo,
|
||||
} = ChromeUtils.importESModule("resource:///modules/SmartTabGrouping.sys.mjs");
|
||||
|
||||
/**
|
||||
* Checks if numbers are close up to decimalPoints decimal points
|
||||
*
|
||||
@@ -80,3 +87,161 @@ function averageStatsValues(itemArray) {
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
 * Read tsv file from string; the first row is used as the column names.
 *
 * @param {string} tsvString string to read from
 * @returns {object[]} one object per data row, keyed by column name; a column
 *   missing from a row maps to undefined
 */
function parseTsvStructured(tsvString) {
  // Accept both LF and CRLF line endings so a stray "\r" never ends up glued
  // to the last column's name or values.
  const [headerRow, ...dataRows] = tsvString.trim().split(/\r?\n/);
  const keys = headerRow.split("\t");
  return dataRows.map(row => {
    const values = row.split("\t");
    return Object.fromEntries(keys.map((key, index) => [key, values[index]]));
  });
}
|
||||
|
||||
/**
 * Read tsv string with embeddings (no header row, one embedding per line).
 *
 * @param {string} tsvString string with embeddings present
 * @returns {number[][]} one array of parsed floats per line; empty input
 *   yields an empty list
 */
function parseTsvEmbeddings(tsvString) {
  const trimmed = tsvString.trim();
  // "".split() would produce a single empty row that parses to [NaN].
  if (trimmed === "") {
    return [];
  }
  return trimmed
    .split(/\r?\n/)
    .map(row => row.split("\t").map(value => Number.parseFloat(value)));
}
|
||||
|
||||
/**
 * Clusters `tabs` around a set of pre-grouped (anchor) tabs and scores the
 * resulting anchor cluster against the `smart_group_label` field.
 *
 * @param {string} clusterMethod value handed to setClusteringMethod
 * @param {string} umapMethod value handed to setDimensionReductionMethod
 * @param {object[]} tabs tabs to cluster; `title` and `smart_group_label` are read
 * @param {object[]} embeddings precomputed embeddings for the tabs
 * @param {number} iterations how many times clustering is run and scored
 * @param {number[]} preGroupedTabIndices indices of tabs that are present in the group
 * @param {string} anchorMethod value handed to setAnchorMethod
 * @param {number} [silBoost] forwarded to setSilBoost only when provided
 * @returns {Promise<object>} averageStatsValues() over the per-iteration stats
 */
async function testAugmentGroup(
  clusterMethod,
  umapMethod,
  tabs,
  embeddings,
  iterations = 1,
  preGroupedTabIndices,
  anchorMethod = ANCHOR_METHODS.FIXED,
  silBoost = undefined
) {
  const TITLE_FIELD = "title";

  const manager = new SmartTabGroupingManager();
  manager.setAnchorMethod(anchorMethod);
  if (silBoost !== undefined) {
    manager.setSilBoost(silBoost);
  }
  const nextRandom = simpleNumberSequence();
  manager.setDataTitleKey("title");
  manager.setClusteringMethod(clusterMethod);
  manager.setDimensionReductionMethod(umapMethod);

  const statsPerRun = [];
  for (let run = 0; run < iterations; run += 1) {
    const result = await manager.generateClusters(
      tabs,
      embeddings,
      0,
      nextRandom,
      preGroupedTabIndices
    );

    const clusterTitleList = result
      .getAnchorCluster()
      .tabs.map(tab => tab[TITLE_FIELD]);
    const clusterTitles = new Set(clusterTitleList);
    result.getAnchorCluster().print();

    const anchorTitles = new Set(
      preGroupedTabIndices.map(tabIndex => tabs[tabIndex][TITLE_FIELD])
    );
    // Every anchor title must have ended up inside the anchor cluster.
    Assert.equal(
      clusterTitles.intersection(anchorTitles).size,
      anchorTitles.size,
      `All anchor indices in target cluster`
    );

    statsPerRun.push(
      result.getAccuracyStatsForCluster(
        "smart_group_label",
        result.getAnchorCluster().tabs[0].smart_group_label
      )
    );
  }
  return averageStatsValues(statsPerRun);
}
|
||||
|
||||
/**
 * Runs the anchor-tab clustering check for each configured method pair.
 *
 * @param {object[]} data tabs to run test on
 * @param {object[]} precomputedEmbeddings embeddings for the tabs
 * @param {number[]} anchorGroupIndices indices of tabs already present in the group
 * @param {string} anchorMethod anchor method forwarded to testAugmentGroup
 * @param {number} [silBoost] silhouette boost forwarded to testAugmentGroup
 * @returns {Promise<object|null>} the stats when exactly one method pair is
 *   configured, otherwise null
 */
async function runAnchorTabTest(
  data,
  precomputedEmbeddings = null,
  anchorGroupIndices,
  anchorMethod = ANCHOR_METHODS.FIXED,
  silBoost = undefined
) {
  // Each entry is [clusterMethod, umapMethod]; the second slot is left unset.
  const methodPairs = [[CLUSTER_METHODS.KMEANS]];
  let latestStats;
  for (const [clusterMethod, umapMethod] of methodPairs) {
    latestStats = await testAugmentGroup(
      clusterMethod,
      umapMethod,
      data,
      precomputedEmbeddings,
      1,
      anchorGroupIndices,
      anchorMethod,
      silBoost
    );
  }
  // Scores from several method pairs are not combined; only a lone result is returned.
  return methodPairs.length === 1 ? latestStats : null;
}
|
||||
|
||||
/**
 * Fetches a file whose URL is `host_prefix` followed by `filename`.
 *
 * @param {string} host_prefix root data folder path
 * @param {string} filename name of file
 * @returns {Promise<string>} response text on status 200; rejects with an
 *   Error on any other status or on a network error
 */
function fetchFile(host_prefix, filename) {
  return new Promise((resolve, reject) => {
    const request = new XMLHttpRequest();
    const url = `${host_prefix}${filename}`;
    request.open("GET", url, true);
    request.onerror = () => reject(new Error(`Network error getting ${url}`));
    request.onload = () => {
      if (request.status !== 200) {
        reject(new Error(`Failed to fetch data: ${request.statusText}`));
        return;
      }
      resolve(request.responseText);
    };
    request.send();
  });
}
|
||||
|
||||
Reference in New Issue
Block a user