๐ ์ค๋ ํ ์ผ
- ๊ธฐ๋ณธ ๊ณผ์ 4, 5
- ํด์ฆ
๐ฅ ํผ์ด์ธ์ ์์ฝ
์ต์ข ํ๋ก์ ํธ ์ฃผ์ ์ ๋ํด์ ์ด์ผ๊ธฐ๋ฅผ ๋๋์๋ค. ๋ฐ๋๋ผ์ธ์ ์ ํด๋๊ณ ๋ ธ์ ์์ ๋ธ๋ ์ธ์คํ ๋ฐ์ ํด ๋ณผ ์์ ์ด๋ค!
๐ ๊ณผ์ ๋ด์ฉ ์ ๋ฆฌ
๊ธฐ๋ณธ ๊ณผ์ 4 - cGAN (Conditional GAN)
- ์ค์ผ๋ ํค ์ฝ๋ ๋ถ์ (
batch_size=64
๊ฐ์ )real_imgs
($x$) = [64, 1, 28, 28]- MNIST ๋ฐ์ดํฐ์ ์ด๋ฏ๋ก 1ร28ร28์ด๋ค.
labels
/gen_labels
($y$) = [64, 10]- ์ค์ผ๋ ํค ์ฝ๋์์ ์ด๋ฏธ one-hot vector๋ค์ tensor๋ก ๋ง๋ค์ด๋์๋ค.
- ๋ค๋ฅธ ์ฝ๋๋ค์์๋ [64]๋ก๋ง ๋ค์ด๊ฐ์, ๋ชจ๋ ๋ด๋ถ์์ embedding ๋๋ค.
noise
($z$) = [64, 100]
- ๊ฒช์๋ ๋ฌธ์ ์ํฉ
- ์๋ฒ ๋ฉ (embedding)
- ๋ฌด์์ธ์ง ๋ชจ๋ฅด๊ฒ ์ด์ ์ฌ๋ฌ ์ฝ๋๋ฅผ ์ฐพ์๋ณด์์ผ๋, ์ธํฐ๋ท์ ์ฌ๋ผ์ ์๋ cGAN ๊ตฌํ ์ฝ๋์ ๊ณผ์ ์ ์ค์ผ๋ ํค ์ฝ๋๊ฐ ๋ฌ๋ผ ํ์ฐธ์ ํค๋งธ๋ค.
- ๊ตฌํ๋์ด ์๋ ์ฝ๋ ๋๋ถ๋ถ์
nn.Embedding
์ ์ด์ฉํด์ 1์ฐจ์ tensorlabel
์ ์๋ฒ ๋ฉํ๋๋ฐ, ์ค์ผ๋ ํค ์ฝ๋์์๋ ์ด๋ฏธlabels
,gen_labels
๊ฐ 2์ฐจ์ tensor(one-hot vector)๋ก ์ฃผ์ด์ง๋ค. ๋ฐ๋ผ์ ๋nn.Embedding
์ ์ ์ฉํ๋ฉด ์๋๋ค. - ๊ฒฐ๊ตญ, ๊ณผ์ ์ ์ค๋ช
์ ๋ฐ๋ฅด๋ ๊ฒ์ด ์ ๋ต์ด์๋ค. ๊ณผ์ ์ ์ค๋ช
์ ๋ฐ๋ฅด๋ฉด ์๋ฒ ๋ฉ == nn.Linear์ ๋ฃ๊ธฐ์ด๋ฉฐ, ๊ทธ ๊ฒฐ๊ณผ๋ฅผ ๊ทธ๋๋ก concat ํ์ฌ ์ฌ์ฉํ๋ฉด ๋๋ ๊ฒ์ด์๋ค. ๋ฐ๋ผ์
nn.Linear
์ ๋ ๋ฒ์งธ ์ฐจ์์ 1/2์ฉ ์ค์ฌ์ ๋ฃ์ด์ผ ํ๋ค.
- ์๋ง์ ๋ฐํ์ ์๋ฌ ๐ฑ
- ๋๋ฌด๋๋ ํ์ฌํ ์ค์๋ค ๋๋ฌธ์ ๋ฒ์ด์ง ์ผ์ด๋ผ ์ค๋ช ์ ์๋ตํ๋คโฆ
- activation function
- ๊ณผ์ ์ค๋ช ์ ์ถฉ์คํ์ฌ ์ฝ๋๋ฅผ ์งฐ์ผ๋ generator loss๊ฐ ํญ๋ฐํ๊ณ discriminator loss๊ฐ ์๋ ดํ๋ ๋ฌธ์ ๊ฐ ๋ฐ์ํ๋ค. ๊ฒ์ํด๋ณด๋ ์ด ์ ๋๋ ์ํคํ ์ฒ๊ฐ ์๋ชป๋ ๊ฒ์ด๋ผ๊ณ ํด์, ๋ค๋ฅธ ๊ตฌํ๋ ์ฝ๋๋ค์ ์ฐธ๊ณ ํด๋ดค๋ค.
์ฐ์ ,
nn.ReLU()
๋ฅผnn.LeakyReLU(0.2, inplace=True)
๋ก ๋ฐ๊ฟ๋ณด์๋ค. ๊ทธ๋๋ ๋ฑํ ๋์์ง์ง ์์๋ค.generator์ ๋ง์ง๋ง activation function์ธ
nn.Sigmoid()
๋ฅผnn.Tanh()
๋ก ๋ฐ๊ฟ๋ณด์๋ค. ๊ทธ๋ฌ๋๋ ๋ฐ๋ก ์ ์ ํ generator loss ๊ฐ๊ณผ discriminator loss ๊ฐ์ด ๊ด์ฐฐ๋์๋ค.generator loss์ discriminator loss๊ฐ ์๋ ดํ์ง ์๋ ๊ฒ์ฒ๋ผ ๋ณด์ด๋ ํ์์ ์์ฐ์ค๋ฌ์ด ํ์์ด๋ค.
- ์๋ฒ ๋ฉ (embedding)
๊ฒฐ๊ณผ (epoch 50) - ๊น๋ํ๊ฒ ์ ๋์จ๋ค!!!
๊ธฐ๋ณธ ๊ณผ์ 5 - CLIP ๋ชจ๋ธ์ ํตํ Multi-modal ๋ชจ๋ธ์ ๋ค์ํ ํ์ฉ
- ๊ธฐ์ตํ ๊ฒ
torch.topk
- ์์ธก ๊ฐ์์ argmax๊ฐ ์๋ top-k ๊ฐ์ ๊ฒฐ๊ณผ ๊ฐ์ ์ป์ ๋ ์ฌ์ฉํ๋ค.1 2 3
# softmax ํจ์๋ก ๊ฐ์ฅ ๋์ K๊ฐ์ ํ๋ฅ ๊ฐ ๊ตฌํ๊ธฐ K = 10 values, indices = torch.topk(similarity.softmax(dim=-1), k=K, dim=-1)
๊ณผ์ ๋ด์ฉ ์ผ๋ถ
๋๋ฌด ์ ๊ธฐํด์ ์ฌ์ง์ ๊ผญ ๋ฃ๊ณ ์ถ์๋ค
๐พ ์ผ์ผ ํ๊ณ
์ค๋์ ํ๋ฃจ ์ข ์ผ ๊ณผ์ ๋ง ํ๋๋ผ ํ๋ ๋จ์ ๊ฐ์๋ฅผ ๋ง์ ๋ฃ์ง๋ ๋ชปํ๊ณ ๋ค๋ฅธ ํ ์ผ๋ ๋ชปํ๋คโฆ ํ์ง๋ง ์๋ง์ ์ฝ์ง๊ณผ ๋ฐํ์ ์๋ฌ ๋์ ๊ณผ์ ๋ฅผ ์์ฑํด๋๋ค๋ ์ ๋ง์ผ๋ก๋ ๋ฟ๋ฏํ๊ฒ ์๊ฐํ๋ ค ํ๋ค. ํนํ CLIP์ ์ข๋ค๊ณ ๋ค์ด๋ณด๊ธฐ๋ง ํ์๋๋ฐ, ์ค์ ๋ก ์ฌ์ฉํด๋ณด๋๊น ๋๋ฌด ๋๋ฌด ์ฌ๋ฏธ์์ด์ ๊ณผ์ ์ค์ ํ๋ง ํ์์ ๊ฐ์ง ๋ฏํ ๋๋์ด์๋ค. ์ฌํ ๊ณผ์ ๋ ์์ง ๋ค ํ์ง ๋ชปํ๋๋ฐ ๋ด์ผ ๋ต์ง๊ฐ ์ฌ๋ผ์ค๊ธฐ ์ ๊น์ง๋ ๋๊น์ง ๋ด์ผ๊ฒ ๋ค.
๐ ๋ด์ผ ํ ์ผ
- ๊ฐ์ ์๊ฐ ๋ฐ ์ ๋ฆฌ
- [CV] 10๊ฐ
- ์๊ณ ๋ฆฌ์ฆ
- ์ถ์ ์ง๋ Notion์ ์ฎ๊ธฐ๊ธฐ
- ์ฌํ ๊ณผ์ 2, 3