动态

详情 返回 返回

從0到1實現:AI版你畫我猜小遊戲 - 动态 详情

作者: vivo 互聯網前端團隊- Wei Xing

全民AI時代,前端er該如何蹭上這波熱度?本文將一步步帶大家瞭解前端應該如何結合端側AI模型,實現一個AI版你畫我猜小遊戲。

1分鐘看圖掌握核心觀點👇

圖片

本文提供配套演示代碼,可下載體驗:

Github | vivo-ai-quickdraw

一、引言

近幾年AI的進化速度堪比科幻片——昨天還在調教ChatGPT寫詩,今天Sora已經能生成電影級畫面了。技術圈彷彿被AI“醃入味”了,説不定連這篇文章都是DeepSeek幫忙寫的(狗頭)。

前端er的野望:當其他行業忙着用AI造火箭時,我們這羣和瀏覽器“鬥智鬥勇”的手藝人,該怎麼蹭上這波熱度?

在深入思考如何蹭上熱度之前,首先,我們需要先簡單瞭解下AI模型的分類。

1.1 雲端模型和端側模型

從模型的部署方式上來看,AI模型可以簡單分為雲端模型(Cloud Model)和端側模型(On-Device Model)兩種。

  • 雲端模型:將模型部署在服務集羣上,提供一些API能力供端側來調用,端側無需處理計算部分,只需調用API獲取計算結果即可。比如OpenAI的官方API就是如此。
  • 端側模型:直接將模型部署在終端設備上,模型的計算、推理完全依賴終端設備,具有更高的實時性、私密性、安全性,但同時對終端設備等硬件要求也較高。

對於前端來説,由於其非常依賴瀏覽器和終端設備性能,所以原本最適合前端的方式其實是直接調用雲端模型API,把計算的負擔轉嫁給服務集羣,頁面只需負責展示結果即可。但通常情況下,搭建集羣、訓練模型、定製API有很高的資源門檻和成本,現實條件往往不允許我們這樣做。

因此,我們可以轉而考慮利用端側模型來賦能。

1.2 大語言模型和特定領域模型

端側模型從概念上來區分,又可以簡單分為大語言模型和特定領域模型。

大語言模型

其中大語言模型(Large Language Model,LLM)就是我們熟知的ChatGPT、DeepSeek、Grok這類模型,它們功能強大,但對設備的性能要求很高,以DeepSeek為例,即使是最小的1.5B版本模型,也至少需要RTX 3060+級別的顯卡才能帶得動,並且模型本身的大小已經達到1.1GB,並不適合部署在前端項目中。

特定領域模型

所以,最終留給我們的選項就是利用一些特定領域模型來賦能前端,它們可以用來處理某些特定領域的問題。例如,利用視覺(CNN、MobileNet)模型實現圖像分類、人臉檢測,或者利用自然語言模型(NLP)實現問答機器人、文本惡意檢測等。

這些模型等特徵是尺寸較小,並且對設備性能要求不高,非常適合直接部署在前端並實現一些AI交互。

所以,接下來我們就來看看,如何從0到1訓練一個圖像分類模型(Doodle Classifier based on CNN),並將模型集成至前端頁面,實現一個經典你畫我猜小遊戲-端側AI版。

二、你畫我猜AI版-玩法簡介

動手實踐之前,先來簡單介紹下你畫我猜AI版的玩法,它和普通版本你畫我猜的區別在於:玩家根據提示詞進行塗鴉,由AI來預測玩家畫的詞是什麼,如果AI順利猜對玩家畫的詞,則玩家得分。

圖片

例如,提示詞是“長城”,則玩家需要通過畫板手繪一個長城,儘量畫的像一些,讓AI猜出正確答案就能得分。

瞭解了基礎玩法之後,接下來正片開始,詳細介紹如何從零到一開始實現它。

我們提供了簡化版的 live demo,你可以訪問鏈接試試看。同時我們也提供了相關的 demo代碼,你可以隨時訪問github倉庫,下載和嘗試運行它。

三、訓練模型

首先,第一步是訓練模型。

根據上面的玩法簡介,我們知道它本質上是一個基於視覺的圖片分類AI模型,而這個模型的功能是:輸入圖片數據後,模型可以計算出圖片的分類置信結果。例如,輸入一張小貓的圖片,模型的分類計算結果可能為:[貓 90%,狗8%,豬2%],表示模型認為這張圖有90%的概率是隻貓,8%的概率是條狗,2%概率是隻豬。

這樣一來,我們通過將用户手繪的canvas中的圖片數據丟給模型,並把模型輸出的置信概率最大的分類當作AI的猜測結果,就可以模擬出AI猜詞的互動了。

而實現這個模型也很簡單,但我們需要了解一些深度學習神經網絡的知識以及tensorflow.js的基礎用法。如果對這兩者不太熟悉,可能需要先自行google一下,做點知識儲備。

那麼假設大家已經有了一些基礎的神經網絡、TensorFlow.js基礎知識,就可以利用TensorFlow.js輕鬆搭建一個基於CNN的圖片分類模型。

3.1 獲取數據集

在進入模型訓練之前,我們需要先獲取數據集。

數據集是訓練模型的基礎,我們可以自己創建數據集(這很困難、費時),或者尋找一些開源數據集。剛好Google Lab提供了一套完整的開源塗鴉數據集(The Quick Draw Dataset),數據集中包含了345個不同類別的塗鴉數據集合,總共有5000萬份塗鴉數據,足夠我們挑選使用。

圖片

我們可以直接訪問開源塗鴉數據集(The Quick Draw Dataset)下載所需的數據。點擊頁面右上角的Get the Data 跳轉github倉庫,可以看到文檔中列出了多種數據類型:

圖片

這裏我們直接選擇下載Numpy bitmap files。

圖片

注意:這裏的數據集有345種類別,如果全部進行訓練的話,訓練時間會很長並且最終的模型大小較大,因此,我們可以視情況挑選其中的部分詞彙,例如選擇80個詞彙進行訓練,對於一款小遊戲來説,詞彙量也足夠了。

3.2 搭建模型和訓練模型

下載完訓練數據之後,接下來我們需要搭建模型結構並進行模型訓練。

如果我們下載了demo代碼,可以看到項目結構如下,主要內容為3個部分:

項目目錄/
├── 📁 src/                    
│   ├── 📄 index.ts            # 程序入口文件
│   ├── 📁 data/               # 數據集
│   │   ├── 📄 Apple.npy
│   │   ├── 📄 The Great Wall.npy  
│   │   └── 📄 ...       
│   └── 📁 model/              # 訓練模型相關
│       ├── 📄 doodle-data.model.ts  # 數據加載
│       └── 📄 classifier.model.ts   # 模型結構
├── 📄 package.json

-data目錄:存放訓練數據集

-model目錄:

  • doodle-data.model.ts:數據加載預處理
  • classifier.model.ts:定義模型結構

-index.ts:訓練程序入口

先來看項目的index.ts入口文件,功能非常簡單,主要邏輯就是四步:

  • 加載訓練數據
  • 創建模型
  • 訓練模型
  • 保存模型參數
import { Classifier } from './model/classifier.model';
import { DoodleData } from './model/doodle-data.model';

async function main(){
  const data = new DoodleData({
    directoryData: 'src/data',
    maxImageClass: 20000
  });

  // 1. 加載訓練數據
  data.loadData();
  // 2. 創建模型
  const model = new Classifier(data);
  // 3. 訓練模型
  await model.train();
  // 4. 保存模型參數
  await model.save();
}

main();

瞭解了核心流程之後,再來詳細看下model目錄下的兩個核心文件:doodle-data.model.ts和classifier.model.ts。

首先是doodle-data.model.ts ,它的核心代碼如下,主要是加載data目錄下的數據,並將數據預處理為tensor張量,後續可於訓練模型。

// 加載data目錄下的數據
loadData() {
  this.classes = fs.readdirSync(this.directoryData)
    .filter((x) => x.endsWith('.npy'))
    .map((x) => x.replace('.npy', ''));
}

// 數據生成器,預處理數據為tensor張量
*dataGenerator() {
  // ...
  for (let j = 0; j < bytes.length; j = j + this.IMAGE_SIZE) {
    const singleImage = bytes.slice(j, j + this.IMAGE_SIZE);
    const image = tf
      .tensor(singleImage)
      .reshape([this.IMAGE_WIDTH, this.IMAGE_HEIGHT, 1])
      .toFloat();
    const xs = image.div(offset);
    const ys = tf.tensor(this.classes.map((x) => (x === label ? 1 : 0)));
    yield { xs, ys };
  }
}

其次是,classifier.model.ts。它的核心代碼如下,代碼的主要功能是:

構建了一個基於CNN的圖像分類模型。通過tf.layers.conv3d()構造了卷積神經網絡結構。

提供了train()方法,用於訓練模型。這裏定義了模型訓練的迭代次數(epochs)、訓練的批次大小(batchSize),這些參數會影響模型訓練的最終結果,就是通常我們所説的“模型調參”,當你覺得模型訓練效果不佳時,可以調整這些參數重新訓練,直到達成不錯的準確率。

提供了save()方法,用於保存模型參數。

import * as tf from "@tensorflow/tfjs-node";
import { DoodleData } from "./doodle-data.model";

exportclassClassifier {
  // ...
  // 定義模型結構
  constructor(data: DoodleData) {
    this.data = data;
    this.model = tf.sequential();
    this.model.add(
      tf.layers.conv2d({
        inputShape: [data.IMAGE_WIDTH, data.IMAGE_HEIGHT, 1],
        kernelSize: 3,
        filters: 16,
        strides: 1,
        activation: "relu",
        kernelInitializer: "varianceScaling",
      })
    );
    this.model.add(
      tf.layers.maxPooling2d({
        poolSize: [2, 2],
        strides: [2, 2],
      })
    );
    this.model.add(
      tf.layers.conv2d({
        kernelSize: 3,
        filters: 32,
        strides: 1,
        activation: "relu",
        kernelInitializer: "varianceScaling",
      })
    );
    this.model.add(
      tf.layers.maxPooling2d({
        poolSize: [2, 2],
        strides: [2, 2],
      })
    );
    this.model.add(tf.layers.flatten());
    this.model.add(
      tf.layers.dense({
        units: this.data.totalClasses,
        kernelInitializer: "varianceScaling",
        activation: "softmax",
      })
    );

    const optimizer = tf.train.adam();
    this.model.compile({
      optimizer,
      loss: "categoricalCrossentropy",
      metrics: ["accuracy"],
    });
  }

  // 模型訓練
  async train(){
    const trainingData = tf.data
      .generator(() => this.data.dataGenerator("train"))
      .shuffle(this.data.maxImageClass * this.data.totalClasses)
      .batch(64);

    const testData = tf.data
      .generator(() => this.data.dataGenerator("test"))
      .shuffle(this.data.maxImageClass * this.data.totalClasses)
      .batch(64);

    await this.model.fitDataset(trainingData, {
      epochs: 5,
      validationData: testData,
      callbacks: {
        onEpochEnd: async (epoch, logs) => {
          this.logger.debug(
            `Epoch: ${epoch} - acc: ${logs?.acc.toFixed(
              3
            )} - loss: ${logs?.loss.toFixed(3)}`
          );
        },
        onBatchBegin: async (epoch, logs) => {
          console.log("onBatchBegin" + epoch + JSON.stringify(logs));
        },
      },
    });
  }

  // 保存模型
  async save(){
    fs.mkdirSync("doodle-model", { recursive: true });
    fs.writeFileSync(
      "doodle-model/classes.json",
      JSON.stringify({ classes: this.data.classes })
    );
    await this.model.save("file://./doodle-model");
  }
}

如果我們從github倉庫下載了demo代碼,在根目錄下執行:

npm run start

開啓模型訓練過程,會有一些輸出如下,表示當前的訓練輪次、識別準確率、損失等。

onBatchBegin0{"batch":0,"size":512}
onBatchBegin1{"batch":1,"size":512}
onBatchBegin2{"batch":2,"size":512}
onBatchBegin3{"batch":3,"size":512}
onBatchBegin4{"batch":4,"size":192}
...
[Classifier] Epoch: 0 - acc: 0.078 - loss: 2.632
...

耐心等待日誌打完,模型訓練完成之後,我們的項目目錄下就會產出一個額外的目錄,存放模型的訓練結果。

  • classes.json:圖片的所有分類,根據data目錄中的數據文件名稱生成
  • model.json:模型的描述文件
  • weights.bin:模型的參數文件
項目目錄/
├── 📁 doodle-model/           # 訓練結果(最終模型)
│   │   ├── 📄 classes.json    # 圖片分類     
│   │   ├── 📄 model.json      # 模型描述文件
│   │   └── 📄 weights.bin     # 模型參數

這樣,我們的模型就訓練完成了。

接下來看看如何在頁面中集成模型,實現從繪製canvas圖片到模型分類預測的效果。

四、集成至頁面

在頁面中的集成模型也非常簡單,我們只需要創建一個可以繪圖的canvas,每隔一段時間就將當前canvas的圖像數據傳輸給模型,觸發一次模型預測即可。

先來看下項目的核心目錄結構:

項目目錄/
├── 📁 public/assets/doodle-modle/   # 將訓練生成的模型放置在public目錄下
│   │   ├── 📄 classes.json    # 圖片分類     
│   │   ├── 📄 model.json      # 模型描述文件
│   │   └── 📄 weights.bin     # 模型參數
├── 📁 src   
│   ├── 📁 models/               
│   │   └── 📄 DoodleClassifier.js  # 圖片分類器
│   ├── 📁 views/               
│   │   └── 📄 DoodleView.vue   # 頁面視圖(canvas畫布)

其中,DoodleClassifier.js的核心代碼如下:

  • loadModel:加載模型,包括model.json、classes.json,在model.json中會自動加載weights.bin
  • predictTopN:輸入圖片數據,調用model.predict() 預測最有可能的TopN個分類結果,並按照置信度排序
import * as tf from "@tensorflow/tfjs";
import apiClient from "@/services/http";

// 加載模型
async loadModel(){
  this.model = await tf.loadLayersModel("assets/doodle-model/model.json");
  const response = await apiClient.get("assets/doodle-model/classes.json");
  this.classes = response.data.classes; 
}

// 預測最有可能的TopN個分類,並按照置信度排序
async predictTopN(data, n){
  const predictions = Array.from(await this.model.predict(data).data());

  const indexedPredictions = predictions.map((probability, index) => ({
    probability,
    index,
  }));

  indexedPredictions.sort((a, b) => b.probability - a.probability);

  const topNPredictions = indexedPredictions.slice(0, n);

  return topNPredictions.map((p) => ({
    label: this.classes[p.index],
    accuracy: p.probability,
  }));
}

// 預測分類結果
async predict(data){
  const argMax = await this.model.predict(data).argMax(-1).data();
  returnthis.classes[argMax[0]];
}

DoodleView.vue的核心代碼如下:

  • 調用new DoodleClassifier()構造圖片分類器
  • 調用loadModel()加載模型
  • 預處理canvas的圖片數據
  • 將預處理的數據傳輸給model.predictTopN(),預測圖片分類
// 構造圖片分類器
this.model = new DoodleClassifier()

// 加載模型
this.model.loadModel()

// 預處理canvas圖片數據
const tensor = tf.browser.fromPixels(imgData, 1);
const resized = tf.image
  .resizeBilinear(tensor, [28, 28])
  .reshape([1, 28, 28, 1]) // Reshape to [1, 28, 28, 1] for batch and single channel
  .toFloat();
const normalized = tf.scalar(1.0).sub(resized.div(tf.scalar(255.0)));

// 預測圖片分類
this.model.predictTopN(normalized, 5).then((predictions) => {
  if (predictions) {
    this.predictions = predictions;
  }
});

到這為止,你畫我猜-AI版就已經基本搭建完成了。實現起來並不複雜。

如果一切順利,並且你按照我們提供的demo構建頁面,就可以直接在項目中運行:

npm run serve

一個簡易版本的你畫我猜AI版就運行成功了,試試看吧。

圖片

五、優化措施

通過上面的步驟,我們完成了模型訓練和canvas圖片分類預測的全流程,成功實現了你畫我猜AI版。但實際上可能會遇到兩個比較關鍵的問題。

5.1 數據標準化

當我們去調整canvas畫布大小、畫筆粗細後,可能會出現預測結果不準確的情況,此時從canvas獲取的圖像數據和我們餵給模型的訓練數據產生了差異。

這時候我們需要在獲取到canvas數據後,額外做一些數據預處理,將數據標準化,例如:

  • 將畫布的內容區域裁剪為正方形,並居中顯示
  • 將畫布的線條適當變粗,使模型更容易識別

5.2 利用 webworker 優化性能

模型的計算過程是十分耗時的,將計算過程放在主線程會導致頁面卡頓,因此我們可以將整個模型的預測部分放入webworker中,以此來提升計算性能,不影響頁面渲染。

六、總結

你畫我猜-端側AI版是前端結合AI的一個簡單案例,為我們提供了前端利用AI賦能的大致思路和基本實現邏輯。條件允許的情況下,我們可以利雲端模型來拓展前端業務。但如果缺乏資源,我們則轉而考慮使用端側的特定領域模型來產出一些新玩法、新交互。相比之下,端側AI具有更強的靈活性、安全性和更低的集成成本。大家可以試着在各自的業務中探索和使用端側AI,或許無法產出太大的效益,但也是在全民AI時代下,一些積極的嘗試和沉澱。

七、參考

部分代碼參考自:

  • Gihuthub | RiccardoGai | doddle-classifier-model
  • Gihuthub | RiccardoGai | doddle-classifier-app
  • Gihuthub | yining1023 | doodleNet 
user avatar u_13529088 头像 u_11365552 头像 joe235 头像 ldh-blog 头像 niandb 头像 johanazhu 头像 jinl9s27 头像 lvweifu 头像 fiveyoboy 头像 yayujs 头像 howiecong 头像 iwan_68b8da84d3d8b 头像
点赞 14 用户, 点赞了这篇动态!
点赞

Add a new 评论

Some HTML is okay.