Swap k-th Nodes in Linked List

Last Updated : 28 Aug, 2025

Given the head of a singly linked list, swap the k-th node from the beginning with the k-th node from the end.

Example: 

Input: head = 5 -> 10 -> 8 -> 5 -> 9 -> 3, k = 2

1-

Output: 5 -> 9 -> 8 -> 5 -> 10 -> 3 

2

Explanation: The 2nd node from the start is 10 and the 2nd node from the end is 9, so these two nodes are swapped.

Input: head = 1 -> 2 -> 3 -> 4 -> 5, k = 1

3-

Output: 5 -> 2 -> 3 -> 4 -> 1

4

Explanation: Since k is 1, the first node and the last node are swapped.

Try It Yourself
redirect icon

[Approach] Two Pointers + Pointer Manipulation on Linked List - O(n) Time and O(1) Space

The idea is to first count the total length of the linked list, then locate the k-th node from the beginning and the k-th node from the end using two separate traversals. If k exceeds the length of the list, or both positions refer to the same (middle) node, no swap is needed. Otherwise, once both nodes (and their previous nodes) are found, we adjust the links of their previous nodes and swap their next pointers.

Additionally, if the k-th node from the beginning is the head (k = 1), then after the swap the head must be updated to point to the k-th node from the end. Similarly, if the k-th node from the end is the head (k equals the length of the list), then after the swap the head must be updated to point to the k-th node from the beginning.

C++
#include <iostream>
using namespace std;

// Node class for singly linked list
class Node {
public:
    int data;    // value stored in this node
    Node* next;  // following node, or nullptr at the tail
    Node(int d) : data(d), next(nullptr) {}
};

// Swaps the k-th node from the beginning with the k-th node from the end by
// relinking pointers (node data is not copied).
//
// head : first node of the list (may be nullptr)
// k    : 1-based position; values outside [1, n] leave the list untouched
// returns the (possibly new) head of the list
//
// O(n) time, O(1) extra space.
Node* swapKth(Node* head, int k) {
    if (!head) return head;

    // Count length
    int n = 0;
    for (Node* temp = head; temp; temp = temp->next) n++;

    // k must name a real position. Without the k < 1 check the second walk
    // below runs past the tail and dereferences nullptr.
    if (k < 1 || k > n) return head;

    // k-th from start and k-th from end are the same (middle) node: no swap
    if (2 * k - 1 == n) return head;

    // find k-th node from start and its predecessor
    Node* prevX = nullptr;
    Node* x = head;
    for (int i = 1; i < k; i++) {
        prevX = x;
        x = x->next;
    }

    // find k-th node from end, i.e. the (n - k + 1)-th from start, and its
    // predecessor
    Node* prevY = nullptr;
    Node* y = head;
    for (int i = 1; i < n - k + 1; i++) {
        prevY = y;
        y = y->next;
    }

    // Redirect predecessors. When x and y are adjacent, one of these writes
    // briefly makes a node point to itself; the next-pointer swap below
    // repairs that self-link, so no special case is needed.
    if (prevX) prevX->next = y;
    if (prevY) prevY->next = x;

    // swap next pointers
    Node* tempNext = x->next;
    x->next = y->next;
    y->next = tempNext;

    // if either node was the head, the other one becomes the new head
    if (k == 1) head = y;
    if (k == n) head = x;

    return head;
}

// utility to print list
void printList(Node* head) {
    // Emit every value front-to-back, placing " -> " between neighbours,
    // then finish the line.
    for (Node* cur = head; cur != nullptr; cur = cur->next) {
        cout << cur->data;
        if (cur->next != nullptr) {
            cout << " -> ";
        }
    }
    cout << endl;
}

int main() {
    
    Node* head = new Node(5);
    head->next = new Node(10);
    head->next->next = new Node(8);
    head->next->next->next = new Node(5);
    head->next->next->next->next = new Node(9);
    head->next->next->next->next->next = new Node(3);

    int k = 2;
    head = swapKth(head, k);

    printList(head);

    return 0;
}
Java
class Node {
    int data;   // value stored in this node
    Node next;  // following node, or null at the tail

    Node(int d) {
        data = d;
        next = null;
    }
}

class GfG {

    /**
     * Swaps the k-th node from the beginning with the k-th node from the end
     * by relinking pointers (node data is not copied).
     *
     * @param head first node of the list; may be null
     * @param k    1-based position; values outside [1, n] leave the list unchanged
     * @return the (possibly new) head of the list
     */
    static Node swapKth(Node head, int k) {
        if (head == null) return head;

        // Count length
        int n = 0;
        for (Node temp = head; temp != null; temp = temp.next) n++;

        // k must name a real position. Without the k < 1 check the second
        // walk below runs past the tail and throws NullPointerException.
        if (k < 1 || k > n) return head;

        // k-th from start and k-th from end are the same (middle) node
        if (2 * k - 1 == n) return head;

        // find k-th node from start and its predecessor
        Node prevX = null;
        Node x = head;
        for (int i = 1; i < k; i++) {
            prevX = x;
            x = x.next;
        }

        // find k-th node from end, i.e. the (n - k + 1)-th from start
        Node prevY = null;
        Node y = head;
        for (int i = 1; i < n - k + 1; i++) {
            prevY = y;
            y = y.next;
        }

        // Redirect predecessors. When x and y are adjacent, one of these
        // writes briefly makes a node point to itself; the next-pointer swap
        // below repairs that self-link, so no special case is needed.
        if (prevX != null) prevX.next = y;
        if (prevY != null) prevY.next = x;

        // swap next pointers
        Node tempNext = x.next;
        x.next = y.next;
        y.next = tempNext;

        // if either node was the head, the other one becomes the new head
        if (k == 1) head = y;
        if (k == n) head = x;

        return head;
    }

    // utility to print list
    static void printList(Node head) {
        while (head != null) {
            System.out.print(head.data);
            if (head.next != null) System.out.print(" -> ");
            head = head.next;
        }
        System.out.println();
    }

    public static void main(String[] args) {
        Node head = new Node(5);
        head.next = new Node(10);
        head.next.next = new Node(8);
        head.next.next.next = new Node(5);
        head.next.next.next.next = new Node(9);
        head.next.next.next.next.next = new Node(3);

        int k = 2;
        head = swapKth(head, k);

        printList(head);
    }
}
Python
# Node class for singly linked list
class Node:
    # Singly linked list node: `data` is the stored value, `next` the
    # following node (None at the tail).
    def __init__(self, d):
        self.data = d
        self.next = None


def swapKth(head, k):
    """Swap the k-th node from the start with the k-th node from the end.

    Nodes are relinked; their ``data`` values are not copied.

    Args:
        head: First node of the list, or None for an empty list.
        k: 1-based position. Values outside ``1..n`` (n = list length)
            leave the list unchanged.

    Returns:
        The (possibly new) head of the list.
    """
    if not head:
        return head

    # Count length
    n = 0
    temp = head
    while temp:
        n += 1
        temp = temp.next

    # k must name a real position. Without the k < 1 check the second walk
    # below runs past the tail and fails on None.next.
    if k < 1 or k > n:
        return head

    # k-th from start and k-th from end are the same (middle) node
    if 2 * k - 1 == n:
        return head

    # find k-th node from start and its predecessor
    prevX = None
    x = head
    for _ in range(1, k):
        prevX = x
        x = x.next

    # find k-th node from end, i.e. the (n - k + 1)-th from start
    prevY = None
    y = head
    for _ in range(1, n - k + 1):
        prevY = y
        y = y.next

    # Redirect predecessors. When x and y are adjacent, one of these writes
    # briefly makes a node point to itself; the next-pointer swap below
    # repairs that self-link, so no special case is needed.
    if prevX is not None:
        prevX.next = y
    if prevY is not None:
        prevY.next = x

    # swap next pointers (right-hand side is evaluated before assignment)
    x.next, y.next = y.next, x.next

    # if either node was the head, the other one becomes the new head
    if k == 1:
        head = y
    if k == n:
        head = x

    return head


# utility to print list
def printList(head):
    """Print the list on one line as ``a -> b -> c``, then a newline."""
    parts = []
    node = head
    while node:
        parts.append(str(node.data))
        node = node.next
    print(" -> ".join(parts))


if __name__ == "__main__":
    # Build the sample list 5 -> 10 -> 8 -> 5 -> 9 -> 3 node by node.
    head = Node(5)
    tail = head
    for value in (10, 8, 5, 9, 3):
        tail.next = Node(value)
        tail = tail.next

    k = 2
    head = swapKth(head, k)

    printList(head)
C#
using System;

class Node {
    public int data;   // value stored in this node
    public Node next;  // following node, or null at the tail

    public Node(int d) {
        data = d;
        next = null;
    }
}

class GfG {

    /// <summary>
    /// Swaps the k-th node from the beginning with the k-th node from the end
    /// by relinking pointers (node data is not copied).
    /// </summary>
    /// <param name="head">First node of the list; may be null.</param>
    /// <param name="k">1-based position; values outside [1, n] leave the list unchanged.</param>
    /// <returns>The (possibly new) head of the list.</returns>
    static Node SwapKth(Node head, int k) {
        if (head == null) return head;

        // Count length
        int n = 0;
        for (Node temp = head; temp != null; temp = temp.next) n++;

        // k must name a real position. Without the k < 1 check the second
        // walk below runs past the tail and throws NullReferenceException.
        if (k < 1 || k > n) return head;

        // k-th from start and k-th from end are the same (middle) node
        if (2 * k - 1 == n) return head;

        // find k-th node from start and its predecessor
        Node prevX = null;
        Node x = head;
        for (int i = 1; i < k; i++) {
            prevX = x;
            x = x.next;
        }

        // find k-th node from end, i.e. the (n - k + 1)-th from start
        Node prevY = null;
        Node y = head;
        for (int i = 1; i < n - k + 1; i++) {
            prevY = y;
            y = y.next;
        }

        // Redirect predecessors. When x and y are adjacent, one of these
        // writes briefly makes a node point to itself; the next-pointer swap
        // below repairs that self-link, so no special case is needed.
        if (prevX != null) prevX.next = y;
        if (prevY != null) prevY.next = x;

        // swap next pointers
        Node tempNext = x.next;
        x.next = y.next;
        y.next = tempNext;

        // if either node was the head, the other one becomes the new head
        if (k == 1) head = y;
        if (k == n) head = x;

        return head;
    }

    // utility to print list
    static void PrintList(Node head) {
        while (head != null) {
            Console.Write(head.data);
            if (head.next != null) {
                Console.Write(" -> ");
            }
            head = head.next;
        }
        Console.WriteLine();
    }

    public static void Main(string[] args) {

        Node head = new Node(5);
        head.next = new Node(10);
        head.next.next = new Node(8);
        head.next.next.next = new Node(5);
        head.next.next.next.next = new Node(9);
        head.next.next.next.next.next = new Node(3);

        int k = 2;
        head = SwapKth(head, k);

        PrintList(head);
    }
}
JavaScript
// Node class for singly linked list
class Node {
    constructor(d) {
        this.data = d;      // value stored in this node
        this.next = null;   // following node, or null at the tail
    }
}

/**
 * Swaps the k-th node from the beginning with the k-th node from the end
 * by relinking pointers (node data is not copied).
 *
 * @param {Node|null} head first node of the list
 * @param {number} k 1-based position; values outside [1, n] leave the list unchanged
 * @returns {Node|null} the (possibly new) head of the list
 */
function swapKth(head, k) {
    if (!head) return head;

    // Count length
    let n = 0;
    for (let temp = head; temp; temp = temp.next) n++;

    // k must name a real position. Without the k < 1 check the second walk
    // below runs past the tail and throws a TypeError on null.next.
    if (k < 1 || k > n) return head;

    // k-th from start and k-th from end are the same (middle) node
    if (2 * k - 1 === n) return head;

    // find k-th node from start and its predecessor
    let prevX = null;
    let x = head;
    for (let i = 1; i < k; i++) {
        prevX = x;
        x = x.next;
    }

    // find k-th node from end, i.e. the (n - k + 1)-th from start
    let prevY = null;
    let y = head;
    for (let i = 1; i < n - k + 1; i++) {
        prevY = y;
        y = y.next;
    }

    // Redirect predecessors. When x and y are adjacent, one of these writes
    // briefly makes a node point to itself; the next-pointer swap below
    // repairs that self-link, so no special case is needed.
    if (prevX) prevX.next = y;
    if (prevY) prevY.next = x;

    // swap next pointers
    const tempNext = x.next;
    x.next = y.next;
    y.next = tempNext;

    // if either node was the head, the other one becomes the new head
    if (k === 1) head = y;
    if (k === n) head = x;

    return head;
}

// utility to print list
function printList(head) {
    // Collect node values front-to-back, then emit them on a single line
    // separated by " -> " (an empty list prints an empty line).
    const parts = [];
    for (let node = head; node; node = node.next) {
        parts.push(node.data);
    }
    console.log(parts.join(" -> "));
}


// Driver Code
// Build the sample list 5 -> 10 -> 8 -> 5 -> 9 -> 3 node by node.
let head = new Node(5);
let tail = head;
for (const value of [10, 8, 5, 9, 3]) {
    tail.next = new Node(value);
    tail = tail.next;
}

let k = 2;
head = swapKth(head, k);

printList(head);

Output
5 -> 9 -> 8 -> 5 -> 10 -> 3
Comment