PyTorch 模型可以保存為兩種格式:.pt:此格式保存整個模型,
包括其架構和學習參數。.pth:此格式僅保存模型的狀態字典,其中包括模型的學習參數和一些元數據。
PyTorch
格式基於 Python 的 pickle 模塊,該模塊用於序列化 Python 對象。
為了理解 pickle
的工作原理,讓我們看以下示例:
import pickle
model_state_dict = { "layer1": "hello", "layer2": "world" }
pickle.dump(model_state_dict, open("model.pkl", "wb"))The pickle.dump() 函數將 model_state_dict 字典序列化並保存到名為
model.pkl. 的文件中。
輸出文件現在包含字典的
二進制表示:
model.pkl hex view要
將序列化的字典加載回