Machine Learning

อธิบาย 10 Metrics พื้นฐานสำหรับวัดผลโมเดล Machine Learning

ML Engineer มี metrics ที่ใช้วัดความถูกต้องของโมเดลหลายตัว แบ่งเป็นสองกลุ่มใหญ่สำหรับปัญหา regression และ classification บทความนี้แอดจะอธิบาย 10 metrics พื้นฐานที่ทุกคนควรรู้จัก

10 metrics มีอะไรบ้าง?
Regression – MAE, MSE, RMSE, R2
Classification – Accuracy, Precision, Recall, F1-Score, F-Beta Score, AUC

ในบทความนี้เราใช้ terminology (คำศัพท์) ดังนี้

  • [y] prediction – ผลการทำนายของโมเดล
  • [y] actual – ค่าจริง
  • metrics คือค่าที่ได้จากการคำนวณ/ เปรียบเทียบผล prediction vs. actual ว่าโมเดลของเราทำนายได้ถูกต้องแค่ไหน

Let’s focus at Regression

มาเริ่มกันที่ metrics สำหรับปัญหา Regression เราสร้างโมเดล Regression เพื่อทำนายตัวแปร y แบบ continuous (ตัวเลข) สามารถเขียนสมการคำนวณค่า error ได้แบบนี้ error = prediction – actual

ML เรียก error function นี้ว่า “Loss function” และเป้าหมายของการ train โมเดลคือการ minimize หรือลดค่า loss ให้มีค่าต่ำที่สุด รูปด้านล่างเราลองวาด Linear Regression ง่ายๆเพื่อทำนายค่า y ด้วยตัวแปร x

จุดสีดำบนรูปคือ actual data ส่วนเส้นตรงสีน้ำเงินคือ prediction ที่ได้จากการ train Linear Regression ระยะห่างของ prediction vs. actual คือค่า loss ที่เราต้องการ minimize

MAE

Metric ตัวแรก (หรือ loss function) ที่เราใช้เทรน Linear Regression คือ MAE ย่อมาจาก “Mean Absolute Error” คำนวณง่ายๆแค่หาผลรวมของค่า absolute(error) แล้วคูณกับ 1/n เพื่อเปลี่ยนเป็นค่าเฉลี่ย

MSE

MSE ย่อมาจาก “Mean Squared Error” จะคล้ายกับ MAE แค่เปลี่ยนจากการทำ absolute เป็น squared (ยกกำลังสองค่า error) ก่อนหาค่าเฉลี่ย

Technical Note – ตอนแอดเรียนบน Udacity อาจารย์สอนว่า MSE จะดีกว่า MAE เวลาใช้พวก gradient descent algorithm เพราะว่าตอนหา derivative ดิฟสมการ loss ด้วยแคลคูลัสจะทำได้ง่ายกว่า

RMSE

RMSE ย่อมาจาก “Root Mean Square Error” ทำ square root ค่า MSE เพื่อให้ได้ค่า loss ที่มีหน่วยเดียวกับตัวแปร y เหตุผลที่ต้องทำ SQRT เพราะว่าเรายกกำลังสองค่า error ก่อนหาค่าเฉลี่ยทำให้หน่วยมันเปลี่ยนไปจากเดิม (คนคิด RMSE เลยบอกว่าขอรูทกลับได้ไหม)

จากที่อธิบายมาสามตัว RMSE จะแปลผลง่ายสุดเลย Linear Regression ที่มี RMSE เท่ากับ 2.56 แปลว่าโดยเฉลี่ยโมเดลทำนาย y ผิดไป +/- 2.56 point

 MAE, MSE, RMSE มีค่ายิ่งต่ำยิ่งดี ถ้าเท่ากับ 0 แปลว่าโมเดลทำนายค่า y ได้ถูกต้องเป๊ะ 100% ในทางปฏิบัติโอกาสที่จะเทรนโมเดลได้ loss = 0 เป็นไปได้ยากมาก เพราะอาจนำไปสู่ปัญหา Overfitting ได้ 

R2

R2 หรือ R-Squared เป็น metric ที่ใช้กันเยอะมากเวลารัน Linear Regression แบบนักสถิติ ในสูตรด้านล่าง y_hat คือ prediction และ y_bar คือค่าเฉลี่ยของ y อีกชื่อที่นักสถิติใช้เรียก R2 คือ Explained Variance

Variance คืออะไร?
ภาษาไทยเรียกว่าความแปรปรวน ใน context ของ linear regression นักสถิติมอง variance เป็น amount of information (y) ที่โมเดลของเราสามารถทำนายได้

อธิบายภาษาคนง่ายๆ R2 คือ variance ที่โมเดลของเราอธิบายได้เป็นสัดส่วนจาก total variance ทั้งหมดของข้อมูลชุดนั้น (นักสถิติใช้สูตรคำนวณ total variance = explained variance + error) โดย R2 จะมีค่าอยู่ระหว่าง 0-1 ยิ่งเข้าใกล้ 1 แปลว่าโมเดลเราทำนายผลได้ดีมาก


Let’s focus at Classification

มาลองดู metrics ของปัญหา classification กันบ้าง ตัวอย่างวันนี้เราสร้างโมเดลทำนาย spam email ชื่อทางการของปัญหานี้คือ binary classification ที่ตัวแปร y มีได้แค่สองค่า {0, 1}

สิ่งแรกที่ทุกคนต้องรู้จักคือ Confusion Matrix ตาราง cross-tabs 2×2 ระหว่าง prediction และ actual label การคำนวณ metrics ต่างๆจะใช้ตัวเลขในตารางนี้เป็นหลัก

ตารางด้านบนสรุปผลการทำนายอีเมล์ทั้งหมด N=100 ผลลัพธ์จะถูกแบ่งเป็นสี่ช่อง แต่ละช่องมีชื่อเรียกทางการว่า True Positive, False Positive, False Negative, True Negative (อ่านจากซ้ายไปขวา บนลงล่าง)

ทำไมถึงเรียกว่า True Positive, False Positive etc.
เราใช้คำว่า positive แทนการทำนายผล y=1 (spam) และ negative สำหรับ y=0 (ham) ถ้าเกิดโมเดลทำทายถูกต้องว่าอีเมล์เป็น spam เราจะเรียกว่า True Positive แต่ถ้าทำนาย ham เป็น spam จะเรียกว่า False Positive

Accuracy

Accuracy คือ metric ที่ใช้งานง่ายที่สุด บอกว่าโมเดลเราทำนายถูกทั้งหมดกี่ % จากตารางด้านบน accuracy = (20 + 50) / 100 = 70% คำนวณจากเส้นทแยงมุมของ confusion matrix ได้เลย

ถ้าแบบทางการหน่อย เราจะเขียนสูตรว่า accuracy = (TP + TN)/ N โดยค่า accuracy จะมีค่าอยู่ระหว่าง 0-1 ยิ่งเข้าใกล้ 1 แปลว่าโมเดลเราทำนายผลได้ดีมาก

Precision

นิยามของ Precision คือความน่าจะเป็นที่โมเดลทำนาย spam ถูกต้องจากการทำนาย spam ทั้งหมด 32 ครั้ง (20 + 12 ผลรวมแถวบนของ confusion matrix)

แทนค่าในสมการ precision = TP / (TP + FP) = 20 / (20 + 12) = 62.5%

Recall

นิยามของ Recall คือความน่าจะเป็นที่โมเดลสามารถตรวจจับ spam จากจำนวน spam email ทั้งหมดในข้อมูลของเรา 38 ฉบับ (20 + 18 ผลรวมคอลั่มแรกของ confusion matrix)

แทนค่าในสมการ recall = TP / (TP + FN) = 20 / (20 + 18) = 52.6%

F1-Score

F1-Score คือค่าเฉลี่ยแบบ harmonic mean ระหว่าง precision และ recall นักวิจัยสร้าง F1 ขึ้นมาเพื่อเป็น single metric ที่วัดความสามารถของโมเดล (ไม่ต้องเลือกระหว่าง precision, recall เพราะเฉลี่ยให้แล้ว)

แทนค่าในสมการ F1 = 2 * ( (0.625 * 0.526) / (0.625 + 0.526) ) = 57.1%

Accuracy ไม่ใช่ metric เดียวที่เราต้องดู
ในทางปฏิบัติเราจะดูค่า precision, recall, F1 ร่วมกับ accuracy เสมอ โดยเฉพาะอย่างยิ่งเวลาเจอกับปัญหา imbalanced classification i.e. y {0,1} มีสัดส่วนไม่เท่ากับ 50:50

F-Beta Score

 F Score มีสูตรทั่วไปที่เราสามารถกำหนดค่า Beta ได้เอง (เช่น F1 คือการกำหนด Beta = 1) ถ้าเราอยากให้น้ำหนักไปทาง precision ให้กำหนดค่า Beta < 1 แต่ถ้าอยากเน้นที่ recall ให้กำหนดค่า Beta > 1 .

ถ้าเรากำหนด Beta=0 จะได้ค่า precision เพียวๆเลย (แอด prove สูตรให้ดูด้านล่าง) การใช้ F-Beta Score ช่วยให้เราปรับ metric ได้เหมาะสมตามสถานการณ์ เช่น งานที่เน้น recall ก็ให้ใช้ Beta = 2, 3, 5 เป็นต้น

# example beta=0
F0 = (1 + 0) * (precision * recall) / (0 + recall)
F0 = precision * recall / recall
F0 = precision

AUC

AUC ย่อมาจาก “Area Under <ROC> Curve” เป็นอีกหนึ่ง metric ยอดนิยมที่ใช้กันแทบทุกงานเลย AUC มีค่าอยู่ระหว่าง 0-1 ยิ่งเข้าใกล้ 1 แปลว่าโมเดลในภาพรวมสามารถทำนาย y ได้ดีมาก

  • AUC = 0.50 ไม่ต่างอะไรกับการเดาสุ่มเลย
  • AUC > 0.70 คือเกณฑ์มาตรฐานสำหรับโมเดลส่วนใหญ่
  • AUC > 0.80 โมเดลทำงานได้ดี
  • AUC > 0.90 โมเดลทำงานได้ดีมาก

Summary

เวลาทำงาน regression แอดจะใช้ RMSE กับ R2 เป็นหลัก ส่วน classification จะดูค่า accuracy, precision, recall, F1, AUC พร้อมกันเลย ถ้าเขียนโปรแกรมใน R/ Python จะมีฟังชั่นสำเร็จรูปที่ใช้คำนวณ metrics พวกนี้อยู่แล้ว import library เขียนโค้ดไลน์เดียวจบเลย

2 comments

Leave a Reply

This site uses Akismet to reduce spam. Learn how your comment data is processed.