이번 글에서는 sum연산을 통해 axis와 keepdims에 대해 간단히 알아보도록 하겠습니다.
Index
1. Axis
2.keepdims
1. Axis
x.sum(axis =option)에서 option은 0부터 n까지 설정할 수 있으며 해당 되는 값의 차원의 합을 구하며 동시에 제거합니다.
예시를 통해 3rd order tensor에 대하여 axis별 연산을 설명해 보도록 하겠습니다.
import numpy as np
a= np.arange(2*3*4).reshape((2,3,4))
print(a,a.shape)
[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]] (2, 3, 4)
#ex1)
sum_0 = a.sum(axis=0)
print(sum_0,sum_0.shape)
[[12 14 16 18]
[20 22 24 26]
[28 30 32 34]] (3, 4)
#ex2)
sum_1 = a.sum(axis=1)
print(sum_1,sum_1.shape)
[[12 15 18 21]
[48 51 54 57]] (2, 4)
#ex3)
sum_2 = a.sum(axis=2)
print(sum_2,sum_2.shape)
[[ 6 22 38]
[54 70 86]] (2, 3)
ex1)
(2,3,4) 중 2가 사라지며 해당되는 차원의 원소값들이 합해지고 차원이 축소되어 (3,4)의 matrix를 만들어 낸 모습입니다.
ex2)
(2,3,4) 중 3이 사라지며 해당되는 차원의 원소값들이 합해지고 차원이 축소되어 (2,4)의 matrix를 만들어 낸 모습입니다.
ex3)
(2,3,4) 중 4가 사라지며 해당되는 차원의 원소값들이 합해지고 차원이 축소되어 (2,3)의 matrix를 만들어 낸 모습입니다.
2.Keepdims
axis 설정을 해주게 되면 차원이 축소 되는데 그렇게 되면 이후에 broadcasting연산시 불편함이 생기게 됩니다.
그래서 차원을 축소하지 않는 방법으로 keepdims를 사용할 수 있습니다.
x.sum(axis=option,keepdims=True) , keepdims 를 True로 설정하게 되면 축소되어야 하는 차원이 1로 남게 되어
broadcasting시 용이하게 됩니다.
1.axis에 사용한 예제에 keepdims=True 로 설정하여 결괏값을 살펴보도록 하겠습니다.
import numpy as np
a= np.arange(2*3*4).reshape((2,3,4))
print(a,a.shape)
[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]] (2, 3, 4)
#ex1)
sum_0 = a.sum(axis=0,keepdims=True)
print(sum_0,sum_0.shape)
[[[12 14 16 18]
[20 22 24 26]
[28 30 32 34]]] (1, 3, 4)
#ex2)
sum_1 = a.sum(axis=1,keepdims=True)
print(sum_1,sum_1.shape)
[[[12 15 18 21]]
[[48 51 54 57]]] (2, 1, 4)
#ex3)
sum_2 = a.sum(axis=2,keepdims=True)
print(sum_2,sum_2.shape)
[[[ 6]
[22]
[38]]
[[54]
[70]
[86]]] (2, 3, 1)
결괏값을 살펴보면 차원이 축소되지 않고 그자리에 1이 들어간 것을 확인 할 수 있습니다.
이 option을 broadcasting에 용이하게 사용할 수 있을것입니다.
'컴퓨터공학 > python' 카테고리의 다른 글
[Numpy] Sorting에 대하여(sort,argsort) 정렬, 인덱스정렬 (0) | 2021.10.07 |
---|---|
[Numpy] Sum, Prod, Diff 사용법 (0) | 2021.10.06 |
[Numpy] Indexing and Slicing (0) | 2021.10.05 |
[Numpy] General Broadcasting Rules에 대하여 (작성중) (0) | 2021.10.05 |
[Numpy] Data type에 대하여 (0) | 2021.10.01 |