มาลองทำ Object Detection หรือการตรวจจับวัตถุว่าเป็นอะไร โดยให้ Machine Learning ได้ประมวลผลทำกับ Unity ร่วมกับ TensorFlow API for .NET หรือ TensorFlowSharp
API การตรวจจับวัตถุ หรือ Object Detection ของ TensorFlow นับว่าเป็นเครื่องมือที่ทรงพลังตัวหนึ่งที่ทุกคนสามารถเปิดใช้งานการประมวลผล Machie Learning ได้อย่างรวดเร็ว โดยเฉพาะผู้ที่ไม่มีพื้นฐานการเรียนรู้ของการทำงานด้าน AI และ Machine Learning เพื่อสร้างและปรับใช้ซอฟต์แวร์ของพวกประมวลผล หรือจดจำรูปภาพให้ทำงานได้อย่างรวดเร็ว และเป็นประโยชน์
บทเรียนนี้เราจำเป็นต้องใช้ Library Assets ของ Unity ตัวหนึ่งที่ชื่อว่า TensorFlowSharp
- https://github.com/migueldeicaza/TensorFlowSharp
- ดาวน์โหลด Assets สำหรับ Import ที่ TensotFlowSharp Unity Package
หรือแบบที่พร้อมใช้งานเลยถ้าเข้าใจแล้ว ดาวน์โหลด (สำหรับสายขี้เกียจ) เป็น Template Project นั้นทางผมได้เตรียมให้แล้วที่ Github ไปดาวน์โหลดมาได้เลย
- ตัวอย่าง Source Code ที่ทำไว้แล้วแบบ Unity Package
- Git ตัว Project ที่ผมเตรียมไว้ให้ https://github.com/banyapon/TensorFowUnitySample แล้วตามด้วยดาวน์โหลด TensorFlowSharpUnity Package
เริ่มต้นพัฒนา
ทำการเปิด Project ของผมที่เราได้ git clone https://github.com/banyapon/TensorFowUnitySample ลงมาบน Unity หลังจากนั้นติดตั้ง ML Kit ของ TensorFlow ด้วยการ Import Assets ของ TensotFlowSharp Unity Package ที่ดาวน์โหลดมาที่เมนู
Assets->Import Package->Custom Package
รอจนกว่าระบบจะประมวลผลเสร็จก็เรียบร้อย
ขั้นตอนต่อมาเราจำเป็นจะต้องเปลี่ยนแพลตฟอร์มเป็น Android หลังจากนั้น ตั้งค่า Build Setting -> Player Setting โดย
ใน Other Setting ให้เราไปเพิ่ม ENABLE_TENSORFLOW ในช่อง Scripting Define Symbols
เมื่อเสร็จแล้วไปที่ Project -> Assets หา Folder ที่ชื่อว่า “ML-Agents” เลือก Plugins->Android เราจะเห็นไฟล์นามสกุล .dll มากมายปรากฏอยู่ให้ทำการลบ .dll ทุกไฟล์ เหลือไว้เพียง
- Java.Interop
- Mono.Android
- System.Linq
- TensorFlowSharp
- TensorFlowSharp.Android
ทีนี้สังเกตการทำงานของ Unity ของเราคือ Object Detect เราจะมีการบังคับเปิดกล้องมือถือผ่าน C# ที่ชื่อว่า PhoneCamera.cs ไปวางใน MainCamera ซึ่งจะมี Mode ให้เราเลือกคือ Detector จะไปเรียก Class ของ Detector.cs อีกที
using System; using System.Collections; using System.Collections.Generic; using UnityEngine; using UnityEngine.UI; using System.IO; using System.Linq; using System.Text; using System.Text.RegularExpressions; using TFClassify; using System.Diagnostics; using System.Threading.Tasks; using Debug = UnityEngine.Debug; using TensorFlow; public enum Mode { Detect, Classify, } public class PhoneCamera : MonoBehaviour { private const int detectImageSize = 300; private const int classifyImageSize = 224; private static Texture2D boxOutlineTexture; private static GUIStyle labelStyle; private bool camAvailable; private WebCamTexture backCamera; private Texture defaultBackground; private Classifier classifier; private Detector detector; private List<BoxOutline> boxOutlines; private Vector2 backgroundSize; private Vector2 backgroundOrigin; public Mode mode; public RawImage background; public AspectRatioFitter fitter; public TextAsset modelFile; public TextAsset labelsFile; public Text uiText; private void Start() { LoadWorker(); defaultBackground = background.texture; WebCamDevice[] devices = WebCamTexture.devices; if(devices.Length == 0) { this.uiText.text = "No camera detected"; camAvailable = false; return; } for(int i = 0; i < devices.Length; i++) { if(!devices[i].isFrontFacing) { this.backCamera = new WebCamTexture(devices[i].name, Screen.width, Screen.height); } } if(backCamera == null) { this.uiText.text = "Unable to find back camera"; return; } this.backCamera.Play(); this.background.texture = this.backCamera; this.backgroundSize = new Vector2(this.backCamera.width, this.backCamera.height); camAvailable = true; string func = mode == Mode.Classify ? nameof(TFClassify) : nameof(TFDetect); InvokeRepeating(func, 1f, 1f); } private void Update() { if(!this.camAvailable) { return; } float ratio = (float)backCamera.width / (float)backCamera.height; fitter.aspectRatio = ratio; float scaleY = backCamera.videoVerticallyMirrored ? -1f : 1f; background.rectTransform.localScale = new Vector3(1f, scaleY, 1f); int orient = -backCamera.videoRotationAngle; background.rectTransform.localEulerAngles = new Vector3(0, 0, orient); } public void OnGUI() { if (this.boxOutlines != null && this.boxOutlines.Any()) { foreach (var outline in this.boxOutlines) { DrawBoxOutline(outline); } } } private void LoadWorker() { try { if (mode == Mode.Classify) { LoadClassifier(); } else { LoadDetector(); } } catch (TFException ex) { if (ex.Message.EndsWith("is up to date with your GraphDef-generating binary.).")) { this.uiText.text = "Error: TFException. Make sure you model trained with same version of TensorFlow as in Unity plugin."; } throw; } } private void LoadClassifier() { this.classifier = new Classifier( this.modelFile.bytes, Regex.Split(this.labelsFile.text, "\n|\r|\r\n") .Where(s => !String.IsNullOrEmpty(s)).ToArray(), classifyImageSize); } private void LoadDetector() { this.detector = new Detector( this.modelFile.bytes, Regex.Split(this.labelsFile.text, "\n|\r|\r\n") .Where(s => !String.IsNullOrEmpty(s)).ToArray(), detectImageSize); } private async void TFClassify() { var snap = TakeTextureSnap(); var scaled = Scale(snap, classifyImageSize); var rotated = await RotateAsync(scaled.GetPixels32(), scaled.width, scaled.height); try { var probabilities = await this.classifier.ClassifyAsync(rotated); this.uiText.text = String.Empty; for(int i = 0; i < 3; i++) { this.uiText.text += probabilities[i].Key + ": " + String.Format("{0:0.000}%", probabilities[i].Value) + "\n"; } } catch (NullReferenceException) { this.uiText.text = "Error: NullReferenceException. Make sure you set correct INPUT_NAME and OUTPUT_NAME"; } finally { Destroy(snap); Destroy(scaled); } } private async void TFDetect() { UpdateBackgroundOrigin(); var snap = TakeTextureSnap(); var scaled = Scale(snap, detectImageSize); var rotated = await RotateAsync(scaled.GetPixels32(), scaled.width, scaled.height); this.boxOutlines = await this.detector.DetectAsync(rotated); Destroy(snap); Destroy(scaled); } private void UpdateBackgroundOrigin() { var backgroundPosition = this.background.transform.position; this.backgroundOrigin = new Vector2(backgroundPosition.x - this.backgroundSize.x / 2, backgroundPosition.y - this.backgroundSize.y / 2); } private void DrawBoxOutline(BoxOutline outline) { var xMin = outline.XMin * this.backgroundSize.x + this.backgroundOrigin.x; var xMax = outline.XMax * this.backgroundSize.x + this.backgroundOrigin.x; var yMin = outline.YMin * this.backgroundSize.y + this.backgroundOrigin.y; var yMax = outline.YMax * this.backgroundSize.y + this.backgroundOrigin.y; DrawRectangle(new Rect(xMin, yMin, xMax - xMin, yMax - yMin), 4, Color.green); DrawLabel(new Rect(xMin + 10, yMin + 10, 200, 20), $"{outline.Label}: {(int)(outline.Score * 100)}%"); } public static void DrawRectangle(Rect area, int frameWidth, Color color) { // Create a one pixel texture with the right color if (boxOutlineTexture == null) { var texture = new Texture2D(1, 1); texture.SetPixel(0, 0, color); texture.Apply(); boxOutlineTexture = texture; } Rect lineArea = area; lineArea.height = frameWidth; GUI.DrawTexture(lineArea, boxOutlineTexture); // Top line lineArea.y = area.yMax - frameWidth; GUI.DrawTexture(lineArea, boxOutlineTexture); // Bottom line lineArea = area; lineArea.width = frameWidth; GUI.DrawTexture(lineArea, boxOutlineTexture); // Left line lineArea.x = area.xMax - frameWidth; GUI.DrawTexture(lineArea, boxOutlineTexture); // Right line } private static void DrawLabel(Rect position, string text) { if (labelStyle == null) { var style = new GUIStyle(); style.fontSize = 50; style.normal.textColor = Color.red; labelStyle = style; } GUI.Label(position, text, labelStyle); } private Texture2D TakeTextureSnap() { var smallest = backCamera.width < backCamera.height ? backCamera.width : backCamera.height; var snap = TextureTools.CropWithRect(backCamera, new Rect(0, 0, smallest, smallest), TextureTools.RectOptions.Center, 0, 0); return snap; } private Texture2D Scale(Texture2D texture, int imageSize) { var scaled = TextureTools.scaled(texture, imageSize, imageSize, FilterMode.Bilinear); return scaled; } private Task<Color32[]> RotateAsync(Color32[] pixels, int width, int height) { return Task.Run(() => { return TextureTools.RotateImageMatrix( pixels, width, height, -90); }); } private Task<Texture2D> RotateAsync(Texture2D texture) { return Task.Run(() => { return TextureTools.RotateTexture(texture, -90); }); } private void SaveToFile(Texture2D texture) { File.WriteAllBytes( Application.persistentDataPath + "/" + "snap.png", texture.EncodeToPNG()); } }
เราจะเรียก LoadDetector() และ TFClassify() จาก Library ร่วมกับคลาสในการส่งไฟล์รูปภาพจากกล้องมาประมวลผลเพื่อวาดกรอบข้อมูล พร้อม Label ใน
public static void DrawRectangle(Rect area, int frameWidth, Color color) { // Create a one pixel texture with the right color if (boxOutlineTexture == null) { var texture = new Texture2D(1, 1); texture.SetPixel(0, 0, color); texture.Apply(); boxOutlineTexture = texture; } Rect lineArea = area; lineArea.height = frameWidth; GUI.DrawTexture(lineArea, boxOutlineTexture); // Top line lineArea.y = area.yMax - frameWidth; GUI.DrawTexture(lineArea, boxOutlineTexture); // Bottom line lineArea = area; lineArea.width = frameWidth; GUI.DrawTexture(lineArea, boxOutlineTexture); // Left line lineArea.x = area.xMax - frameWidth; GUI.DrawTexture(lineArea, boxOutlineTexture); // Right line } private static void DrawLabel(Rect position, string text) { if (labelStyle == null) { var style = new GUIStyle(); style.fontSize = 50; style.normal.textColor = Color.red; labelStyle = style; } GUI.Label(position, text, labelStyle); }
ไป implement C# เพิ่มอีกไฟล์คือ Detector.cs ดังนี้:
using System; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; using TensorFlow; using UnityEngine; namespace TFClassify { public class BoxOutline { public float YMin { get; set; } = 0; public float XMin { get; set; } = 0; public float YMax { get; set; } = 0; public float XMax { get; set; } = 0; public string Label { get; set; } public float Score { get; set; } } public class Detector { private static int IMAGE_MEAN = 117; private static float IMAGE_STD = 1; // Minimum detection confidence to track a detection. private static float MINIMUM_CONFIDENCE = 0.6f; private int inputSize; private TFGraph graph; private string[] labels; public Detector(byte[] model, string[] labels, int inputSize) { #if UNITY_ANDROID TensorFlowSharp.Android.NativeBinding.Init(); #endif this.labels = labels; this.inputSize = inputSize; this.graph = new TFGraph(); this.graph.Import(new TFBuffer(model)); } public Task<List<BoxOutline>> DetectAsync(Color32[] data) { return Task.Run(() => { using (var session = new TFSession(this.graph)) using (var tensor = TransformInput(data, this.inputSize, this.inputSize)) { var runner = session.GetRunner(); runner.AddInput(this.graph["image_tensor"][0], tensor) .Fetch(this.graph["detection_boxes"][0], this.graph["detection_scores"][0], this.graph["detection_classes"][0], this.graph["num_detections"][0]); var output = runner.Run(); var boxes = (float[,,])output[0].GetValue(jagged: false); var scores = (float[,])output[1].GetValue(jagged: false); var classes = (float[,])output[2].GetValue(jagged: false); foreach(var ts in output) { ts.Dispose(); } return GetBoxes(boxes, scores, classes, MINIMUM_CONFIDENCE); } }); } public static TFTensor TransformInput(Color32[] pic, int width, int height) { byte[] floatValues = new byte[width * height * 3]; for (int i = 0; i < pic.Length; ++i) { var color = pic[i]; floatValues [i * 3 + 0] = (byte)((color.r - IMAGE_MEAN) / IMAGE_STD); floatValues [i * 3 + 1] = (byte)((color.g - IMAGE_MEAN) / IMAGE_STD); floatValues [i * 3 + 2] = (byte)((color.b - IMAGE_MEAN) / IMAGE_STD); } TFShape shape = new TFShape(1, width, height, 3); return TFTensor.FromBuffer(shape, floatValues, 0, floatValues.Length); } private List<BoxOutline> GetBoxes(float[,,] boxes, float[,] scores, float[,] classes, double minScore) { var x = boxes.GetLength(0); var y = boxes.GetLength(1); var z = boxes.GetLength(2); float ymin = 0, xmin = 0, ymax = 0, xmax = 0; var results = new List<BoxOutline>(); for (int i = 0; i < x; i++) { for (int j = 0; j < y; j++) { if (scores [i, j] < minScore) continue; for (int k = 0; k < z; k++) { var box = boxes [i, j, k]; switch (k) { case 0: ymin = box; break; case 1: xmin = box; break; case 2: ymax = box; break; case 3: xmax = box; break; } } int value = Convert.ToInt32(classes[i, j]); var label = this.labels[value]; var boxOutline = new BoxOutline { YMin = ymin, XMin = xmin, YMax = ymax, XMax = xmax, Label = label, Score = scores[i, j], }; results.Add(boxOutline); } } return results; } } }
เพิ่ม Class อีกตัวที่น่าสนใจคือ Classifier สร้าง C# ใหม่ขึ้นมาว่า Classifier.cs
using System; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; using TensorFlow; using UnityEngine; namespace TFClassify { public class Classifier { private static int IMAGE_MEAN = 117; private static float IMAGE_STD = 1; private static string INPUT_NAME = "input"; private static string OUTPUT_NAME = "output"; private int inputSize; private TFGraph graph; private string[] labels; public Classifier(byte[] model, string[] labels, int inputSize) { #if UNITY_ANDROID TensorFlowSharp.Android.NativeBinding.Init(); #endif this.labels = labels; this.inputSize = inputSize; this.graph = new TFGraph(); this.graph.Import(model, ""); } public Task<List<KeyValuePair<string, float>>> ClassifyAsync(Color32[] data) { return Task.Run(() => { var map = new List<KeyValuePair<string, float>>(); using (var session = new TFSession(this.graph)) using (var tensor = TransformInput(data, this.inputSize, this.inputSize)) { var runner = session.GetRunner(); runner.AddInput(graph[INPUT_NAME][0], tensor).Fetch(graph[OUTPUT_NAME][0]); var output = runner.Run(); // output[0].Value() is a vector containing probabilities of // labels for each image in the "batch". The batch size was 1. // Find the most probably label index. var result = output[0]; var rshape = result.Shape; if (result.NumDims != 2 || rshape [0] != 1) { var shape = ""; foreach (var d in rshape) { shape += $"{d} "; } shape = shape.Trim (); Debug.Log("Error: expected to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape [{shape}]"); Environment.Exit (1); } var probabilities = ((float[][])result.GetValue(jagged: true))[0]; for (int i = 0; i < labels.Length; i++) { map.Add(new KeyValuePair<string, float>(labels[i], probabilities[i] * 100)); } foreach (var ts in output) { ts.Dispose(); } } return map.OrderByDescending(x => x.Value).ToList(); }); } public static TFTensor TransformInput(Color32[] pic, int width, int height) { float[] floatValues = new float[width * height * 3]; for (int i = 0; i < pic.Length; ++i) { var color = pic[i]; floatValues [i * 3 + 0] = (color.r - IMAGE_MEAN) / IMAGE_STD; floatValues [i * 3 + 1] = (color.g - IMAGE_MEAN) / IMAGE_STD; floatValues [i * 3 + 2] = (color.b - IMAGE_MEAN) / IMAGE_STD; } TFShape shape = new TFShape(1, width, height, 3); return TFTensor.FromBuffer(shape, floatValues, 0, floatValues.Length); } } }
เสียบสาย USB เข้ากับเครื่องคอมพิวเตอร์ของเราหลังจากนั้นให้ Build & run ตัว Android ของเราลงสมาร์ตโฟน เพื่อเริ่มต้นทดสอบ
จะเห็นว่าเราสามารถทำ Object Detector ง่ายๆ ด้วย TensorFlowSharp กับ Unity ได้แล้ว
สำหรับคนที่มี Model ของ TensorFlow เป็นของตัวเอง Model ที่เลือกมาต้องถูก Trained ด้วย TensorFlow 1.4 ขึ้นไปนะครับ เปลี่ยนนามสกุลไฟล์จาก .pb เป็น .bytes ด้วย