Files
tubestation/browser/components/urlbar/private/MLSuggest.sys.mjs

312 lines
9.2 KiB
JavaScript

/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at https://mozilla.org/MPL/2.0/. */
/**
 * MLSuggest produces ML-based suggestions for a query: it classifies the
 * query's intent and extracts the location (city/state) mentioned in it.
 */
const lazy = {};
ChromeUtils.defineESModuleGetters(lazy, {
createEngine: "chrome://global/content/ml/EngineProcess.sys.mjs",
UrlbarPrefs: "resource:///modules/UrlbarPrefs.sys.mjs",
});
// Prepositions that are stripped from the end of a subject once the location
// has been removed from the query (e.g. "coffee shops in" -> "coffee shops").
const PREPOSITIONS = [
  "in",
  "at",
  "on",
  "for",
  "to",
  "near",
];
/**
 * Class for handling ML-based suggestions using intent and NER models.
 *
 * Engines are created through `createEngine()` and cached by `engineId`. When
 * an engine fails while running, it is dropped from the cache and re-created
 * in the background; until the replacement is ready, lookups return `null`.
 *
 * @class
 */
class _MLSuggest {
  // Cache of created engines, keyed by `engineId`.
  #modelEngines = {};

  // Pipeline options used to create the intent classification engine.
  INTENT_OPTIONS = {
    taskName: "text-classification",
    featureId: "suggest-intent-classification",
    engineId: "ml-suggest-intent",
    timeoutMS: -1,
    numThreads: 2,
  };

  // Pipeline options used to create the named-entity-recognition engine.
  NER_OPTIONS = {
    taskName: "token-classification",
    featureId: "suggest-NER",
    engineId: "ml-suggest-ner",
    timeoutMS: -1,
    numThreads: 2,
  };

  /**
   * Thin wrapper around `lazy.createEngine` so tests can replace it.
   *
   * @param {object} args
   *   The engine pipeline options.
   * @returns {Promise<object>}
   *   Whatever `lazy.createEngine` returns for these options.
   */
  createEngine(args) {
    return lazy.createEngine(args);
  }

  /**
   * Initializes the intent and NER models.
   *
   * Both engines are created in parallel. Engines that are already cached are
   * reused, so calling this again only creates the ones that are missing.
   */
  async initialize() {
    await Promise.all([
      this.#initializeModelEngine(this.INTENT_OPTIONS),
      this.#initializeModelEngine(this.NER_OPTIONS),
    ]);
  }

  /**
   * Generates ML-based suggestions by finding intent, detecting entities, and
   * combining locations.
   *
   * @param {string} query
   *   The user's input query.
   * @returns {Promise<object | null>}
   *   `null` if either model is unavailable, fails, or yields no intent.
   *   Otherwise an object with:
   *   - {string} intent: The label of the first intent result.
   *   - {object} location: `{ city, state }`, where each field is the detected
   *     string or `null` if nothing was detected.
   *   - {string} subject: The query with the detected location removed.
   *   - {object} metrics: `{ intent, ner }`, the `metrics` properties of the
   *     intent and NER results respectively.
   */
  async makeSuggestions(query) {
    let intentRes, nerResult;
    try {
      [intentRes, nerResult] = await Promise.all([
        this._findIntent(query),
        this._findNER(query),
      ]);
    } catch (error) {
      return null;
    }

    if (!intentRes || !nerResult) {
      return null;
    }

    // An empty classification result carries no label; treat it like a
    // missing result instead of throwing on `intentRes[0].label`.
    const topIntent = intentRes[0];
    if (!topIntent) {
      return null;
    }

    const location = this.#combineLocations(
      nerResult,
      lazy.UrlbarPrefs.get("nerThreshold")
    );

    return {
      intent: topIntent.label,
      location,
      subject: this.#findSubjectFromQuery(query, location),
      metrics: { intent: intentRes.metrics, ner: nerResult.metrics },
    };
  }

  /**
   * Shuts down all initialized engines.
   *
   * Every cached engine is terminated and removed from the cache even if an
   * earlier one fails to terminate. If any termination failed, the first
   * failure is rethrown once all engines have been processed.
   */
  async shutdown() {
    const failures = [];
    for (const [engineId, engine] of Object.entries(this.#modelEngines)) {
      try {
        await engine?.terminate?.();
      } catch (error) {
        failures.push(error);
      } finally {
        // Remove each engine after termination, successful or not.
        delete this.#modelEngines[engineId];
      }
    }
    if (failures.length) {
      throw failures[0];
    }
  }

  /**
   * Returns the cached engine for `options.engineId`, creating and caching it
   * first if necessary.
   *
   * @param {object} options
   *   The engine pipeline options, including `engineId`.
   * @returns {Promise<object>}
   *   The engine.
   */
  async #initializeModelEngine(options) {
    const { engineId } = options;
    const cachedEngine = this.#modelEngines[engineId];
    if (cachedEngine) {
      return cachedEngine;
    }

    const engine = await this.createEngine({ ...options, engineId });
    this.#modelEngines[engineId] = engine;
    return engine;
  }

  /**
   * Runs `query` through the cached engine described by `engineOptions`.
   *
   * If the run fails, the engine is evicted from the cache and re-created in
   * the background so that later calls can recover.
   *
   * @param {object} engineOptions
   *   `INTENT_OPTIONS` or `NER_OPTIONS`.
   * @param {string} query
   *   The user's input query.
   * @param {object} options
   *   The options for the engine pipeline.
   * @returns {Promise<object[] | null>}
   *   The engine's result, or `null` if the engine is missing or failed.
   */
  async #runEngine(engineOptions, query, options) {
    const { engineId } = engineOptions;
    const engine = this.#modelEngines[engineId];
    if (!engine) {
      return null;
    }

    try {
      // `await` is required here: returning the bare promise would let a
      // rejection escape this try/catch.
      return await engine.run({ args: [query], options });
    } catch (error) {
      // The engine could time out or fail, so drop it from the cache and
      // re-create it. Only the call that still sees the failed engine in the
      // cache does this, so concurrent failures trigger a single re-creation.
      if (this.#modelEngines[engineId] === engine) {
        delete this.#modelEngines[engineId];
        // Best effort: if re-creation fails the engine simply stays missing
        // and callers keep getting `null` until `initialize()` succeeds.
        this.#initializeModelEngine(engineOptions).catch(() => {});
      }
      return null;
    }
  }

  /**
   * Finds the intent of the query using the intent classification model.
   * (This has been made public to enable testing)
   *
   * @param {string} query
   *   The user's input query.
   * @param {object} options
   *   The options for the engine pipeline
   * @returns {Promise<object[] | null>}
   *   The intent results, or null if the model is not initialized or failed.
   */
  async _findIntent(query, options = {}) {
    return this.#runEngine(this.INTENT_OPTIONS, query, options);
  }

  /**
   * Finds named entities in the query using the NER model.
   * (This has been made public to enable testing)
   *
   * @param {string} query
   *   The user's input query.
   * @param {object} options
   *   The options for the engine pipeline
   * @returns {Promise<object[] | null>}
   *   The NER results, or null if the model is not initialized or failed.
   */
  async _findNER(query, options = {}) {
    return this.#runEngine(this.NER_OPTIONS, query, options);
  }

  /**
   * Combines location tokens detected by NER into separate city and state
   * components.
   *
   * Handles the following entity types:
   * - B-CITY, I-CITY: Identifies city tokens.
   * - B-STATE, I-STATE: Identifies state tokens.
   * - B-CITYSTATE, I-CITYSTATE: Identifies tokens that represent a combined
   *   city and state.
   *
   * A token whose word starts with "##" is appended (without the "##") to the
   * previous token of the same kind, if there is one. If any combined
   * city-state tokens were found they take precedence: they are joined and
   * split on "," into city and state.
   *
   * @param {object[]} nerResult
   *   The NER results; each entry has `entity`, `score` and `word`.
   * @param {number} nerThreshold
   *   The confidence threshold. Only tokens whose score is strictly greater
   *   than this value are used.
   * @returns {object}
   *   An object with `city` and `state` fields:
   *   - {string|null} city: The detected city, or `null` if no city is found.
   *   - {string|null} state: The detected state, or `null` if no state is
   *     found.
   */
  #combineLocations(nerResult, nerThreshold) {
    const cityTokens = [];
    const stateTokens = [];
    const cityStateTokens = [];

    // A Map (rather than a plain object) so that arbitrary entity labels can
    // never resolve to an inherited property.
    const tokensByEntity = new Map([
      ["B-CITY", cityTokens],
      ["I-CITY", cityTokens],
      ["B-STATE", stateTokens],
      ["I-STATE", stateTokens],
      ["B-CITYSTATE", cityStateTokens],
      ["I-CITYSTATE", cityStateTokens],
    ]);

    for (const { entity, score, word } of nerResult) {
      const tokens = tokensByEntity.get(entity);
      // Written as a negated `>` so a missing threshold or score rejects the
      // token, exactly like a failed `score > nerThreshold` check.
      if (!tokens || !(score > nerThreshold)) {
        continue;
      }
      if (word.startsWith("##") && tokens.length) {
        tokens[tokens.length - 1] += word.slice(2);
      } else {
        tokens.push(word);
      }
    }

    // Handle city_state as combined and split into city and state.
    if (cityStateTokens.length) {
      const [city, state] = cityStateTokens.join(" ").split(",");
      return {
        city: city?.trim() || null,
        state: state?.trim() || null,
      };
    }

    // Return city and state as separate components if detected.
    return {
      city: cityTokens.join(" ").trim() || null,
      state: stateTokens.join(" ").trim() || null,
    };
  }

  /**
   * Derives the subject of the query by removing the detected location.
   *
   * @param {string} query
   *   The user's input query.
   * @param {object|null} location
   *   The `{ city, state }` object produced by `#combineLocations`.
   * @returns {string}
   *   The unmodified query if there is no city and no state. Otherwise the
   *   query with the first occurrence of the city and of the state removed,
   *   commas dropped, whitespace collapsed and a trailing preposition
   *   stripped.
   */
  #findSubjectFromQuery(query, location) {
    if (!location || (!location.city && !location.state)) {
      return query;
    }

    let subject = query;
    for (const part of [location.city, location.state]) {
      if (part) {
        subject = subject.replace(part, "").trim();
      }
    }

    // Remove leftover commas and collapse runs of whitespace.
    subject = subject.replaceAll(",", "").replace(/\s+/g, " ").trim();
    return this.#cleanSubject(subject);
  }

  /**
   * Strips one trailing preposition (see `PREPOSITIONS`) from the subject.
   *
   * @param {string} subject
   *   The subject with the location already removed.
   * @returns {string}
   *   The subject without the trailing preposition and the whitespace before
   *   it, or the subject unchanged if it does not end with a preposition.
   */
  #cleanSubject(subject) {
    const trailing = PREPOSITIONS.find(
      p => subject === p || subject.endsWith(" " + p)
    );
    if (!trailing) {
      return subject;
    }
    return subject.substring(0, subject.length - trailing.length).trimEnd();
  }
}
// Export the singleton instance. `_MLSuggest` is instantiated once, when this
// module is first evaluated, and every importer shares that one instance.
export var MLSuggest = new _MLSuggest();