Tree Based Models in R

วันนี้เราจะมาสอนสร้าง tree-based models ง่ายๆใน R

เรา assume ว่าเพื่อนๆรู้เรื่อง machine learning นิดหน่อย เช่น ทำไมต้อง train/ test split และทำไมต้องจูนค่า hyperparameter เป็นต้น ถ้าใครยังไม่รู้ว่า ML คืออะไร? ลองอ่านบทความแนะนำของ dataquest ได้ในลิ้งด้านล่าง

tutorial วันนี้แบ่งเป็น 5 พาร์ท ใช้เวลาอ่านและทำตามประมาณ 20 นาที

  1. Getting to Know R – Optional สำหรับเพื่อนๆที่ยังไม่เคยเขียน R มาก่อนเลย
  2. Prepare Data
  3. Decision Tree
  4. Random Forest
  5. Summary

Ready to fly? เปิด RStudio ขึ้นมาแล้วลอง copy โค้ดด้านล่างไปรันใน console ได้เลย 😛


Getting to Know R

5 นาที – อธิบายการเขียน R เบื้องต้น ถ้าใครเขียน R เป็นบ้างแล้ว สามารถ skip ไปที่หัวข้อ Prepare Data ได้เลย

  • line 3-6 ลองสร้าง vector ง่ายๆด้วยฟังชั่น c()
  • line 9-12 เราใช้ฟังชั่น data.frame() เพื่อสร้าง dataframe ซึ่งเป็นหัวใจสำคัญของการทำ data analysis และ machine learning ใน R
  • line 15-21 คือฟังชั่นที่เราใช้บ่อยๆกับ dataframe เช่น str, head, tail และ summary

จบพาร์ทแรก ตอนนี้เพื่อนๆสามารถเขียนโค้ด R ง่ายๆเพื่อสร้าง vector และ dataframe รวมถึงการใช้งานฟังชั่นเบื้องต้น session ต่อไป เราจะสอนโหลด dataset สำหรับ tutorial นี้


Prepare Data

5 นาที – โหลด Breast Cancer dataset เข้าสู่ RStudio จัดการกับ missing values และแบ่งข้อมูลเป็น 70% train และ 30% test

Breast Cancer เป็นข้อมูลสำหรับปัญหา binary classification มีทั้งหมด 11 columns 699 observations โดยตัวแปรที่เราต้องการ predict คือ Class {benign, malignant}

  • line 2-3 ดาวน์โหลดและติดตั้ง package mlbench ใน R ซึ่งเป็น package ที่รวบรวม datasets สำหรับงาน machine learning จากเว็บ UCI ML repository
  • line 6 ใช้ฟังชั่น data() เพื่อโหลด BreastCancer dataframe เข้าไปใน R
  • line 21-24 เป็น standard code ที่ใช้ split data เป็นสองส่วน – 70% train และ 30% test ถ้าใครอยากลองเปลี่ยน ratio สามารถเปลี่ยนได้ที่เลข 0.7 (เป็น 0.8 หรือ 0.6 ก็ได้)

Good to know – ในชีวิตจริงเรานิยม split data เป็นสามส่วน {train, validate, test} หรือใช้ k-fold cross validation ในการ train model


Decision Tree

5 นาที – train decision tree และการจูนค่า complexity parameter เพื่อให้ได้ accuracy ที่สูงขึ้นและลดการ overfit ของโมเดล

หน้าตาของ decision tree เวลาเรา visualize จะออกมาแบบรูปด้านล่าง

โดยตัวแปรแรกที่ถูกใช้ split data ที่ root node (ด้านบนสุดของ tree) คือ Cell.size = 1,2 ถ้าตอบ yes จะวิ่งไปทางซ้าย ตอบ no จะไปทางขวา จนลงมาถึง terminal node (ด้านล่างสุดของ tree) ที่กระบวนการ split data หยุดตรงนี้ มาลองอ่าน diagram กัน

  • ถ้า case นี้มี Cell.size = 1,2 โมเดลจะ predict ว่า benign (เซลล์ดี)
  • ถ้า case นี้มี Cell.size > 2 และ Cell.shape > 2 โมเดลจะ predict ว่า malignant (เซลล์ร้าย)
decision tree in R
visualize decision tree using rpart.plot() function

สำหรับ package หลักที่เราใช้ build และ visualize tree ใน R คือ rpart และ rpart.plot

  • line 8 คือการเขียนโค้ดเพื่อสร้าง decision tree model ด้วยฟังชั่น rpart()
    • Class ~ . คือ formula ใน R อ่านว่า “ตัวแปร Class เป็นฟังชั่นของตัวแปร x ทั้งหมดใน train dataset”
  • line 11 เราเรียกดูค่า complexity parameter (เรียกสั้นๆว่า cp) ที่ทำให้ค่า xerror ต่ำที่สุด
    • cp สูงเกินไป – tree ของเราจะ underfit ทำให้ได้ accuracy ต่ำ ↔ xerror สูง
    • cp ต่ำเกินไป – tree ของเราจะ overfit ไม่สามารถนำโมเดลไปใช้กับ new data ได้
    • เราต้องเลือกค่า cp ที่เหมาะสมเพื่อให้ได้ optimal final model
  • line 19 เราใช้ฟังชั่น prune() เพื่อสร้าง decision tree ด้วยค่า cp ที่ดีที่สุดของเรา ในตัวอย่างเราเลือกใช้ cp = 0.01 สำหรับ final model
  • line 22 visualize final model ของเราด้วยฟังชั่น rpart.plot()

พอเรา prune จนได้ final model แล้ว เราจะใช้มันทำนาย test dataset ด้วยฟังชั่น predict() และสร้าง confusion matrix เพื่อวัด accuracy ของโมเดลด้วยฟังชั่น table() ใน line 25-26

เราสามารถคำนวณ accuracy จาก confusion matrix ด้านบน ด้วยการหาผลรวมเส้นทแยงมุม หารด้วยจำนวน testing cases ทั้งหมดที่เราทดสอบ (124 + 70) / (124 + 7 + 4 + 70) = 0.9463415

decision tree ที่เราสร้างขึ้นมาใช้ cp = 0.10 และได้ test accuracy = 94.63%


Random Forest

5 นาที – train random forest ด้วยฟังชั่น randomForest()

cats

concept ของ random forest คือการสร้าง decision tree หลายๆต้น โดยค่า default ของ number of trees – ntree ในฟังชั่น randomForest() จะอยู่ที่ 500 ต้น แล้วค่อยเอาผล predictions ของทั้ง 500 ต้นมาโหวตกันว่า new case นั้นจะเป็น benign หรือ malignant

เช่น จากทั้งหมด 500 ต้น – 400 ต้นทำนายว่า benign และอีก 100 ต้นทำนายว่า malignant เราจะยึดผลโหวตส่วนใหญ่ final prediction เท่ากับ benign

random forest เป็นโมเดลประเภท Ensemble Learning ที่เกิดจากการสร้างและรวมหลายๆโมเดลเข้าด้วยกันเพื่อให้ได้ model performance ที่ดีขึ้น and it works !!

วิธีการคำนวณ accuracy จาก confusion matrix จะเหมือนกับของ decision tree เลย (127 + 73) / (127 + 4 + 1 + 73) = 0.9756098

random forest ที่เราสร้างขึ้นมาใช้ ntree = 500 และได้ test accuracy = 97.56% สูงกว่า decision tree ประมาณ 3% แต่แลกมากับเวลาในการ train ที่นานขึ้น (จะเห็นความแตกต่างเรื่องเวลาชัดมาก ถ้า dataset เราใหญ่)


Summary

ใน tutorial นี้ เราเรียน concept ของ tree-based models เบื้องต้นใน R ซึ่งสองตัวที่เราใช้กันเยอะมากในหลายๆ applications คือ decision tree และ random forest

  • decision tree – train เร็ว accuracy ปานกลาง ค่อนข้าง overfit ถ้าเราไม่ prune มันก่อน แต่ข้อดีคืออธิบายง่ายมาก
  • random forest – train ช้า accuracy สูง แต่อธิบายยากเพราะกระบวนการสร้าง trees 500 ต้นเกิดขึ้นแบบ random อย่างที่ชื่อ algorithm implies

decision tree ปกติจะเป็นโมเดลแรกๆที่เราลองสร้าง (baseline model) แล้วค่อยพยายาม improve performance ด้วยการจูน hyperparameter หรือเปลี่ยนไปใช้ tree-based แบบอื่นๆอย่าง random forest (bagging algorithm) หรือ xgboost (boosting algorithm)

อยากเรียน algorithm อะไรอีก? คอมเม้นบอกเราใต้บล๊อกวันนี้ได้เลย 😎

Leave a Reply

Fill in your details below or click an icon to log in:

WordPress.com Logo

You are commenting using your WordPress.com account. Log Out /  Change )

Google+ photo

You are commenting using your Google+ account. Log Out /  Change )

Twitter picture

You are commenting using your Twitter account. Log Out /  Change )

Facebook photo

You are commenting using your Facebook account. Log Out /  Change )

Connecting to %s

This site uses Akismet to reduce spam. Learn how your comment data is processed.