Projects
My current research mainly lies in the following two areas: theoretical foundations of deep/machine learning and efficient learning algorithms.
Theoretical foundations of deep/machine learning
Deep learning explanation, convergence, and generalization analysis
Graph neural network learning theory
Fairness in deep/machine learning
Deep reinforcement learning theory
Matrix completion
Efficient learning algorithms
Neural network compression
Data compression
Efficient neural network architectures
Few-shot learning
Distributed machine learning
Foundation models
Research highlights
Neural network pruning [slides]
Description: Neural network pruning reduces the size of a neural network by removing unnecessary connections and neurons while maintaining or even improving its performance. The resulting sparsity brings several advantages, such as decreased computational cost, memory usage, energy consumption, and carbon footprint. Furthermore, recent numerical findings indicate that a well-pruned neural network can exhibit improved test accuracy and a faster convergence rate. Such a pruned network is commonly referred to as a "winning ticket" in the context of the lottery ticket hypothesis (LTH), and numerical evidence suggests that magnitude-based pruning is effective in finding these "winning tickets". However, the LTH and related works cannot explain the benefits of training the "winning tickets", nor why magnitude-based pruning can find them.
First, we provide theoretical guarantees for using magnitude-based pruning to find the "winning ticket". Specifically, we prove that neuron weights that learn class-irrelevant features, e.g., background or noise features, tend to have small magnitudes. Therefore, removing the small-magnitude weights does not change the expressive power of the neural network in learning good (class-relevant) features.
Second, we provide a theoretical explanation for why a well-pruned network achieves improved test accuracy and an accelerated convergence rate. Specifically, our analysis demonstrates that training the "winning ticket" corresponds to a wider and steeper convergence region, which significantly diminishes the generalization gap and thus improves test accuracy. Consequently, training on a "winning ticket" requires fewer samples to reach a sound initialization and guarantee convergence, yielding a faster convergence rate and reduced generalization error.
Third, we numerically verify the improved performance of magnitude-based pruning.
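The idea above can be illustrated with a minimal sketch of unstructured magnitude-based pruning; the weight matrix and the sparsity level are illustrative, and real pipelines prune trained networks layer by layer and then retrain.

```python
def magnitude_prune(weights, sparsity):
    """Zero out the smallest-magnitude fraction `sparsity` of the weights."""
    flat = sorted(abs(w) for row in weights for w in row)
    k = int(sparsity * len(flat))
    threshold = flat[k] if k < len(flat) else float("inf")
    # Keep weights whose magnitude reaches the threshold; prune the rest.
    return [[w if abs(w) >= threshold else 0.0 for w in row] for row in weights]

# Toy example: small-magnitude entries, which in our analysis model weights
# that learned class-irrelevant (background/noise) features, are removed.
W = [[0.9, -0.05, 0.4], [0.01, -0.7, 0.1]]
pruned = magnitude_prune(W, sparsity=0.5)
```

Half of the entries are zeroed, leaving only the large-magnitude weights that carry class-relevant features.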
Selected Publications:
Shuai Zhang, Meng Wang, Pin-Yu Chen, Sijia Liu, Songtao Lu, Miao Liu. “Joint Edge-Model Sparse Learning is Provably Efficient for Graph Neural Networks.” In International Conference on Learning Representations (ICLR), 2023. [pdf]
Shuai Zhang, Meng Wang, Sijia Liu, Pin-Yu Chen, and Jinjun Xiong. “Why Lottery Ticket Wins? A Theoretical Perspective of Sample Complexity on Sparse Neural Networks.” In Proc. of the Thirty-fifth Conference on Neural Information Processing Systems (NeurIPS), 2021. [pdf]
Self-training via unlabeled data (semi-supervised learning) [slides]
Description: Self-training is a semi-supervised learning approach that combines labeled and unlabeled data to improve a model's accuracy. It is useful when obtaining labeled data is expensive or time-consuming but unlabeled data is abundant. In many real-world scenarios, labeled data is scarce or difficult to obtain, e.g., medical images, and labeling large datasets can be expensive or time-consuming, e.g., labeling ImageNet took almost 4 years with 49,000 workers from 167 countries. Most importantly, by incorporating unlabeled data, self-training can improve a model's accuracy beyond what is possible with labeled data alone. Despite the wide use of self-training with deep learning, there is currently little theoretical understanding of their integration and performance. In addition, some numerical experiments indicate that the nonlinear characteristics of neural networks can degrade performance when self-training is used. To address this gap between numerical results and theoretical understanding, we present a convergence analysis of the self-training algorithm, along with theoretical guidelines for selecting hyper-parameters to ensure improved generalization from unlabeled data. Our specific contributions include:
First, we provide a quantitative justification of the generalization improvement from unlabeled data. Specifically, we prove that the generalization improvement is a linear function of 1/\sqrt{M}, where M is the number of unlabeled data.
Second, we provide analytical justification for the hyper-parameter selection that guarantees improved performance when using unlabeled data. Our analysis focuses on the weighted loss parameter \lambda. Lowering \lambda degrades performance, while increasing it too much causes the algorithm to diverge. The optimal value of \lambda is approximately \sqrt{N}/(\sqrt{M}+\sqrt{N}), where M and N denote the number of unlabeled and labeled data, respectively. Alternatively, one can progressively raise \lambda until the algorithm diverges.
Third, we provide convergence and sample complexity analyses for learning a proper model. We quantify the impact of labeled and unlabeled data on the generalization of the learned model, and we show that the iterates converge to the desired model within a characterizable bound given a sufficiently large number of unlabeled data.
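A minimal sketch of the weighted self-training objective is given below. The rule of thumb for \lambda follows the analysis above; which loss term \lambda weights is a convention assumed here (labeled term), and the loss function and pseudo-labeler are illustrative placeholders, not the trained network of the paper.

```python
import math

def weighted_loss_param(n_labeled, n_unlabeled):
    """Rule of thumb from the analysis: lambda ~ sqrt(N) / (sqrt(M) + sqrt(N)),
    where N is the number of labeled and M the number of unlabeled samples."""
    return math.sqrt(n_labeled) / (math.sqrt(n_unlabeled) + math.sqrt(n_labeled))

def self_training_loss(loss_fn, labeled, unlabeled, pseudo_label, lam):
    """Weighted combination of the supervised loss on labeled pairs and the
    loss on pseudo-labeled unlabeled points (convention: lam weights the
    labeled term, so more unlabeled data shifts weight to pseudo-labels)."""
    sup = sum(loss_fn(x, y) for x, y in labeled) / len(labeled)
    unsup = sum(loss_fn(x, pseudo_label(x)) for x in unlabeled) / len(unlabeled)
    return lam * sup + (1 - lam) * unsup
```

For instance, with N = 9 labeled and M = 16 unlabeled samples the rule gives \lambda = 3/7, so most of the weight goes to the abundant pseudo-labeled data.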
Selected Publications:
Shuai Zhang, Meng Wang, Sijia Liu, Pin-Yu Chen, and Jinjun Xiong. “How unlabeled data improve generalization in self-training? A one-hidden-layer theoretical analysis.” In Proc. of The Tenth International Conference on Learning Representations (ICLR), 2022. [pdf]
Graph Neural Networks [slides]
Description: Graph neural networks (GNNs) are a class of deep learning models designed for learning with graph-structured data. Examples of data that can be represented as graphs include social networks, biological networks, recommendation systems, communication networks, Internet of Things (IoT) networks, and transportation networks.
Compared with traditional neural networks, GNNs propagate information across the graph through a series of neural network layers. At each layer, each node aggregates information from its neighboring nodes and updates its own representation accordingly. The aggregation and update rules are learned by the model during training, allowing it to capture complex patterns and dependencies in the graph. GNNs have been shown to be effective in a wide range of applications, including node classification, link prediction, and graph classification.
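The aggregate-and-update step can be sketched as follows; this parameter-free mean aggregation is a simplification, as real GNN layers apply learned weight matrices and nonlinearities, and the toy graph and features are illustrative.

```python
def gnn_layer(adj, features):
    """One message-passing step: each node averages its neighbors' feature
    vectors together with its own (a simplified, parameter-free GNN layer)."""
    new_features = {}
    for node, feat in features.items():
        neighborhood = [features[n] for n in adj[node]] + [feat]
        dim = len(feat)
        new_features[node] = [
            sum(f[d] for f in neighborhood) / len(neighborhood) for d in range(dim)
        ]
    return new_features

# Tiny path graph 0 - 1 - 2 with a scalar feature per node: after one layer,
# each node's representation mixes in information from its neighbors.
adj = {0: [1], 1: [0, 2], 2: [1]}
x = {0: [0.0], 1: [3.0], 2: [6.0]}
h = gnn_layer(adj, x)
```

Stacking such layers lets information propagate across multi-hop neighborhoods, which is also why the cost grows quickly with graph size.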
However, training and inference of GNNs suffer from high computational costs, which prevents GNNs from being scaled up to large-scale real-world graph applications.
On the one hand, the computational complexity of processing the entire graph grows exponentially with the size of the dataset. As a comparison, the computational complexity of a GNN can be much higher than that of popular convolutional neural networks (CNNs) whose model sizes are 100x larger.
On the other hand, directly adopting GPUs on large graphs remains challenging because GPU memory capacity is limited. For example, modern graph benchmark datasets, e.g., OGBN-proteins, can take up to 350 gigabytes (GB) of memory, requiring multiple GPUs at an estimated cost of $400K.
First, my research focuses on the generalization and convergence analysis of graph neural networks for graph-structured data learning.
Second, my research focuses on developing graph sparsification methods with theoretical guarantees that reduce the computational cost and sample complexity of training graph neural networks. For example, our proposed joint edge-model sparsification algorithm achieves performance similar to training on the original data and model with a significant reduction in computational cost (the number near each data point denotes the required MACs).
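The edge-sparsification side of this idea can be sketched as a score-and-keep step; ranking edges by an importance score is a simplified stand-in for the actual joint edge-model algorithm, and the edge list and scores here are illustrative.

```python
def sparsify_edges(edges, keep_ratio):
    """Keep the top `keep_ratio` fraction of edges by importance score.
    Dropping low-importance edges shrinks the neighborhoods that each
    GNN layer must aggregate, cutting per-layer computation."""
    ranked = sorted(edges, key=lambda e: abs(e[2]), reverse=True)
    k = max(1, int(keep_ratio * len(ranked)))
    return ranked[:k]

# Edges as (src, dst, importance): half of the edges survive sparsification.
edges = [(0, 1, 0.9), (1, 2, 0.1), (2, 3, 0.6), (3, 0, 0.05)]
kept = sparsify_edges(edges, keep_ratio=0.5)
```

In the joint scheme, such edge sparsification is combined with model (weight) sparsification, so both the graph and the network shrink together.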
Selected Publications:
Shuai Zhang, Meng Wang, Pin-Yu Chen, Sijia Liu, Songtao Lu, Miao Liu. “Joint Edge-Model Sparse Learning is Provably Efficient for Graph Neural Networks.” In International Conference on Learning Representations (ICLR), 2023. [pdf]
Shuai Zhang, Meng Wang, Sijia Liu, Pin-Yu Chen, and Jinjun Xiong. “Fast Learning of Graph Neural Networks with Guaranteed Generalizability: One-hidden-layer Case.” In Proc. of 2020 International Conference on Machine Learning (ICML), pp. 11268-11277. PMLR, 2020. [pdf]
Low-rank Hankel Matrix Completion [slides]
Description: Given partially observed data, we need to recover the original data by filling in missing entries and removing outliers. This problem is common in fields such as recommendation systems, computer vision, and signal processing. To illustrate, imagine a data matrix representing the renewable energy output of multiple solar arrays over time. The time-series data for each array may be driven by a few shared factors, like temperature, wind speed, and UV index. Since there are more data points than factors, the data matrix is low-rank. Being able to fill in missing data or predict future outputs is crucial for energy systems planning and control. However, existing algorithms require at least one observation in each column, making them infeasible in practical scenarios where entire columns are lost or inaccessible, such as prediction tasks. To solve this issue, we utilize the Hankel matrix to capture the temporal correlation across the data. We propose non-convex algorithms that exploit the structured Hankel matrix, and these algorithms achieve reduced sample complexity and lower computational time with theoretical guarantees. Our specific contributions include:
First, we propose a non-convex approach with convergence analysis for data satisfying the low-rank Hankel property, e.g., video processing, image super-resolution, medical image reconstruction, direction-of-arrival estimation, and linear dynamical systems. For a rank-r data matrix satisfying the low-rank Hankel property, our algorithm converges to the ground truth at a linear (exponential-decay) rate, while existing convex approximation approaches only enjoy a sublinear convergence rate.
Second, our algorithm can withstand the loss of a constant fraction of columns, whereas traditional methods cannot cope with even a single lost column. For a rank-r data matrix satisfying the low-rank Hankel property, our algorithm tolerates up to a 1/r fraction of column-wise losses or corruptions.
Third, our algorithm requires fewer samples and less computational time to recover the ground truth. For an n-by-n time-series matrix satisfying the low-rank Hankel property, we reduce both the sample complexity and the per-iteration computational complexity to a 1/n fraction of those of traditional approaches.
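The Hankel structure that underlies these results can be sketched directly; the single-exponential signal below is an illustrative example of why temporally correlated data yields a low-rank Hankel matrix.

```python
def hankel(series, rows):
    """Build a Hankel matrix with constant anti-diagonals: entry (i, j)
    equals series[i + j], so the matrix encodes temporal correlation."""
    cols = len(series) - rows + 1
    return [[series[i + j] for j in range(cols)] for i in range(rows)]

# A signal driven by one exponential mode gives a rank-1 Hankel matrix:
# every row is a scalar multiple of the first row.
x = [2.0 ** t for t in range(6)]   # 1, 2, 4, 8, 16, 32
H = hankel(x, rows=3)
```

With r underlying modes the Hankel matrix has rank r, which is the structure the proposed non-convex algorithms exploit even when entire columns of the raw data matrix are missing.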
Selected Publications:
Shuai Zhang and Meng Wang. “Correction of corrupted columns through fast robust Hankel matrix completion.” IEEE Transactions on Signal Processing (TSP), no. 10: 2580-2594. IEEE, 2019. [pdf]
Shuai Zhang, Yingshuai Hao, Meng Wang, and Joe H. Chow. “Multichannel Hankel matrix completion through nonconvex optimization.” IEEE Journal of Selected Topics in Signal Processing (JSTSP), no. 4: 617-632. IEEE, 2018. [pdf]
