بازنویسی عملکرد تابع fit برای آموزش یک شبکه GAN

در این مطلب می خواهیم نمونه‌ای از نحوه بازنویسی training step که درون تابع fit در کلاس Model در تنسورفلو/کراس وجود دارد را ببینم و سپس با این بازنویسی یک مدل GAN را بسته بندی کنیم و آموزش دهیم.

آشنایی با تابع fit

زمانی که شما در حال انجام یک کار supervised learning هستید، شما می توانید از تابع fit استفاده کنید و همه چیز به سادگی انجام خواهد شد.

زمانی هم که شما نیاز داشته باشید تا حلقه گام های آموزش یا training loop را خودتان از اول بنویسید، آنگاه از GradientTape استفاده خواهید کرد و کنترل همه جزئیات را در دست می گیرید.

اما اگر شما الگوریتم آموزش ویژه و متفاوتی در نظر داشته باشید و هنوز بخواهید از تابع fit و مزایای آن (مانند callback ها) استفاده کنید، چه؟ راه حل چیست؟

یکی از اصول طراحی Keras با عنوان progressive disclosure of complexity شناخته می شود؛ طبق این اصل، شما همیشه باید قادر باشید تا به سطوح کاری زیرین کتابخانه به شیوه تدریجی دسترسی پیدا کنید. در واقع اگر مورد کار شما دقیقاً با تعریف های پایه ای کتابخانه نیست، برای رفع نیاز خود شما نباید از قله انتزاع به دره پیچیدگی های سطح پایین سقوط کنید. شما در هر مرحله باید بتوانید به آن مقداری از جزئیات که برای تغییر نیاز دارید دسترسی داشته باشید.

زمانی که شما می خواهید عملکرد fit را شخصی سازی کنید، شما در عمل باید این تابع از کلاس Model را که مسئول اجرای قدم های آموزش مدل است را بازنویسی کنید. این تابع برای دسته یا batch از داده صدا زده می شود. پس از بازنویسی این تابع از کلاس Model ، این تابع می تواند به طور معمول از کد شما صدا زده شود تا الگوریتم شخصی شما را روی داده اجرا کند تا آموزش شبکه انجام شود.

توجه کنید که این الگو شما را از استفاده مدل های Functional API با Keras منع نمی کند و به طور مثال می توانید در مدل Sequential خود استفاده کنید.

تنظیمات این مطلب

توضیحات و کدی که در این مطلب وجود دارد، TensorFlow نسخه ۲.۲ و بالاتر را نیاز دارد.

یک مثال ساده بازنویسی train_step

با یک مثال ساده بازنویسی تابع fit را شروع کنیم:

  • ما یک کلاس مدل تنسورفلوی جدید می سازیم که از keras.Model ارث بری می کند.
  • ما فقط تابع عضو train_step(self, data) را بازنویسی می کنیم (override).
  • ما یک dictionary را باز می گردانیم که شامل نام متریک و معیارها (شامل loss) و مقادیر مرتبط با آنهاست.

 

آرگومان ورودی data همان چیزی که است به عنوان داده آموزش به تابع fit داده می شود.

در بدنه تابع train_step ما یک آموزش و به روز رسانی وزن ساده را پیاده سازی می کنیم، شبیه به آنچه شما از قبل و در مفاهیم شبکه های عصبی عمیق، با آن آشنا هستید. نکته مهم آن است که ما loss را با self.compiled_loss محاسبه می کنیم (که توابع loss موجود در مدل را که با فراخوانی ()compile معرفی شده اند را پوشش می دهد).

به همین شیوه، ما با فراخوانی self.compiled_metrics.update_state(y, y_pred) مقادیر متریک و معیارهای داده شده در ()compile را به روز رسانی می کنیم و با self.metrics آنها را دریافت می کنیم.

بیایید این کد را استفاده کنیم:

مثالی از بازنویسی گام ارزیابی یا test_step

حالا اگر بخواهید در آموزش مدل با فراخوانی ()model.evaluate عملکرد خاصی انجام شود چه؟ آنگاه شما تابع test_step را بازنویسی خواهید کرد. این کار می تواند به سادگی و به صورت زیر انجام شود:

 

مثال پیاده سازی کلاس GAN با بازنویسی فرآیند fit

بیایید یک GAN را به طور کامل در یک کلاس Model بگنجانیم و سرتاسر فرآیند آموزش آن را در همان کلاس تعریف کنیم.

در این مثال در نظر داشته باشید که این موارد را نیاز داریم:

  • یک شبکه generator قرار است تصاویر ۲۸x28x1 تولید کند.
  • یک شبکه discriminator قرار است تصاویر با ابعاد ۲۸x28x1 را در دو دسته واقعی و غیرواقعی یا “fake” و “real” دسته بندی کند.
  • هر کدام از آن دو به طور جداگانه یک Optimizer دارند.
  • یک تابع خطا یا loss برای آموزش دادن شبکه discriminator .

 

و حالا در اینجا یک کلاس GAN با مشخصات کامل را با بازنویسی ()compile و train_step تعریف می کنیم که در بخش گام آموزش تنها با ۱۷ خط الگوریتم آموزش GAN را پیاده کرده ایم.

 

بیایید تعریف بالا را آزمایش و استفاده کنیم:

 

امیدواریم با به کارگیری این آموزش، نحوه پیاده سازی مدل های ویژه و سفارشی سازی شده برای شما آسان تر شود.

 

 

+ ترجمه و تلخیص از منبع زیر:

Customizing what happens in fit()

نظرتان را برای ما بنویسید

نشانی ایمیل شما منتشر نخواهد شد. بخش‌های موردنیاز علامت‌گذاری شده‌اند *