/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */

/**
 * MLSuggest helps with ML based suggestions around intents and location.
 */

const lazy = {};

ChromeUtils.defineESModuleGetters(lazy, {
  createEngine: "chrome://global/content/ml/EngineProcess.sys.mjs",
  UrlbarPrefs: "resource:///modules/UrlbarPrefs.sys.mjs",
});

// List of prepositions used in subject cleaning.
const PREPOSITIONS = ["in", "at", "on", "for", "to", "near"];

// Maps every NER entity label handled by `#combineLocations` to the kind of
// location fragment it contributes to.
const LOCATION_ENTITY_KINDS = new Map([
  ["B-CITY", "city"],
  ["I-CITY", "city"],
  ["B-STATE", "state"],
  ["I-STATE", "state"],
  ["B-CITYSTATE", "cityState"],
  ["I-CITYSTATE", "cityState"],
]);

/**
 * Class for handling ML-based suggestions using intent and NER models.
 *
 * @class
 */
class _MLSuggest {
  // Cache of created engines keyed by engine id. A slot is set to `null` when
  // its engine failed and a replacement is still being created.
  #modelEngines = {};

  // Options passed to `createEngine` for the intent classification model.
  INTENT_OPTIONS = {
    taskName: "text-classification",
    featureId: "suggest-intent-classification",
    engineId: "ml-suggest-intent",
    timeoutMS: -1,
    numThreads: 2,
  };

  // Options passed to `createEngine` for the named-entity-recognition model.
  NER_OPTIONS = {
    taskName: "token-classification",
    featureId: "suggest-NER",
    engineId: "ml-suggest-ner",
    timeoutMS: -1,
    numThreads: 2,
  };

  // Helper to wrap createEngine for testing purpose
  createEngine(args) {
    return lazy.createEngine(args);
  }

  /**
   * Initializes the intent and NER models.
   */
  async initialize() {
    await Promise.all([
      this.#initializeModelEngine(this.INTENT_OPTIONS),
      this.#initializeModelEngine(this.NER_OPTIONS),
    ]);
  }

  /**
   * Generates ML-based suggestions by finding intent, detecting entities, and
   * combining locations.
   *
   * @param {string} query
   *   The user's input query.
   * @returns {object | null}
   *   The suggestion result, or null if either model is unavailable or fails.
   *   The result has the following properties:
   *   - {string} intent: The label of the first intent result for the query.
   *   - {object} location: The detected location from the query, which is an
   *     object with `city` and `state` fields:
   *     - {string|null} city: The detected city, or `null` if none is found.
   *     - {string|null} state: The detected state, or `null` if none is found.
   *   - {string} subject: The query after the location is removed.
   *   - {object} metrics: The `metrics` attached to the intent and NER model
   *     results, as `{ intent, ner }`.
   */
  async makeSuggestions(query) {
    let intentRes, nerResult;
    try {
      [intentRes, nerResult] = await Promise.all([
        this._findIntent(query),
        this._findNER(query),
      ]);
    } catch {
      // Suggestions are best-effort: any unexpected failure means "no result".
      return null;
    }

    if (!intentRes || !nerResult) {
      return null;
    }

    const locationResVal = await this.#combineLocations(
      nerResult,
      lazy.UrlbarPrefs.get("nerThreshold")
    );

    return {
      intent: intentRes[0].label,
      location: locationResVal,
      subject: this.#findSubjectFromQuery(query, locationResVal),
      metrics: { intent: intentRes.metrics, ner: nerResult.metrics },
    };
  }

  /**
   * Shuts down all initialized engines.
   */
  async shutdown() {
    for (const [key, engine] of Object.entries(this.#modelEngines)) {
      try {
        // A slot holds `null` while a failed engine is being replaced, so
        // guard the engine itself and not only its `terminate` method.
        await engine?.terminate?.();
      } finally {
        // Remove each engine after termination
        delete this.#modelEngines[key];
      }
    }
  }

  /**
   * Returns the cached engine for `options.engineId`, creating and caching it
   * first when there is none.
   *
   * @param {object} options
   *   The engine creation options, including `engineId`.
   * @returns {object}
   *   The engine.
   */
  async #initializeModelEngine(options) {
    const engineId = options.engineId;

    // uses cache if engine was used
    if (this.#modelEngines[engineId]) {
      return this.#modelEngines[engineId];
    }

    const engine = await this.createEngine({ ...options, engineId });
    // Cache the engine
    this.#modelEngines[engineId] = engine;
    return engine;
  }

  /**
   * Runs `query` through the cached engine described by `engineOptions`.
   * When the run fails, the engine is dropped from the cache and a
   * replacement is created in the background.
   *
   * @param {object} engineOptions
   *   Either `INTENT_OPTIONS` or `NER_OPTIONS`.
   * @param {string} query
   *   The user's input query.
   * @param {object} options
   *   The options for the engine pipeline
   * @returns {object[] | null}
   *   The engine results, or null if the engine is not initialized or the run
   *   failed.
   */
  async #runEngine(engineOptions, query, options) {
    const { engineId } = engineOptions;
    const engine = this.#modelEngines[engineId];
    if (!engine) {
      return null;
    }

    try {
      // The `await` is required: without it a rejected run would bypass the
      // catch block below and the broken engine would stay cached.
      return await engine.run({ args: [query], options });
    } catch {
      // engine could timeout or fail, so remove that from cache
      // and reinitialize
      this.#modelEngines[engineId] = null;
      // Deliberately not awaited so the caller is not blocked on engine
      // creation; the handler prevents an unhandled rejection if it fails.
      this.#initializeModelEngine(engineOptions).catch(error => {
        console.error(`MLSuggest: failed to recreate ${engineId}`, error);
      });
      return null;
    }
  }

  /**
   * Finds the intent of the query using the intent classification model.
   * (This has been made public to enable testing)
   *
   * @param {string} query
   *   The user's input query.
   * @param {object} options
   *   The options for the engine pipeline
   * @returns {object[] | null}
   *   The intent results or null if the model is not initialized or failed.
   */
  async _findIntent(query, options = {}) {
    return this.#runEngine(this.INTENT_OPTIONS, query, options);
  }

  /**
   * Finds named entities in the query using the NER model.
   * (This has been made public to enable testing)
   *
   * @param {string} query
   *   The user's input query.
   * @param {object} options
   *   The options for the engine pipeline
   * @returns {object[] | null}
   *   The NER results or null if the model is not initialized or failed.
   */
  async _findNER(query, options = {}) {
    return this.#runEngine(this.NER_OPTIONS, query, options);
  }

  /**
   * Combines location tokens detected by NER into separate city and state
   * components. This method processes city, state, and combined city-state
   * entities, returning an object with `city` and `state` fields.
   *
   * Handles the following entity types:
   * - B-CITY, I-CITY: Identifies city tokens.
   * - B-STATE, I-STATE: Identifies state tokens.
   * - B-CITYSTATE, I-CITYSTATE: Identifies tokens that represent a combined
   *   city and state.
   *
   * @param {object[]} nerResult
   *   The NER results containing tokens and their corresponding entity labels.
   * @param {number} nerThreshold
   *   The confidence threshold for including entities. Tokens whose score is
   *   not strictly above this threshold are ignored.
   * @returns {object}
   *   An object with `city` and `state` fields:
   *   - {string|null} city: The detected city, or `null` if no city is found.
   *   - {string|null} state: The detected state, or `null` if no state is found.
   */
  async #combineLocations(nerResult, nerThreshold) {
    const tokensByKind = { city: [], state: [], cityState: [] };

    for (const { entity, score, word } of nerResult) {
      const kind = LOCATION_ENTITY_KINDS.get(entity);
      // Written as a negated `>` so a missing score or threshold rejects the
      // token, exactly like a plain `score > nerThreshold` acceptance test.
      if (!kind || !(score > nerThreshold)) {
        continue;
      }

      const tokens = tokensByKind[kind];
      if (word.startsWith("##") && tokens.length) {
        // A "##" prefix marks a continuation piece: glue it onto the
        // previous token of the same kind.
        tokens[tokens.length - 1] += word.slice(2);
      } else {
        tokens.push(word);
      }
    }

    // Handle city_state as combined and split into city and state
    if (tokensByKind.cityState.length) {
      const [city, state] = tokensByKind.cityState.join(" ").split(",");
      return {
        city: city?.trim() || null,
        state: state?.trim() || null,
      };
    }

    // Return city and state as separate components if detected
    return {
      city: tokensByKind.city.join(" ").trim() || null,
      state: tokensByKind.state.join(" ").trim() || null,
    };
  }

  /**
   * Derives the subject of the query by removing the detected location.
   *
   * @param {string} query
   *   The user's input query.
   * @param {object|null} location
   *   The `{ city, state }` object produced by `#combineLocations`.
   * @returns {string}
   *   The query without the first occurrence of the city and state, without
   *   commas, with whitespace collapsed and without a trailing preposition.
   */
  #findSubjectFromQuery(query, location) {
    // If location is null or no city/state, return the entire query
    if (!location || (!location.city && !location.state)) {
      return query;
    }

    // Remove the city and state from the query
    let subjectWithoutLocation = query;
    if (location.city) {
      subjectWithoutLocation = subjectWithoutLocation
        .replace(location.city, "")
        .trim();
    }
    if (location.state) {
      subjectWithoutLocation = subjectWithoutLocation
        .replace(location.state, "")
        .trim();
    }

    // Remove leftover commas and collapse runs of whitespace
    subjectWithoutLocation = subjectWithoutLocation
      .replaceAll(",", "")
      .replace(/\s+/g, " ")
      .trim();

    return this.#cleanSubject(subjectWithoutLocation);
  }

  /**
   * Strips one trailing preposition (see `PREPOSITIONS`) from the subject.
   *
   * @param {string} subject
   *   The subject to clean.
   * @returns {string}
   *   The subject without the trailing preposition, if there was one.
   */
  #cleanSubject(subject) {
    let end = PREPOSITIONS.find(
      p => subject === p || subject.endsWith(" " + p)
    );
    if (end) {
      subject = subject.substring(0, subject.length - end.length).trimEnd();
    }
    return subject;
  }
}

// Export the singleton instance
export var MLSuggest = new _MLSuggest();