# Understand Recursion & Memoization

**Recursion** is a technique to solve problems by

1. dividing the original problem into one or more similar sub-problems,
    
2. solving the sub-problems by **calling the same method again**, and then
    
3. using the solutions of sub-problems to get the solution to the actual problem.
    

Sometimes, recursion can be slow because we end up solving the same sub-problem many times. To avoid this, we use **memoization**. The idea is to solve each sub-problem once and remember its solution. If we encounter the same sub-problem again, we return the remembered solution instead of solving it again.

## When to use recursion?

When a problem can be broken down into smaller sub-problems whose solutions can be combined to solve the original problem, you can use recursion. If the recursive solution is slow because it repeats work (which happens in many cases), you can add memoization to it.

Although many problems that can be solved with recursion can also be solved without it, the recursive version is often more readable and expressive.

## How to use recursion?

In programming, recursion means a function or method calling itself. In general, each call passes a changed input to the function. If the input never changes, the function keeps calling itself until the call stack runs out of space, and you get a stack-overflow error.
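As a minimal sketch of this idea, the hypothetical `sumTo` method below (it is not one of the article's examples) adds the integers from `1` to `n`. Each call passes `n - 1`, so the input shrinks until it reaches `n == 0`, where the calls stop. The sketch assumes `n` is non-negative.

```java
class SumDemo {
    // Hypothetical helper for illustration: returns 1 + 2 + ... + n.
    // Assumes n >= 0; a negative n would never reach n == 0.
    static int sumTo(int n) {
        if (n == 0) return 0;     // nothing left to add, so stop here
        return n + sumTo(n - 1);  // the input changes on every call
    }

    public static void main(String[] args) {
        System.out.println(sumTo(4)); // adds 4 + 3 + 2 + 1
    }
}
```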

Let's take some examples to understand it better.

## Factorial (n!)

The [Factorial](https://en.wikipedia.org/wiki/Factorial) of a non-negative integer `n` is the product of all positive integers less than or equal to `n`. That means,

`fact(n) = 1 x 2 x ... x (n - 1) x n`

Let's solve this problem using two techniques: iteration and recursion.

### Factorial using iteration

We can simply iterate from `1` to `n` and multiply each number to get the value of `fact(n)`.

```java
public static int fact(int n) {
    int answer = 1;
    for (int i = 1; i <= n; i++) {
        answer = answer * i;
    }
    return answer;
}
```

This works fine and is readable as well. Still, we can use recursion to calculate `fact(n)`, and the recursive version follows the mathematical definition more closely.

### Factorial using recursion

To calculate the value of `fact(n)` using recursion, we need to break this problem into smaller sub-problems. If we pay close attention, we see that:

```plaintext
fact(n) = n x (n - 1) x ... x 2 x 1
fact(n) = n x fact(n - 1)
```

Here, we break down the problem of calculating the factorial of `n` into calculating the factorial of `n - 1`, which is a sub-problem of the original problem. So, we can use recursion to solve this problem. Let's look at the code:

```java
public static int fact(int n) {
    return n * fact(n - 1);
}
```

Here, we call the `fact` method again, but with the argument `n - 1`. We assume that it will give us the correct value `(n - 1)!`, and then we multiply this value by `n` to get `n!`. Let's call this method from the `main` method with `n = 5`.

```java
class Main {
    public static int fact(int n) {
        return n * fact(n - 1);
    }
    public static void main(String[] args) {
        System.out.println(fact(5));
    }
}
```

Running this program gives an error:

```plaintext
Exception in thread "main" java.lang.StackOverflowError
```

This is because the `fact` method keeps calling itself and never stops:

```plaintext
fact(5) = 5 * fact(4)
        = 5 * 4 * fact(3)
        = 5 * 4 * 3 * fact(2)
        = 5 * 4 * 3 * 2 * fact(1)
        = 5 * 4 * 3 * 2 * 1 * fact(0)
        = 5 * 4 * 3 * 2 * 1 * 0 * fact(-1)
        = 5 * 4 * 3 * 2 * 1 * 0 * -1 * fact(-2)
... and so on
```

Since we have not specified when to stop, we end up with more recursive calls of the `fact` method than the call stack can hold. We need to tell the `fact` method when to stop recursing, that is, we need to define the **base case**. For the factorial function, the base case is `n == 0`, since `0! = 1`. Let's rewrite the `fact` method so that it returns `1` for `n = 0`.

```java
public static int fact(int n) {
    if (n == 0) return 1;
    return n * fact(n - 1);
}
```

Now, run the `main` method and we see the correct output:

`120`

Here is how it works:

```plaintext
fact(5) = 5 * fact(4)
        = 5 * (4 * fact(3))
        = 5 * (4 * (3 * fact(2)))
        = 5 * (4 * (3 * (2 * fact(1))))
        = 5 * (4 * (3 * (2 * (1 * fact(0)))))
        = 5 * (4 * (3 * (2 * (1 * 1))))
        = 5 * (4 * (3 * (2 * 1)))
        = 5 * (4 * (3 * 2))
        = 5 * (4 * 6)
        = 5 * 24
        = 120
```
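The base case only helps if every call eventually reaches it. As a hedged sketch (the class name `SafeFactorial` and the method name `factorial` are made up for this illustration), the version below rejects negative input, which would otherwise skip past `n == 0` and overflow the stack again. It also uses `Math.multiplyExact`, so a result too large for `int` (anything beyond `12!`) raises an `ArithmeticException` instead of silently wrapping around.

```java
class SafeFactorial {
    // Illustrative variant of fact(n) with input and overflow checks.
    static int factorial(int n) {
        if (n < 0) {
            // A negative n would never reach the base case n == 0.
            throw new IllegalArgumentException("n must be non-negative, got " + n);
        }
        if (n == 0) return 1; // base case: 0! = 1
        // multiplyExact throws ArithmeticException if the product overflows int.
        return Math.multiplyExact(n, factorial(n - 1));
    }

    public static void main(String[] args) {
        System.out.println(factorial(5));
        System.out.println(factorial(12)); // the largest factorial that fits in an int
    }
}
```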

Now, let's take another example.

## Fibonacci sequence

The [Fibonacci sequence](https://en.wikipedia.org/wiki/Fibonacci_sequence) is a sequence in which each number is the sum of the two preceding ones. That means:

`fib(n) = fib(n - 1) + fib(n - 2)`, where `fib(0) = 0`, `fib(1) = 1`.

The task here is to calculate the `n`th Fibonacci number, `fib(n)`. We won't look at the iterative way to solve this problem. We jump directly to a recursive solution.

### Fibonacci Number using recursion

Since we know the base case and we know how to divide the actual problem into smaller problems, we can easily write the method to do that recursively:

```java
public static int fib (int n) {
    if (n == 0 || n == 1) return n;
    return fib (n - 1) + fib (n - 2);
}
```

The above solution is clean and correct, but it has a problem. Let's look at the calls made to the `fib` method when we call `fib(6)`.

![Image showing calls made to fib method to calculate the 6th fibonacci number](https://cdn.hashnode.com/res/hashnode/image/upload/v1684569736956/94f9b540-074b-4b5f-89c6-93a5954b64a2.png align="center")

Referring to the above image, let's answer some questions.

1. How many times do we calculate the value of `fib(6)`?
    
    Answer: Once
    
2. How many times do we calculate the value of `fib(5)`?
    
    Answer: Once
    
3. How many times do we calculate the value of `fib(4)`?
    
    Answer: 2 times
    
4. How many times do we calculate the value of `fib(3)`?
    
    Answer: 3 times
    
5. How many times do we calculate the value of `fib(2)`?
    
    Answer: 5 times
    
6. How many times do we calculate the value of `fib(1)`?
    
    Answer: 8 times
    
7. How many times do we calculate the value of `fib(0)`?
    
    Answer: 5 times
    
    

As we can see, we are calculating `fib(x)` multiple times for many values of `x`. The number of calls grows exponentially with `n`, which makes the method very slow for larger inputs. To avoid calculating the same value multiple times, we can use memoization.
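The counts in the list above can be checked with a small experiment. This is only a sketch: the class `FibCallCounter`, the `calls` array, and the `countCalls` helper are names invented here, and the `fib` method is the plain recursive one with a counter added.

```java
class FibCallCounter {
    // calls[k] holds how many times fib(k) was invoked (bookkeeping added for this sketch).
    static int[] calls;

    static int fib(int n) {
        calls[n]++; // record this call before doing any work
        if (n == 0 || n == 1) return n;
        return fib(n - 1) + fib(n - 2);
    }

    // Runs fib(n) once and returns the per-argument call counts.
    static int[] countCalls(int n) {
        calls = new int[n + 1];
        fib(n);
        return calls;
    }

    public static void main(String[] args) {
        int[] counts = countCalls(6);
        for (int k = 6; k >= 0; k--) {
            System.out.println("fib(" + k + ") called " + counts[k] + " time(s)");
        }
    }
}
```

For `fib(6)` this adds up to 25 calls in total, even though there are only 7 distinct values to compute.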

### Fibonacci Number using Memoization

Basically, we calculate each value once and remember it for the lifetime of the program. To remember the values, we can use a `Map` or any similar data structure. While calculating the value of `fib(n)`, we first check whether the map already holds the value of `fib` for the given `n`. If it does, we return that value. Otherwise, we calculate the value of `fib(n)` using recursion and store the result in the map, so that the next time we need `fib(n)` we can read it from the map. Let's implement it in code:

```java
import java.util.Map;
import java.util.HashMap;

class Main {
    private static final Map<Integer, Integer> map = new HashMap<>();
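    public static int fib(int n) {
        // Base cases: fib(0) = 0 and fib(1) = 1.
        if (n == 0 || n == 1) return n;
        // Already calculated: return the remembered value.
        if (map.containsKey(n)) return map.get(n);
        // Calculate once, remember it, then return it.
        int answer = fib(n - 1) + fib(n - 2);
        map.put(n, answer);
        return answer;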
    }
    ...
}
```

In the above code snippet, we check whether the `map` has `n` as a key (if it does, we have already calculated `fib(n)` and stored it in the `map`). If yes, we return the corresponding value. Otherwise, we calculate `fib(n)` using the recursive formula and store the `(n, fib(n))` key-value pair in the map for future reference.
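To see how much work memoization saves, here is a hedged variant of the same idea. It is a sketch under a few assumptions that are not in the code above: it stores values in a `long[]` instead of a `Map`, it returns `long` because `fib(n)` no longer fits in an `int` beyond `n = 46`, and the names `MemoFibDemo`, `fibMemo`, and `computed` are invented for this illustration.

```java
class MemoFibDemo {
    private static long[] memo; // memo[k] == 0 means fib(k) is not stored yet
    static int computed;        // how many values were actually calculated

    private static long fib(int n) {
        if (n == 0 || n == 1) return n;
        if (memo[n] != 0) return memo[n]; // fib(k) >= 1 for k >= 2, so 0 is a safe marker
        computed++;
        memo[n] = fib(n - 1) + fib(n - 2);
        return memo[n];
    }

    // Entry point: sets up a fresh table and calculates fib(n).
    static long fibMemo(int n) {
        if (n < 0 || n > 92) {
            // fib(93) no longer fits in a long.
            throw new IllegalArgumentException("n must be between 0 and 92, got " + n);
        }
        memo = new long[n + 1];
        computed = 0;
        return fib(n);
    }

    public static void main(String[] args) {
        System.out.println(fibMemo(50));
        System.out.println("values calculated: " + computed);
    }
}
```

Each value from `fib(2)` to `fib(n)` is calculated exactly once, so the work grows linearly with `n` instead of exponentially.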

That's how recursion and memoization work. Let's take a final example and solve it using iteration, recursion, and memoization.

## Pascal's triangle

It's hard for me to write the definition of [Pascal's Triangle](https://en.wikipedia.org/wiki/Pascal%27s_triangle) formally, so I am attaching 2 images of it from Wikipedia:

![Pascal's Triangle](https://cdn.hashnode.com/res/hashnode/image/upload/v1684576488713/2d237446-f1b4-4437-bdab-1b52a8d31171.png align="center")

![In Pascal's triangle, each number is the sum of the two numbers directly above it.](https://cdn.hashnode.com/res/hashnode/image/upload/v1684571709527/9574dbcc-f2ea-457b-8bf5-097a81649383.gif align="center")

In programming, it is represented as a matrix or 2-d array:

![Pascal's Triangle](https://cdn.hashnode.com/res/hashnode/image/upload/v1684572510970/3c0db4ed-6a9b-435b-a785-50fd02d98c2b.png align="center")

We can see that the value at row `r` and column `c` of Pascal's triangle, which we write as `P(r, c)` with `r >= c`, is

* `1` if `c = 0`
    
* `1` if `r = c`
    
* `P(r - 1, c - 1) + P(r - 1, c)`, else
    

Now, the problem is this: we are given an integer `n`, and we need to print Pascal's Triangle up to row `n`. (The first row is considered as row `0`.)

### Iterative way to Pascal's Triangle

Let's create a 2-d array `P` with `n + 1` rows and `n + 1` columns. For each row `r`, the first column (`c = 0`) and the last column (`c = r`) are `1`. The rest can be calculated using the formula above. Let's write code for it.

```java
private static void printPascalTriangle (int n) {
    int[][] P = new int[n + 1][n + 1];
    for (int r = 0; r <= n; r++) {
        P[r][0] = 1;
        P[r][r] = 1;
        for (int c = 1; c < r; c++) {
            P[r][c] = P[r - 1][c - 1] + P[r - 1][c];
        }
    }
    for (int r = 0; r <= n; r++) {
        for (int c = 0; c <= r; c++) {
            System.out.print(P[r][c] + " ");
        }
        System.out.println();
    }
}
```

This method prints Pascal's Triangle from row `0` to row `n`. Now, let's solve it using recursion.

### Recursive way to Pascal's Triangle

Let's create a method `pascalNumber(int r, int c)` that gives the number at row `r` and column `c` in Pascal's Triangle. We already have a recursive formula to calculate this, with base cases, from the definition of Pascal's Triangle. So, it is quite easy to implement this method.

```java
private static int pascalNumber (int r, int c) {
    if (c == 0) return 1;
    if (r == c) return 1;
    return pascalNumber(r - 1, c - 1) + pascalNumber(r - 1, c);
}

private static void printPascalTriangle (int n) {
    for (int r = 0; r <= n; r++) {
        for (int c = 0; c <= r; c++) {
            System.out.print(pascalNumber(r, c) + " ");
        }
        System.out.println();
    }
}
```

In the above code snippet, we make recursive calls to solve sub-problems. But we end up solving the same sub-problems again and again, so the program becomes slow as `n` grows. So, let's use memoization to improve the performance.
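To get a feel for how much work is repeated, the sketch below adds a call counter to the recursive `pascalNumber` method. The class name `PascalCallCounter`, the `calls` field, and the `countCalls` helper are assumptions made for this illustration.

```java
class PascalCallCounter {
    static long calls; // number of pascalNumber invocations (added for this sketch)

    static int pascalNumber(int r, int c) {
        calls++;
        if (c == 0) return 1;
        if (r == c) return 1;
        return pascalNumber(r - 1, c - 1) + pascalNumber(r - 1, c);
    }

    // Resets the counter, calculates P(r, c), and returns how many calls it took.
    static long countCalls(int r, int c) {
        calls = 0;
        pascalNumber(r, c);
        return calls;
    }

    public static void main(String[] args) {
        System.out.println("P(10, 5) = " + pascalNumber(10, 5));
        System.out.println("calls for P(10, 5): " + countCalls(10, 5));
        System.out.println("calls for P(20, 10): " + countCalls(20, 10));
    }
}
```

Every call that reaches a base case contributes `1` to the final sum, so calculating a single entry takes `2 * P(r, c) - 1` calls. That is 503 calls for `P(10, 5) = 252`, and the count grows as fast as the entries of the triangle themselves.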

### Memoization way to Pascal's Triangle

To store the calculated value for each `(r, c)` pair, we can use a 2-d array `P`. We initialize `P` in the `printPascalTriangle` method. Every number in Pascal's Triangle is at least `1`, so a zero in `P[r][c]` means we have not calculated the value for `(r, c)` yet. In that case, we calculate the value using recursion and store it. Otherwise, if it is not zero, we simply return it.

```java
private static int[][] P;
private static int pascalNumber (int r, int c) {
    if (P[r][c] != 0) return P[r][c];
    if (c == 0) return 1;
    if (r == c) return 1;
    P[r][c] = pascalNumber(r - 1, c - 1) + pascalNumber(r - 1, c);
    return P[r][c];
}

private static void printPascalTriangle (int n) {
    P = new int[n + 1][n + 1];
    for (int r = 0; r <= n; r++) {
        for (int c = 0; c <= r; c++) {
            System.out.print(pascalNumber(r, c) + " ");
        }
        System.out.println();
    }
}
```

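This way we don't calculate the same value again and again. Each `P[r][c]` is calculated at most once, so printing the triangle up to row `n` takes time proportional to the number of entries, which is `O(n^2)`.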
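As a rough check of that claim, the sketch below counts how many entries are actually calculated by recursion. It is an illustrative variant, not the article's exact code: the class `MemoPascalDemo`, the `computed` counter, and the `buildTriangle` helper (which returns the rows instead of printing them) are assumptions, and it expects a non-negative `n` small enough for the entries to fit in an `int`.

```java
class MemoPascalDemo {
    private static int[][] P; // P[r][c] == 0 means "not calculated yet"
    static int computed;      // entries calculated by recursion and then stored

    static int pascalNumber(int r, int c) {
        if (P[r][c] != 0) return P[r][c];
        if (c == 0 || r == c) return 1; // edges of the triangle
        computed++;
        P[r][c] = pascalNumber(r - 1, c - 1) + pascalNumber(r - 1, c);
        return P[r][c];
    }

    // Builds rows 0..n; row r has r + 1 entries.
    static int[][] buildTriangle(int n) {
        P = new int[n + 1][n + 1];
        computed = 0;
        int[][] rows = new int[n + 1][];
        for (int r = 0; r <= n; r++) {
            rows[r] = new int[r + 1];
            for (int c = 0; c <= r; c++) {
                rows[r][c] = pascalNumber(r, c);
            }
        }
        return rows;
    }

    public static void main(String[] args) {
        int[][] rows = buildTriangle(10);
        System.out.println(java.util.Arrays.toString(rows[4]));
        System.out.println("entries calculated by recursion: " + computed);
    }
}
```

Only the interior entries (those with `0 < c < r`) need a recursive calculation, and each of them is calculated once, which is `n * (n - 1) / 2` entries for rows `0` to `n`.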

## Conclusion

I hope you have learned something new from this article. The source code can be found [here](https://github.com/java-rush/java10daysChallenge/tree/main/recursion-and-memoization). Please share your suggestions and subscribe to my newsletter. It will be a great help. Thank you for your time. See you soon.
