Project author: batra98

Project description:
Using MDP based models (Value Iteration and Policy Iteration) on toy environments.
Primary language: Jupyter Notebook
Repository: git://github.com/batra98/MDP-Basics.git
Created: 2019-08-31T16:09:12Z
Project page: https://github.com/batra98/MDP-Basics

License: MIT License


MDP Basics and Dynamic Programming Methods

Introduction

In this assignment we create toy environments and model them as MDPs, which we solve with Value Iteration and Policy Iteration. The agents are implemented using dynamic programming and bootstrapping, and learn to navigate the environments so as to maximize their returns.

:file_folder: File Structure

```
├── assignment.pdf
├── agent.py
├── env.py
├── main.ipynb
├── README.md
├── main.pdf
└── main_files

2 directories, 6 files
```
  • env.py - Contains the code for the environments (the GridWorld, Gambler, and Jack's car-rental problems). Details of each environment are given below in the appropriate sections.
  • agent.py - Contains the code for the various agents used.
  • assignment.pdf - Contains the details of the assignment.
  • main.ipynb - Contains the plots of the policies generated by the various agents in the different environments.
  • main.pdf - A .pdf export of the .ipynb notebook.

Details of the various problems

Problem-:one: : GridWorld

Environment Setting


The environment is an $8 \times 8$ grid with a fixed start state and a terminal state in the bottom-right corner.

Every transition yields a negative reward, and the terminal state has value 0, so the agent is encouraged to reach the terminal state along the least costly path.

The possible actions in each state are $\{\text{up}, \text{right}, \text{down}, \text{left}\}$ (encoded as 0, 1, 2, 3 in the policies printed below).

```python
import env
import agent
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

Env = env.Grid_1()
agent_1 = agent.ValueIteration(Env)
agent_2 = agent.PolicyIteration(Env)
agent_3 = agent.ConfusedAgent(Env)
```
  • We train 3 agents corresponding to Value Iteration, Policy Iteration, and a random (confused) agent.
```python
def plot(Env, policy, V):
    pp = np.reshape(np.argmax(policy, axis=1), Env.shape)
    print(pp)
    cmp = plt.matshow(np.reshape(V, Env.shape))
    # plt.arrow(0, 0.5, 0, -0.7, head_width=0.1)
    plt.colorbar(cmp)
    # draw an arrow for the greedy action in each non-terminal cell
    # (0 = up, 1 = right, 2 = down, 3 = left)
    for i in range(Env.shape[0]):
        for j in range(Env.shape[1]):
            if i == (Env.shape[0] - 1) and j == (Env.shape[1] - 1):
                continue
            if pp[i][j] == 0:
                plt.arrow(j, i + 0.5, 0, -0.7, head_width=0.1)
            elif pp[i][j] == 2:
                plt.arrow(j, i - 0.5, 0, +0.7, head_width=0.1)
            elif pp[i][j] == 1:
                plt.arrow(j - 0.5, i, 0.7, 0, head_width=0.1)
            elif pp[i][j] == 3:
                plt.arrow(j + 0.5, i, -0.7, 0, head_width=0.1)
    plt.show()
```

Function for plotting the optimal policies obtained by different agents

Part-a,b

Value-Iteration

```python
### Value Iteration
# agent_1.set_gamma(0.5)
agent_1.clear()
itr = 0
while True:
    agent_1.reset()
    agent_1.update()
    # stop once the largest value change in a sweep (delta) falls below the threshold
    if agent_1.get_delta() < agent_1.get_threshold():
        break
    itr += 1
print(itr)
policy = agent_1.get_policy()
print(np.reshape(agent_1.V, Env.shape))
plot(Env, policy, agent_1.V)
# agent_1.clear()
```
  1. 121
  2. [[-5.22041551 -4.66374662 -3.76022869 -3.51898212 -3.39792422 -3.27240293
  3. -3.53854272 -3.20898484]
  4. [-4.9694722 -4.4278591 -3.51898212 -3.39792422 -3.27240293 -3.1021674
  5. -2.68712729 -2.8775704 ]
  6. [-5.24781006 -4.71161351 -4.22887404 -3.31763094 -3.1021674 -2.68712729
  7. -2.16822194 -2.68712729]
  8. [-4.3842771 -4.22887404 -3.31763094 -3.15465386 -2.78012179 -2.16822194
  9. -1.72471154 -2.16822194]
  10. [-3.58602673 -3.53169027 -3.15465386 -2.63622143 -2.17376983 -1.72471154
  11. -1.52491241 -1.60236044]
  12. [-3.53169027 -3.19126 -2.63622143 -2.53203929 -2.01816864 -1.44891952
  13. -1.25594018 -0.62961128]
  14. [-3.58602673 -2.80017455 -2.53203929 -2.38232128 -1.44891952 -1.25594018
  15. -0.62961128 0. ]
  16. [-4.07129585 -3.14357618 -2.80017455 -2.53203929 -1.6153412 -0.73101098
  17. 0. 0. ]]
  18. [[1 1 1 2 2 2 2 2]
  19. [1 1 1 1 1 2 2 2]
  20. [1 1 2 2 1 1 2 3]
  21. [2 1 1 2 2 1 2 3]
  22. [2 2 1 2 1 1 2 2]
  23. [1 1 1 2 1 2 2 2]
  24. [0 1 1 1 1 1 1 2]
  25. [1 1 0 0 1 1 1 0]]

png

  1. In the Value Iteration algorithm, we initialize all state values to 0 and repeatedly apply the Bellman optimality operator until the values stop changing, i.e., until the largest update in a sweep falls below the threshold $\theta$:

     $V_{k+1}(s) = \max_{a} \sum_{s'} P(s' \mid s, a)\,\big[ R(s, a, s') + \gamma V_k(s') \big]$

  2. We then obtain the optimal policy by acting greedily with respect to the converged value function:

     $\pi^*(s) = \arg\max_{a} \sum_{s'} P(s' \mid s, a)\,\big[ R(s, a, s') + \gamma V^*(s') \big]$

  3. Above, we show the final state-value vector and the corresponding optimal policy (121 sweeps were needed with the default settings).
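
To make the procedure concrete, here is a minimal, self-contained value-iteration sketch for a generic finite MDP. The transition format `P[s][a] = [(prob, next_state, reward, done), ...]`, the function name, and the tiny example MDP are assumptions made for this illustration; they are not the interface of `env.py` or `agent.py`.

```python
import numpy as np

def value_iteration_sketch(P, n_states, n_actions, gamma=0.9, theta=1e-8):
    """Sweep the Bellman optimality backup until the largest change is below theta."""
    V = np.zeros(n_states)

    def q_values(s):
        # one-step look-ahead: expected reward plus discounted value of the successor
        return [sum(p * (r + gamma * V[s2] * (not done)) for p, s2, r, done in P[s][a])
                for a in range(n_actions)]

    while True:
        delta = 0.0
        for s in range(n_states):
            best = max(q_values(s))
            delta = max(delta, abs(best - V[s]))
            V[s] = best
        if delta < theta:
            break
    # extract the greedy (optimal) policy from the converged values
    policy = np.array([int(np.argmax(q_values(s))) for s in range(n_states)])
    return V, policy

# Tiny hypothetical 2-state MDP: from state 0, action 1 reaches the terminal state 1.
P = {0: {0: [(1.0, 0, -1.0, False)], 1: [(1.0, 1, -1.0, True)]},
     1: {0: [(1.0, 1, 0.0, True)], 1: [(1.0, 1, 0.0, True)]}}
V, pi = value_iteration_sketch(P, n_states=2, n_actions=2)
print(V, pi)   # V[0] is -1 (one step to the goal) and pi[0] is 1
```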

Policy-Iteration

```python
### Policy Iteration
# agent_2.set_gamma(0.5)
agent_2.clear()
while True:
    V = agent_2.evaluate_policy()
    # print(V)
    stable = agent_2.update()
    # stop once the greedy improvement no longer changes the policy
    if stable == True:
        break
print(np.reshape(agent_2.V, Env.shape))
plot(Env, agent_2.policy, agent_2.V)
# agent_2.clear()
```
  1. [[-5.22041551 -4.66374662 -3.76022869 -3.51898212 -3.39792422 -3.27240293
  2. -3.53854272 -3.20898484]
  3. [-4.9694722 -4.4278591 -3.51898212 -3.39792422 -3.27240293 -3.1021674
  4. -2.68712729 -2.8775704 ]
  5. [-5.24781006 -4.71161351 -4.22887404 -3.31763094 -3.1021674 -2.68712729
  6. -2.16822194 -2.68712729]
  7. [-4.3842771 -4.22887404 -3.31763094 -3.15465386 -2.78012179 -2.16822194
  8. -1.72471154 -2.16822194]
  9. [-3.58602673 -3.53169027 -3.15465386 -2.63622143 -2.17376983 -1.72471154
  10. -1.52491241 -1.60236044]
  11. [-3.53169027 -3.19126 -2.63622143 -2.53203929 -2.01816864 -1.44891952
  12. -1.25594018 -0.62961128]
  13. [-3.58602673 -2.80017455 -2.53203929 -2.38232128 -1.44891952 -1.25594018
  14. -0.62961128 0. ]
  15. [-4.07129585 -3.14357618 -2.80017455 -2.53203929 -1.6153412 -0.73101098
  16. 0. 0. ]]
  17. [[1 1 1 2 2 2 2 2]
  18. [1 1 1 1 1 2 2 2]
  19. [1 1 2 2 1 1 2 3]
  20. [2 1 1 2 2 1 2 3]
  21. [2 2 1 2 1 1 2 2]
  22. [1 1 1 2 1 2 2 2]
  23. [0 1 1 1 1 1 1 2]
  24. [1 1 0 0 1 1 1 0]]

png

  1. In Policy Iteration, we start with a random policy. The algorithm alternates between two steps:
    1. Policy Evaluation
    2. Policy Improvement

Policy Evaluation

  1. Here we evaluate the current policy $\pi$ and return the resulting state-value vector.
  2. Evaluation is done by repeatedly applying the Bellman expectation equation:

     $V_{k+1}(s) = \sum_{a} \pi(a \mid s) \sum_{s'} P(s' \mid s, a)\,\big[ R(s, a, s') + \gamma V_k(s') \big]$

Policy Improvement

  1. We update the current policy using the state-value vector obtained from the policy-evaluation step.
  2. Improvement is done by acting greedily with respect to that value function:

     $\pi'(s) = \arg\max_{a} \sum_{s'} P(s' \mid s, a)\,\big[ R(s, a, s') + \gamma V^{\pi}(s') \big]$

In the above part, we show the final state-value vector and the optimal policy obtained by the Policy Iteration algorithm.
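
For comparison, here is a compact policy-iteration sketch in the same spirit, alternating evaluation and greedy improvement. As above, the `P[s][a] = [(prob, next_state, reward, done), ...]` format is a convention assumed for this illustration only, not the actual `agent.py` classes.

```python
import numpy as np

def policy_iteration_sketch(P, n_states, n_actions, gamma=0.9, theta=1e-8):
    policy = np.zeros(n_states, dtype=int)   # start from an arbitrary deterministic policy
    V = np.zeros(n_states)

    def q_value(s, a):
        return sum(p * (r + gamma * V[s2] * (not done)) for p, s2, r, done in P[s][a])

    while True:
        # policy evaluation: iterate the Bellman expectation backup for the current policy
        while True:
            delta = 0.0
            for s in range(n_states):
                v_new = q_value(s, policy[s])
                delta = max(delta, abs(v_new - V[s]))
                V[s] = v_new
            if delta < theta:
                break
        # policy improvement: act greedily with respect to the evaluated value function
        stable = True
        for s in range(n_states):
            best_a = int(np.argmax([q_value(s, a) for a in range(n_actions)]))
            if best_a != policy[s]:
                stable = False
            policy[s] = best_a
        if stable:
            return V, policy
```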

```python
agent_3.get_policy()
print(np.reshape(agent_3.V, Env.shape))
plot(Env, agent_3.policy, agent_3.V)
```
  1. [[-0.04345 -0.55666888 -1.11333777 -2.0168557 -0.1255213 -0.17023553
  2. -0.53937516 -0.74397335]
  3. [-0.57585387 -0.53619655 -2.0168557 -2.92573268 -0.17023553 -0.29575683
  4. -0.33141444 -1.4879467 ]
  5. [-1.42734639 -1.07780964 -2.92573268 -0.16297707 -0.41504011 -0.46599235
  6. -0.19044311 -0.70934846]
  7. [-0.91915017 -1.83830034 -0.16297707 -0.51843243 -0.6814095 -1.47286136
  8. -0.19979913 -0.74597145]
  9. [-0.77347435 -2.69456567 -0.55503856 -1.04926075 -0.44905829 -2.33561133
  10. -2.78466962 -2.98446874]
  11. [-0.05433646 -0.94084391 -1.28127419 -1.56769318 -1.67187533 -0.19297934
  12. -0.76222846 -1.0312007 ]
  13. [-0.94084391 -0.26813526 -1.83631275 -0.36901193 -0.51872993 -0.6263289
  14. -0.73101098 -2.00394986]
  15. [-0.92771967 -0.92771967 -0.36901193 -0.71241356 -0.96835084 -0.81930824
  16. -1.70363847 0. ]]
  17. [[0 0 3 3 2 2 0 1]
  18. [2 2 0 3 1 3 1 0]
  19. [0 0 0 2 1 0 1 3]
  20. [3 3 1 2 3 3 2 1]
  21. [1 0 2 3 1 0 3 3]
  22. [3 2 3 0 3 2 3 3]
  23. [1 1 0 2 3 1 2 0]
  24. [1 2 1 3 2 0 3 2]]

png

We observe that the final state-value vectors and the optimal policies are the same for both the Value Iteration and Policy Iteration algorithms. This is expected: both methods converge to the (unique) optimal value function of the MDP.

Part-c

```python
def plot_mean_and_CI(mean, lb, ub, color_mean=None, color_shading=None):
    # plot the shaded range of the confidence intervals
    plt.fill_between(range(mean.shape[0]), ub, lb,
                     color=color_shading, alpha=.5)
    # plot the mean on top
    plt.plot(mean, color_mean)
```

```python
# track the value of state (0, 0) together with the min/max state value after every sweep
agent_1.clear()
mu = []
low = []
high = []
while True:
    agent_1.reset()
    agent_1.update()
    # mu.append((np.matrix(np.reshape(agent_1.V,Env.shape)))[0][0])
    # mu.append(agent_1.V[0][0])
    mat = np.matrix(np.reshape(agent_1.V, Env.shape))
    mu.append(mat[0, 0])
    low.append((np.matrix(np.reshape(agent_1.V, Env.shape))).min())
    high.append((np.matrix(np.reshape(agent_1.V, Env.shape))).max())
    if agent_1.get_delta() < agent_1.get_threshold():
        break
mu = np.array(mu)
high = np.array(high)
low = np.array(low)
print(mu)
plot_mean_and_CI(mu, high, low, color_mean='g--', color_shading='g')
```
  1. [ 0. -0.04345 -0.08689999 -0.13034999 -0.17379999 -0.21724999
  2. -0.26069998 -0.30414998 -0.34759998 -0.39104997 -0.43449997 -0.47794997
  3. -0.52139996 -0.56484996 -0.60829996 -0.65174996 -0.69519995 -0.73864995
  4. -0.78209995 -0.82554994 -0.86899994 -0.91244994 -0.95589993 -0.99934993
  5. -1.04279993 -1.08624993 -1.12969992 -1.17314992 -1.21659992 -1.26004991
  6. -1.30349991 -1.34694991 -1.3903999 -1.4338499 -1.4772999 -1.5207499
  7. -1.56419989 -1.60764989 -1.65109989 -1.69454988 -1.73799988 -1.78144988
  8. -1.82489988 -1.86834987 -1.91179987 -1.95524987 -1.99869986 -2.04214986
  9. -2.08559986 -2.12904985 -2.17249985 -2.21594985 -2.25939985 -2.30284984
  10. -2.34629984 -2.38974984 -2.43319983 -2.47664983 -2.52009983 -2.56354982
  11. -2.60699982 -2.65044982 -2.69389982 -2.73734981 -2.78079981 -2.82424981
  12. -2.8676998 -2.9111498 -2.9545998 -2.9980498 -3.04149979 -3.08494979
  13. -3.12839979 -3.17184978 -3.21529978 -3.25874978 -3.30219977 -3.34564977
  14. -3.38909977 -3.43254977 -3.47599976 -3.51944976 -3.56289976 -3.60634975
  15. -3.64979975 -3.69324975 -3.73669974 -3.78014974 -3.82359974 -3.86704974
  16. -3.91049973 -3.95394973 -3.99739973 -4.04084972 -4.08429972 -4.12774972
  17. -4.17119971 -4.21464971 -4.25809971 -4.30154971 -4.3449997 -4.3884497
  18. -4.4318997 -4.47534969 -4.51879969 -4.56224969 -4.60569969 -4.64914968
  19. -4.69259968 -4.73604968 -4.77949967 -4.82294967 -4.86639967 -4.90984966
  20. -4.95329966 -4.99674966 -5.04019966 -5.08364965 -5.12709965 -5.17054965
  21. -5.21399964 -5.22041551]

png

Above is the plot of the value of state $(0, 0)$ versus the number of iterations for Value Iteration, with the minimum and maximum state values shaded around it.

  1. We observe that as the number of iterations increases, the state values converge to constant values.
```python
agent_2.clear()
mu = []
low = []
high = []
while True:
    V = agent_2.evaluate_policy()
    # mu.append((np.matrix(np.reshape(V,Env.shape))).mean())
    mat = np.matrix(np.reshape(agent_2.V, Env.shape))
    mu.append(mat[0, 0])
    low.append((np.matrix(np.reshape(V, Env.shape))).min())
    high.append((np.matrix(np.reshape(V, Env.shape))).max())
    stable = agent_2.update()
    if stable == True:
        break
mu = np.array(mu)
high = np.array(high)
low = np.array(low)
print(mu)
plot_mean_and_CI(mu, high, low, color_mean='b--', color_shading='b')
```

```
[-192.41222849 -7.22184676 -6.25734004 -5.62044015 -5.22041551]
```

png

Above is the plot of the value of state $(0, 0)$ versus the number of policy-iteration steps for Policy Iteration, with the minimum and maximum state values (after each policy-evaluation step) shaded around it.

  1. We observe that as the number of iterations increases, the state values converge to constant values.
```python
agent_3.clear()
mu = []
low = []
high = []
for i in range(100):
    agent_3.get_policy()
    # mu.append((np.matrix(np.reshape(agent_3.V,Env.shape))).mean())
    mat = np.matrix(np.reshape(agent_3.V, Env.shape))
    mu.append(mat[0, 0])
    low.append((np.matrix(np.reshape(agent_3.V, Env.shape))).min())
    high.append((np.matrix(np.reshape(agent_3.V, Env.shape))).max())
mu = np.array(mu)
high = np.array(high)
low = np.array(low)
print(mu)
plot_mean_and_CI(mu, high, low, color_mean='b--', color_shading='b')
```
  1. [ -0.85149252 -0.89494252 -2.82045201 -3.04805649 -3.09150649
  2. -3.13495648 -3.17840648 -3.77852536 -4.92484152 -4.96829152
  3. -6.91806603 -6.96151602 -10.97207975 -14.5501409 -15.12599477
  4. -15.16944477 -14.34906268 -19.06797577 -19.11142576 -19.15487576
  5. -19.19832576 -19.24177576 -25.58355328 -25.62700328 -25.22381249
  6. -25.26726249 -25.31071249 -29.43737012 -29.48082012 -29.52427012
  7. -29.56772011 -30.46266263 -31.61107069 -31.65452069 -31.69797068
  8. -31.74142068 -31.78487068 -35.36977725 -40.98164786 -38.837261
  9. -38.88071099 -39.99404876 -41.26846264 -41.31191264 -43.21823714
  10. -43.26168714 -48.93040912 -49.83392705 -49.87737705 -49.92082704
  11. -45.59886027 -45.64231027 -45.68576027 -50.64489701 -50.68834701
  12. -50.731797 -54.98156821 -55.02501821 -57.51033007 -57.55378007
  13. -57.59723007 -58.19734895 -58.24079895 -58.28424894 -61.28885544
  14. -63.54599585 -63.58944585 -63.63289584 -64.23301473 -64.27646472
  15. -64.31991472 -64.36336472 -68.40375532 -68.44720532 -73.68853667
  16. -73.73198667 -71.04513289 -71.08858288 -71.13203288 -72.0269754
  17. -72.92191792 -72.96536792 -73.00881791 -73.05226791 -73.09571791
  18. -87.10745609 -87.15090609 -87.19435609 -89.14413059 -89.18758059
  19. -89.23103059 -91.5098695 -91.5533195 -91.59676949 -91.64021949
  20. -91.68366949 -91.72711949 -96.30590684 -97.76609365 -99.46831686]

png

  • From the above plots we conclude that, although the values obtained by the confused (random) agent may occasionally be higher, on average the confused agent performs far worse than the learned agents.
  • We also observe that the final average state value is the same for both Value Iteration and Policy Iteration.

Comparison of the policies obtained with different values of $\gamma$

Value Iteration

$\gamma = 0$

```python
agent_1.clear()
# policies_value = []
# V_value = []
agent_1.set_gamma(0)
while True:
    agent_1.reset()
    agent_1.update()
    if agent_1.get_delta() < agent_1.get_threshold():
        break
policy = agent_1.get_policy()
print(np.reshape(agent_1.V, Env.shape))
plot(Env, policy, agent_1.V)
```
  1. [[-0.14512055 -0.30882554 -0.30882554 -0.24017891 -0.19047494 -0.19047494
  2. -0.19047494 -0.30045085]
  3. [-0.14512055 -0.14512055 -0.29377688 -0.15884583 -0.24017891 -0.19047494
  4. -0.30045085 -0.36050373]
  5. [-0.14512055 -0.29377688 -0.15884583 -0.29377688 -0.15884583 -0.15183197
  6. -0.16053415 -0.44040548]
  7. [-0.27280255 -0.17011694 -0.06925676 -0.15884583 -0.15183197 -0.16053415
  8. -0.11034099 -0.16053415]
  9. [-0.27280255 -0.06925676 -0.47161264 -0.06925676 -0.09368626 -0.11034099
  10. -0.16053415 -0.11034099]
  11. [-0.36257063 -0.50018686 -0.06925676 -0.05339038 -0.10052064 -0.09368626
  12. -0.11034099 -0.41490262]
  13. [-0.4238164 -0.36257063 -0.05339038 -0.10052064 -0.05339038 -0.26303996
  14. -0.37154444 0. ]
  15. [-0.36731126 -0.13121401 -0.13121401 -0.05339038 -0.09717174 -0.09717174
  16. 0. 0. ]]
  17. [[2 0 3 1 1 0 3 3]
  18. [3 3 2 2 0 0 0 3]
  19. [0 1 1 3 3 2 2 3]
  20. [3 0 2 0 1 1 2 3]
  21. [0 1 0 3 2 1 0 3]
  22. [1 2 0 2 3 3 0 0]
  23. [2 0 1 0 3 3 2 2]
  24. [1 1 2 0 2 3 1 0]]

png

$\gamma = 0.1$

```python
agent_1.clear()
# policies_value = []
# V_value = []
agent_1.set_gamma(0.1)
while True:
    agent_1.reset()
    agent_1.update()
    if agent_1.get_delta() < agent_1.get_threshold():
        break
policy = agent_1.get_policy()
print(np.reshape(agent_1.V, Env.shape))
plot(Env, policy, agent_1.V)
```
  1. [[-0.16124344 -0.34313605 -0.34313605 -0.26134067 -0.21163671 -0.21163671
  2. -0.21163671 -0.32161262]
  3. [-0.16124344 -0.16124344 -0.31278745 -0.19012164 -0.26134067 -0.21163671
  4. -0.32161262 -0.39266309]
  5. [-0.16124344 -0.31278745 -0.19012164 -0.31278745 -0.19012164 -0.16916036
  6. -0.17329999 -0.45773387]
  7. [-0.30311091 -0.20139274 -0.11758913 -0.19012164 -0.16916036 -0.17329999
  8. -0.12766939 -0.17329999]
  9. [-0.30311091 -0.11758913 -0.48337086 -0.10692832 -0.10437809 -0.12766939
  10. -0.17329999 -0.12766939]
  11. [-0.41640676 -0.53837446 -0.10692832 -0.0640822 -0.10692832 -0.10437809
  12. -0.12766939 -0.42766846]
  13. [-0.462004 -0.38188914 -0.0640822 -0.10692832 -0.0640822 -0.26944764
  14. -0.37154444 0. ]
  15. [-0.38188914 -0.14579189 -0.14579189 -0.0640822 -0.10796752 -0.10796752
  16. 0. 0. ]]
  17. [[2 0 3 1 1 0 3 3]
  18. [3 3 2 2 0 0 0 3]
  19. [0 1 1 3 3 2 2 3]
  20. [3 0 2 0 1 1 2 3]
  21. [0 1 0 2 2 1 0 3]
  22. [1 2 1 2 3 3 0 0]
  23. [2 2 1 0 3 3 2 2]
  24. [1 1 2 0 2 3 1 0]]

png

$\gamma = 0.5$

```python
agent_1.clear()
# policies_value = []
# V_value = []
agent_1.set_gamma(0.5)
while True:
    agent_1.reset()
    agent_1.update()
    if agent_1.get_delta() < agent_1.get_threshold():
        break
policy = agent_1.get_policy()
print(np.reshape(agent_1.V, Env.shape))
plot(Env, policy, agent_1.V)
```
  1. [[-0.29023224 -0.61763222 -0.55453705 -0.43064222 -0.38093826 -0.38093826
  2. -0.38093826 -0.49091417]
  3. [-0.29023224 -0.29023224 -0.49758728 -0.40763051 -0.43064222 -0.38093826
  4. -0.49091417 -0.60595501]
  5. [-0.29023224 -0.49758728 -0.40763051 -0.49758728 -0.36935294 -0.29562629
  6. -0.28759844 -0.5841998 ]
  7. [-0.54558845 -0.41890161 -0.35296401 -0.36935294 -0.29562629 -0.28759844
  8. -0.25413532 -0.28759844]
  9. [-0.54558845 -0.35296401 -0.56741776 -0.16961688 -0.17849163 -0.25413532
  10. -0.28759844 -0.25413532]
  11. [-0.73728736 -0.74944148 -0.16961688 -0.13819575 -0.16961688 -0.17849163
  12. -0.25413532 -0.54196692]
  13. [-0.67307102 -0.49851726 -0.13819575 -0.16961688 -0.13819575 -0.3321362
  14. -0.37154444 0. ]
  15. [-0.49851726 -0.26242001 -0.26242001 -0.13819575 -0.19433754 -0.19433754
  16. 0. 0. ]]
  17. [[2 0 1 1 1 0 3 3]
  18. [3 3 2 2 0 0 0 3]
  19. [0 1 1 3 2 2 2 3]
  20. [3 0 2 1 1 1 2 3]
  21. [0 1 1 2 2 1 0 3]
  22. [1 2 1 2 3 3 0 0]
  23. [2 2 1 0 3 3 2 2]
  24. [1 1 2 0 2 3 1 0]]

png

$\gamma = 0.75$

```python
agent_1.clear()
# policies_value = []
# V_value = []
agent_1.set_gamma(0.75)
while True:
    agent_1.reset()
    agent_1.update()
    if agent_1.get_delta() < agent_1.get_threshold():
        break
policy = agent_1.get_policy()
print(np.reshape(agent_1.V, Env.shape))
plot(Env, policy, agent_1.V)
```
  1. [[-0.58046375 -0.93655564 -0.94790034 -0.81157952 -0.76187555 -0.76187555
  2. -0.76187555 -0.87185147]
  3. [-0.58046375 -0.58046375 -0.94376715 -0.86320238 -0.81157952 -0.69392968
  4. -0.85745711 -1.00359146]
  5. [-0.58046375 -0.94376715 -0.86666614 -0.94376715 -0.64820076 -0.5688836
  6. -0.55607352 -0.85745711]
  7. [-1.00943013 -0.87793725 -0.61193023 -0.64820076 -0.5688836 -0.55607352
  8. -0.52739263 -0.55607352]
  9. [-1.0298701 -0.61193023 -0.7235689 -0.3212776 -0.33464276 -0.52739263
  10. -0.55607352 -0.52739263]
  11. [-1.16573032 -1.07088514 -0.3212776 -0.29434688 -0.3212776 -0.33464276
  12. -0.52739263 -0.81044199]
  13. [-0.99451468 -0.76093661 -0.29434688 -0.3212776 -0.29434688 -0.48379693
  14. -0.37154444 0. ]
  15. [-0.76093661 -0.52483936 -0.52483936 -0.29434688 -0.38867459 -0.37154444
  16. 0. 0. ]]
  17. [[2 3 1 1 1 0 3 3]
  18. [3 3 2 1 0 2 2 3]
  19. [0 1 1 3 2 2 2 3]
  20. [1 0 2 1 1 1 2 3]
  21. [0 1 1 2 2 1 0 3]
  22. [1 2 1 2 3 3 0 0]
  23. [2 2 1 0 3 3 2 2]
  24. [1 1 2 0 2 1 1 0]]

png

$\gamma = 1$

```python
agent_1.clear()
# policies_value = []
# V_value = []
agent_1.set_gamma(1)
while True:
    agent_1.reset()
    agent_1.update()
    if agent_1.get_delta() < agent_1.get_threshold():
        break
policy = agent_1.get_policy()
print(np.reshape(agent_1.V, Env.shape))
plot(Env, policy, agent_1.V)
```
  1. [[-4.00052043 -3.87213884 -3.5688979 -3.22967614 -2.98949723 -2.79902229
  2. -2.59654604 -2.8969969 ]
  3. [-3.85539988 -3.15875858 -2.98864165 -2.94775684 -2.78891101 -2.21474088
  4. -2.23604231 -2.59654604]
  5. [-3.15875858 -2.98864165 -2.69486477 -2.78891101 -2.16901196 -1.9474688
  6. -1.79563683 -2.23604231]
  7. [-3.04585131 -2.69486477 -2.22325213 -2.15399537 -1.9474688 -1.79563683
  8. -1.63510268 -1.79563683]
  9. [-2.84204343 -2.22325213 -2.15399537 -1.67138298 -1.61115823 -1.63510268
  10. -1.52476168 -1.63510268]
  11. [-2.97756514 -2.61499451 -1.67138298 -1.57086235 -1.51747197 -1.52476168
  12. -0.76130864 -0.83198647]
  13. [-2.79131282 -2.29112596 -1.57086235 -1.51747197 -1.25443201 -0.76130864
  14. -0.37154444 0. ]
  15. [-2.65161468 -2.28430342 -2.15308941 -1.25443201 -1.15726028 -0.37154444
  16. 0. 0. ]]
  17. [[2 2 1 1 1 2 2 3]
  18. [2 2 2 2 2 2 2 3]
  19. [1 1 2 1 2 2 2 3]
  20. [1 1 2 2 1 1 2 3]
  21. [1 1 1 2 2 1 2 3]
  22. [1 1 1 2 2 1 2 2]
  23. [1 1 1 1 2 1 2 2]
  24. [1 1 1 1 1 1 1 0]]

png

Policy Iteration

$\gamma = 0$

```python
agent_2.clear()
agent_2.set_gamma(0)
while True:
    V = agent_2.evaluate_policy()
    # print(V)
    stable = agent_2.update()
    if stable == True:
        break
print(np.reshape(agent_2.V, Env.shape))
plot(Env, agent_2.policy, agent_2.V)
```
  1. [[-0.14512055 -0.30882554 -0.30882554 -0.24017891 -0.19047494 -0.19047494
  2. -0.19047494 -0.30045085]
  3. [-0.14512055 -0.14512055 -0.29377688 -0.15884583 -0.24017891 -0.19047494
  4. -0.30045085 -0.36050373]
  5. [-0.14512055 -0.29377688 -0.15884583 -0.29377688 -0.15884583 -0.15183197
  6. -0.16053415 -0.44040548]
  7. [-0.27280255 -0.17011694 -0.06925676 -0.15884583 -0.15183197 -0.16053415
  8. -0.11034099 -0.16053415]
  9. [-0.27280255 -0.06925676 -0.47161264 -0.06925676 -0.09368626 -0.11034099
  10. -0.16053415 -0.11034099]
  11. [-0.36257063 -0.50018686 -0.06925676 -0.05339038 -0.10052064 -0.09368626
  12. -0.11034099 -0.41490262]
  13. [-0.4238164 -0.36257063 -0.05339038 -0.10052064 -0.05339038 -0.26303996
  14. -0.37154444 0. ]
  15. [-0.36731126 -0.13121401 -0.13121401 -0.05339038 -0.09717174 -0.09717174
  16. 0. 0. ]]
  17. [[2 0 3 1 1 0 3 3]
  18. [3 3 2 2 0 0 0 3]
  19. [0 1 1 3 3 2 2 3]
  20. [3 0 2 0 1 1 2 3]
  21. [0 1 0 3 2 1 0 3]
  22. [1 2 0 2 3 3 0 0]
  23. [2 0 1 0 3 3 2 2]
  24. [1 1 2 0 2 3 1 0]]

png

$\gamma = 0.1$

```python
agent_2.clear()
agent_2.set_gamma(0.1)
while True:
    V = agent_2.evaluate_policy()
    # print(V)
    stable = agent_2.update()
    if stable == True:
        break
print(np.reshape(agent_2.V, Env.shape))
plot(Env, agent_2.policy, agent_2.V)
```
  1. [[-0.16124344 -0.34313605 -0.34313605 -0.26134067 -0.21163671 -0.21163671
  2. -0.21163671 -0.32161262]
  3. [-0.16124344 -0.16124344 -0.31278745 -0.19012164 -0.26134067 -0.21163671
  4. -0.32161262 -0.39266309]
  5. [-0.16124344 -0.31278745 -0.19012164 -0.31278745 -0.19012164 -0.16916036
  6. -0.17329999 -0.45773387]
  7. [-0.30311091 -0.20139274 -0.11758913 -0.19012164 -0.16916036 -0.17329999
  8. -0.12766939 -0.17329999]
  9. [-0.30311091 -0.11758913 -0.48337086 -0.10692832 -0.10437809 -0.12766939
  10. -0.17329999 -0.12766939]
  11. [-0.41640676 -0.53837446 -0.10692832 -0.0640822 -0.10692832 -0.10437809
  12. -0.12766939 -0.42766846]
  13. [-0.462004 -0.38188914 -0.0640822 -0.10692832 -0.0640822 -0.26944764
  14. -0.37154444 0. ]
  15. [-0.38188914 -0.14579189 -0.14579189 -0.0640822 -0.10796752 -0.10796752
  16. 0. 0. ]]
  17. [[2 0 3 1 1 0 3 3]
  18. [3 3 2 2 0 0 0 3]
  19. [0 1 1 3 3 2 2 3]
  20. [3 0 2 0 1 1 2 3]
  21. [0 1 0 2 2 1 0 3]
  22. [1 2 1 2 3 3 0 0]
  23. [2 2 1 0 3 3 2 2]
  24. [1 1 2 0 2 3 1 0]]

png

$\gamma = 0.5$

```python
agent_2.clear()
agent_2.set_gamma(0.5)
while True:
    V = agent_2.evaluate_policy()
    # print(V)
    stable = agent_2.update()
    if stable == True:
        break
print(np.reshape(agent_2.V, Env.shape))
plot(Env, agent_2.policy, agent_2.V)
```
  1. [[-0.29023224 -0.61763222 -0.55453705 -0.43064222 -0.38093826 -0.38093826
  2. -0.38093826 -0.49091417]
  3. [-0.29023224 -0.29023224 -0.49758728 -0.40763051 -0.43064222 -0.38093826
  4. -0.49091417 -0.60595501]
  5. [-0.29023224 -0.49758728 -0.40763051 -0.49758728 -0.36935294 -0.29562629
  6. -0.28759844 -0.5841998 ]
  7. [-0.54558845 -0.41890161 -0.35296401 -0.36935294 -0.29562629 -0.28759844
  8. -0.25413532 -0.28759844]
  9. [-0.54558845 -0.35296401 -0.56741776 -0.16961688 -0.17849163 -0.25413532
  10. -0.28759844 -0.25413532]
  11. [-0.73728736 -0.74944148 -0.16961688 -0.13819575 -0.16961688 -0.17849163
  12. -0.25413532 -0.54196692]
  13. [-0.67307102 -0.49851726 -0.13819575 -0.16961688 -0.13819575 -0.3321362
  14. -0.37154444 0. ]
  15. [-0.49851726 -0.26242001 -0.26242001 -0.13819575 -0.19433754 -0.19433754
  16. 0. 0. ]]
  17. [[2 0 1 1 1 0 3 3]
  18. [3 3 2 2 0 0 0 3]
  19. [0 1 1 3 2 2 2 3]
  20. [3 0 2 1 1 1 2 3]
  21. [0 1 1 2 2 1 0 3]
  22. [1 2 1 2 3 3 0 0]
  23. [2 2 1 0 3 3 2 2]
  24. [1 1 2 0 2 3 1 0]]

png

$\gamma = 0.75$

```python
agent_2.clear()
agent_2.set_gamma(0.75)
while True:
    V = agent_2.evaluate_policy()
    # print(V)
    stable = agent_2.update()
    if stable == True:
        break
print(np.reshape(agent_2.V, Env.shape))
plot(Env, agent_2.policy, agent_2.V)
```
  1. [[-0.58046375 -0.93655564 -0.94790034 -0.81157952 -0.76187555 -0.76187555
  2. -0.76187555 -0.87185147]
  3. [-0.58046375 -0.58046375 -0.94376715 -0.86320238 -0.81157952 -0.69392968
  4. -0.85745711 -1.00359146]
  5. [-0.58046375 -0.94376715 -0.86666614 -0.94376715 -0.64820076 -0.5688836
  6. -0.55607352 -0.85745711]
  7. [-1.00943013 -0.87793725 -0.61193023 -0.64820076 -0.5688836 -0.55607352
  8. -0.52739263 -0.55607352]
  9. [-1.0298701 -0.61193023 -0.7235689 -0.3212776 -0.33464276 -0.52739263
  10. -0.55607352 -0.52739263]
  11. [-1.16573032 -1.07088514 -0.3212776 -0.29434688 -0.3212776 -0.33464276
  12. -0.52739263 -0.81044199]
  13. [-0.99451468 -0.76093661 -0.29434688 -0.3212776 -0.29434688 -0.48379693
  14. -0.37154444 0. ]
  15. [-0.76093661 -0.52483936 -0.52483936 -0.29434688 -0.38867459 -0.37154444
  16. 0. 0. ]]
  17. [[2 3 1 1 1 0 3 3]
  18. [3 3 2 1 0 2 2 3]
  19. [0 1 1 3 2 2 2 3]
  20. [1 0 2 1 1 1 2 3]
  21. [0 1 1 2 2 1 0 3]
  22. [1 2 1 2 3 3 0 0]
  23. [2 2 1 0 3 3 2 2]
  24. [1 1 2 0 2 1 1 0]]

png

$\gamma = 1$

```python
agent_2.clear()
agent_2.set_gamma(1)
while True:
    V = agent_2.evaluate_policy()
    # print(V)
    stable = agent_2.update()
    if stable == True:
        break
print(np.reshape(agent_2.V, Env.shape))
plot(Env, agent_2.policy, agent_2.V)
```
  1. [[-4.00052043 -3.87213884 -3.5688979 -3.22967614 -2.98949723 -2.79902229
  2. -2.59654604 -2.8969969 ]
  3. [-3.85539988 -3.15875858 -2.98864165 -2.94775684 -2.78891101 -2.21474088
  4. -2.23604231 -2.59654604]
  5. [-3.15875858 -2.98864165 -2.69486477 -2.78891101 -2.16901196 -1.9474688
  6. -1.79563683 -2.23604231]
  7. [-3.04585131 -2.69486477 -2.22325213 -2.15399537 -1.9474688 -1.79563683
  8. -1.63510268 -1.79563683]
  9. [-2.84204343 -2.22325213 -2.15399537 -1.67138298 -1.61115823 -1.63510268
  10. -1.52476168 -1.63510268]
  11. [-2.97756514 -2.61499451 -1.67138298 -1.57086235 -1.51747197 -1.52476168
  12. -0.76130864 -0.83198647]
  13. [-2.79131282 -2.29112596 -1.57086235 -1.51747197 -1.25443201 -0.76130864
  14. -0.37154444 0. ]
  15. [-2.65161468 -2.28430342 -2.15308941 -1.25443201 -1.15726028 -0.37154444
  16. 0. 0. ]]
  17. [[2 2 1 1 1 2 2 3]
  18. [2 2 2 2 2 2 2 3]
  19. [1 1 2 1 2 2 2 3]
  20. [1 1 2 2 1 1 2 3]
  21. [1 1 1 2 2 1 2 3]
  22. [1 1 1 2 2 1 2 2]
  23. [1 1 1 1 2 1 2 2]
  24. [1 1 1 1 1 1 1 0]]

png


We observe that different values of $\gamma$ result in different policies.

If the value of $\gamma$ is close to 0, the agent focuses on short-term gains and acts greedily.

If the value of $\gamma$ is close to 1, the agent puts more weight on long-term gains and does not act greedily.

In our MDP, as $\gamma \to 1$, we obtain a better policy.
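
The role of $\gamma$ can be seen directly in the (standard) definition of the discounted return that the state-value function estimates, shown here for reference:

```latex
G_t = R_{t+1} + \gamma R_{t+2} + \gamma^2 R_{t+3} + \cdots
    = \sum_{k=0}^{\infty} \gamma^{k} R_{t+k+1},
\qquad
v_\pi(s) = \mathbb{E}_\pi\left[ G_t \mid S_t = s \right].
```

With $\gamma$ close to 0 only the first few terms contribute, whereas with $\gamma$ close to 1 distant rewards carry almost as much weight as immediate ones.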

Problem-:three:

Problem Statement: (Sutton and Barto Exercise 4.8 - Value Iteration)

  • A gambler has the opportunity to make bets on the outcomes of a sequence of coin flips. If the coin comes up heads, he wins as many dollars as he has staked on that flip; if it is tails, he loses his stake. The game ends
    when the gambler wins by reaching his goal of $100, or loses by running out of money.
  • On each flip, the gambler must decide what portion of his capital to stake, in integer numbers of dollars.

Part-a

The problem can be modelled as an undiscounted, episodic, finite MDP tuple $\langle \mathcal{S}, \mathcal{A}, \mathcal{P}, \mathcal{R} \rangle$ where:

  1. $\mathcal{S}$ is the set of all possible states (levels of capital) of the gambler,

     $\mathcal{S} = \{0, 1, 2, \dots, 100\}$,

     where 0 and 100 (the goal) are terminal states.

  2. $\mathcal{A}(s)$ is the set of all possible bets the gambler can place in a given state $s$, $a \in \{1, 2, \dots, \min(s, 100 - s)\}$.

  3. The reward $R$ is 0 on every transition except the one that reaches the goal of 100, for which it is 1.

  4. Given a state $s$, if the gambler makes a bet of $a$, then the next state $s'$ can be:

    1. $s + a$ with probability $p_h$ (the coin comes up heads),
    2. $s - a$ with probability $1 - p_h$.

Bellman Equation for value iteration

  1. The Bellman update equation for this problem is:

     $V(s) \leftarrow \max_{a \in \mathcal{A}(s)} \Big[\, p_h \big( R(s + a) + V(s + a) \big) + (1 - p_h) \big( R(s - a) + V(s - a) \big) \,\Big]$

     with $V(0) = V(100) = 0$ held fixed (terminal states).

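A minimal sketch of this update, assuming the formulation above (capital 0–100, stakes up to $\min(s, 100 - s)$, reward 1 only on the transition that reaches 100, and $V(0) = V(100) = 0$). This is an illustration only, not the code in `agent.py`:

```python
import numpy as np

def gambler_value_iteration(p_h, theta=1e-10):
    V = np.zeros(101)            # states 0..100; 0 and 100 are terminal and stay at 0
    reward = np.zeros(101)
    reward[100] = 1.0            # reward 1 only for reaching the goal

    def q_values(s):
        stakes = range(1, min(s, 100 - s) + 1)
        return [p_h * (reward[s + a] + V[s + a]) +
                (1 - p_h) * (reward[s - a] + V[s - a]) for a in stakes]

    while True:
        delta = 0.0
        for s in range(1, 100):
            best = max(q_values(s))
            delta = max(delta, abs(best - V[s]))
            V[s] = best
        if delta < theta:
            break
    # greedy stakes (offset by 1 because stakes start at 1)
    policy = np.array([0] + [1 + int(np.argmax(q_values(s))) for s in range(1, 100)] + [0])
    return V, policy

V, policy = gambler_value_iteration(p_h=0.3)   # V[50] converges to 0.3 here
```
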
Part-b

```python
Gambler_env = env.Gambler_env()
Gambler_env.set_p_h(0.3)
agent_1 = agent.Gambler_ValueIteration(Gambler_env)
Gambler_env.set_p_h(0.15)
agent_2 = agent.Gambler_ValueIteration(Gambler_env)
Gambler_env.set_p_h(0.65)
agent_3 = agent.Gambler_ValueIteration(Gambler_env)
```

```python
def plot_fig1(y):
    x = range(100)
    for i in range(1, len(y)):
        plt.plot(x, y[i][:100])
    plt.show()
```

```python
def plot_fig2(policy):
    x = range(100)
    plt.bar(x, policy)
    plt.show()
```

Plots showing the iterations of the Value Iteration algorithm and the optimal policy for different values of $p_h$

  • The first plot shows the successive estimates of the state-value function, which gives the probability of winning from each state.
  • The second plot shows the optimal policy, i.e., a mapping from levels of capital to stakes.

$p_h = 0.3$

```python
agent_1.clear()
y = []
while True:
    agent_1.reset()
    agent_1.update()
    y.append(agent_1.V)
    if agent_1.get_delta() < agent_1.get_threshold():
        break
policy = agent_1.get_policy()
print(agent_1.get_policy())
print(agent_1.V)
```
  1. [ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 12. 11. 15. 9. 8.
  2. 7. 6. 5. 21. 3. 2. 1. 25. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10.
  3. 11. 13. 38. 11. 10. 9. 8. 7. 6. 5. 4. 3. 2. 1. 50. 49. 2. 3.
  4. 4. 5. 6. 7. 8. 9. 10. 11. 12. 12. 11. 10. 9. 33. 7. 6. 20. 4.
  5. 3. 2. 1. 25. 1. 2. 3. 4. 5. 6. 7. 8. 16. 10. 11. 12. 12. 11.
  6. 10. 9. 8. 7. 6. 5. 4. 3. 2. 1.]
  7. [0.00000000e+00 2.66917018e-04 8.89723393e-04 1.92325355e-03
  8. 2.96574464e-03 4.32158176e-03 6.41084517e-03 8.50388325e-03
  9. 9.88581548e-03 1.18309578e-02 1.44052725e-02 1.77664658e-02
  10. 2.13694839e-02 2.71868419e-02 2.83462775e-02 3.00251072e-02
  11. 3.29527183e-02 3.52816705e-02 3.94365260e-02 4.60307893e-02
  12. 4.80175751e-02 5.16971693e-02 5.92215525e-02 6.31880185e-02
  13. 7.12316130e-02 9.00000000e-02 9.06228064e-02 9.20760213e-02
  14. 9.44875916e-02 9.69200708e-02 1.00083691e-01 1.04958639e-01
  15. 1.09842394e-01 1.13066903e-01 1.17605568e-01 1.23612303e-01
  16. 1.31455087e-01 1.39862129e-01 1.53435964e-01 1.56141314e-01
  17. 1.60058584e-01 1.66889676e-01 1.72323898e-01 1.82018561e-01
  18. 1.97405175e-01 2.02041008e-01 2.10626728e-01 2.28183623e-01
  19. 2.37438710e-01 2.56207097e-01 3.00000000e-01 3.00622806e-01
  20. 3.02076021e-01 3.04487592e-01 3.06920071e-01 3.10083691e-01
  21. 3.14958639e-01 3.19842394e-01 3.23066903e-01 3.27605568e-01
  22. 3.33612303e-01 3.41455087e-01 3.49862129e-01 3.63435964e-01
  23. 3.66141314e-01 3.70058584e-01 3.76889676e-01 3.82323898e-01
  24. 3.92018561e-01 4.07405175e-01 4.12041008e-01 4.20626728e-01
  25. 4.38183623e-01 4.47438710e-01 4.66207097e-01 5.10000000e-01
  26. 5.11453215e-01 5.14844050e-01 5.20471047e-01 5.26146832e-01
  27. 5.33528612e-01 5.44903490e-01 5.56298920e-01 5.63822773e-01
  28. 5.74412993e-01 5.88428706e-01 6.06728536e-01 6.26344968e-01
  29. 6.58017250e-01 6.64329733e-01 6.73470028e-01 6.89409244e-01
  30. 7.02089095e-01 7.24709975e-01 7.60612075e-01 7.71429020e-01
  31. 7.91462366e-01 8.32428453e-01 8.54023656e-01 8.97816560e-01
  32. 0.00000000e+00]
```python
plot_fig1(y)
plot_fig2(policy)
```

png

png

$p_h = 0.15$

```python
agent_2.clear()
y = []
while True:
    agent_2.reset()
    agent_2.update()
    y.append(agent_2.V)
    if agent_2.get_delta() < agent_2.get_threshold():
        break
policy = agent_2.get_policy()
print(policy)
print(agent_2.V)
```
  1. [ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 12. 11. 10. 9. 8.
  2. 7. 6. 5. 4. 3. 2. 1. 25. 1. 2. 3. 4. 5. 6. 32. 8. 9. 10.
  3. 11. 13. 38. 11. 10. 9. 8. 7. 6. 5. 4. 3. 2. 49. 50. 1. 2. 3.
  4. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 11. 10. 16. 8. 7. 6. 5. 4.
  5. 3. 2. 1. 25. 1. 2. 3. 4. 5. 6. 18. 17. 9. 10. 11. 13. 12. 11.
  6. 10. 9. 8. 7. 6. 5. 4. 3. 2. 1.]
  7. [0.00000000e+00 1.92759560e-06 1.28536475e-05 3.73999231e-05
  8. 8.57080408e-05 1.42805249e-04 2.49880557e-04 5.10991667e-04
  9. 5.71386939e-04 6.72316904e-04 9.52034995e-04 1.30250559e-03
  10. 1.66587038e-03 3.37663800e-03 3.40670777e-03 3.49636996e-03
  11. 3.80934292e-03 3.94646937e-03 4.48212975e-03 6.24514185e-03
  12. 6.34689997e-03 6.72948446e-03 8.68337057e-03 9.09506179e-03
  13. 1.11058025e-02 2.25000000e-02 2.25109230e-02 2.25728373e-02
  14. 2.27119329e-02 2.29856789e-02 2.33092297e-02 2.39159898e-02
  15. 2.53956194e-02 2.57378593e-02 2.63097958e-02 2.78948650e-02
  16. 2.98808650e-02 3.19399321e-02 4.16342820e-02 4.18046774e-02
  17. 4.23127631e-02 4.40862765e-02 4.48633264e-02 4.78987352e-02
  18. 5.78891372e-02 5.84657665e-02 6.06337453e-02 7.17057666e-02
  19. 7.40386835e-02 8.54328810e-02 1.50000000e-01 1.50010923e-01
  20. 1.50072837e-01 1.50211933e-01 1.50485679e-01 1.50809230e-01
  21. 1.51415990e-01 1.52895619e-01 1.53237859e-01 1.53809796e-01
  22. 1.55394865e-01 1.57380865e-01 1.59439932e-01 1.69134282e-01
  23. 1.69304677e-01 1.69812763e-01 1.71586277e-01 1.72363326e-01
  24. 1.75398735e-01 1.85389137e-01 1.85965766e-01 1.88133745e-01
  25. 1.99205767e-01 2.01538684e-01 2.12932881e-01 2.77500000e-01
  26. 2.77561897e-01 2.77912745e-01 2.78700953e-01 2.80252180e-01
  27. 2.82085635e-01 2.85523942e-01 2.93908510e-01 2.95847869e-01
  28. 2.99088843e-01 3.08070902e-01 3.19324902e-01 3.30992949e-01
  29. 3.85927598e-01 3.86893172e-01 3.89772324e-01 3.99822234e-01
  30. 4.04225516e-01 4.21426166e-01 4.78038444e-01 4.81306010e-01
  31. 4.93591223e-01 5.56332677e-01 5.69552540e-01 6.34119659e-01
  32. 0.00000000e+00]
```python
plot_fig1(y)
plot_fig2(policy)
```

png

png

$p_h = 0.65$

```python
agent_3.clear()
y = []
while True:
    agent_3.reset()
    agent_3.update()
    y.append(agent_3.V)
    if agent_3.get_delta() < agent_3.get_threshold():
        break
policy = agent_3.get_policy()
print(policy)
print(agent_3.V)
```
  1. [0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
  2. 1. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
  3. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
  4. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 1. 1. 1. 1. 1. 1. 1. 1.
  5. 1. 1. 1. 1.]
  6. [0. 0.46137334 0.70981407 0.84360309 0.91565822 0.9544721
  7. 0.97538609 0.98666051 0.99274307 0.99602866 0.99780691 0.99877234
  8. 0.99929903 0.99958853 0.99974948 0.99984048 0.9998932 0.99992476
  9. 0.99994446 0.99995737 0.99996628 0.99997274 0.99997763 0.99998147
  10. 0.99998454 0.99998706 0.99998911 0.99999081 0.9999922 0.99999336
  11. 0.99999432 0.99999513 0.9999958 0.99999638 0.99999686 0.99999728
  12. 0.99999763 0.99999794 0.9999982 0.99999843 0.99999862 0.9999988
  13. 0.99999894 0.99999907 0.99999919 0.99999929 0.99999938 0.99999945
  14. 0.99999952 0.99999958 0.99999963 0.99999968 0.99999972 0.99999975
  15. 0.99999978 0.99999981 0.99999984 0.99999986 0.99999987 0.99999989
  16. 0.99999991 0.99999992 0.99999993 0.99999994 0.99999995 0.99999995
  17. 0.99999996 0.99999997 0.99999997 0.99999998 0.99999998 0.99999998
  18. 0.99999998 0.99999999 0.99999999 0.99999999 0.99999999 0.99999999
  19. 0.99999999 1. 1. 1. 1. 1.
  20. 1. 1. 1. 1. 1. 1.
  21. 1. 1. 1. 1. 1. 1.
  22. 1. 1. 1. 1. 0. ]
```python
plot_fig1(y)
plot_fig2(policy)
```

png

png

Part-c

  • $\theta$ is the convergence threshold parameter in Value Iteration.
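
Concretely, a sweep of value iteration is stopped once the largest change in any state value falls below $\theta$; in standard notation (the usual stopping rule, which is what the `get_delta() < get_threshold()` check in the loops above appears to implement):

```latex
\Delta_k = \max_{s \in \mathcal{S}} \big| V_{k+1}(s) - V_k(s) \big|,
\qquad \text{stop when } \Delta_k < \theta .
```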

$\theta = 10^{-17}$

```python
agent_1.set_threshold(0.00000000000000001)
agent_1.clear()
y = []
while True:
    agent_1.reset()
    agent_1.update()
    y.append(agent_1.V)
    if agent_1.get_delta() < agent_1.get_threshold():
        break
policy = agent_1.get_policy()
print(policy)
print(agent_1.V)
```
  1. [ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 11. 15. 9. 17.
  2. 7. 6. 5. 4. 3. 2. 1. 25. 1. 2. 3. 4. 30. 6. 7. 8. 9. 10.
  3. 11. 13. 12. 11. 10. 9. 8. 7. 6. 5. 4. 3. 2. 1. 50. 1. 2. 3.
  4. 4. 5. 6. 7. 8. 9. 10. 11. 12. 12. 11. 10. 9. 33. 7. 6. 5. 4.
  5. 28. 2. 1. 25. 1. 2. 3. 4. 5. 6. 7. 8. 9. 15. 11. 12. 12. 11.
  6. 10. 9. 8. 7. 6. 5. 4. 3. 2. 1.]
  7. [0.00000000e+00 2.66917018e-04 8.89723393e-04 1.92325355e-03
  8. 2.96574464e-03 4.32158176e-03 6.41084517e-03 8.50388325e-03
  9. 9.88581548e-03 1.18309578e-02 1.44052725e-02 1.77664658e-02
  10. 2.13694839e-02 2.71868419e-02 2.83462775e-02 3.00251072e-02
  11. 3.29527183e-02 3.52816705e-02 3.94365260e-02 4.60307893e-02
  12. 4.80175751e-02 5.16971693e-02 5.92215525e-02 6.31880185e-02
  13. 7.12316130e-02 9.00000000e-02 9.06228064e-02 9.20760213e-02
  14. 9.44875916e-02 9.69200708e-02 1.00083691e-01 1.04958639e-01
  15. 1.09842394e-01 1.13066903e-01 1.17605568e-01 1.23612303e-01
  16. 1.31455087e-01 1.39862129e-01 1.53435964e-01 1.56141314e-01
  17. 1.60058584e-01 1.66889676e-01 1.72323898e-01 1.82018561e-01
  18. 1.97405175e-01 2.02041008e-01 2.10626728e-01 2.28183623e-01
  19. 2.37438710e-01 2.56207097e-01 3.00000000e-01 3.00622806e-01
  20. 3.02076021e-01 3.04487592e-01 3.06920071e-01 3.10083691e-01
  21. 3.14958639e-01 3.19842394e-01 3.23066903e-01 3.27605568e-01
  22. 3.33612303e-01 3.41455087e-01 3.49862129e-01 3.63435964e-01
  23. 3.66141314e-01 3.70058584e-01 3.76889676e-01 3.82323898e-01
  24. 3.92018561e-01 4.07405175e-01 4.12041008e-01 4.20626728e-01
  25. 4.38183623e-01 4.47438710e-01 4.66207097e-01 5.10000000e-01
  26. 5.11453215e-01 5.14844050e-01 5.20471047e-01 5.26146832e-01
  27. 5.33528612e-01 5.44903490e-01 5.56298920e-01 5.63822773e-01
  28. 5.74412993e-01 5.88428706e-01 6.06728536e-01 6.26344968e-01
  29. 6.58017250e-01 6.64329733e-01 6.73470028e-01 6.89409244e-01
  30. 7.02089095e-01 7.24709975e-01 7.60612075e-01 7.71429020e-01
  31. 7.91462366e-01 8.32428453e-01 8.54023656e-01 8.97816560e-01
  32. 0.00000000e+00]
```python
plot_fig1(y)
plot_fig2(policy)
```

png

png

$\theta = 10^{-14}$

```python
agent_1.set_threshold(0.00000000000001)
agent_1.clear()
y = []
while True:
    agent_1.reset()
    agent_1.update()
    y.append(agent_1.V)
    if agent_1.get_delta() < agent_1.get_threshold():
        break
policy = agent_1.get_policy()
print(policy)
print(agent_1.V)
```
  1. [ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 12. 11. 15. 9. 8.
  2. 7. 6. 5. 21. 3. 2. 1. 25. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10.
  3. 11. 13. 38. 11. 10. 9. 8. 7. 6. 5. 4. 3. 2. 1. 50. 49. 2. 3.
  4. 4. 5. 6. 7. 8. 9. 10. 11. 12. 12. 11. 10. 9. 33. 7. 6. 20. 4.
  5. 3. 2. 1. 25. 1. 2. 3. 4. 5. 6. 7. 8. 16. 10. 11. 12. 12. 11.
  6. 10. 9. 8. 7. 6. 5. 4. 3. 2. 1.]
  7. [0.00000000e+00 2.66917018e-04 8.89723393e-04 1.92325355e-03
  8. 2.96574464e-03 4.32158176e-03 6.41084517e-03 8.50388325e-03
  9. 9.88581548e-03 1.18309578e-02 1.44052725e-02 1.77664658e-02
  10. 2.13694839e-02 2.71868419e-02 2.83462775e-02 3.00251072e-02
  11. 3.29527183e-02 3.52816705e-02 3.94365260e-02 4.60307893e-02
  12. 4.80175751e-02 5.16971693e-02 5.92215525e-02 6.31880185e-02
  13. 7.12316130e-02 9.00000000e-02 9.06228064e-02 9.20760213e-02
  14. 9.44875916e-02 9.69200708e-02 1.00083691e-01 1.04958639e-01
  15. 1.09842394e-01 1.13066903e-01 1.17605568e-01 1.23612303e-01
  16. 1.31455087e-01 1.39862129e-01 1.53435964e-01 1.56141314e-01
  17. 1.60058584e-01 1.66889676e-01 1.72323898e-01 1.82018561e-01
  18. 1.97405175e-01 2.02041008e-01 2.10626728e-01 2.28183623e-01
  19. 2.37438710e-01 2.56207097e-01 3.00000000e-01 3.00622806e-01
  20. 3.02076021e-01 3.04487592e-01 3.06920071e-01 3.10083691e-01
  21. 3.14958639e-01 3.19842394e-01 3.23066903e-01 3.27605568e-01
  22. 3.33612303e-01 3.41455087e-01 3.49862129e-01 3.63435964e-01
  23. 3.66141314e-01 3.70058584e-01 3.76889676e-01 3.82323898e-01
  24. 3.92018561e-01 4.07405175e-01 4.12041008e-01 4.20626728e-01
  25. 4.38183623e-01 4.47438710e-01 4.66207097e-01 5.10000000e-01
  26. 5.11453215e-01 5.14844050e-01 5.20471047e-01 5.26146832e-01
  27. 5.33528612e-01 5.44903490e-01 5.56298920e-01 5.63822773e-01
  28. 5.74412993e-01 5.88428706e-01 6.06728536e-01 6.26344968e-01
  29. 6.58017250e-01 6.64329733e-01 6.73470028e-01 6.89409244e-01
  30. 7.02089095e-01 7.24709975e-01 7.60612075e-01 7.71429020e-01
  31. 7.91462366e-01 8.32428453e-01 8.54023656e-01 8.97816560e-01
  32. 0.00000000e+00]
```python
plot_fig1(y)
plot_fig2(policy)
```

png

png

  1. We see from the above plots that as $\theta$ is made smaller, the state-value vector and the optimal policy become stable, i.e., decreasing $\theta$ further does not change them.

Problem-:two:

Problem Statement: (Sutton and Barto Exercise 4.4 - Policy Iteration)

  • Jack manages two locations for a nationwide car rental company. Each day, some number of customers arrive at each location to rent cars. If Jack has a car available, he rents it out and is credited $10 by the national company.
  • If he is out of cars at that location, then the business is lost. Cars become available for renting the day after they are returned. To help ensure that cars are available where they are needed, Jack can move them between the two locations overnight, at a cost of $2 per car moved.
  • We assume that the number of cars requested and returned at each location are Poisson random variables, meaning that the probability that the number is $n$ is

    $\Pr\{n\} = \dfrac{\lambda^n}{n!} e^{-\lambda},$

    where $\lambda$ is the expected number. Suppose $\lambda$ is 3 and 4 for rental requests at the first and second locations
    and 3 and 2 for returns.

  • To simplify the problem slightly, we assume that there can be no more than 20 cars at each location (any additional cars are returned to the nationwide company, and thus disappear from the problem) and a maximum of five cars can be moved from one location to the other in one night.

Part-a

The problem can be modelled as an MDP tuple $\langle \mathcal{S}, \mathcal{A}, \mathcal{P}, \mathcal{R} \rangle$ where:

  1. $\mathcal{S}$ is the set of all possible states $(n_1, n_2)$, where $n_i$ is the number of cars at location $i$ and $0 \le n_i \le 20$.

  2. $\mathcal{A}(s)$ is the set of all possible overnight movements of cars that can be made in a given state $s$ (at most 5 cars in either direction).

  1. Let the numbers of cars rented at locations 1 and 2 be $req_1$ and $req_2$ respectively, and the numbers of cars returned be $ret_1$ and $ret_2$ (no more cars can be rented at a location than are available there).

  2. $req_1, req_2, ret_1, ret_2$ are sampled from Poisson distributions with $\lambda = 3, 4, 3, 2$ respectively.

  3. Given a state $s$, if $a$ cars are moved overnight and the resulting state is $s'$, the reward is $10 for every car rented minus $2 for every car moved, and $s'$ is obtained from $s$ by applying the move, the rentals, and the returns at each location (capped at 20 cars).

Assumption: I have assumed an upper bound of 8 on the number of cars that can be rented or returned, since the Poisson probability $\Pr\{n\}$ is negligible for $n > 8$ with the values of $\lambda$ used here.
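
As a quick sanity check of this truncation (my own illustration using the standard Poisson pmf, not anything from `env.py`), the probability mass beyond 8 events is small for every rate used here: roughly 2% for $\lambda = 4$ and well below 0.5% for $\lambda = 2, 3$.

```python
from math import exp, factorial

def poisson_pmf(n, lam):
    """P(N = n) for a Poisson random variable with mean lam."""
    return lam ** n * exp(-lam) / factorial(n)

for lam in (2, 3, 4):
    tail = 1.0 - sum(poisson_pmf(n, lam) for n in range(9))   # P(N > 8)
    print(f"lambda = {lam}: P(N > 8) = {tail:.2e}")
```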

Bellman Equation for policy iteration

  1. The Bellman update equations for this problem are as follows:

    1. Policy evaluation (let $a = \pi(s)$ be the action determined by the current policy $\pi$):

       $V(s) \leftarrow \sum_{s'} P(s' \mid s, a)\,\big[ R(s, a, s') + \gamma V(s') \big]$

    2. Policy improvement:

       $\pi(s) \leftarrow \arg\max_{a \in \mathcal{A}(s)} \sum_{s'} P(s' \mid s, a)\,\big[ R(s, a, s') + \gamma V(s') \big]$

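To make the structure of these backups concrete, here is a rough sketch of the expected one-step return for a single state–action pair under the truncated Poisson model above. The variable names, the discount factor of 0.9, and the simplified handling of moves, rentals, and returns are assumptions for this illustration, not the implementation in `agent.py` or `env.py`.

```python
import numpy as np
from math import exp, factorial

MAX_CARS, MAX_EVENTS = 20, 8                 # capacities and the truncation assumed above
LAM_REQ, LAM_RET = (3, 4), (3, 2)            # request / return rates for the two locations

def pmf(n, lam):
    return lam ** n * exp(-lam) / factorial(n)

def expected_update(state, action, V, gamma=0.9):
    """Expected reward plus gamma * V(s') when `action` cars are moved from location 1 to 2."""
    # (assumes 0 <= action <= state[0]; boundary handling omitted for brevity)
    n1 = min(state[0] - action, MAX_CARS)     # cars available after the overnight move
    n2 = min(state[1] + action, MAX_CARS)
    total = -2.0 * abs(action)                # $2 per car moved
    for req1 in range(MAX_EVENTS + 1):
        for req2 in range(MAX_EVENTS + 1):
            for ret1 in range(MAX_EVENTS + 1):
                for ret2 in range(MAX_EVENTS + 1):
                    p = (pmf(req1, LAM_REQ[0]) * pmf(req2, LAM_REQ[1]) *
                         pmf(ret1, LAM_RET[0]) * pmf(ret2, LAM_RET[1]))
                    rent1, rent2 = min(n1, req1), min(n2, req2)
                    reward = 10.0 * (rent1 + rent2)            # $10 per car rented
                    s1 = min(n1 - rent1 + ret1, MAX_CARS)
                    s2 = min(n2 - rent2 + ret2, MAX_CARS)
                    total += p * (reward + gamma * V[s1, s2])
    return total

# Hypothetical usage: V would be a 21 x 21 array of state values.
V = np.zeros((MAX_CARS + 1, MAX_CARS + 1))
print(expected_update(state=(10, 10), action=2, V=V))
```
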
Part-b

```python
def three_dimentional_plot(V):
    fig = plt.figure()
    ax = fig.gca(projection='3d')
    X = np.arange(0, V.shape[0], 1)
    Y = np.arange(0, V.shape[1], 1)
    X, Y = np.meshgrid(X, Y)
    surf = ax.plot_surface(X, Y, V, rstride=1, cstride=1, cmap=cm.coolwarm,
                           linewidth=0, antialiased=False)
    ax.set_xlabel('Location 2')
    ax.set_ylabel('Location 1')
    plt.show()
```

```python
def plot_policy(all_P):
    itr = 0
    for pi in all_P:
        cmp = plt.matshow(pi)
        plt.xlabel('Location 2')
        plt.ylabel('Location 1')
        plt.colorbar(cmp)
        plt.show()
```

```python
# run policy iteration, recording the value function and policy after every improvement step
Jack_env = env.Jack_env()
agent_1 = agent.Jack_PolicyIteration(Jack_env)
all_V = []
all_P = []
all_P.append(agent_1.policy.copy())
while True:
    agent_1.evaluate_policy()
    stable = agent_1.update()
    all_V.append(agent_1.V.copy())
    all_P.append(np.flip(agent_1.policy.copy(), 0))
    if stable == True:
        break
```

```python
plot_policy(all_P)
```

png

png

png

png

png

```python
three_dimentional_plot(agent_1.V)
```

png

Part-c

  • One of Jack’s employees at the first location rides a bus home each night
    and lives near the second location. She is happy to shuttle one car to the second location for free.
  • Each additional car still costs $2, as do all cars moved in the other direction. In addition, Jack has limited parking space at each location. If more than 10 cars are kept overnight at a location
    (after any moving of cars), then an additional cost of $4 must be incurred to use a second parking lot (independent of how many cars are kept there).
  • These sorts of non-linearities and arbitrary dynamics often occur in real problems and cannot easily be handled by optimization methods
    other than dynamic programming.
  • Solve the problem incorporating the new complications using
    Policy Iteration.
```python
Jack_env = env.Jack_env()
agent_2 = agent.Jack_PolicyIteration_2(Jack_env)
all_V = []
all_P = []
all_P.append(agent_2.policy.copy())
while True:
    agent_2.evaluate_policy()
    stable = agent_2.update()
    all_V.append(agent_2.V.copy())
    all_P.append(np.flip(agent_2.policy.copy(), 0))
    if stable == True:
        break
```

```python
plot_policy(all_P)
```

png

png

png

png

png

```python
three_dimentional_plot(agent_2.V)
```

png