Arbitrating the integrity of stochastic gradient descent with proof-of-learning

by Hengrui Jia, Mohammad Yaghini, Christopher A. Choquette-Choo, Natalie Dullerud, Anvith Thudi, Varun Chandrasekaran, and Nicolas Papernot

Machine learning (ML) is expensive. Training state-of-the-art models such as GPT-3 may cost millions of dollars or require specialized hardware. With such hefty training costs, having a model wrongfully appropriated or stolen, through insider attacks or model extraction, is a serious risk [1,2]. So, how would one prove ownership of a model they have trained? This is the problem we tackle in our proof-of-learning paper, which we presented at IEEE S&P 2021 (“Oakland”) earlier this year. Here is the recording of the talk we gave at the conference if you’d rather listen to the talk than read this blog post.

To illustrate what we mean by ownership, take the example of purchasing a car. Here, ownership could be defined as having acquired the car (you own your car if you bought it) by spending some resources (the money you spent). Similarly, the crux of the proof-of-learning approach presented in this blog post is to observe that owning a model implies that an entity expended some resources (that is, computational resources) to train the model.

What if multiple entities claim ownership over a particular model? In such cases, a trusted arbitrator needs to verify their claims of ownership. Observe, however, that verifying ML training isn’t just useful for disputes over model ownership: verification schemes can also be used whenever the fidelity of training could be questioned. Decentralized (distributed) training is one example, where many workers train a model in parallel. With no way to verify that a worker is performing the necessary training steps, there is no real defence against workers that fail to do so at any time and without warning (known as a Byzantine fault), thus disrupting training.

Outside ML, the general problem of verifying ownership isn’t new. The premise of blockchains revolves around ownership of secret information: the inverse of a one-way function. Systems based on proof-of-work allow a verifier to check whether a party has expended the significant computational resources required to obtain said secret information. We wish to design a scheme where a verifier can observe secret information that a party obtained while training an ML model, to verify whether they’ve expended the significant computational resources required to obtain the model that is eventually released. We also enable a verifier to assess the integrity of the computations required for training.

Verifying Training

What secret information does an honest party who effortfully trained a model have? When you train a model, you pass through many intermediate values for the model’s parameters, i.e., states. The sequence of these states (and the states themselves) can’t be guessed in advance because training is very noisy: you can start from the same initialization with the same batch ordering and still end up with vastly different final weights when training on different GPUs. Therefore, the intermediate states during training can serve as secret information known only to the honest worker.
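One commonly cited source of such hardware-level noise (our illustration here, not an analysis from the paper) is that floating-point addition is not associative: parallel hardware may sum the same numbers in a different order and obtain a slightly different result, and over many training steps such tiny discrepancies can compound. A minimal Python illustration:

```python
# The same three numbers summed in two different orders,
# as a parallel reduction on a GPU might do from one run to the next.
a, b, c = 0.1, 0.2, 0.3
left = (a + b) + c
right = a + (b + c)
print(left, right, left == right)  # 0.6000000000000001 0.6 False
```

A single rounding difference like this is negligible on its own; the point is that it makes the exact sequence of weights depend on low-level execution details that an adversary cannot predict.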

The problem then becomes how one can verify whether a sequence of intermediate states (i.e., checkpoints of intermediate weights) came from training and isn’t random (or worse, forged by a malicious party). First, we need to know the inputs and training settings the prover claims to have used, so that we can reasonably reproduce the training procedure. This involves obtaining the data batch ordering, the learning rate and its schedule, the loss function, and any other training hyperparameters. Next, we need to reliably reproduce the states originally obtained during training despite various sources of randomness. The solution is step-wise verification: rather than replaying the whole of training end to end, where noise would accumulate, we check one step at a time.
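As a rough sketch, the claims a prover submits could be bundled as below. The class and field names (`ProofOfLearning`, `batch_indices`, `learning_rates`, `is_well_formed`) are hypothetical and not the format used in our repository, and the sketch assumes a checkpoint is stored after every single step.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List

import numpy as np


@dataclass
class ProofOfLearning:
    """Hypothetical container for what a prover hands to the verifier."""
    checkpoints: List[np.ndarray]    # w_0, w_1, ..., w_T (flattened model weights)
    batch_indices: List[np.ndarray]  # indices of the training examples used at step t
    learning_rates: List[float]      # learning rate actually used at step t (the schedule, unrolled)
    hyperparameters: Dict[str, Any] = field(default_factory=dict)  # e.g. loss, optimizer, batch size

    def num_steps(self) -> int:
        return len(self.checkpoints) - 1

    def is_well_formed(self) -> bool:
        """Every step t needs its own batch and learning rate to be re-executable."""
        steps = self.num_steps()
        return steps > 0 and len(self.batch_indices) == steps and len(self.learning_rates) == steps


proof = ProofOfLearning(
    checkpoints=[np.zeros(3), np.ones(3), 2 * np.ones(3)],
    batch_indices=[np.array([0, 1]), np.array([2, 3])],
    learning_rates=[0.1, 0.1],
    hyperparameters={"loss": "cross_entropy", "batch_size": 2},
)
print(proof.num_steps(), proof.is_well_formed())  # 2 True
```

Unrolling the learning-rate schedule into one value per step is a design choice that lets the verifier re-execute any single step without re-deriving the scheduler's state.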

Step-wise verification involves choosing a checkpoint \(w_t\) and recomputing the training update from the corresponding data and hyperparameters to obtain the immediately following checkpoint \(w'_{t+1}\). We then measure the difference between the recomputed \(w'_{t+1}\) and the claimed checkpoint for that step, \(w_{t+1}\). If the difference is below a very small threshold \(\delta\) that bounds the step-wise noise, we say that the step passed verification; if it exceeds the threshold, then with high probability the step was forged. In practice, we set \(\delta\) to be two orders of magnitude smaller than the average difference between two independent models trained with the same hyperparameters, dataset, and architecture. This essential setup is what defines a proof-of-learning.
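Here is a minimal sketch of step-wise verification. It assumes a toy linear model trained with mini-batch SGD on mean squared error in NumPy, standing in for a neural network; `sgd_step`, `train`, and `verify_step` are hypothetical helpers, and the Euclidean norm is our choice of distance. \(\delta\) is calibrated following the rule above, using two runs that (by our assumption) differ only in their random initialization.

```python
import numpy as np


def sgd_step(w, X_batch, y_batch, lr):
    """One mini-batch SGD step on mean squared error for a linear model (toy stand-in)."""
    residual = X_batch @ w - y_batch
    grad = 2.0 * X_batch.T @ residual / len(y_batch)
    return w - lr * grad


def train(w0, X, y, batches, lr):
    """Run SGD over a fixed batch ordering and record every checkpoint w_0, ..., w_T."""
    checkpoints = [w0.copy()]
    w = w0.copy()
    for idx in batches:
        w = sgd_step(w, X[idx], y[idx], lr)
        checkpoints.append(w.copy())
    return checkpoints


def verify_step(w_t, w_next_claimed, X_batch, y_batch, lr, delta):
    """Recompute w'_{t+1} from w_t and accept if it lies within delta of the claimed w_{t+1}."""
    w_next_recomputed = sgd_step(w_t, X_batch, y_batch, lr)
    distance = float(np.linalg.norm(w_next_recomputed - w_next_claimed))
    return distance <= delta, distance


rng = np.random.default_rng(0)
X = rng.normal(size=(256, 10))
y = X @ rng.normal(size=10) + 0.1 * rng.normal(size=256)
lr, num_steps, batch_size = 0.05, 20, 32
batches = [rng.choice(256, size=batch_size, replace=False) for _ in range(num_steps)]

# Prover: an honest training run whose checkpoints form the proof.
proof = train(rng.normal(size=10), X, y, batches, lr)

# Calibrate delta: two orders of magnitude below the distance between two independent runs.
run_a = train(rng.normal(size=10), X, y, batches, lr)[-1]
run_b = train(rng.normal(size=10), X, y, batches, lr)[-1]
delta = float(np.linalg.norm(run_a - run_b)) / 100.0

# Verifier: check step t of the honest proof, then the same step with a tampered checkpoint.
t = 5
ok_honest, _ = verify_step(proof[t], proof[t + 1], X[batches[t]], y[batches[t]], lr, delta)
forged = proof[t + 1] + 0.05
ok_forged, _ = verify_step(proof[t], forged, X[batches[t]], y[batches[t]], lr, delta)
print(ok_honest, ok_forged)  # True False
```

On a single CPU the honest step recomputes bit for bit, so its distance is zero; on real accelerators the distance would be small but nonzero, which is why a threshold \(\delta\) is needed at all.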

However, this raises another issue: to verify every single step, a verifier would need to expend as much compute as the original training, and the prover would need to store every single checkpoint. In response, we empirically found that it suffices to store the weights only every \(k^{th}\) step, in which case the verifier re-executes the \(k\) steps between two stored checkpoints. We also found that, against the several spoofing strategies we considered (ways an adversary might try to pass verification without a legitimately obtained proof-of-learning), one only needs to verify the one or two updates with the largest change in weights in each epoch to determine whether the proof-of-learning is valid. Because legitimate updates tend to have a small magnitude (to avoid overshooting during gradient descent), a spoofed proof is most likely to reveal itself in these largest updates. Together, these two methods reduce the storage and computational costs associated with a proof-of-learning.
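The two cost reductions can be combined in a sketch like the one below, again assuming the toy linear model in NumPy. The helper names and the values of `k`, `q` (the number of largest updates checked per epoch), and `delta` are illustrative assumptions; because recomputation on the same CPU is exact in this toy, `delta` is a small constant here rather than calibrated, and the "forgery" is a crude tampering rather than one of the spoofing strategies from the paper.

```python
import numpy as np


def sgd_step(w, X_batch, y_batch, lr):
    """Toy update: one mini-batch SGD step on mean squared error for a linear model."""
    return w - lr * 2.0 * X_batch.T @ (X_batch @ w - y_batch) / len(y_batch)


def replay(w_start, X, y, batches, lr):
    """Re-run k consecutive steps from a stored checkpoint (intermediate weights are not stored)."""
    w = w_start.copy()
    for idx in batches:
        w = sgd_step(w, X[idx], y[idx], lr)
    return w


def select_largest_updates(checkpoints, q):
    """Indices i of the q stored intervals with the largest ||w_{(i+1)k} - w_{ik}||."""
    sizes = [np.linalg.norm(b - a) for a, b in zip(checkpoints[:-1], checkpoints[1:])]
    return sorted(np.argsort(sizes)[::-1][:q].tolist())


def verify_epoch(checkpoints, X, y, batches, lr, k, q, delta):
    """Verify only the q largest updates of one epoch, each by replaying its k steps."""
    for i in select_largest_updates(checkpoints, q):
        recomputed = replay(checkpoints[i], X, y, batches[i * k:(i + 1) * k], lr)
        if np.linalg.norm(recomputed - checkpoints[i + 1]) > delta:
            return False
    return True


rng = np.random.default_rng(1)
X = rng.normal(size=(256, 10))
y = X @ rng.normal(size=10)
lr, k, q, delta = 0.05, 2, 2, 1e-6
batches = np.split(rng.permutation(256), 8)  # one epoch: 8 batches of 32 examples

# Prover: train for one epoch but store the weights only every k-th step.
w = rng.normal(size=10)
checkpoints = [w.copy()]
for step, idx in enumerate(batches, start=1):
    w = sgd_step(w, X[idx], y[idx], lr)
    if step % k == 0:
        checkpoints.append(w.copy())

# A crude forgery: one stored checkpoint is shifted away from the true trajectory.
forged = [c.copy() for c in checkpoints]
forged[2] = forged[2] + 1.0

print(verify_epoch(checkpoints, X, y, batches, lr, k, q, delta),
      verify_epoch(forged, X, y, batches, lr, k, q, delta))  # True False
```

Here the prover stores 5 of the 9 weight vectors in the epoch, and the verifier replays only 2 of the 4 stored intervals; the tampered checkpoint creates unusually large updates, so it lands in exactly the intervals that get checked.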

Conclusion

Overall, the proof-of-learning (PoL) concept lays a foundation for future work on verifying the integrity of training procedures, drawing intuition from cryptographic ideas. The verification procedure and checkpointing strategy outlined here are one specific instantiation of proof-of-learning. As with cryptographic mechanisms such as proof-of-work and proof-of-stake, the specifics of the verification procedure and of the proof elements can be changed. This leaves room for future work to explore cheaper ways, in terms of both verifier computation and prover storage, to verify that a party is the originator of a trained model, or that workers in distributed settings are not corrupt in their returned model updates.

Want to read more or try it out?

You can find more information in our IEEE S&P 2021 paper. Code for reproducing all our experiments is available in our GitHub repository.

References

[1] Stealing Machine Learning Models via Prediction APIs; Florian Tramèr, Fan Zhang, Ari Juels, Michael K. Reiter, Thomas Ristenpart; USENIX Security 2016

[2] High Accuracy and High Fidelity Extraction of Neural Networks; Matthew Jagielski, Nicholas Carlini, David Berthelot, Alex Kurakin, Nicolas Papernot; USENIX Security 2020