Kolmogorov-Arnold Networks in Java
Reimplementing Kolmogorov-Arnold Networks from scratch in pure Java.
Introduction
Kolmogorov-Arnold Networks (KANs) are a new type of artificial neural network that has gained significant attention in the machine learning community. In this post, I will present my implementation of KANs in plain Java, aiming to make this technology more accessible.
I will walk you through my Java implementation, highlighting key differences from the original paper, and demonstrating its ability to learn complex functions. While this implementation is not yet production-ready and lacks some features like backpropagation and visualization, it provides a solid foundation for understanding and exploring KANs.
What are KANs?
KANs are a new form of Artificial Neural Network (ANN). A KAN is an ANN in which each edge carries a learnable activation function, implemented as a B-spline, while each node simply sums its inputs. The edges do not have weights and the nodes do not have biases. The preprint has garnered a significant amount of attention, with 60 citations in the two months since its publication. The authors highlight several advantages of this new network type over Multi-Layer Perceptrons (MLPs), including smaller network size, higher interpretability, and immunity to catastrophic forgetting. For a full description, refer to the original paper [1]; it is helpful for understanding the high-level concepts but sparse on details. If you need a summary of the paper, there is an abundance of posts you can refer to [2][3][4]. For the rest of this post, I will assume that you are familiar with KANs.
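To make the structural difference from an MLP concrete, here is a minimal sketch of a single KAN layer in Java. All names here are illustrative, not taken from the implementation discussed below; the point is only that each edge applies its own learned univariate function and each node merely sums the results.

```java
import java.util.function.DoubleUnaryOperator;

// Minimal sketch of one KAN layer: outputs[j] = sum_i spline[i][j](inputs[i]).
// UnivariateFunction storage and names are illustrative placeholders.
final class KanLayerSketch {
    // One learnable univariate function per (input, output) edge.
    private final DoubleUnaryOperator[][] edgeFunctions; // [in][out]

    KanLayerSketch(DoubleUnaryOperator[][] edgeFunctions) {
        this.edgeFunctions = edgeFunctions;
    }

    double[] forward(double[] inputs) {
        int out = edgeFunctions[0].length;
        double[] outputs = new double[out];
        for (int j = 0; j < out; j++) {
            double sum = 0.0;
            for (int i = 0; i < inputs.length; i++) {
                // No weight and no bias: the edge IS the learned function.
                sum += edgeFunctions[i][j].applyAsDouble(inputs[i]);
            }
            outputs[j] = sum;
        }
        return outputs;
    }
}
```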
Why Re-implement KANs in Java?
The original paper gives only a high-level description; the details have to be gleaned from the source code. However, that source code is hard to read, and I aim to make it more accessible.
It was an important learning experience for me. I often have a much easier time understanding source code than mathematical formulae.
I want to do more research into the properties of KANs and ANNs, for which it is helpful to have a simple codebase with full control.
Description
The source code consists of 180 lines of code and can be found at [5]. Functionality is grouped into classes. The code lacks comments; however, the variable, method, and class names are self-explanatory, and the codebase is very small.
The `Application` class is the main application. In this class, you can perform various operations on and around the KAN: loading training data, generating randomized networks, training the networks, and using them. `KanNetwork.toString()` produces a string representation that can be used as valid Java code to recreate the original value. The `BSplines` class handles all operations related to splines and their calculation. Each network layer consists of a matrix of B-splines. A `KanNetwork` contains the network itself (composed of `Layer`s) and the methods to calculate the network's output for a given input. The `Training` class contains the code related to training. Currently, it uses a very simple genetic algorithm, which will be replaced with back-propagation in the future.
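Since the `BSplines` class is where all the spline math lives, it may help to see what evaluating a B-spline involves. The following is a minimal sketch, assuming a plain Cox-de Boor recursion over a fixed knot vector; the actual class's API and grid handling may differ.

```java
// Hedged sketch of B-spline evaluation via the Cox-de Boor recursion.
// Knot layout, degree, and method names are assumptions, not the post's API.
final class BSplineSketch {
    private final double[] knots;   // non-decreasing knot vector
    private final double[] coeffs;  // knots.length - degree - 1 coefficients
    private final int degree;

    BSplineSketch(double[] knots, double[] coeffs, int degree) {
        this.knots = knots;
        this.coeffs = coeffs;
        this.degree = degree;
    }

    /** Value of the i-th basis function of degree d at x (half-open spans). */
    private double basis(int i, int d, double x) {
        if (d == 0) {
            return (x >= knots[i] && x < knots[i + 1]) ? 1.0 : 0.0;
        }
        double left = 0.0, right = 0.0;
        double denomL = knots[i + d] - knots[i];
        if (denomL > 0) left = (x - knots[i]) / denomL * basis(i, d - 1, x);
        double denomR = knots[i + d + 1] - knots[i + 1];
        if (denomR > 0) right = (knots[i + d + 1] - x) / denomR * basis(i + 1, d - 1, x);
        return left + right;
    }

    double evaluate(double x) {
        double sum = 0.0;
        for (int i = 0; i < coeffs.length; i++) {
            sum += coeffs[i] * basis(i, degree, x);
        }
        return sum;
    }
}
```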
There are some important differences between this implementation and the original paper:
No back-propagation: The main reason for this is that I do not understand how it was implemented in the original paper. Back-propagation is essential for the viability of ANNs in general, so my Java implementation cannot be put to sensible use yet.
No visualization: A unique feature of KANs is their interpretability, for which visualizations are important. I decided to skip this part and keep the code simple, because it is not needed immediately.
Randomized initialization: The network is currently initialized randomly. There are more efficient ways to initialize ANNs.
Also note that the network is currently not performant, not usable as a library, and not production-ready.
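To illustrate the kind of simple genetic algorithm mentioned above, here is a hedged sketch of a mutate-and-keep-if-better loop over the spline coefficients. The `Network` accessors are my own assumptions for the sketch, not the actual `Training` API, and the real selection scheme may differ.

```java
import java.util.Random;

// Hedged sketch of a minimal evolutionary training loop: mutate one spline
// coefficient, keep the mutant only if it lowers the mean squared error.
final class EvolutionaryTrainingSketch {
    interface Network {
        double predict(double[] input);
        double[] coefficients();              // flat view of all spline coefficients
        Network withCoefficients(double[] c); // copy with replaced coefficients
    }

    static Network train(Network net, double[][] inputs, double[] targets,
                         int generations, double mutationScale, Random rng) {
        double bestError = meanSquaredError(net, inputs, targets);
        for (int g = 0; g < generations; g++) {
            double[] mutated = net.coefficients().clone();
            // Perturb one randomly chosen coefficient.
            int k = rng.nextInt(mutated.length);
            mutated[k] += rng.nextGaussian() * mutationScale;
            Network candidate = net.withCoefficients(mutated);
            double err = meanSquaredError(candidate, inputs, targets);
            if (err < bestError) { // greedy selection: keep only improvements
                net = candidate;
                bestError = err;
            }
        }
        return net;
    }

    static double meanSquaredError(Network net, double[][] inputs, double[] targets) {
        double sum = 0.0;
        for (int i = 0; i < inputs.length; i++) {
            double diff = net.predict(inputs[i]) - targets[i];
            sum += diff * diff;
        }
        return sum / inputs.length;
    }
}
```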
The result
To demonstrate the viability of the network, I trained it on a two-variable test function, using a [2, 4, 4, 1, 1] architecture and multiple sets of about 10,000 training samples.
The system was able to learn the function successfully. This is an amazing feat and shows that KANs as presented in the original paper can be implemented concisely.
With the simple training algorithm, training took a long time (approximately 10 hours). I cancelled it once the plot looked satisfactory, since the goal was only to show that the system works. The result is off by about 0.82 on average, and it also looks "wavy", which indicates over-fitting. Further training would improve the result.
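If you want to set up a similar experiment, generating the training data is straightforward. The following sketch builds a set of roughly 10,000 samples; note that the target function here is only a stand-in for illustration, not the function used in the experiment above.

```java
import java.util.Random;

// Hedged sketch of building a training set of ~10,000 (input, target) pairs.
// The target function below is a stand-in, not the post's actual target.
final class TrainingDataSketch {
    public static void main(String[] args) {
        int n = 10_000;
        Random rng = new Random(42);
        double[][] inputs = new double[n][2]; // two inputs, matching [2, 4, 4, 1, 1]
        double[] targets = new double[n];
        for (int i = 0; i < n; i++) {
            double x = rng.nextDouble() * 2 - 1; // sample uniformly in [-1, 1]
            double y = rng.nextDouble() * 2 - 1;
            inputs[i][0] = x;
            inputs[i][1] = y;
            targets[i] = standIn(x, y);
        }
        System.out.println("first sample: f(" + inputs[0][0] + ", " + inputs[0][1]
                + ") = " + targets[0]);
    }

    // Stand-in two-variable function, for illustration only.
    static double standIn(double x, double y) {
        return Math.sin(Math.PI * x) + y * y;
    }
}
```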
Synopsis
In this post, I have presented my Java implementation of Kolmogorov-Arnold Networks (KANs), a novel type of artificial neural network. The rudimentary implementation has demonstrated the ability to learn complex functions. Importantly, working on it has led me to a significant insight: KANs are not as revolutionary as initially thought; in fact, they are functionally equivalent to Multi-Layer Perceptrons (MLPs). I will explore this equivalence in detail in my next post, which may open up new avenues for understanding and potentially enhancing both KANs and traditional neural network architectures.