1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
| import numpy as np import pandas as pd import time import IPython
# --- Environment -----------------------------------------------------------
# The corridor has states 0..length; reaching state `length` ends an episode.
length = 5
# Action codes: 0 moves left, any other value (here 1) moves right.
a = [0, 1]

# --- Learning hyperparameters ----------------------------------------------
# Probability that `greedy` exploits the best-known action; with probability
# 1 - epsilon it picks a uniformly random action instead.
epsilon = 0.9
# Step size applied to every Q-table update.
alpha = 0.1
# Discount factor for future reward.
gamma = 0.9
# Lambda: per-step decay applied (together with gamma) to the traces.
trace_decay = 0.9

# --- Learned tables --------------------------------------------------------
# Q_a[state][action] is the action-value estimate; one row per state,
# including the goal state, and one column per action.
Q_a = np.zeros((length + 1, len(a)))
# Eligibility traces, same shape as Q_a and initially all zero.
E = np.zeros_like(Q_a)

# Set to True by `update_state` when the agent steps onto the goal.
game_over = False
def print_environment(state):
    """Print a one-line picture of the corridor with the agent at ``state``.

    The agent is drawn as ``o``, empty cells as ``-`` and the goal as a
    trailing ``$``. When ``state`` equals the module-level ``length`` the
    agent is on the goal cell, so the ``$`` is omitted.

    Args:
        state: Index of the cell the agent currently occupies.
    """
    # Built from named pieces instead of a local called `str`, which would
    # shadow the builtin for the rest of the function.
    cells_before = '-' * state
    # A non-positive repeat count yields an empty string, so this is empty
    # both on the last regular cell and on the goal cell.
    cells_after = '-' * (length - state - 1)
    goal_marker = '$' if state != length else ''
    print(cells_before + 'o' + cells_after + goal_marker)
def best_Q(state):
    """Return the largest action value stored in the Q-table for ``state``.

    Args:
        state: Row index into the module-level ``Q_a`` table.

    Returns:
        The maximum entry of ``Q_a[state]``.
    """
    action_values = Q_a[state, :]
    return action_values.max()
def best_action(state):
    """Return a highest-valued action for ``state``, breaking ties at random.

    Args:
        state: Row index into the module-level ``Q_a`` table.

    Returns:
        The index of an action whose Q-value equals the row maximum. When
        several actions share that maximum, one of them is drawn uniformly.
    """
    row = Q_a[state]
    # Action indices ordered from highest to lowest Q-value.
    ranked = np.argsort(row)[::-1]
    top_value = row[ranked[0]]
    # All entries tied with the maximum sit at the front of `ranked`, so a
    # random offset below the tie count selects uniformly among them.
    tie_count = (row == top_value).sum()
    pick = np.random.randint(0, tie_count)
    return ranked[pick]
def update_Q(state, next_state, action, next_action, reward):
    """Apply one Sarsa(lambda) update to the Q-table via eligibility traces.

    Unlike plain Sarsa, the TD error is spread over every state-action pair
    in proportion to its trace in ``E``, so earlier pairs are updated too.

    Args:
        state: State the action was taken in.
        next_state: State reached after taking ``action``.
        action: Action taken in ``state``.
        next_action: Action selected for ``next_state``.
        reward: Reward received for the transition.

    Side effects:
        Updates the module-level ``Q_a`` and ``E`` arrays in place.
    """
    global E
    global Q_a

    # Replacing trace: clear the whole row for this state and mark only the
    # action just taken. (An accumulating trace would instead increment
    # E[state][action] by 1.)
    E[state, :] = 0
    E[state, action] = 1

    td_error = reward + gamma * Q_a[next_state][next_action] - Q_a[state][action]
    Q_a += alpha * E * td_error

    # Decay every trace by gamma * lambda before the next step.
    E *= trace_decay * gamma
def greedy(state):
    """Choose an action for ``state`` with an epsilon-greedy rule.

    Note that ``epsilon`` here is the probability of *exploiting*: with
    probability ``epsilon`` the result of ``best_action(state)`` is returned,
    otherwise an action index is drawn uniformly at random.

    Args:
        state: Row index into the module-level ``Q_a`` table.

    Returns:
        The index of the chosen action.
    """
    explore = np.random.rand() >= epsilon
    if explore:
        action_count = len(Q_a[state])
        return np.random.randint(0, action_count)
    return best_action(state)
def update_state(now_state, action):
    """Advance the environment by one step and report the reward.

    Action ``0`` moves one cell left and stays put at cell 0. Any other
    action moves one cell right. Stepping right from cell ``length - 1``
    reaches the goal, which yields reward 1 and sets the module-level
    ``game_over`` flag. Every other transition yields reward 0.

    The resulting state is drawn with ``print_environment`` before returning.

    Args:
        now_state: Current cell index.
        action: Action code to apply.

    Returns:
        A ``(next_state, reward)`` tuple.
    """
    global game_over

    reward = 0
    if action == 0:
        # The left wall blocks movement at cell 0.
        next_state = now_state if now_state == 0 else now_state - 1
    else:
        next_state = now_state + 1
        if now_state == length - 1:
            reward = 1
            game_over = True

    print_environment(next_state)
    return next_state, reward
def Sarsa_Lambda(start, rounds, max_times):
    """Train the Q-table with Sarsa(lambda), animating every step.

    Each round places the agent at ``start`` and runs until the goal is
    reached or ``max_times`` steps have been taken. The corridor is redrawn
    after every step, with a one-second pause and an IPython output clear
    between frames.

    Args:
        start: State every round begins in.
        rounds: Number of rounds (episodes) to run.
        max_times: Maximum number of environment steps per round.

    Side effects:
        Updates the module-level ``Q_a`` and ``E`` tables, toggles
        ``game_over``, prints to stdout and sleeps between frames.
    """
    global game_over

    for _ in range(rounds):
        # Start every round from a clean slate. Traces from a round that hit
        # the step limit must not leak into the next one, and a stale
        # terminal flag must not end the round before it begins.
        E[:] = 0
        game_over = False

        state = start
        print_environment(state)
        time.sleep(1)
        IPython.display.clear_output()
        action = greedy(state)

        for _ in range(max_times):
            next_state, reward = update_state(state, action)
            # Sarsa is on-policy: the bootstrap action must be sampled for the
            # state just reached, not for the state the agent is leaving.
            next_action = greedy(next_state)
            update_Q(state, next_state, action, next_action, reward)
            state, action = next_state, next_action
            time.sleep(1)
            IPython.display.clear_output()

            # Check right after the step so a goal reached on the final
            # permitted step is still acknowledged within this round.
            if game_over:
                game_over = False
                print('Game Over')
                time.sleep(1)
                IPython.display.clear_output()
                break
# Run the demo: 10 rounds starting from the leftmost cell, each capped at
# 100 iterations.
Sarsa_Lambda(start=0, rounds=10, max_times=100)
|