스트라센 알고리즘은 분할 정복을 이용한 알고리즘으로,
$O(n^3)$ 으로 알려진 행렬 곱셈을 약 $O(n^{2.807})$ 으로 더 빠르게 구하는 알고리즘입니다.
스트라센 알고리즘에 대한 이해를 돕기 위해, 일반적인 행렬 곱셈 방법부터 알아봅시다.
1.일반적인 행렬 곱셈
$AB = C$ 를 만족하는 $n*n$ 크기의 행렬 $A, B, C$ 에서, 일반적인 행렬 곱셈 방법은 다음과 같습니다. \(AB = C =\begin{bmatrix} \sum_{k=1}^nA_{1k}B_{k1} & \sum_{k=1}^nA_{1k}B_{k2} & \cdots & \sum_{k=1}^nA_{1k}B_{kn} \\\sum_{k=1}^nA_{2k}B_{k1} & \sum_{k=1}^nA_{2k}B_{k2} & \cdots & \sum_{k=1}^nA_{2k}B_{kn} \\\vdots & \vdots & \ddots & \vdots \\\sum_{k=1}^nA_{nk}B_{k1} & \sum_{k=1}^nA_{nk}B_{k2} & \cdots & \sum_{k=1}^nA_{nk}B_{kn} \end{bmatrix}\) 이를 코드로 나타내면 아래와 같습니다.
1 |
|
시간복잡도
n번의 반복이 3중으로 일어나므로, 시간복잡도는 $O(n^3)$ 입니다.
2. 분할정복을 이용한 행렬곱셈
$A*B = C$ 를 만족하는 $n = 2^k$ 크기의 행렬을 $2^{k-1}$ 크기의 4개의 부분행렬(submatrix)로 쪼개어, C의 값을 다음과 같이 구할 수 있습니다. \(C = \begin{bmatrix}C_{11}&C_{12}\\C_{21}&C_{22}\end{bmatrix}, A = \begin{bmatrix}A_{11}&A_{12}\\A_{21}&A_{22}\end{bmatrix}, B = \begin{bmatrix}B_{11}&B_{12}\\B_{21}&B_{22}\end{bmatrix}\)
\[C_{11}=A_{11}B_{11}+A_{12}B_{21} \\C_{12}=A_{11}B_{12}+A_{12}B_{22} \\C_{21}=A_{21}B_{11}+A_{22}B_{21} \\C_{22}=A_{21}B_{12}+A_{22}B_{22}\]시간복잡도
매 과정마다 부분행렬끼리의 8번의 곱셈과 4번의 덧셈이, 재귀적으로 위의 과정이 $k (=log_2n)$ 번 반복됩니다. 이를 통해 시간복잡도를 증명해봅시다. (여기서 c는 덧셈, 대입 등의 상수복잡도입니다.) \(T(n)=8T(\frac{n}{2})+c \\=8[8T(\frac{n}{4})+c]+c \\=8[8^2T(\frac{n}{8})+8c+c]+c \\=8[8^3T(\frac{n}{16})+8^2c+8c+c]+c \\\vdots \\=8^{log_2n}T(1)+\frac{8^{log_2n}-1}{8-1}c \\\approx 8^{log_2n}(T(1)+c) \\=n^{log_28}(T(1)+c)=n^3(T(1)+c) \\\implies O(n^3)\) 이 방법으로도, 시간복잡도는 그대로 $O(n^3)$ 임을 알 수 있습니다.
3. 스트라센 알고리즘
위의 방법에서, 부분행렬끼리의 곱셈 횟수를 7번으로 줄인 알고리즘이 바로 스트라센 알고리즘입니다.
그 방법은 다음과 같습니다. \(M_1=(A_{11}+A_{22})(B_{11}+B_{22}) \\M_2=(A_{21}+A_{22})B_{11} \\M_3=A_{11}(B_{12}−B_{22}) \\M_4=A_{22}(B_{21}−B_{11}) \\M_5=(A_{11}+A_{12})B_{22} \\M_6=(A_{21}-A_{11})(B_{11}+B_{12}) \\M_7=(A_{12}-A_{22})(B_{21}+B_{22})\) \(C_{11}=M_1+M_4-M_5+M_7 \\C_{12}=M_3+M_5 \\C_{21}=M_2+M_4 \\C_{22}=M_1-M_2+M_3+M_6\)
시간복잡도
매 과정마다 부분행렬끼리의 7번의 곱셈과 6+6번의 덧셈, 4+2번의 뺄셈이, 재귀적으로 위의 과정이 $k(=log_2n)$ 번 반복됩니다. 마찬가지로 시간복잡도를 증명해봅시다. \(T(n)=7T(\frac{n}{2^1})+c \\=7[7T(\frac{n}{2^2})+c]+c \\=7[7^2T(\frac{n}{2^3})+7c+c]+c \\=7[7^3T(\frac{n}{2^4})+7^2c+7c+c]+c \\\vdots \\=7^{log_2n}T(1)+\frac{7^{log_2n}-1}{7-1}c \\\approx 7^{log_2n}(T(1)+c) \\=n^{log_27}(T(1)+c) \approx n^{2.807}(T(1)+c) \\\implies O(n^{2.807})\) 이를 구현한 코드입니다. ($A[0] = A_{11} , A[1] = A_{12} , A[2] = A_{21} , A[3] = A_{22} 이고, M[0] = M_1, …, M[6] = M_7$ 입니다.)
1 |
|
행렬곱의 실행 시간을 구하는 방법은 아래의 방법으로 구현했습니다. (위의 코드에도 포함되어 있습니다.)
1 |
|
4. 실행 결과
- 행렬곱 결과
1 |
|
- 실행시간 비교 (strassen VS normal mult)
1 |
|
clock함수의 리턴값이 정수인가봅니다. 소수점은 표현되지 않았습니다.
visual studio 2019의 debug모드에서 실행한 결과인데, release모드에선 조금 더 빠르게 동작합니다.
코드에서는 임계값을 5로 설정했는데, 실제 실행시간을 보니 임계값이 더 높아야겠습니다.
(코드가 최적화1되지 않아서인지도 모르겠습니다. 더 자세한 이유를 아시는 분은 댓글 남겨주세요!)
(실험중에 임계값을 높이면 실험의 의미가 없으므로 코드는 그대로 놔두겠습니다.)
아무튼 중요한 사실은, n의 값이 2배가 될 때마다 strassen의 실행시간 증가폭은 x7정도로 유지되는 반면, normal mult 방식의 실행시간 증가폭은 x8(혹은 그 이상)이라는 것입니다. 위에서 증명한 시간복잡도와 비슷한 결과입니다.
이것으로 스트라센 알고리즘의 소개를 마치겠습니다. 감사합니다!
-
예를 들어, M_6, M_7은 subC를 구할 때 중복 사용되지 않으므로 따로 할당하지 않고 구하는 방법을 생각해 볼 수 있습니다. ↩