次はソースを実装。
作るのは、Academy継承クラスとAgent継承クラスです。
(1) Academy継承クラス
Academyクラスを継承。
MLAgentをインポートして、Academyクラスを継承するようにすればOKです。
今回やりたいことに範囲では特にメソッドの実装はいらなさそうです。
using UnityEngine;
using MLAgents;
public class TestAcademy : Academy
{
public override void InitializeAcademy()
{
}
public override void AcademyReset()
{
}
public override void AcademyStep()
{
}
}
(2) Agent継承クラス
Agentクラスを継承。
部分的には、Basicクラスからそのまま持ってきています。整理していないのはご勘弁。
using UnityEngine;
using MLAgents;
public class TestAgent : Agent
{
public GameObject m_Goal1;
public GameObject m_Goal2;
public GameObject m_Player;
public float timeBetweenDecisionsAtInference;
TestAcademy m_Academy;
int m_PlayerPosition;
int m_Goal1Position;
int m_Goal2Position;
float m_TimeSinceDecision;
public override void InitializeAgent()
{
m_Academy = FindObjectOfType(typeof(TestAcademy)) as TestAcademy;
m_Goal1.transform.position = new Vector3(-10f, 0.5f, 0f);
m_Goal2.transform.position = new Vector3( 10f, 0.5f, 0f);
m_Player.transform.position = new Vector3(0f, 0.5f, 0f);
}
public override void CollectObservations()
{
AddVectorObs(m_PlayerPosition, 20);
}
public override void AgentAction(float[] vectorAction)
{
var movement = (int)vectorAction[0];
var direction = 0;
switch (movement)
{
case 1:
direction = -1;
break;
case 2:
direction = 1;
break;
}
m_PlayerPosition += direction;
m_Player.transform.position = new Vector3(m_PlayerPosition, 0.5f, 0f);
AddReward(-0.01f);
if (m_PlayerPosition <= m_Goal1Position)
{
Done();
AddReward(0.1f);
}
if (m_PlayerPosition >= m_Goal2Position)
{
Done();
AddReward(1f);
}
}
public override void AgentReset()
{
m_PlayerPosition = 0;
m_Goal1Position = -10;
m_Goal2Position = 10;
m_Goal1.transform.position = new Vector3(-10f, 0.5f, 0f);
m_Goal2.transform.position = new Vector3( 10f, 0.5f, 0f);
m_Player.transform.position = new Vector3(0f, 0.5f, 0f);
}
public override void AgentOnDone()
{
m_PlayerPosition = 0;
m_Goal1Position = -10;
m_Goal2Position = 10;
m_Goal1.transform.position = new Vector3(-10f, 0.5f, 0f);
m_Goal2.transform.position = new Vector3(10f, 0.5f, 0f);
m_Player.transform.position = new Vector3(0f, 0.5f, 0f);
}
public void FixedUpdate()
{
WaitTimeInference();
}
void WaitTimeInference()
{
if (!m_Academy.IsCommunicatorOn)
{
RequestDecision();
}
else
{
if (m_TimeSinceDecision >= timeBetweenDecisionsAtInference)
{
m_TimeSinceDecision = 0f;
RequestDecision();
}
else
{
m_TimeSinceDecision += Time.fixedDeltaTime;
}
}
}
}
まず、publicメンバ変数として、Goal1, Goal2, Playerを定義。
これらは、作成したGoal1, Goal2, Playerのそれぞれのオブジェクトを設定します。
■ Agentクラスの各関数
- InitializeAgent()メソッド:Agentの初期化時に呼ばれる
- AgentAction()メソッド:ステップ毎に呼ばれる
- AgentReset()メソッド:リセット時に呼ばれる
- AgentOnDone()メソッド:完了時に呼ばれる
- CollectObservation()メソッド:Stateの取得時に呼ばれる。
(A) InitiaizeAgentメソッド
初期化。まあ念のためです。
(B) AgentActionメソッド
このメソッドが一番のメインです。
今回の学習では、素早くまっすぐGoal2にたどり着けばいいという学習結果を身に着けることが期待値です。
Vector Actionとしては、「Discrete(離散)」か「Continuous(連続)」を選択できますが、今回の場合は、入力値としては「0」「1」「2」の整数値のいずれかでいいので、
- 「Space Type」としては「Discrete(離散)」を選択。
- 「Branches Size」は「1」(パラメータの数だと思います。今回の入力値は1つなので。)
- 「Branch 0 Size」は「3」(「0」「1」「2」の3種類のいずれかから返す)
を設定。
こうすることで,AgentActionメソッドのパラメータとして、Brainが判定した「0」「1」「2」のいずれかを受け取れます。
つまり,vectorAction[0] には「0」「1」「2」のいずれかが入っています。
その結果、
- 「0」の場合、direction = 0 つまり移動無し
- 「1」の場合、direction = -1 つまり左に1移動
- 「2」の場合、direction = 1 つまり右に1移動
となります。
今回の場合,Goalは2つ用意しており,
- Goal1に到着 → 報酬(小)を与える
- Goal2に到着 → 報酬(大)を与える
というようにします。
またGoalに早くたどり着いてもらうために、ステップごとに罰(小)を与えるようにしています。つまり、立ち止まっていたり、ウロウロするだけで、ポイントが下がるようにします。
今回の場合は
- Goal2に到着: 報酬「1」を与える
- Goal1に到着: 報酬「0.1」を与える
- ステップ毎:罰「0.01」(報酬「-0.01」)を与える
という感じです。
いずれかのゴールにたどり着いたら、Done()を呼び出し、その回の実行は終了。
次の学習に移ります。
(C) AgentReset / AgentOnDone メソッド
次の学習のため、各オブジェクトの位置を初期化します。
(ただ、Player以外は特に不要ですね)
あと、関数化すれば二回も同じことしなくてもいいんですね。初期化でも使えますし。
(D) CollectObservcationメソッド
今回の場合は、Playerオブジェクトの位置を渡すことで、どの状態の時にどうすればいいかの判断材料にしてもらいます。
あと、FixedUpdate() はML-Agents固有のものではなく、Unityの関数。
一定時間ごとの評価のために呼び出されます。
それでも判定を一定時間以上の間隔にしないと、学習している間隔が早すぎてわかりにくいですので。
あとは、Unity上でAcademyとAgentをそれぞれEmptyオブジェクトにアタッチして、Agent側ではpublicオブジェクトを設定して完了です。