# Barren Plateau Analysis
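This notebook analyses the barren plateau problem: for randomly initialized parameterized quantum circuits, the gradients with respect to the circuit parameters concentrate around zero as the number of qubits grows, which makes such circuits hard to train. We compare circuits with plain rotation gates against circuits whose rotation weights are remapped with `qw_map.tanh`.
The code is heavily based on the [TensorFlow Quantum barren plateau tutorial](https://www.tensorflow.org/quantum/tutorials/barren_plateaus).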

In [None]:
# Add the grandparent directory (two levels above the current working directory)
# to the import path, presumably where the local `qandle` package lives.
import sys
import os

_project_root = os.path.dirname(os.path.dirname(os.getcwd()))
# guard so that re-running this cell does not append duplicate sys.path entries
if _project_root not in sys.path:
    sys.path.append(_project_root)
import qandle

# project helper; presumably reloads qandle so local edits are picked up — confirm
qandle.__reimport()

### Install missing libraries

In [None]:
# %pip (unlike !pip) installs into the environment of the running kernel
%pip install -q pandas seaborn matplotlib

## Importing Libraries

In [None]:
# for simulation
import qandle 
import torch
import qw_map

# for plotting
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# helper libraries
import random
import dataclasses


## Setting up the problem
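We fix the circuit depth, the number of random circuits sampled per configuration and the maximum number of qubits, and use a fixed random seed for reproducibility.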

In [None]:
DEPTH = 50  # number of random rotation + CNOT layers per circuit
NUM_REPEATS = 10  # random circuits sampled per (remapping, num_qubits) configuration
MAX_QUBITS = 5  # circuit widths run from 2 to MAX_QUBITS (inclusive)
SEED = 42  # single seed shared by every random number generator below

qandle.config.DRAW_SHIFT_LEFT = True
qandle.config.DRAW_SHOW_VALUES = True

# `random` chooses the gate types; torch presumably initializes the gate parameters
random.seed(SEED)
torch.manual_seed(SEED)


## Generate random circuits
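Each circuit starts with a fixed layer that rotates every qubit by $R_y\left(\frac{\pi}{4}\right)$. Since $R_y(\theta) = \begin{pmatrix} \cos\frac{\theta}{2} & -\sin\frac{\theta}{2} \\ \sin\frac{\theta}{2} & \cos\frac{\theta}{2} \end{pmatrix}$, its matrix is built from the half angle $\frac{\pi}{8}$. This is followed by `DEPTH` layers, each applying a randomly chosen $R_x$, $R_y$ or $R_z$ rotation to every qubit and then CNOT gates between neighbouring qubits.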

In [None]:
# RY(theta) = [[cos(theta/2), -sin(theta/2)], [sin(theta/2), cos(theta/2)]],
# so the fixed RY(pi/4) gate is built from the half angle pi/8.
pi8 = torch.tensor(torch.pi / 8)
cos_half, sin_half = torch.cos(pi8), torch.sin(pi8)
ry_pi_4 = torch.stack(
    (
        torch.stack((cos_half, -sin_half)),
        torch.stack((sin_half, cos_half)),
    )
)


def generate_random_circuit(num_qubits, remapping=None, depth=None, split_max_qubits=6):
    """Build a random layered circuit for the barren plateau experiment.

    The circuit starts with a fixed RY(pi/4) gate on every qubit, followed by
    ``depth`` layers. Each layer applies a rotation gate chosen uniformly at
    random from RX/RY/RZ to every qubit, then CNOTs between neighbouring qubits:
    first with even qubits as controls (0->1, 2->3, ...), then with odd qubits
    as controls (1->2, 3->4, ...).

    Args:
        num_qubits: number of qubits in the circuit.
        remapping: optional weight remapping forwarded to every rotation gate
            (e.g. ``qw_map.tanh``); ``None`` keeps the default of the gates.
        depth: number of random layers; ``None`` uses the global ``DEPTH``,
            read at call time.
        split_max_qubits: forwarded to ``qandle.Circuit``, which is meant to
            split the circuit into smaller circuits if it is too large.

    Returns:
        The assembled ``qandle.Circuit``.
    """
    if depth is None:
        depth = DEPTH
    rotation_gates = [qandle.RX, qandle.RY, qandle.RZ]

    # fixed initial layer: RY(pi/4) on all qubits
    layers = [
        qandle.CustomGate(
            i, matrix=ry_pi_4, num_qubits=num_qubits, self_description="RY(pi/4)"
        )
        for i in range(num_qubits)
    ]
    for _ in range(depth):
        # randomly choose one rotation gate per qubit
        for qubit in range(num_qubits):
            gate = random.choice(rotation_gates)
            layers.append(gate(qubit=qubit, remapping=remapping))
        # CNOT ladder: even controls first, then odd controls; control + 1 is the
        # target, so the last valid control is num_qubits - 2
        for first_control in (0, 1):
            for control in range(first_control, num_qubits - 1, 2):
                layers.append(qandle.CNOT(control=control, target=control + 1))

    return qandle.Circuit(layers, split_max_qubits=split_max_qubits)


### Test a circuit
We generate a random circuit with 5 qubits and print it. We also execute it by calling `example_circuit()` and print the results.

In [None]:
# Build one 5-qubit example circuit, show its diagram, then print the output of a forward pass.
example_circuit = generate_random_circuit(5)
circuit_diagram = example_circuit.draw()
print(circuit_diagram)

example_output = example_circuit()
print(example_output)

## Measure the gradients
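For every combination of remapping (none or `qw_map.tanh`) and number of qubits (2 to `MAX_QUBITS`), we create `NUM_REPEATS` random circuits and measure their gradients after a forward and backward pass. For each circuit, the standard deviation of the gradients and the mean of their absolute values are saved in a pandas DataFrame.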

In [None]:
@dataclasses.dataclass
class GradData:
 """Helper class for neatly storing gradient data before converting to a DataFrame.

 One instance holds the gradient statistics of a single random circuit.
 Field order matters: instances are constructed positionally, and
 ``dataclasses.asdict`` turns the fields into DataFrame columns in this order.
 """

 num_qubits: int  # number of qubits of the measured circuit
 remapping: str  # __name__ of the remapping function, or the string "None"
 grad_std: float  # standard deviation of all parameter gradients of the circuit
 grad_mean: float  # mean of the ABSOLUTE parameter gradients (not |mean|)


circ_grads = []
# every (remapping, num_qubits) configuration: no remapping vs. tanh, 2..MAX_QUBITS qubits
combinations = [[r, q] for r in [None, qw_map.tanh] for q in range(2, MAX_QUBITS + 1)]

for remapping, num_qubits in combinations:
    remap_str = remapping.__name__ if remapping is not None else "None"
    for _ in range(NUM_REPEATS):
        c = generate_random_circuit(num_qubits, remapping=remapping)
        # forward pass and backward pass using a dummy scalar loss |sum(output)|
        c().sum().abs().backward()
        # flatten each gradient so parameters with more than one element are
        # handled too (p.grad.item() would raise for them)
        param_grads = [
            p.grad.detach().flatten() for p in c.parameters() if p.grad is not None
        ]
        c_grads = torch.cat(param_grads) if param_grads else torch.empty(0)
        # standard deviation of the gradients
        grad_std = c_grads.std().item()
        # mean of the absolute gradients (NOT the absolute value of the mean),
        # so positive and negative gradients do not cancel out
        grad_mean = c_grads.abs().mean().item()
        circ_grads.append(GradData(num_qubits, remap_str, grad_std, grad_mean))

df = pd.DataFrame([dataclasses.asdict(gd) for gd in circ_grads])

## Plot the results
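The solid lines (left axis) show the standard deviation of the gradients, and the dotted lines (right axis) show the mean absolute gradient. Remapping increases both the mean absolute gradient and the standard deviation.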

In [None]:
fig, ax = plt.subplots()
# second y-axis on the right: the two metrics live on different scales
ax_mean = ax.twinx()
pltargs = dict(data=df, x="num_qubits", hue="remapping")
sns.lineplot(**pltargs, ax=ax, y="grad_std", linestyle="-")
sns.lineplot(**pltargs, ax=ax_mean, y="grad_mean", linestyle=":", legend=False)
# the line style is the only visual cue separating the metrics, so name it in the labels
ax.set(
    xlabel="number of qubits",
    ylabel="std of gradients (solid)",
    title="Gradient statistics vs. number of qubits",
)
ax_mean.set_ylabel("mean |gradient| (dotted)")
# set major x ticks to integer values
ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))
plt.show()