TPU without Estimator

Varun Patil
4 min read · Jun 22, 2018

If you’re using Tensorflow to train a deep learning model that takes hours to run on good hardware, chances are that you’ve heard of Google’s latest cloud offering — the Tensor Processing Unit (TPU). According to Google, these ASICs can be 15x to 30x faster than the hardware you are currently using.

However, if you glance at the documentation for using cloud TPUs with TensorFlow, you will find that Google always refers you to the Estimator API, a high level API built on top of the lower level Graph and Session APIs. Using this API is advisable, since it was written with things such as TPUs in mind and will probably perform better than your code written against the low level APIs. Still, there are situations where you may want a performance boost on your old code that uses the old low level APIs. So I’m going to quickly summarise below how I ran my old model on a cloud TPU with only a few lines of changes.

Note: If you’re not familiar with the basics of Google Cloud, I encourage reading my post on Google Cloud Basics and/or other documentation/tutorials first.

DISCLAIMER: The author of this article is a 19 year old kid messing around with cloud infrastructure, i.e. me. I’m not responsible for anything bad that happens due to you doing any of this. Reader discretion is advised :)

I will be using a toy problem plagiarized from here to demonstrate how we can make it use TPUs with just a few additional lines of code. So here is a simple neural network. Note that I’m using TensorFlow 1.8, the latest version as of writing this.
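
Something along these lines will do; this is just a sketch of a tiny fully connected network fitting made-up data with the plain Graph and Session APIs, and any similar low level model works the same way:

    import numpy as np
    import tensorflow as tf

    # Made-up data: learn y = 2x + 1.
    train_x = np.random.rand(1024, 1).astype(np.float32)
    train_y = 2 * train_x + 1

    x = tf.placeholder(tf.float32, [None, 1])
    y = tf.placeholder(tf.float32, [None, 1])

    hidden = tf.layers.dense(x, 8, activation=tf.nn.relu)
    predictions = tf.layers.dense(hidden, 1)

    loss = tf.losses.mean_squared_error(y, predictions)
    optimizer = tf.train.GradientDescentOptimizer(0.1)
    train_op = optimizer.minimize(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for step in range(500):
            _, cur_loss = sess.run([train_op, loss],
                                   feed_dict={x: train_x, y: train_y})
            if step % 100 == 0:
                print(step, cur_loss)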

As a sanity check, you might want to confirm that this (or your existing) model runs on CPU/GPU first. Once that is working, the first thing you need is the address pointing to the TPU. This can be obtained by adding the following at the top of the file, replacing tpu_name with the name of the TPU you created with ctpu. You probably want to refactor this to use the TPU_NAME environment variable, since I believe Google sets this for you if you do everything right, and it becomes easier to switch between TPUs this way. Note that this requires your compute instance and the TPU to be in the same region.
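
Something like this, assuming the TPU was created with the name my-tpu (swap in your own name, or read it from TPU_NAME as below):

    import os
    import tensorflow as tf

    # Name of the TPU as created with ctpu; reading TPU_NAME makes it
    # easy to switch TPUs without touching the code.
    tpu_name = os.environ.get('TPU_NAME', 'my-tpu')

    # Resolve the name to the gRPC address of the TPU.
    tpu_address = tf.contrib.cluster_resolver.TPUClusterResolver(
        tpu=[tpu_name]).get_master()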

As of now, this is still doing nothing, so you need to pass this address to tf.Session() as the target argument (which is also the first positional argument). Another thing to be done is to initialize the TPU system when the session is created and to clean up when you're done.
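
Roughly, with tpu_address being the address obtained above:

    # Point the session at the TPU instead of the local machine.
    sess = tf.Session(tpu_address)

    # Bring the TPU system up before running anything else on it.
    sess.run(tf.contrib.tpu.initialize_system())
    sess.run(tf.global_variables_initializer())

    # ... the usual training loop with sess.run(...) goes here ...

    # Clean up when done.
    sess.run(tf.contrib.tpu.shutdown_system())
    sess.close()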

And that’s it! On running the model now, it should train on the TPU if it exists in the network. To verify it actually worked, you can make three checks:

  • CPU usage on your instance should be significantly higher when you are not training on the TPU.
  • Your cloud console should show a minor CPU usage for your TPU (0.9% in my case).
  • It might run slower (if you’re running the code above). Since this is just a toy problem, the network latency and other overhead of transferring information between the instance and the TPU probably becomes the bottleneck here.

However, doing only the above probably means you are not using all shards, or the entire computing power, of the TPU. Again, if you want to be foolproof about this, you should go for the Estimator API. Still, one thing that could possibly work (only speculation henceforth) is to use the CrossShardOptimizer wrapper around your optimizer.
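
With a plain gradient descent optimizer, for example (the learning rate is just a placeholder), this should make your optimizer look something like:

    # Aggregate gradients across the TPU shards.
    optimizer = tf.contrib.tpu.CrossShardOptimizer(
        tf.train.GradientDescentOptimizer(0.1))
    train_op = optimizer.minimize(loss)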

However, as soon as you do this, you should see a warning like

WARNING:tensorflow:CrossShardOptimizer should be used within a tpu_shard_context, but got unset number_of_shards. Assuming 1.

So we are using only one shard out of eight. One really ugly way (which touches the internal APIs of TensorFlow) I found to fix the warning (emphasis on the warning, since I still don’t know if it really does use all shards after this) is to explicitly set the number of shards:
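
The module path below comes from TensorFlow 1.8’s contrib sources, so treat it as fragile and liable to break in other versions:

    # WARNING: internal API, not part of the public interface.
    from tensorflow.contrib.tpu.python.tpu import tpu_function

    # Claim all 8 shards of the TPU for the CrossShardOptimizer.
    tpu_function.get_tpu_context().set_number_of_shards(8)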

I don’t recommend this at all, but it might work. Do let me know if you actually try this and it does/doesn’t work.

Note that since you’re doing everything yourself, you need to take care of a couple of things. Firstly, you cannot write to local storage from the TPU, so you need to either comment out all writes or use a cloud storage bucket instead. Secondly, do a sanity check first, since the code might not behave exactly the same, and a couple of changes here and there might be required for everything to work (again, mostly related to files).
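
For example, a Saver can write checkpoints straight to a bucket; the bucket name and path here are just placeholders:

    # Write checkpoints to a GCS bucket instead of the local disk.
    saver = tf.train.Saver()
    saver.save(sess, 'gs://my-bucket/toy-model/model.ckpt')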

So finally, you have your TensorFlow model with the low level APIs running on a TPU. Hopefully, we will have some official instructions on how to do this once TPUs are no longer in beta, since, after all, these APIs are far from deprecated.

Happy Training!

Originally published at https://pulsejet.github.io on June 22, 2018.
