[Reinforcement Learning] A general Python Q-learning template for maze problems

Published: 2019-06-28

Author: hhh5460

0. Notes

This post provides a fairly general template for two-dimensional maze problems; very little needs to be changed to reuse it.

For any two-dimensional maze, only three places in class Agent need to be modified: MAZE_R, MAZE_C, and rewards. Leave everything else untouched, as shown below (an adapted example follows the snippet):

class Agent(object):
    '''Agent class'''
    MAZE_R = 6  # maze rows
    MAZE_C = 6  # maze columns

    def __init__(self, alpha=0.1, gamma=0.9):
        '''Initialize'''
        # ... ...
        self.rewards = [0,-10,0,  0,  0, 0,
                        0,-10,0,  0,-10, 0,
                        0,-10,0,-10,  0, 0,
                        0,-10,0,-10,  0, 0,
                        0,-10,0,-10,  1, 0,
                        0,  0,0,-10,  0,10,]  # reward set: exit +10, trap -10, gold +1
        # ... ...
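For example, to reuse the template on a hypothetical 4x5 maze with a single trap and the exit in the bottom-right corner, only those three members change. The layout and reward values below are made up for illustration and are not part of the original post:

class Agent(object):
    '''Agent for a hypothetical 4x5 maze (illustrative values only)'''
    MAZE_R = 4  # maze rows
    MAZE_C = 5  # maze columns

    def __init__(self, alpha=0.1, gamma=0.9):
        '''Initialize'''
        # ... ...
        # One reward per cell, listed row by row: state = row * MAZE_C + col.
        # Here: a trap (-10) at row 1, col 2 and the exit (+10) at row 3, col 4.
        self.rewards = [0, 0,   0, 0,  0,
                        0, 0, -10, 0,  0,
                        0, 0,   0, 0,  0,
                        0, 0,   0, 0, 10,]
        # ... ...

Note that the Maze GUI class in the full code hard-codes its own MAZE_R, MAZE_C and trap rectangles, so the drawing has to be adjusted separately if you also want the visualization to match.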

 

1. Complete code

import pandas as pd
import random
import time
import pickle
import pathlib
import os
import tkinter as tk

'''
The 6*6 maze:
-------------------------------------------
| Start| Trap |      |      |      |      |
-------------------------------------------
|      | Trap |      |      | Trap |      |
-------------------------------------------
|      | Trap |      | Trap |      |      |
-------------------------------------------
|      | Trap |      | Trap |      |      |
-------------------------------------------
|      | Trap |      | Trap | Gold |      |
-------------------------------------------
|      |      |      | Trap |      | Exit |
-------------------------------------------

Author: hhh5460
Date: 20181219
Place: Tai Zi Miao
'''


class Maze(tk.Tk):
    '''Environment class (GUI)'''
    UNIT = 40   # pixels per cell
    MAZE_R = 6  # grid rows
    MAZE_C = 6  # grid columns

    def __init__(self):
        '''Initialize the window and draw the maze'''
        super().__init__()
        self.title('Maze')
        h = self.MAZE_R * self.UNIT
        w = self.MAZE_C * self.UNIT
        self.geometry('{0}x{1}'.format(w, h))  # window size: width x height
        self.canvas = tk.Canvas(self, bg='white', height=h, width=w)
        # draw the grid lines
        for c in range(1, self.MAZE_C):
            self.canvas.create_line(c * self.UNIT, 0, c * self.UNIT, h)
        for r in range(1, self.MAZE_R):
            self.canvas.create_line(0, r * self.UNIT, w, r * self.UNIT)
        # draw the traps
        self._draw_rect(1, 0, 'black')  # column 1, row 0; same below
        self._draw_rect(1, 1, 'black')
        self._draw_rect(1, 2, 'black')
        self._draw_rect(1, 3, 'black')
        self._draw_rect(1, 4, 'black')
        self._draw_rect(3, 2, 'black')
        self._draw_rect(3, 3, 'black')
        self._draw_rect(3, 4, 'black')
        self._draw_rect(3, 5, 'black')
        self._draw_rect(4, 1, 'black')
        # draw the gold reward
        self._draw_rect(4, 4, 'yellow')
        # draw the player (keep the handle!)
        self.rect = self._draw_rect(0, 0, 'red')
        self.canvas.pack()  # show the canvas

    def _draw_rect(self, x, y, color):
        '''Draw a rectangle; x, y are the column and row indices of the cell'''
        padding = 5  # 5px inner padding, cf. CSS
        coor = [self.UNIT * x + padding, self.UNIT * y + padding,
                self.UNIT * (x + 1) - padding, self.UNIT * (y + 1) - padding]
        return self.canvas.create_rectangle(*coor, fill=color)

    def move_agent_to(self, state, step_time=0.01):
        '''Move the player to the new position given by state'''
        coor_old = self.canvas.coords(self.rect)  # e.g. [5.0, 5.0, 35.0, 35.0] (top-left/bottom-right of the first cell)
        x, y = state % self.MAZE_C, state // self.MAZE_C  # column and row of the target cell
        padding = 5  # 5px inner padding, cf. CSS
        coor_new = [self.UNIT * x + padding, self.UNIT * y + padding,
                    self.UNIT * (x + 1) - padding, self.UNIT * (y + 1) - padding]
        dx_pixels, dy_pixels = coor_new[0] - coor_old[0], coor_new[1] - coor_old[1]  # offset of the top-left corner
        self.canvas.move(self.rect, dx_pixels, dy_pixels)
        self.update()  # tkinter's built-in update!
        time.sleep(step_time)


class Agent(object):
    '''Agent class'''
    MAZE_R = 6  # maze rows
    MAZE_C = 6  # maze columns

    def __init__(self, alpha=0.1, gamma=0.9):
        '''Initialize'''
        self.states = range(self.MAZE_R * self.MAZE_C)  # state set: 0~35, 36 states in total
        self.actions = list('udlr')                     # action set: up, down, left, right (4 actions)
        self.rewards = [0,-10,0,  0,  0, 0,
                        0,-10,0,  0,-10, 0,
                        0,-10,0,-10,  0, 0,
                        0,-10,0,-10,  0, 0,
                        0,-10,0,-10,  1, 0,
                        0,  0,0,-10,  0,10,]  # reward set: exit +10, trap -10, gold +1
        #self.hell_states = [1,7,13,19,25,15,31,37,43,10]  # trap positions

        self.alpha = alpha
        self.gamma = gamma

        self.q_table = pd.DataFrame(data=[[0 for _ in self.actions] for _ in self.states],
                                    index=self.states,
                                    columns=self.actions)

    def save_policy(self):
        '''Save the Q table'''
        with open('q_table.pickle', 'wb') as f:
            pickle.dump(self.q_table, f, pickle.HIGHEST_PROTOCOL)

    def load_policy(self):
        '''Load the Q table'''
        with open('q_table.pickle', 'rb') as f:
            self.q_table = pickle.load(f)

    def choose_action(self, state, epsilon=0.8):
        '''Choose an action for the current state, randomly or greedily according to epsilon'''
        #if (random.uniform(0,1) > epsilon) or ((self.q_table.loc[state] == 0).all()):  # explore
        if random.uniform(0, 1) > epsilon:  # explore
            action = random.choice(self.get_valid_actions(state))
        else:
            #action = self.q_table.loc[state].idxmax()  # exploit; when several maxima exist, this always picks the first one!
            #action = self.q_table.loc[state].filter(items=self.get_valid_actions(state)).idxmax()  # big improvement, yet same problem as above
            s = self.q_table.loc[state].filter(items=self.get_valid_actions(state))
            action = random.choice(s[s == s.max()].index)  # choose randomly among the (possibly several) maxima!
        return action

    def get_q_values(self, state):
        '''Get all Q values for the given state'''
        q_values = self.q_table.loc[state, self.get_valid_actions(state)]
        return q_values

    def update_q_value(self, state, action, next_state_reward, next_state_q_values):
        '''Update a Q value according to the Bellman equation'''
        self.q_table.loc[state, action] += self.alpha * (next_state_reward + self.gamma * next_state_q_values.max() - self.q_table.loc[state, action])

    def get_valid_actions(self, state):
        '''Get all legal actions for the current state'''
        valid_actions = set(self.actions)
        if state // self.MAZE_C == 0:                  # first row: cannot go up
            valid_actions -= set(['u'])
        elif state // self.MAZE_C == self.MAZE_R - 1:  # last row: cannot go down
            valid_actions -= set(['d'])

        if state % self.MAZE_C == 0:                   # first column: cannot go left
            valid_actions -= set(['l'])
        elif state % self.MAZE_C == self.MAZE_C - 1:   # last column: cannot go right
            valid_actions -= set(['r'])

        return list(valid_actions)

    def get_next_state(self, state, action):
        '''Get the next state after taking the action in the given state'''
        #u,d,l,r,n = -6,+6,-1,+1,0
        if action == 'u' and state // self.MAZE_C != 0:                  # not in the first row: up = -MAZE_C
            next_state = state - self.MAZE_C
        elif action == 'd' and state // self.MAZE_C != self.MAZE_R - 1:  # not in the last row: down = +MAZE_C
            next_state = state + self.MAZE_C
        elif action == 'l' and state % self.MAZE_C != 0:                 # not in the first column: left = -1
            next_state = state - 1
        elif action == 'r' and state % self.MAZE_C != self.MAZE_C - 1:   # not in the last column: right = +1
            next_state = state + 1
        else:
            next_state = state
        return next_state

    def learn(self, env=None, episode=1000, epsilon=0.8):
        '''Q-learning algorithm'''
        print('Agent is learning...')
        for i in range(episode):
            current_state = self.states[0]

            if env is not None:  # if an environment is provided, reset it!
                env.move_agent_to(current_state)

            while current_state != self.states[-1]:
                current_action = self.choose_action(current_state, epsilon)  # choose randomly or greedily, with probability epsilon
                next_state = self.get_next_state(current_state, current_action)
                next_state_reward = self.rewards[next_state]
                next_state_q_values = self.get_q_values(next_state)
                self.update_q_value(current_state, current_action, next_state_reward, next_state_q_values)
                current_state = next_state

                #if next_state not in self.hell_states:  # move forward unless it is a trap; otherwise stay put
                #    current_state = next_state

                if env is not None:  # if an environment is provided, update it!
                    env.move_agent_to(current_state)
            print(i)
        print('\nok')

    def test(self):
        '''Test whether the agent has already learned the maze'''
        count = 0
        current_state = self.states[0]
        while current_state != self.states[-1]:
            current_action = self.choose_action(current_state, 1.)  # 1. means 100% greedy
            next_state = self.get_next_state(current_state, current_action)
            current_state = next_state
            count += 1

            if count > self.MAZE_R * self.MAZE_C:  # failed to escape the maze within 36 steps
                return False                       # not smart yet
        return True  # smart

    def play(self, env=None, step_time=0.5):
        '''Play the game using the learned policy'''
        assert env is not None, 'Env must be not None!'

        if not self.test():  # if not smart yet
            if pathlib.Path("q_table.pickle").exists():
                self.load_policy()
            else:
                print("I need to learn before playing this game.")
                self.learn(env, episode=1000, epsilon=0.5)
                self.save_policy()
        print('Agent is playing...')
        current_state = self.states[0]
        env.move_agent_to(current_state, step_time)
        while current_state != self.states[-1]:
            current_action = self.choose_action(current_state, 1.)  # 1. means 100% greedy
            next_state = self.get_next_state(current_state, current_action)
            current_state = next_state
            env.move_agent_to(current_state, step_time)
        print('\nCongratulations, Agent got it!')


if __name__ == '__main__':
    env = Maze()     # environment
    agent = Agent()  # agent
    agent.learn(env, episode=1000, epsilon=0.6)  # learn first
    #agent.save_policy()
    #agent.load_policy()
    agent.play(env)                              # then play

    #env.after(0, agent.learn, env, 1000, 0.8)  # learn first
    #env.after(0, agent.save_policy)            # save what was learned
    #env.after(0, agent.load_policy)            # load what was learned
    #env.after(0, agent.play, env)              # then play
    env.mainloop()
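The update in update_q_value is the standard Q-learning rule: Q(s, a) <- Q(s, a) + alpha * (r + gamma * max Q(s', a') - Q(s, a)), where r is the reward of the next state and the max is taken over that state's valid actions. Once learning has finished and save_policy() has written q_table.pickle, the learned table can be reloaded and replayed without training again. A minimal sketch, assuming the code above is saved as maze_q_learning.py in the current working directory (the module name is only an assumption, not from the original post):

# Replay a previously learned policy without retraining.
# Assumes the code above is saved as maze_q_learning.py and that
# q_table.pickle was produced earlier by agent.save_policy().
from maze_q_learning import Maze, Agent

env = Maze()          # build the tkinter maze window
agent = Agent()       # fresh agent with an all-zero Q table
agent.load_policy()   # restore the Q table from q_table.pickle
agent.play(env)       # step through the maze greedily in the GUI
env.mainloop()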

 

Just enjoy it!

 

Reprinted from: http://ksgml.baihongyu.com/
