AndroidでPyTorchを動かす その2

もう少しで動きそう
github.com

(追記)
動いたのでMainActivityを貼っておく
MainActivity.kt

package com.example.predictor

import android.content.Context
import android.graphics.*
import android.os.Bundle
import android.widget.Button
import android.widget.Toast
import androidx.appcompat.app.AppCompatActivity
import org.pytorch.IValue
import org.pytorch.Module
import org.pytorch.torchvision.TensorImageUtils
import java.io.File
import java.io.FileOutputStream

/**
 * Single-screen activity that runs a bundled TorchScript classifier on a bundled
 * sample image and shows the predicted label in a [Toast].
 *
 * The model asset is copied into the app's private files directory on first launch
 * because [Module.load] is given a filesystem path rather than an asset stream.
 */
class MainActivity : AppCompatActivity() {

    /** TorchScript module loaded from the bundled model asset in [onCreate]. */
    lateinit var module: Module

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)

        // Load the model before wiring the button so a click handler can never
        // observe an uninitialized `module`.
        module = Module.load(assetFilePath(this, MODEL_ASSET))

        findViewById<Button>(R.id.predict_button).setOnClickListener { predict() }
    }

    /**
     * Returns the absolute path of a file in [Context.getFilesDir] holding a copy of
     * the asset [assetName], copying it there first if no non-empty copy exists.
     *
     * The copy is written to a temporary sibling file and renamed into place only
     * after it completes, so an interrupted copy can never be mistaken for a valid
     * cached file on the next launch.
     *
     * @param context context used to resolve the assets and the files directory
     * @param assetName name of the asset, relative to the assets root
     * @return absolute path of the cached copy
     * @throws java.io.IOException if the asset cannot be read or the copy cannot be written
     */
    fun assetFilePath(context: Context, assetName: String): String {
        val cached = File(context.filesDir, assetName)
        if (cached.exists() && cached.length() > 0) {
            return cached.absolutePath
        }

        // Asset names may contain sub-directories; make sure the target directory exists.
        cached.parentFile?.mkdirs()

        val partial = File(context.filesDir, "$assetName.part")
        try {
            context.assets.open(assetName).use { input ->
                FileOutputStream(partial).use { output ->
                    input.copyTo(output)
                }
            }
            if (!partial.renameTo(cached)) {
                throw java.io.IOException("Could not move ${partial.path} to ${cached.path}")
            }
        } finally {
            // No-op after a successful rename; removes the leftover after a failure.
            partial.delete()
        }
        return cached.absolutePath
    }

    /**
     * Classifies the bundled sample image with [module] and shows the resulting
     * label in a short [Toast].
     */
    private fun predict() {
        // `use` closes the asset stream once decoding is done.
        val bitmap: Bitmap? = assets.open(IMAGE_ASSET).use { stream ->
            BitmapFactory.decodeStream(stream)
        }
        if (bitmap == null) {
            // decodeStream returns null when the data cannot be decoded as an image.
            Toast.makeText(applicationContext, "Could not decode $IMAGE_ASSET", Toast.LENGTH_SHORT).show()
            return
        }

        val inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
            bitmap,
            TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
            TensorImageUtils.TORCHVISION_NORM_STD_RGB
        )
        val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()
        val scores = outputTensor.dataAsFloatArray

        val best = argmax(scores)
        // Guard against a model whose output has more entries than there are labels.
        val label = LABELS.getOrNull(best) ?: "Unknown class $best"
        Toast.makeText(applicationContext, label, Toast.LENGTH_SHORT).show()
    }

    /**
     * Returns the index of the largest element of [array]. On ties the lowest index wins.
     *
     * @throws IllegalArgumentException if [array] is empty
     */
    fun argmax(array: FloatArray): Int {
        require(array.isNotEmpty()) { "argmax of an empty array is undefined" }
        var best = 0
        for (i in 1 until array.size) {
            if (array[i] > array[best]) {
                best = i
            }
        }
        return best
    }

    private companion object {
        /** Asset name of the TorchScript model. */
        const val MODEL_ASSET = "model.pt"

        /** Asset name of the sample image that is classified. */
        const val IMAGE_ASSET = "leaf.jpg"

        /** Class labels, indexed by the position of the score in the model output. */
        val LABELS = arrayOf("Early blight", "Healthy", "Late blight")
    }
}