My current research and projects mainly lie in two areas: the theoretical foundations of deep/machine learning and efficient learning algorithms.

Theoretical foundations of deep/machine learning

  • Deep learning explanation, convergence, and generalization analysis

  • Graph neural network learning theory

  • Fairness in deep/machine learning

  • Deep reinforcement learning theory

  • Matrix completion

Efficient learning algorithms

  • Neural network compression

  • Data compression

  • Efficient neural network architectures

  • Few-shot learning

  • Distributed machine learning

  • Foundation models

Research highlights

Neural network pruning [slides]

Description: Neural network pruning reduces the size of a neural network by removing unnecessary connections and neurons while maintaining or even improving its performance. The resulting sparsity brings several advantages, such as lower computational cost, memory usage, energy consumption, and carbon footprint. Moreover, recent numerical findings indicate that a well-pruned neural network can achieve improved test accuracy and faster convergence. Such a pruned network is commonly referred to as a "winning ticket" in the context of the lottery ticket hypothesis (LTH), and numerical evidence suggests that magnitude-based pruning is effective in finding these "winning tickets". However, the LTH and related works cannot explain why training the "winning tickets" is beneficial or why the magnitude-based pruning approach can find them.

  • First, we provide theoretical guarantees for using magnitude-based pruning to find the "winning ticket". Specifically, we prove that neuron weights that learn class-irrelevant features, e.g., background/noise features, tend to have small magnitudes. Therefore, removing the small-magnitude neuron weights does not change the expressive power of the neural network in learning good (class-relevant) features.
  • Second, we provide theoretical explanations for why a well-pruned network achieves improved test accuracy and an accelerated convergence rate. Specifically, our analysis demonstrates that training the "winning ticket" corresponds to a wider and steeper convergence region, which significantly diminishes the generalization gap and thus improves test accuracy. Consequently, training a "winning ticket" requires fewer samples to reach a sound initialization and ensure convergence, resulting in a faster convergence rate and reduced generalization error.
  • Third, we numerically verify the improved performance of magnitude-based pruning.
  • Selected Publications:

    1. Shuai Zhang, Meng Wang, Pin-Yu Chen, Sijia Liu, Songtao Lu, Miao Liu. “Joint Edge-Model Sparse Learning is Provably Efficient for Graph Neural Networks.” In International Conference on Learning Representations (ICLR), 2023. [pdf]

    2. Shuai Zhang, Meng Wang, Sijia Liu, Pin-Yu Chen, and Jinjun Xiong. “Why Lottery Ticket Wins? A Theoretical Perspective of Sample Complexity on Sparse Neural Networks.” In Proc. of the Thirty-fifth Conference on Neural Information Processing Systems (NeurIPS), 2021. [pdf]
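As a concrete illustration, here is a minimal sketch of one-shot magnitude-based pruning in NumPy. The function name and the simple global-threshold rule are illustrative assumptions, not the exact procedure analyzed in the papers above (which study pruning within a full training pipeline):

```python
import numpy as np

def magnitude_prune(weights, sparsity):
    """Zero out the smallest-magnitude fraction `sparsity` of the entries.

    A one-shot sketch; practical LTH pipelines prune iteratively and
    retrain (or rewind) between pruning rounds.
    """
    flat = np.abs(weights).ravel()
    k = int(sparsity * flat.size)
    if k == 0:
        return weights.copy()
    threshold = np.partition(flat, k - 1)[k - 1]  # k-th smallest magnitude
    mask = np.abs(weights) > threshold            # keep only larger weights
    return weights * mask

rng = np.random.default_rng(0)
W = rng.standard_normal((4, 4))
W_pruned = magnitude_prune(W, sparsity=0.5)  # half of the 16 entries zeroed
```

The surviving mask can then be fixed while the remaining weights are retrained, which is the "training the winning ticket" step discussed above.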

Self-training via unlabeled data (semi-supervised learning) [slides]

Description: Self-training is a semi-supervised learning approach that combines labeled and unlabeled data to improve a model's accuracy. It is useful when labeled data are expensive or time-consuming to obtain but unlabeled data are abundant. In many real-world scenarios, labeled data can be scarce or difficult to obtain, e.g., medical images, and labeling large datasets can be expensive or time-consuming, e.g., labeling ImageNet took almost 4 years with 49,000 workers from 167 countries. Most importantly, by incorporating unlabeled data, self-training can improve a model's accuracy beyond what is possible with labeled data alone. Despite the wide use of self-training with deep learning, there is currently little theoretical understanding of their integration and performance. In addition, some numerical experiments have indicated that the non-linear characteristics of neural networks can cause a performance decline when self-training is used. To bridge the gap between numerical results and theoretical understanding of self-training, we present a convergence analysis of the self-training algorithm, along with theoretical guidelines for selecting hyperparameters to ensure improved generalization from unlabeled data. Our specific contributions include

  • First, we provide a quantitative justification of the generalization improvement from unlabeled data. Specifically, we prove that the improvement in generalization is a linear function of 1/\sqrt{M}, where M is the number of unlabeled data.
  • Second, we provide analytical justification for hyperparameter selection to guarantee improved performance when using unlabeled data. Our analysis focuses on the impact of the weighted loss parameter, denoted as \lambda. Lowering the value of \lambda decreases performance, while increasing it too much can cause the algorithm to diverge. The optimal value of \lambda is approximately \sqrt{N}/(\sqrt{M}+\sqrt{N}), where M and N denote the number of unlabeled and labeled data, respectively. Alternatively, one can progressively raise \lambda up to the point where the algorithm diverges.
  • Third, we provide convergence and sample complexity analyses for learning a proper model. We quantify the impact of labeled and unlabeled data on the generalization of the learned model, and we show that the iterates converge to the desired model with a characterizable error bound given a sufficiently large number of unlabeled data.
  • Selected Publications:

    1. Shuai Zhang, Meng Wang, Sijia Liu, Pin-Yu Chen, and Jinjun Xiong. “How unlabeled data improve generalization in self-training? A one-hidden-layer theoretical analysis.” In Proc. of The Tenth International Conference on Learning Representations (ICLR), 2022. [pdf]
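The \lambda rule of thumb above can be sketched in a few lines of NumPy. The function names are ours, and assigning \lambda to the pseudo-label term in the convex combination is an illustrative modeling choice, not necessarily the paper's exact loss formulation:

```python
import numpy as np

def self_training_weight(num_labeled, num_unlabeled):
    """Rule of thumb from the analysis: lambda ~ sqrt(N) / (sqrt(M) + sqrt(N)),
    with N labeled and M unlabeled samples."""
    n, m = float(num_labeled), float(num_unlabeled)
    return np.sqrt(n) / (np.sqrt(m) + np.sqrt(n))

def weighted_self_training_loss(loss_labeled, loss_pseudo, lam):
    # Convex combination of the supervised loss and the pseudo-label loss.
    # Which term lambda scales is an assumption made for this sketch.
    return (1.0 - lam) * loss_labeled + lam * loss_pseudo

lam = self_training_weight(num_labeled=100, num_unlabeled=900)   # -> 0.25
loss = weighted_self_training_loss(loss_labeled=0.8, loss_pseudo=0.4, lam=lam)
```

Consistent with the guideline above, more unlabeled data (larger M) shrinks the recommended \lambda, and in practice one can increase \lambda until divergence is observed.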

Graph Neural Networks [slides]

Description: Graph neural networks (GNNs) are a class of deep learning models designed for learning with graph-structured data. Examples of data that can be represented as graphs include social networks, biological networks, recommendation systems, communication networks, Internet of Things (IoT) networks, and transportation networks.

Compared with traditional neural networks, GNNs propagate information across the graph through a series of neural network layers. At each layer, each node aggregates information from its neighboring nodes and updates its own representation accordingly. The aggregation and update rules are learned during training, allowing the model to capture complex patterns and dependencies in the graph. GNNs have proven effective in a wide range of applications, including node classification, link prediction, and graph classification. However, training and inference of GNNs suffer from high computational costs, which prevents GNNs from being scaled up to large real-world graph applications. On the one hand, the computational complexity of processing the entire graph grows exponentially with the size of the dataset; the computational complexity of a GNN can be much higher than that of popular convolutional neural networks (CNNs) whose model sizes are 100x larger. On the other hand, directly adopting GPUs on large graphs remains challenging because GPU memory capacity is limited and can be insufficient. For example, modern graph benchmark datasets, e.g., OGBN-proteins, can take up to 350 gigabytes (GB) of memory, which requires multiple GPUs at an estimated cost of $400K.

  • First, my research focuses on the generalization and convergence analysis of graph neural networks for learning with graph-structured data.
  • Second, my research focuses on developing graph sparsification methods that reduce the computational cost and sample complexity of training graph neural networks, with theoretical guarantees. For example, our proposed joint edge-model sparsification algorithm achieves performance similar to training on the original data and model with a significant reduction in computational cost (the number near each data point denotes the required MACs).
  • Selected Publications:

    1. Shuai Zhang, Meng Wang, Pin-Yu Chen, Sijia Liu, Songtao Lu, Miao Liu. “Joint Edge-Model Sparse Learning is Provably Efficient for Graph Neural Networks.” In International Conference on Learning Representations (ICLR), 2023. [pdf]

    2. Shuai Zhang, Meng Wang, Sijia Liu, Pin-Yu Chen, and Jinjun Xiong. “Fast Learning of Graph Neural Networks with Guaranteed Generalizability: One hidden-layer Case.” In Proc. of 2020 International Conference on Machine Learning (ICML), pp. 11268-11277. PMLR, 2020. [pdf]
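The aggregate-and-update step described above can be sketched as a single GCN-style layer in NumPy. The mean aggregation, self-loops, and ReLU here are generic illustrative choices, not the exact update rule analyzed in these papers:

```python
import numpy as np

def gnn_layer(adj, features, weight):
    """One message-passing layer: mean-aggregate neighbor features,
    then apply a learned linear map and a ReLU.

    adj: (n, n) binary adjacency matrix, features: (n, d_in),
    weight: (d_in, d_out).
    """
    adj_hat = adj + np.eye(adj.shape[0])      # add self-loops
    deg = adj_hat.sum(axis=1, keepdims=True)  # node degrees (all >= 1)
    agg = (adj_hat @ features) / deg          # mean aggregation over neighbors
    return np.maximum(agg @ weight, 0.0)      # linear update + ReLU

# Tiny 3-node path graph: 0 - 1 - 2, with one-hot input features.
adj = np.array([[0., 1., 0.],
                [1., 0., 1.],
                [0., 1., 0.]])
X = np.eye(3)
W = np.ones((3, 2))
H = gnn_layer(adj, X, W)  # new (3, 2) node representations
```

Graph sparsification in this setting amounts to dropping entries of `adj` (edges) before the aggregation, which directly shrinks the cost of the `adj_hat @ features` product.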

Low-rank Hankel Matrix Completion [slides]

Description: Given partially observed data, the goal is to recover the original data by filling in missing entries and removing outliers. This problem is common in various fields such as recommendation systems, computer vision, and signal processing. To illustrate, imagine a data matrix representing the renewable energy output of multiple solar arrays over time. The time series for each array may be affected by a few shared factors, such as temperature, wind speed, and UV index. Since there are more data points than factors, the data matrix is low-rank. Being able to fill in missing data or predict future outputs is crucial for energy systems planning and control. However, existing algorithms require at least one observation in each column, making them infeasible in practical scenarios where entire columns of data are lost or inaccessible, such as prediction tasks. To address this issue, we utilize the Hankel matrix to capture the temporal correlation in the data. We propose non-convex algorithms that exploit the structured Hankel matrix and achieve reduced sample complexity and computational time with theoretical guarantees. Our specific contributions include

  • First, we propose a non-convex approach with convergence analysis for data satisfying the low-rank Hankel property, e.g., video processing, image super-resolution, medical image reconstruction, direction-of-arrival estimation, and linear dynamical systems. For a data matrix satisfying the low-rank Hankel property with rank r, our algorithm converges to the ground truth linearly (the error decays exponentially), while existing convex approximation approaches only enjoy a sublinear convergence rate.
  • Second, our algorithm can withstand the loss of a constant fraction of columns, whereas traditional methods cannot cope with even a single lost column. For a data matrix satisfying the low-rank Hankel property with rank r, our algorithm can tolerate up to a 1/r fraction of columns being lost or corrupted.
  • Third, our algorithm requires fewer samples and less computation to recover the ground truth. For an n-by-n time-series matrix satisfying the low-rank Hankel property, we reduce the sample complexity by a factor of 1/n and the per-iteration computational complexity by a factor of 1/n compared with traditional approaches.
  • Selected Publications:

    1. Shuai Zhang and Meng Wang. “Correction of corrupted columns through fast robust Hankel matrix completion.” IEEE Transactions on Signal Processing (TSP), no. 10: 2580-2594. IEEE, 2019. [pdf]

    2. Shuai Zhang, Yingshuai Hao, Meng Wang, and Joe H. Chow. “Multichannel Hankel matrix completion through nonconvex optimization.” IEEE Journal of Selected Topics in Signal Processing (JSTSP), no. 4: 617-632. IEEE, 2018. [pdf]
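To make the low-rank Hankel structure concrete, here is a small NumPy sketch (the helper name is ours): a time series driven by r shared exponential modes yields a rank-r Hankel matrix, which is exactly the structure the completion algorithms above exploit.

```python
import numpy as np

def hankel_from_series(x, num_rows):
    """Build the Hankel matrix H with H[i, j] = x[i + j].

    For a length-n series and num_rows = p, H is p x (n - p + 1).
    """
    x = np.asarray(x, dtype=float)
    cols = x.size - num_rows + 1
    return np.stack([x[i:i + cols] for i in range(num_rows)])

# A two-mode signal: the sum of two geometric sequences.
t = np.arange(12)
x = 0.9 ** t + 2 * 0.5 ** t
H = hankel_from_series(x, num_rows=6)          # 6 x 7 Hankel matrix
rank = np.linalg.matrix_rank(H)                # rank 2, one per mode
```

Because the rank equals the number of shared modes rather than the matrix dimension, even a column that is entirely missing can be reconstructed from the overlapping Hankel structure, which is what enables the column-wise robustness described above.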