作者: vivo 互聯網前端團隊- Wei Xing
全民AI時代,前端er該如何蹭上這波熱度?本文將一步步帶大家瞭解前端應該如何結合端側AI模型,實現一個AI版你畫我猜小遊戲。
1分鐘看圖掌握核心觀點👇
本文提供配套演示代碼,可下載體驗:
Github | vivo-ai-quickdraw
一、引言
近幾年AI的進化速度堪比科幻片——昨天還在調教ChatGPT寫詩,今天Sora已經能生成電影級畫面了。技術圈彷彿被AI“醃入味”了,説不定連這篇文章都是DeepSeek幫忙寫的(狗頭)。
前端er的野望:當其他行業忙着用AI造火箭時,我們這羣和瀏覽器“鬥智鬥勇”的手藝人,該怎麼蹭上這波熱度?
在深入思考如何蹭上熱度之前,首先,我們需要先簡單瞭解下AI模型的分類。
1.1 雲端模型和端側模型
從模型的部署方式上來看,AI模型可以簡單分為雲端模型(Cloud Model)和端側模型(On-Device Model)兩種。
- 雲端模型:將模型部署在服務集羣上,提供一些API能力供端側來調用,端側無需處理計算部分,只需調用API獲取計算結果即可。比如OpenAI的官方API就是如此。
- 端側模型:直接將模型部署在終端設備上,模型的計算、推理完全依賴終端設備,具有更高的實時性、私密性、安全性,但同時對終端設備等硬件要求也較高。
對於前端來説,由於其非常依賴瀏覽器和終端設備性能,所以原本最適合前端的方式其實是直接調用雲端模型API,把計算的負擔轉嫁給服務集羣,頁面只需負責展示結果即可。但通常情況下,搭建集羣、訓練模型、定製API有很高的資源門檻和成本,現實條件往往不允許我們這樣做。
因此,我們可以轉而考慮利用端側模型來賦能。
1.2 大語言模型和特定領域模型
端側模型從概念上來區分,又可以簡單分為大語言模型和特定領域模型。
大語言模型
其中大語言模型(Large Language Model,LLM)就是我們熟知的ChatGPT、DeepSeek、Grok這類模型,它們功能強大,但對設備的性能要求很高,以DeepSeek為例,即使是最小的1.5B版本模型,也至少需要RTX 3060+級別的顯卡才能帶得動,並且模型本身的大小已經達到1.1GB,並不適合部署在前端項目中。
特定領域模型
所以,最終留給我們的選項就是利用一些特定領域模型來賦能前端,它們可以用來處理某些特定領域的問題。例如,利用視覺(CNN、MobileNet)模型實現圖像分類、人臉檢測,或者利用自然語言模型(NLP)實現問答機器人、文本惡意檢測等。
這些模型等特徵是尺寸較小,並且對設備性能要求不高,非常適合直接部署在前端並實現一些AI交互。
所以,接下來我們就來看看,如何從0到1訓練一個圖像分類模型(Doodle Classifier based on CNN),並將模型集成至前端頁面,實現一個經典你畫我猜小遊戲-端側AI版。
二、你畫我猜AI版-玩法簡介
動手實踐之前,先來簡單介紹下你畫我猜AI版的玩法,它和普通版本你畫我猜的區別在於:玩家根據提示詞進行塗鴉,由AI來預測玩家畫的詞是什麼,如果AI順利猜對玩家畫的詞,則玩家得分。
例如,提示詞是“長城”,則玩家需要通過畫板手繪一個長城,儘量畫的像一些,讓AI猜出正確答案就能得分。
瞭解了基礎玩法之後,接下來正片開始,詳細介紹如何從零到一開始實現它。
我們提供了簡化版的 live demo,你可以訪問鏈接試試看。同時我們也提供了相關的 demo代碼,你可以隨時訪問github倉庫,下載和嘗試運行它。
三、訓練模型
首先,第一步是訓練模型。
根據上面的玩法簡介,我們知道它本質上是一個基於視覺的圖片分類AI模型,而這個模型的功能是:輸入圖片數據後,模型可以計算出圖片的分類置信結果。例如,輸入一張小貓的圖片,模型的分類計算結果可能為:[貓 90%,狗8%,豬2%],表示模型認為這張圖有90%的概率是隻貓,8%的概率是條狗,2%概率是隻豬。
這樣一來,我們通過將用户手繪的canvas中的圖片數據丟給模型,並把模型輸出的置信概率最大的分類當作AI的猜測結果,就可以模擬出AI猜詞的互動了。
而實現這個模型也很簡單,但我們需要了解一些深度學習神經網絡的知識以及tensorflow.js的基礎用法。如果對這兩者不太熟悉,可能需要先自行google一下,做點知識儲備。
那麼假設大家已經有了一些基礎的神經網絡、TensorFlow.js基礎知識,就可以利用TensorFlow.js輕鬆搭建一個基於CNN的圖片分類模型。
3.1 獲取數據集
在進入模型訓練之前,我們需要先獲取數據集。
數據集是訓練模型的基礎,我們可以自己創建數據集(這很困難、費時),或者尋找一些開源數據集。剛好Google Lab提供了一套完整的開源塗鴉數據集(The Quick Draw Dataset),數據集中包含了345個不同類別的塗鴉數據集合,總共有5000萬份塗鴉數據,足夠我們挑選使用。
我們可以直接訪問開源塗鴉數據集(The Quick Draw Dataset)下載所需的數據。點擊頁面右上角的Get the Data 跳轉github倉庫,可以看到文檔中列出了多種數據類型:
這裏我們直接選擇下載Numpy bitmap files。
注意:這裏的數據集有345種類別,如果全部進行訓練的話,訓練時間會很長並且最終的模型大小較大,因此,我們可以視情況挑選其中的部分詞彙,例如選擇80個詞彙進行訓練,對於一款小遊戲來説,詞彙量也足夠了。
3.2 搭建模型和訓練模型
下載完訓練數據之後,接下來我們需要搭建模型結構並進行模型訓練。
如果我們下載了demo代碼,可以看到項目結構如下,主要內容為3個部分:
項目目錄/
├── 📁 src/
│ ├── 📄 index.ts # 程序入口文件
│ ├── 📁 data/ # 數據集
│ │ ├── 📄 Apple.npy
│ │ ├── 📄 The Great Wall.npy
│ │ └── 📄 ...
│ └── 📁 model/ # 訓練模型相關
│ ├── 📄 doodle-data.model.ts # 數據加載
│ └── 📄 classifier.model.ts # 模型結構
├── 📄 package.json
-data目錄:存放訓練數據集
-model目錄:
- doodle-data.model.ts:數據加載預處理
- classifier.model.ts:定義模型結構
-index.ts:訓練程序入口
先來看項目的index.ts入口文件,功能非常簡單,主要邏輯就是四步:
- 加載訓練數據
- 創建模型
- 訓練模型
- 保存模型參數
import { Classifier } from './model/classifier.model';
import { DoodleData } from './model/doodle-data.model';
async function main(){
const data = new DoodleData({
directoryData: 'src/data',
maxImageClass: 20000
});
// 1. 加載訓練數據
data.loadData();
// 2. 創建模型
const model = new Classifier(data);
// 3. 訓練模型
await model.train();
// 4. 保存模型參數
await model.save();
}
main();
瞭解了核心流程之後,再來詳細看下model目錄下的兩個核心文件:doodle-data.model.ts和classifier.model.ts。
首先是doodle-data.model.ts ,它的核心代碼如下,主要是加載data目錄下的數據,並將數據預處理為tensor張量,後續可於訓練模型。
// 加載data目錄下的數據
loadData() {
this.classes = fs.readdirSync(this.directoryData)
.filter((x) => x.endsWith('.npy'))
.map((x) => x.replace('.npy', ''));
}
// 數據生成器,預處理數據為tensor張量
*dataGenerator() {
// ...
for (let j = 0; j < bytes.length; j = j + this.IMAGE_SIZE) {
const singleImage = bytes.slice(j, j + this.IMAGE_SIZE);
const image = tf
.tensor(singleImage)
.reshape([this.IMAGE_WIDTH, this.IMAGE_HEIGHT, 1])
.toFloat();
const xs = image.div(offset);
const ys = tf.tensor(this.classes.map((x) => (x === label ? 1 : 0)));
yield { xs, ys };
}
}
其次是,classifier.model.ts。它的核心代碼如下,代碼的主要功能是:
構建了一個基於CNN的圖像分類模型。通過tf.layers.conv3d()構造了卷積神經網絡結構。
提供了train()方法,用於訓練模型。這裏定義了模型訓練的迭代次數(epochs)、訓練的批次大小(batchSize),這些參數會影響模型訓練的最終結果,就是通常我們所説的“模型調參”,當你覺得模型訓練效果不佳時,可以調整這些參數重新訓練,直到達成不錯的準確率。
提供了save()方法,用於保存模型參數。
import * as tf from "@tensorflow/tfjs-node";
import { DoodleData } from "./doodle-data.model";
exportclassClassifier {
// ...
// 定義模型結構
constructor(data: DoodleData) {
this.data = data;
this.model = tf.sequential();
this.model.add(
tf.layers.conv2d({
inputShape: [data.IMAGE_WIDTH, data.IMAGE_HEIGHT, 1],
kernelSize: 3,
filters: 16,
strides: 1,
activation: "relu",
kernelInitializer: "varianceScaling",
})
);
this.model.add(
tf.layers.maxPooling2d({
poolSize: [2, 2],
strides: [2, 2],
})
);
this.model.add(
tf.layers.conv2d({
kernelSize: 3,
filters: 32,
strides: 1,
activation: "relu",
kernelInitializer: "varianceScaling",
})
);
this.model.add(
tf.layers.maxPooling2d({
poolSize: [2, 2],
strides: [2, 2],
})
);
this.model.add(tf.layers.flatten());
this.model.add(
tf.layers.dense({
units: this.data.totalClasses,
kernelInitializer: "varianceScaling",
activation: "softmax",
})
);
const optimizer = tf.train.adam();
this.model.compile({
optimizer,
loss: "categoricalCrossentropy",
metrics: ["accuracy"],
});
}
// 模型訓練
async train(){
const trainingData = tf.data
.generator(() => this.data.dataGenerator("train"))
.shuffle(this.data.maxImageClass * this.data.totalClasses)
.batch(64);
const testData = tf.data
.generator(() => this.data.dataGenerator("test"))
.shuffle(this.data.maxImageClass * this.data.totalClasses)
.batch(64);
await this.model.fitDataset(trainingData, {
epochs: 5,
validationData: testData,
callbacks: {
onEpochEnd: async (epoch, logs) => {
this.logger.debug(
`Epoch: ${epoch} - acc: ${logs?.acc.toFixed(
3
)} - loss: ${logs?.loss.toFixed(3)}`
);
},
onBatchBegin: async (epoch, logs) => {
console.log("onBatchBegin" + epoch + JSON.stringify(logs));
},
},
});
}
// 保存模型
async save(){
fs.mkdirSync("doodle-model", { recursive: true });
fs.writeFileSync(
"doodle-model/classes.json",
JSON.stringify({ classes: this.data.classes })
);
await this.model.save("file://./doodle-model");
}
}
如果我們從github倉庫下載了demo代碼,在根目錄下執行:
npm run start
開啓模型訓練過程,會有一些輸出如下,表示當前的訓練輪次、識別準確率、損失等。
onBatchBegin0{"batch":0,"size":512}
onBatchBegin1{"batch":1,"size":512}
onBatchBegin2{"batch":2,"size":512}
onBatchBegin3{"batch":3,"size":512}
onBatchBegin4{"batch":4,"size":192}
...
[Classifier] Epoch: 0 - acc: 0.078 - loss: 2.632
...
耐心等待日誌打完,模型訓練完成之後,我們的項目目錄下就會產出一個額外的目錄,存放模型的訓練結果。
- classes.json:圖片的所有分類,根據data目錄中的數據文件名稱生成
- model.json:模型的描述文件
- weights.bin:模型的參數文件
項目目錄/
├── 📁 doodle-model/ # 訓練結果(最終模型)
│ │ ├── 📄 classes.json # 圖片分類
│ │ ├── 📄 model.json # 模型描述文件
│ │ └── 📄 weights.bin # 模型參數
這樣,我們的模型就訓練完成了。
接下來看看如何在頁面中集成模型,實現從繪製canvas圖片到模型分類預測的效果。
四、集成至頁面
在頁面中的集成模型也非常簡單,我們只需要創建一個可以繪圖的canvas,每隔一段時間就將當前canvas的圖像數據傳輸給模型,觸發一次模型預測即可。
先來看下項目的核心目錄結構:
項目目錄/
├── 📁 public/assets/doodle-modle/ # 將訓練生成的模型放置在public目錄下
│ │ ├── 📄 classes.json # 圖片分類
│ │ ├── 📄 model.json # 模型描述文件
│ │ └── 📄 weights.bin # 模型參數
├── 📁 src
│ ├── 📁 models/
│ │ └── 📄 DoodleClassifier.js # 圖片分類器
│ ├── 📁 views/
│ │ └── 📄 DoodleView.vue # 頁面視圖(canvas畫布)
其中,DoodleClassifier.js的核心代碼如下:
- loadModel:加載模型,包括model.json、classes.json,在model.json中會自動加載weights.bin
- predictTopN:輸入圖片數據,調用model.predict() 預測最有可能的TopN個分類結果,並按照置信度排序
import * as tf from "@tensorflow/tfjs";
import apiClient from "@/services/http";
// 加載模型
async loadModel(){
this.model = await tf.loadLayersModel("assets/doodle-model/model.json");
const response = await apiClient.get("assets/doodle-model/classes.json");
this.classes = response.data.classes;
}
// 預測最有可能的TopN個分類,並按照置信度排序
async predictTopN(data, n){
const predictions = Array.from(await this.model.predict(data).data());
const indexedPredictions = predictions.map((probability, index) => ({
probability,
index,
}));
indexedPredictions.sort((a, b) => b.probability - a.probability);
const topNPredictions = indexedPredictions.slice(0, n);
return topNPredictions.map((p) => ({
label: this.classes[p.index],
accuracy: p.probability,
}));
}
// 預測分類結果
async predict(data){
const argMax = await this.model.predict(data).argMax(-1).data();
returnthis.classes[argMax[0]];
}
DoodleView.vue的核心代碼如下:
- 調用new DoodleClassifier()構造圖片分類器
- 調用loadModel()加載模型
- 預處理canvas的圖片數據
- 將預處理的數據傳輸給model.predictTopN(),預測圖片分類
// 構造圖片分類器
this.model = new DoodleClassifier()
// 加載模型
this.model.loadModel()
// 預處理canvas圖片數據
const tensor = tf.browser.fromPixels(imgData, 1);
const resized = tf.image
.resizeBilinear(tensor, [28, 28])
.reshape([1, 28, 28, 1]) // Reshape to [1, 28, 28, 1] for batch and single channel
.toFloat();
const normalized = tf.scalar(1.0).sub(resized.div(tf.scalar(255.0)));
// 預測圖片分類
this.model.predictTopN(normalized, 5).then((predictions) => {
if (predictions) {
this.predictions = predictions;
}
});
到這為止,你畫我猜-AI版就已經基本搭建完成了。實現起來並不複雜。
如果一切順利,並且你按照我們提供的demo構建頁面,就可以直接在項目中運行:
npm run serve
一個簡易版本的你畫我猜AI版就運行成功了,試試看吧。
五、優化措施
通過上面的步驟,我們完成了模型訓練和canvas圖片分類預測的全流程,成功實現了你畫我猜AI版。但實際上可能會遇到兩個比較關鍵的問題。
5.1 數據標準化
當我們去調整canvas畫布大小、畫筆粗細後,可能會出現預測結果不準確的情況,此時從canvas獲取的圖像數據和我們餵給模型的訓練數據產生了差異。
這時候我們需要在獲取到canvas數據後,額外做一些數據預處理,將數據標準化,例如:
- 將畫布的內容區域裁剪為正方形,並居中顯示
- 將畫布的線條適當變粗,使模型更容易識別
5.2 利用 webworker 優化性能
模型的計算過程是十分耗時的,將計算過程放在主線程會導致頁面卡頓,因此我們可以將整個模型的預測部分放入webworker中,以此來提升計算性能,不影響頁面渲染。
六、總結
你畫我猜-端側AI版是前端結合AI的一個簡單案例,為我們提供了前端利用AI賦能的大致思路和基本實現邏輯。條件允許的情況下,我們可以利雲端模型來拓展前端業務。但如果缺乏資源,我們則轉而考慮使用端側的特定領域模型來產出一些新玩法、新交互。相比之下,端側AI具有更強的靈活性、安全性和更低的集成成本。大家可以試着在各自的業務中探索和使用端側AI,或許無法產出太大的效益,但也是在全民AI時代下,一些積極的嘗試和沉澱。
七、參考
部分代碼參考自:
- Gihuthub | RiccardoGai | doddle-classifier-model
- Gihuthub | RiccardoGai | doddle-classifier-app
- Gihuthub | yining1023 | doodleNet