پیش‌نویس:شبکه مولد تخاصمی (Wasserstein GAN)

از testwiki
پرش به ناوبری پرش به جستجو

شبکه مولد تخاصمی (Wasserstein GAN)

شبکه متخاصم مولد Wasserstein Generative Adversarial Network (WGAN) نوعی از شبکه متخاصم مولد (GAN) است که در سال 2017 پیشنهاد شد و هدف آن بهبود پایداری در یادگیری، خلاص شدن از مشکلاتی مانند فروپاشی حالت (به انگلیسی: Mode Collapse) و ارائه منحنی‌های یادگیری معنادار است که برای اشکال زدایی و جستجو فراپارامترها (به انگلیسی: Hyperparameters) مفید هستند. ".[۱][۲]

در مقایسه با شبکه‌های تخاصمی اولیه، WGAN سیگنال‌های راهنمای بهتری برای یادگیری مولد (generator) فراهم میکند و این باعث میشود زمانی که مولد در حال یادگیری توزیع هایی با ابعاد بالاست یادگیری پایدارتر باشد.

انگیزه

بازی GAN

شبکه های تخاصمی اولیه بر پایه بازی GAN هستند، یک بازی جمع صفر (zero sum) با دو بازیکن: مولد و تفکیک کننده (discriminator). این بازی بر روی یک فضای احتمال به صورت (Ω,,μref) تعریف میشود. مجموعه استراتژی های مولد، مجموعه تمام احتمالات μG بر روی (Ω,) است و مجموعه استراتژی های تفکیک کننده به صورت تابع D:Ω[0,1] است. تابع هدف این بازی به صورت زیر است: L(μG,D):=𝔼xμref[lnD(x)]+𝔼xμG[ln(1D(x))]مولد در تلاش برای کمینه کردن تابع فوق و تفکیک کننده به دنبال بیشینه کردن آن است. در بازی GAN یک نظریه پایه می گوید: الگو:Math theoremاگر بازی GAN را به تعداد زیاد تکرار کنیم که هربار مولد حرکت اول و تفکیک کننده حرکت دوم را انجام دهد، هربار استراتژی مولد μG تغییر می‌کند و تفکیک کننده مجبور است برای دادن پاسخ بهینه به مقدار ایده آل نزدیک شود.

D*(x)=dμrefd(μref+μG).

از آنجایی که ما به μref نیاز داریم، تابع تفکیک کننده D به تنهایی برایمان ارزشی نداشته و فقط نسبت احتمال بین توزیع مولد و توزیع مرجع (Reference) را محاسبه میکند. در حالت تعادل، خروجی تفکیک کننده همواره برابر با 12 است که در واقع انگار تفکیک کننده تسلیم شده است.

در بازی GAN اگر استراتژی مولد μG را ثابت نگه داریم و مرحله به مرحله تفکیک کننده را بهبود ببخشیم، با داشتن μD,t بعنوان حرکت تفکیک کننده در زمان t، آنگاه در حالت ایده آل خواهیم داشت:

L(μG,μD,1)L(μG,μD,2)maxμDL(μG,μD)=2DJS(μrefμG)2ln2,

که یعنی تفکیک کننده به دنبال حد پایین DJS(μrefμG) است.

فاصله واسرستاین (Wasserstein distance)

همانطور که دیدیم، تفکیک کننده نقش یک منتقد را دارد و به مولد اعلام می‌کند "چقدر از حقیقت دور است" که تعریف "دور" همان واگرایی جیسون-شنون است.

طبیعتا، امکان تعریف معیارهای دیگری از دور بودن مطرح می شود. امروزه معیارهای زیادی برای انتخاب وجود دارد مانند خانواده f-divergence، که به ما f-GAN را می دهد.[۳]

به همین صورت WGAN با استفاده از معیار واسرستاین (Wasserstein metric)، که در قضیه نمایش دوگانه صدق می کند، به دست می آید.

الگو:Math theorem

اثبات این نظریه را می توانید در صفحه اصلی Wasserstein metric مشاهده کنید.

تعریف

باتوجه به دوگانگی کانتوروویچ-روبنشتاین، تعریف WGAN به صورت زیر است:

الگو:Blockquoteبرای هر استراتژی مولد μG، جواب بهینه از طرف تفکیک کننده برابر است با D* بطوری که:

LWGAN(μG,D*)=KW1(μG,μref).

در نتیجه، اگر تفکیک کننده خوب عمل کند، مولد همواره به کمینه کردن W1(μG,μref) ترغیب می شود و همانطور که باید، استراتژی بهینه برای آن μG=μref است.

مقایسه با GAN

در WGAN تفکیک کننده گرادیان بهتری نسبت به GAN فراهم می کند.

تفکیک کننده بهینه واسرستاین DWGAN و تفکیک کننده بهینه GAN، D برای توزیع مرجع μref و توزیع مولد μG ثابت.
همان نمودار بالا، اما تفکیک کننده GAN، D با ln(1D) جایگزین شده است.

به طور مثال یک بازی بر روی خط اعداد حقیقی داریم، که μG و μref توزیع نرمال هستند. در نمودار زیر مولد بهینه D و تفکیک کننده بهینه واسرستاین DWGAN نشان داده شده اند.

برای یک تفکیک کننده ثابت، مولد باید توابع هدف زیر را کمینه کند.

  • برای بازی GAN: 𝔼xμG[ln(1D(x))]
  • برای بازی WGAN: 𝔼xμG[DWGAN(x)]

فرض کنید μG از پارامترهای θ تشکیل شده است، این گونه می توانیم با استفاده از تخمین گر نااریب (unbiased estimator) گرادیان، یک گرادیان کاهشی تصادفی اجرا کنیم:

θ𝔼xμG[ln(1D(x))]=𝔼xμG[ln(1D(x))θlnρμG(x)]θ𝔼xμG[DWGAN(x)]=𝔼xμG[DWGAN(x)θlnρμG(x)]

برای به دست آوردن فرمول های بالا از تغییر متغیر (reparameterization trick)

همانطور که مشاهده می شود، در GAN، مولد ترغیب می شود از قله ln(1D(x)) به سمت پایین سقوط کند. مولد WGAN نیز به همین صورت است.

در WGAN، DWGAN تقریبا همواره گرادیان برابر یک دارد، این در حالیست که در GAN، ln(1D) در میانه گرادیان برابر صفر و سایر نقاط گرادیانی بزرگ دارد. این باعث می شود واریانس تخمین گر در GAN معمولا بسیار بیشتر از WGAN باشد.

مشکل DJS در موارد واقعی یادگیری ماشین بسیار بزرگتر است. فرض کنید می خواهیم یک GAN را برای ImageNet ، یک مجموعه از عکس های 256 در 256 ، آموزش دهیم. فضای تمام این عکس ها 2562 است، در حالیکه عکس های داخل ImageNet، μref، بر روی یک فراوانی با ابعاد بسیار پایین تر تمرکز دارد. در نتیجه هر استراتژی μG برای مولد تقریبا بطور کامل از μref مجزاست که باعث می شود DJS(μGμref)=+. بنابراین یک تفکیک کننده خوب تقریبا همواره می تواند μref و حتی هر μG نزدیک به μG را از μG تشخیص دهد. این باعث می شود گرادیان تقریبا صفر باشد μGL(μG,D)0، و هیچ گونه سیگنالی برای بهبود مولد تولید نشود.

جزئیات این نظریه را می توانید در اینجا مشاهده کنید.[۴]

آموزش Wasserstein GAN

آموزش مولد در WGAN و GAN صرفا براساس گرادیان کاهشی است، اما آموزش تفکیک کننده متفاوت است به این دلیل که در WGAN تفکیک کننده یک محدودیت جدید دارد که همان حد Lipschitz norm است. برای انجام این کار روش های مختلفی وجود دارد.

قراردادن حد بالا برای Lipschitz norm

تابع تفکیک کننده D را به صورت یک پرستپترون چندلایه (multilayer perceptron) پیاده سازی میکنیم.

D=DnDn1D1

که در آن Di(x)=h(Wix) و h: یک تابع فعالساز ثابت با supx|h(x)|1 است. برای مثال، تابع تانژانت هایپربولیک h=tanh در شرط گفته شده صدق می کند. برای هر x، قرار می دهیم xi=(DiDi1D1)(x) و با استفاده از قانون زنجیره ای خواهیم داشت:

dD(x)=diag(h(Wnxn1))Wndiag(h(Wn1xn2))Wn1diag(h(W1x))W1dx

اینگونه Lipschitz norm تفکیک کننده حد بالای زیر را دارد:

DLsupxdiag(h(Wnxn1))Wndiag(h(Wn1xn2))Wn1diag(h(W1x))W1F

که s همان operator norm یا spectral radius یا بزرگترین مقدار ویژه ماتریس است (این سه مفهوم در ماتریس ها یک معنی را می دهند اما برای دیگر عملگرهای خطی می توانند متفاوت باشند).

از آنجایی که supx|h(x)|1 داریم diag(h(Wixi1))s=maxj|h(Wixi1,j)|1 و به همین دلیل حد بالا برابر است با:

DLi=1nWis

بنابراین، اگر بتوانیم بر روی Wis تمام ماتریس ها حد بالایی تعریف کنیم در واقع حد بالایی برای Lipschitz norm تفکیک کننده تعیین کرده ایم.

برش وزن (Weight clipping)

برای هر ماتریس W با ابعاد m×l قرار می دهیم c=maxi,j|Wi,j|، آنگاه خواهیم داشت:

Ws2=supx2=1Wx22=supx2=1i(jWi,jxj)2=supx2=1i,j,kWijWikxjxkc2ml2

با محدود کردن تمام درایه های W به بازه [c,c]، در واقع Wis را محدود کرده ایم.

این روش برش وزن را می توانید در مقاله اصلی مطالعه کنید.

جریمه بر روی گرادیان (Gradient penalty)

به جای گذاشتن حد بر روی DL، می توانیم برای گرادیان یک جریمه به فرم زیر تعریف کنیم:

𝔼xμ^[(D(x)2a)2]

ه μ^ یک توزیع ثابت برای تخمین زدن مقداری است که تفکیک کننده از حد Lipschitz norm تجاوز کرده است.

تفکیک کننده برای کمینه کردن تابع هزینه جدید، تلاش میکند D(x) را به a نزدیک کند، این باعث می شود DLa

این روش برش وزن را می توانید در مقاله اصلی مطالعه کنید.[۵]

منابع