Machine Learning R

อธิบาย Logistic Regression พร้อมโค้ดตัวอย่างใน R

Tutorial วันนี้เรามาอธิบาย concept ของ Logistic Regression เบื้องต้น พร้อมโค้ดตัวอย่างใน R สำหรับสร้างและทดสอบโมเดล – Case Study ทำนายการเกิดมะเร็งเต้านม (Breast Cancer Dataset)

When to use?

เรานิยมใช้ Logistic Regression กับปัญหา Binary Classification i.e. ทำนาย target variable ที่มีสอง classes และใช้ค่า % accuracy สำหรับวัดผลโมเดลเบื้องต้น ด้านล่างเป็นตัวอย่าง use cases ในชีวิตจริง

  • Churn prediction – ทำนายว่าลูกค้าจะเลิกใช้บริการหรือเปล่า (yes/ no)
  • Repeated purchase prediction – ทำนายว่าลูกค้าจะกลับมาซื้อสินค้าหรือเปล่า (yes/ no)
  • Disease detection – ทำนายว่าจะเป็นโรคหรือเปล่า (yes/ no)
  • Spam classification – ทำนายว่าอีเมล์เป็น spam หรือเปล่า (yes/ no)
  • (Marketing) Conversion prediction – ทำนายว่า user จะ take action หรือเปล่า (yes/ no)

Key Concept

โมเดลนี้ง่ายกว่าที่คิด !! เพราะ Logistic Regression จริงๆคือ Extended Version ของ Linear Regression รูปด้านล่างแอดลอง plot กราฟขึ้นมาจะเห็นว่าแกนตั้ง y มีได้สองค่าคือ {0, 1} ส่วนแกนนอนคือตัวแปร X1

สำหรับข้อมูลประเภทนี้ ถ้าใช้ Linear Regression ทั่วไป ถามว่าเทรนได้ไหม ก็ทำได้แต่ผลจะออกมาไม่ดี เหตุผลที่เราไม่ใช้ linear regression กับปัญหา {0, 1} แบบนี้คือ  [1] linear regression เหมาะกับตัวแปร y แบบ continuous และ [2] ผลทำนายของสมการ linear regression เป็นแบบเส้นตรงมีโอกาสที่จะต่ำกว่าศูนย์หรือสูงกว่าหนึ่ง ซึ่งไม่ตอบโจทย์ binary classification ที่ค่า y ต้องอยู่ในช่วง [0,1]  

แล้วเราจะแก้ปัญหานี้ยังไงดี? นักคณิตศาสตร์เลยคิด Sigmoid Function ขึ้นมาเพื่อใช้ normalize ตัวเลขอะไรก็ได้ให้มีค่าอยู่ระหว่าง [0, 1] สำหรับแก้ปัญหา binary classification โดยเฉพาะ กราฟด้านล่างเราเปลี่ยนสมการเส้นตรงให้กลายเป็น S-Curve ที่ fit กับข้อมูล [0, 1] ได้ดีขึ้นมาก .. What a Cool Trick!

ใช้ sigmoid function เพื่อ normalize ค่าให้อยู่ระหว่าง [0,1]

ดาวน์โหลดไฟล์ตัวอย่างการเขียน Sigmoid Function ใน Excel ได้ที่นี่

How Sigmoid Works?

Sigmoid สามารถเขียนเป็นสมการทางคณิตศาสตร์ได้ตามรูปด้านล่าง โดยที่ e คือ Exponential Function หรือฟังชั่น exp() ใน Excel/ R นักคณิตศาสตร์ใช้ฟังชั่นตระกูล exp ในการเปลี่ยนสมการ linear เป็น non-linear เส้นกราฟ Sigmoid ที่เราเห็นด้านบนเลยกลายเป็น s-curve สวยงาม แฮร่!

ค่า Z ในสมการคือค่า weighted sum (เหมือนสมการ linear regression) แต่ Logistic Reression ใช้เทคนิคที่เรียกว่า Maximum Likelihood ในการคำนวณ weights (bo, b1, b2, …) แทนการใช้ Least Squares

## weighted sum (just like linear regression)
Z = b0 + b1x1 + b2x2 + b3x3 + b4x4 + ...

## apply sigmoid to Z value
probability_y = exp(Z) / (1+exp(Z))

ผลลัพธ์ที่ได้จาก sigmoid(Z) คือความน่าจะเป็นที่ y=1 เราสามารถกำหนด threshold สำหรับการทำนายของโมเดลได้ เช่น ถ้า sigmoid(Z) >= 0.5 ให้ทำนาย y=1 (positive) แต่ถ้าน้อยกว่า 0.5 ให้ทำนาย y=0 (negative)

ทำไมมันง่ายอย่างงี้ๆๆ ตอนนี้เราสามารถสร้างโมเดลที่ fit กับข้อมูล 0, 1 ได้แล้ว 😛

Implementation in R

โอเคร! ตอนนี้เราเข้าใจ concept เบื้องต้นของ Logistic Regression แล้ว ถัดไปมาลองเขียนโค้ด R กันบ้างโดยโค้ดของเราจะแบ่งเป็น 5 ขั้นตอน Load data → Clean data → Split data → Train model → Test model

[1] Load Data

Tutorial วันนี้เราใช้ข้อมูล Breast Cancer จาก package mlbench นักเรียนสามารถโหลดข้อมูลเข้าสู่ RStudio ด้วยโค้ดด้านล่าง ถ้าใครยังไม่เคยติดตั้ง package นี้ให้รันโค้ด install.packages("mlbench") ก่อน

Target variable ที่เราต้องการทำนายคือ “Class” {benign, malignant} โดย positive class ของโมเดลเราคือ malignant (เนื้อร้าย/ เป็นมะเร็ง) ซึ่งมีอยู่ประมาณ 35% ใน dataset นี้

## install.packages("mlbench")
library(mlbench)
data("BreastCancer")

[2] Clean Data

โหลดข้อมูลเสร็จแล้ว นักเรียนสามารถเรียกดู structure เบื้องต้นของ dataframe ด้วยฟังชั่น str() หรือ head() โค้ดด้านล่างเราใช้ฟังชั่น na.omit() เพื่อลบแถวที่มี missing values และลบคอลั่ม Id ด้วยการ assign NULL

## check if any missing values
mean(complete.cases(BreastCancer)) 

## remove rows with NA
df <- na.omit(BreastCancer) 

## remove column id
df$Id <- NULL 

[3] Split Data

แบ่งข้อมูลเป็น train 80% และ test 20% อย่าลืม set.seed(1) เพื่อให้ผล random id ของเราสามารถทำซ้ำได้

set.seed(1)
id <- sample(1:nrow(df), 0.8*nrow(df))
train_df <- df[id, ]
test_df <- df[-id, ]

[4] Train Model

เทรน logistic regression ด้วยฟังชั่น glm() กำหนด family = “binomial” เพราะว่า target ของเรามี 2 classes {benign, malignant} เสร็จแล้วนำโมเดลที่ได้ไปทำนาย train_df และคำนวณค่า train accuracy

ถ้ารันโค้ดด้านล่างเสร็จแล้วจะพบว่าค่า train accuracy เท่ากับ 100% ตอนนี้เราต้อง skeptical กับผลที่ได้แล้ว เพราะเป็นสัญญาณหนึ่งของปัญหา overfitting เด๋วเราจะลองนำโมเดลนี้ไปทำนาย test_df ในขั้นตอนต่อไป

## Train logistic regression
log_model <- glm(Class ~ ., data = train_df, family = "binomial")

## Predict and evaluate train dataset
p1 <- predict(log_model, type = "response")
p1 <- ifelse(p1 >= .5, T, F)
train_result <- table(p1, train_df$Class)
print(paste0("Train Accuracy: ", sum(diag(train_result)/ nrow(train_df))) )

Note – accuracy ไม่ใช่ metric เดียวที่เราใช้วัดผลโมเดล binary classification ปกติเราจะดูค่า precision, recall และ F1-score ด้วย อ่านเพิ่มเติมได้ในบทความแนะนำด้านล่าง

10 Metrics พื้นฐานสำหรับวัดผลโมเดล ML

รู้จักกับ Accuracy, Precision, Recall, F1-Score สำหรับ Classification Problem

[5] Test Model

โค้ดด้านล่างเขียนเหมือนตอนทำนาย train_df แค่เปลี่ยนชื่อข้อมูลทั้งหมดเป็น test_df เราจะได้ ค่า test accuracy อยู่ที่ 91.97% ซึ่งน้อยกว่า train accuracy ประมาณ 9% เราสามารถสรุปได้ทันทีว่าโมเดลที่เราสร้างขึ้นมามีปัญหา overfitting ดูง่ายๆจากผล test accuracy ที่มีค่าน้อยกว่า train accuracy มากๆ 

## Predict and evaluate test dataset
p2 <- predict(log_model, newdata = test_df, type = "response")
p2 <- ifelse(p2 >= .5, T, F)
test_result <- table(p2, test_df$Class)
print(paste0("Test Accuracy: ", round(sum(diag(test_result)/ nrow(test_df)), 4)))

แล้วเราจะลดปัญหา overfitting ได้ยังไง? คำตอบอยู่ข้างล่าง Read On!

Regularization

เทคนิคสำคัญที่ ML practitioners ใช้ลดปัญหา overfitting เรียกว่า “Regularization” เป็นเทคนิคที่ทำให้โมเดลของเรา simple ขึ้น → generalize ได้ดีขึ้น สำหรับปัญหา regression เทคนิคนี้จะไปปรับค่า coefficients ของโมเดลให้มีขนาดเล็กลง หรือใน extreme cases คือปรับ coefficients เป็นศูนย์เลย

Next Time – บทความต่อไปเราจะเขียนอธิบาย regularization ให้นักเรียนอ่านเต็มๆอีกที 🙂

Full R Code

โค้ดตั้งแต่ line 22 เป็นต้นไปใช้สำหรับสร้าง Regularized Logistic Regression (ridge, lasso, elastic net) ด้วย package glmnet และ caret (อีกชื่อหนึ่งของ Regularization คือ Penalized Regression)

  • Line 31: Regularization จะมีสอง hyperparameters หลักที่เราจูนได้คือ alpha และ lambda
  • Line 32-33: เทรนโมเดลด้วย Grid Search 5-Fold CV
  • Line 48: Train accuracy 97.25%
  • Line 54: Test accuracy 97.08%

Leave a Reply

This site uses Akismet to reduce spam. Learn how your comment data is processed.