Spaces:
Runtime error
Runtime error
| # =================== | |
| # Part 1: Importing Libraries | |
| # =================== | |
| import matplotlib.pyplot as plt | |
| # =================== | |
| # Part 2: Data Preparation | |
| # =================== | |
| # Data for plotting | |
| x = [10, 20, 30, 50, 155] | |
| y = [1.30, 1.21, 1.27, 1.28, 1.29] | |
| x2 = [50] | |
| y2 = [1.19] | |
| # Labels and Plot Types | |
| label_Llama_2_7B = "Llama 2 7B" | |
| label_Llama_2_13B = "Llama 2 13B" | |
| ax1_txt = [ | |
| "1.30\nLlaSMol Lite", | |
| "1.21\nLlaSMol Attn", | |
| "1.27\nLlaSMol FFN", | |
| "1.28\nLlaSMol", | |
| "1.29\nLlaSMol Plus", | |
| ] | |
| ax2_txt = "1.19\nLlaSMol Large" | |
| # Axes Limits and Labels | |
| xlabel_value = "Trainable Parameter Size (M)" | |
| ylabel_value = "RMSE" | |
| xticklabels1 = [str(num) for num in x] | |
| ylim_values = [1.15, 1.35] | |
| yticks_values = [ | |
| 1.15, | |
| 1.20, | |
| 1.25, | |
| 1.30, | |
| ] | |
| xlim_values = [-10, 170] | |
| xticks_values = [0, 50, 100, 150] | |
| xticklabels2 = ["0", "50", "100", "150"] | |
| # =================== | |
| # Part 3: Plot Configuration and Rendering | |
| # =================== | |
| # Create the figure and axis | |
| fig, ax = plt.subplots( | |
| figsize=(6, 8) | |
| ) # Adjust the size to match the original image's dimensions | |
| # Plot the data | |
| ax.plot(x, y, "ro-", label=label_Llama_2_7B, linewidth=2) | |
| ax.plot(x2, y2, "b*", markersize=10, label=label_Llama_2_13B) | |
| # Annotate the points | |
| for i, txt in enumerate(ax1_txt): | |
| ax.annotate( | |
| txt, | |
| (x[i], y[i]), | |
| textcoords="offset points", | |
| xytext=(0, 5), | |
| ha="center", | |
| fontsize=10, | |
| ) | |
| ax.annotate( | |
| ax2_txt, | |
| (x2[0], y2[0]), | |
| textcoords="offset points", | |
| xytext=(0, 5), | |
| ha="center", | |
| color="black", | |
| fontsize=10, | |
| ) | |
| # Set labels and title | |
| ax.set_xlabel(xlabel_value, fontsize=10) | |
| ax.set_ylabel(ylabel_value, fontsize=10) | |
| # Set the legend | |
| legend = ax.legend(fontsize=10) | |
| # Adjust x-axis labels | |
| ax.set_xticks(x) | |
| ax.set_xticklabels(xticklabels1, ha="center") | |
| ax.set_ylim(ylim_values) | |
| ax.set_yticks(yticks_values) | |
| ax.set_xlim(xlim_values) | |
| ax.set_xticks(xticks_values) | |
| ax.set_xticklabels(xticklabels2, ha="center") | |
| # =================== | |
| # Part 4: Saving Output | |
| # =================== | |
| # Show the plot with tight layout | |
| plt.tight_layout() | |
| plt.savefig("CB_23.pdf", bbox_inches="tight") | |