2018年9月21日星期五

算法学习之股票预测

动机


学了好长时间的机器学习一直找不到实践的机会,工作上的场景都被算法大拿们占据了,小菜鸟只能自谋出路。恰逢前些时中概股大跌,损失惨重,一怒之下干脆研究研究能否用机器学习来搞搞了。  


数据获取


历史股价数据可以从雅虎金融上获取,而且已经有现成的python lib来做这个。
import pandas_datareader.data as reader
import fix_yahoo_finance as yf

def fetch_stock_list_from_web():
    global COMPANIES
    COMPANIES = pd.read_csv(SP500_LIST_PATH)

    if _data_already_loaded():
        return

    yf.pdr_override()
    for cell in COMPANIES['Symbol']:
        try:
            data = reader.get_data_yahoo(cell)
            data.to_csv('./stock_data/' + cell + '.csv')
        except Exception as e:
            print(e)

其中SP500_LIST_PATH=‘constituents.csv’ 这个google一下就能找到。


打印预测结果


跳跃一下先把周边的工具准备完毕,有了股票预测数据之后一般都要打印一个预测和真实数据的对比图,这个可以用matplotlib.pyplot搞定。
import matplotlib.pyplot as plt

def _print_comparison_gui(pred, real):
    truths = _flatten(real)[-200:]
    preds = _flatten(pred)[-200:]
    days = range(len(truths))[-200:]

    plt.figure(figsize=(12, 6))
    plt.plot(days, truths, label='truth')
    plt.plot(days, preds, label='pred')
    plt.legend(loc='upper left', frameon=False)
    plt.xlabel("day")
    plt.ylabel("normalized price")
    plt.ylim((min(preds), max(truths)))
    plt.grid(ls='--')
    plt.show()


naive版本


预测股价的模型应该是怎样的呢?初始想法是把股价表达成时间的函数,因为某些时间具有特殊性比如周五,年底等等,所以把时间拆成年,月,日而不是作为一个整体,其实我们预测股价的时候一般会看公司的基本面,大环境,一些相关消息等等,但搞定这些东西还远超出我的能力所以就先忽略了。最终x是(year, month, day, day_of_week), y就是price。
def preprocess_data(symbols):
    global PROCESSED_STOCK_DATA

    for i in range(len(symbols)):
        symbol = symbols[i]
        print('Loading ' + symbol)
        price = ORIGINAL_STOCK_DATA[symbol]['Open']
        price = list(map(lambda each: [each],price))
        date = list(map(lambda each: datetime.datetime.strptime(each, '%Y-%m-%d'), ORIGINAL_STOCK_DATA[symbol]['Date']))
        x = list(map(lambda each: (each.year, each.month, each.day, each.weekday()), date))

        train_size = int(len(x) * (1.0 - TEST_RATIO))
        train_X, test_X = x[:train_size], x[train_size:]
        train_y, test_y = price[:train_size], price[train_size:]

        PROCESSED_STOCK_DATA[symbol] = (train_X, train_y, test_X, test_y)


模型选择了一个简单的神经网络,

def build_model(x):
    L1 = 30
    w1 = tf.Variable(tf.random_uniform([INPUT_VARIABLE_COUNT, L1], 0, 1))
    b1 = tf.Variable(tf.zeros([1, L1]))
    wb1 = tf.matmul(x, w1) + b1
    layer1 = tf.nn.relu(wb1)

    L2 = 40
    w2 = tf.Variable(tf.random_uniform([L1, L2], 0, 1))
    b2 = tf.Variable(tf.zeros([1, L2]))
    wb2 = tf.matmul(layer1, w2) + b2
    layer2 = tf.nn.relu(wb2)

    w3 = tf.Variable(tf.random_uniform([L2, 1], 0, 1))
    b3 = tf.Variable(tf.zeros([1, 1]))
    wb3 = tf.matmul(layer2, w3) + b3

    return wb3

可惜一跑就杯具了,跑了一小会所有的预测值就都在12上下跳动了,感觉像是梯度消失只好搬出batchnorm来救场,

def build_model(x, is_test, iteration):
    L1 = 30
    w1 = tf.Variable(tf.random_uniform([INPUT_VARIABLE_COUNT, L1], 0, 1))
    S1 = tf.Variable(tf.ones([L1]))
    b1 = tf.Variable(tf.zeros([1, L1]))
    wb1 = tf.matmul(x, w1) + b1
    Y1bn, update_ema1 = batchnorm(wb1, b1, S1, is_test, iteration)
    layer1 = tf.nn.relu(Y1bn)

    L2 = 40
    w2 = tf.Variable(tf.random_uniform([L1, L2], 0, 1))
    S2 = tf.Variable(tf.ones([L2]))
    b2 = tf.Variable(tf.zeros([1, L2]))
    wb2 = tf.matmul(layer1, w2) + b2
    Y2bn, update_ema2 = batchnorm(wb2, b2, S2, is_test, iteration)
    layer2 = tf.nn.relu(Y2bn)

    w3 = tf.Variable(tf.random_uniform([L2, 1], 0, 1))
    b3 = tf.Variable(tf.zeros([1, 1]))
    wb3 = tf.matmul(layer2, w3) + b3
    update_ema = tf.group(update_ema1, update_ema2)

    return wb3, update_ema

最终拿到的结果是这样的,





















感觉相当的不靠谱,程序似乎学到了一点大的趋势,但在基准值和连续性上面都差距较大。连续性差感觉可能是变量不够,网络层次太浅,可惜加大层数后在我的macbook pro上面跑了一晚上也没有一点要收敛的样子,只能放弃了。


LSTM版本


回想一下股价预测的方法中有一个流派叫技术分析也就是看k线图,这个似乎比较符合我现在只有股价做输入的状况,而k线图不就是和LSTM模型的场景一样一样的么?不过从头写一个LSTM还是有点超出本菜鸟的能力了,所以就在别人的基础上改吧,借鉴了使用rnn预测股票价格  

rnn文章里面用的面向对象编程风格,而且把训练过程和测试过程放到一起,都让程序的理解难度加大了,所以我改写成了简单的过程式风格而且分开了训练和测试过程。 

一些要点

1. LSTM版本除了模型不一样之外比较重要的是对股价做了规范化处理,把绝对股价换算成股价变动百分比,这样就不会碰到预测不出训练时没训练过的股价的问题了。

        data = [data[0] / data[0][0] - 1.0] + [
            curr / data[i][-1] - 1.0 for i, curr in enumerate(data[1:])]

2. 数据预处理的时候有INPUT_SIZE的概念,就是把连续几天的股价作为一个向量输入,这样更能体现股价变化的特征

        data = [np.array(data[i * INPUT_SIZE: (i + 1) * INPUT_SIZE])
                for i in range(len(data) // INPUT_SIZE)]


3. 预测的模型实际上是用连续多天(比如30天)的股价预测后面一天的股价,所以x和y是下面的样子

        X = np.array([data[i: i + NUMBER_OF_STEPS] for i in range(len(data) - NUMBER_OF_STEPS)])
        y = np.array([data[i + NUMBER_OF_STEPS] for i in range(len(data) - NUMBER_OF_STEPS)])


4. rnn文章对股票价格的input还做了一个向量转换(tf.nn.embedding_lookup),但我实际测试的结果不理想,而且感觉做了数据规范化以后的股价已经是向量了,似乎没必要再做一层转换,所以我的模型是这样的

def build_model(symbols,inputs,keep_prob):
    def _create_one_cell():
        lstm_cell = tf.contrib.rnn.LSTMCell(LSTM_SIZE, state_is_tuple=True)
        lstm_cell = tf.contrib.rnn.DropoutWrapper(lstm_cell, output_keep_prob=keep_prob)
        return lstm_cell

    cell = tf.contrib.rnn.MultiRNNCell(
        [_create_one_cell() for _ in range(NUMBER_OF_LAYERS)],
        state_is_tuple=True
    ) if NUMBER_OF_LAYERS > 1 else _create_one_cell()

    # Run dynamic RNN
    val, state_ = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32, scope="dynamic_rnn")
    # Before transpose, val.get_shape() = (batch_size, num_steps, lstm_size)
    # After transpose, val.get_shape() = (num_steps, batch_size, lstm_size)
    val = tf.transpose(val, [1, 0, 2])
    # 这里取向量的最后一位也就是最后一个step对应的数据,也就是模型预测的股票数据,也解释了为啥要用transpose
    last = tf.gather(val, int(val.get_shape()[0]) - 1, name="lstm_state")
    ws = tf.Variable(tf.truncated_normal([LSTM_SIZE, INPUT_SIZE]), name="w")
    bias = tf.get_variable("b", [INPUT_SIZE])
    pred = tf.matmul(last, ws) + bias
    return pred, cell

结果





















从上面的图来看,预测结果还是有点靠谱的,有几个较大的波动都给预测出来了,这是不是也说明K线图法是靠谱的呢?


代码



import pandas_datareader.data as reader
import fix_yahoo_finance as yf
import numpy as np
import os
import pandas as pd
import random
import tensorflow as tf
import matplotlib.pyplot as plt


SP500_LIST_PATH = './stock_data/constituents.csv'
DATA_PATH = './stock_data'
CHECK_POINTS_PATH = "./checkpoints/stock_check_points"
TRAIN_STOCK_LIST = ['AAPL','AMZN','GOOG','FB','MSFT','NFLX','NVDA','ORCL']
#TRAIN_STOCK_LIST = ['AAPL']


COMPANIES = None
ORIGINAL_STOCK_DATA = {}
PROCESSED_STOCK_DATA = {}

INPUT_SIZE = 5
# 可以理解成序列长度
NUMBER_OF_STEPS = 30
LSTM_SIZE = 128
EMBED_SIZE = 128
NUMBER_OF_LAYERS = 2
STOCK_COUNT = 10
MAX_EPOCH = 50000
INIT_EPOCH = 5

LEARNING_RATE_BASE = 0.001
LEARNING_RATE_DECAY_STEP = 1000
LEARNING_RATE_DECAY_RATE = 0.95

INIT_LEARNING_RATE = 0.001
LEARNING_RATE_DECAY = 0.99

TEST_RATIO = 0.1
BATCH_SIZE = 64
KEEP_PROB = 0.8
SAVE_STEP = 1000


def _data_already_loaded():
    dirs = os.listdir(DATA_PATH)
    if dirs and len(dirs) > 100:
        return True
    else:
        return False


def fetch_stock_list_from_web():
    global COMPANIES
    COMPANIES = pd.read_csv(SP500_LIST_PATH)

    if _data_already_loaded():
        return

    yf.pdr_override()
    for cell in COMPANIES['Symbol']:
        try:
            data = reader.get_data_yahoo(cell)
            data.to_csv('./stock_data/' + cell + '.csv')
        except Exception as e:
            print(e)


def load_stock_data():
    global ORIGINAL_STOCK_DATA
    for cell in COMPANIES['Symbol']:
        csv_file = './stock_data/' + cell + '.csv'
        if not (os.path.exists(csv_file)):
            continue
        csv_data = pd.read_csv(csv_file)
        ORIGINAL_STOCK_DATA[cell] = csv_data


def preprocess_data(symbols):
    global PROCESSED_STOCK_DATA

    for symbol in symbols:
        print('Loading ' + symbol)
        data = ORIGINAL_STOCK_DATA[symbol]['Open']
        # 按input_size进行拆分
        data = [np.array(data[i * INPUT_SIZE: (i + 1) * INPUT_SIZE])
                for i in range(len(data) // INPUT_SIZE)]
        # 计算相对增量,注意后面的for里面i是从0开始的
        data = [data[0] / data[0][0] - 1.0] + [
            curr / data[i][-1] - 1.0 for i, curr in enumerate(data[1:])]

        X = np.array([data[i: i + NUMBER_OF_STEPS] for i in range(len(data) - NUMBER_OF_STEPS)])
        y = np.array([data[i + NUMBER_OF_STEPS] for i in range(len(data) - NUMBER_OF_STEPS)])

        train_size = int(len(X) * (1.0 - TEST_RATIO))
        train_X, test_X = X[:train_size], X[train_size:]
        train_y, test_y = y[:train_size], y[train_size:]

        PROCESSED_STOCK_DATA[symbol] = (train_X, train_y, test_X, test_y)


def _generate_batch(symbol):
    train_X, train_y, test_X, test_y = PROCESSED_STOCK_DATA[symbol]
    num_batches = int(len(train_X)) // BATCH_SIZE
    if BATCH_SIZE * num_batches < len(train_X):
        num_batches += 1

    batch_indices = list(range(num_batches))
    random.shuffle(batch_indices)
    for j in batch_indices:
        batch_X = train_X[j * BATCH_SIZE: (j + 1) * BATCH_SIZE]
        batch_y = train_y[j * BATCH_SIZE: (j + 1) * BATCH_SIZE]
        assert set(map(len, batch_X)) == {NUMBER_OF_STEPS}
        yield batch_X, batch_y


def build_model(symbols,inputs,keep_prob):
    def _create_one_cell():
        lstm_cell = tf.contrib.rnn.LSTMCell(LSTM_SIZE, state_is_tuple=True)
        lstm_cell = tf.contrib.rnn.DropoutWrapper(lstm_cell, output_keep_prob=keep_prob)
        return lstm_cell

    cell = tf.contrib.rnn.MultiRNNCell(
        [_create_one_cell() for _ in range(NUMBER_OF_LAYERS)],
        state_is_tuple=True
    ) if NUMBER_OF_LAYERS > 1 else _create_one_cell()


    # embed_matrix = tf.Variable(
    #     tf.random_uniform([STOCK_COUNT, EMBED_SIZE], -1.0, 1.0),
    #     name="embed_matrix"
    # )

    # 对embedding lookup的效果比较怀疑
    # stock_label_embeds.shape = (batch_size, embedding_size)
    # stacked_symbols = tf.tile(symbols, [1, NUMBER_OF_STEPS], name='stacked_stock_labels')
    # stacked_embeds = tf.nn.embedding_lookup(embed_matrix, stacked_symbols)

    # After concat, inputs.shape = (batch_size, num_steps, input_size + embed_size)
    #inputs_with_embed = tf.concat([inputs, stacked_embeds], axis=2, name="inputs_with_embed")

    inputs_with_embed = tf.identity(inputs)

    print("inputs.shape:", inputs.shape)
    print("inputs_with_embed.shape:", inputs_with_embed.shape)

    # Run dynamic RNN
    val, state_ = tf.nn.dynamic_rnn(cell, inputs_with_embed, dtype=tf.float32, scope="dynamic_rnn")
    # Before transpose, val.get_shape() = (batch_size, num_steps, lstm_size)
    # After transpose, val.get_shape() = (num_steps, batch_size, lstm_size)
    val = tf.transpose(val, [1, 0, 2])
    # 这里取向量的最后一位也就是最后一个step对应的数据,也就是模型预测的股票数据,也解释了为啥要用transpose
    last = tf.gather(val, int(val.get_shape()[0]) - 1, name="lstm_state")
    ws = tf.Variable(tf.truncated_normal([LSTM_SIZE, INPUT_SIZE]), name="w")
    bias = tf.get_variable("b", [INPUT_SIZE])
    pred = tf.matmul(last, ws) + bias
    return pred, cell


def train():
    print('start training ...')

    learning_rate = tf.placeholder(tf.float32, None, name="learning_rate")
    keep_prob = tf.placeholder(tf.float32, None, name="keep_prob")

    # Stock symbols are mapped to integers.
    symbols = tf.placeholder(tf.int32, [None, 1], name='stock_labels')

    inputs = tf.placeholder(tf.float32, [None, NUMBER_OF_STEPS, INPUT_SIZE], name="inputs")
    targets = tf.placeholder(tf.float32, [None, INPUT_SIZE], name="targets")

    pred, cell = build_model(symbols,inputs,keep_prob)

    # 方差损失
    loss = tf.reduce_mean(tf.square(pred - targets), name="loss_mse_train")
    global_step = tf.Variable(0, trainable=False)
    add_global_step = global_step.assign_add(1)

    trainable_variables = tf.trainable_variables()
    grads, a = tf.clip_by_global_norm(tf.gradients(loss, trainable_variables), 5) # prevent loss divergence caused by gradient explosion
    learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step=global_step,
                                               decay_steps=LEARNING_RATE_DECAY_STEP, decay_rate=LEARNING_RATE_DECAY_RATE)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    optim = optimizer.apply_gradients(zip(grads, trainable_variables))

    #optim = tf.train.RMSPropOptimizer(learning_rate).minimize(loss, name="rmsprop_optim")


    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        if not os.path.exists(CHECK_POINTS_PATH):
            os.makedirs(CHECK_POINTS_PATH)

        check_point = tf.train.get_checkpoint_state(CHECK_POINTS_PATH)
        # if have checkPoint, restore checkPoint
        if check_point and check_point.model_checkpoint_path:
            saver.restore(sess, check_point.model_checkpoint_path)
            print("restored %s" % check_point.model_checkpoint_path)
        else:
            print("no checkpoint found!")

        g_step = 0
        for epoch in range(MAX_EPOCH):
            epoch_step = 0
            each_turn_learning_rate = INIT_LEARNING_RATE * (
                    LEARNING_RATE_DECAY ** max(float(epoch + 1 - INIT_EPOCH), 0.0)
            )

            for label_, d_ in PROCESSED_STOCK_DATA.items():
                label_pos = list(PROCESSED_STOCK_DATA.keys()).index(label_)
                for batch_x, batch_y in _generate_batch(label_):
                    g_step += 1
                    epoch_step += 1
                    batch_labels = np.array([[label_pos]] * len(batch_x))
                    train_data_feed = {
                        learning_rate: each_turn_learning_rate,
                        keep_prob: KEEP_PROB,
                        inputs: batch_x,
                        targets: batch_y,
                        symbols: batch_labels,
                    }
                    train_loss, train_pred, _, _ = sess.run(
                        [loss, pred, optim, add_global_step], train_data_feed)

                    print("epoch: %d, steps: %d, loss: %3f" % (epoch + 1, epoch_step, train_loss))
                    # save and test
                    if g_step % SAVE_STEP == SAVE_STEP - 1: # prevent save at the beginning
                        print("save model")
                        saver.save(sess, os.path.join(CHECK_POINTS_PATH, 'stock.model'), global_step=g_step)


def test():
    print('start test ...')

    learning_rate = tf.placeholder(tf.float32, None, name="learning_rate")
    keep_prob = tf.placeholder(tf.float32, None, name="keep_prob")

    # Stock symbols are mapped to integers.
    symbols = tf.placeholder(tf.int32, [None, 1], name='stock_labels')

    inputs = tf.placeholder(tf.float32, [None, NUMBER_OF_STEPS, INPUT_SIZE], name="inputs")
    targets = tf.placeholder(tf.float32, [None, INPUT_SIZE], name="targets")

    pred, cell = build_model(symbols,inputs,keep_prob)

    # 方差损失
    loss = tf.reduce_mean(tf.square(pred - targets), name="loss_mse_train")

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        check_point = tf.train.get_checkpoint_state(CHECK_POINTS_PATH)
        # if have checkPoint, restore checkPoint
        if check_point and check_point.model_checkpoint_path:
            saver.restore(sess, check_point.model_checkpoint_path)
            print("restored %s" % check_point.model_checkpoint_path)
        else:
            print("no checkpoint found!")
            exit(1)

        for label_, d_ in PROCESSED_STOCK_DATA.items():
            label_pos = list(PROCESSED_STOCK_DATA.keys()).index(label_)
            a, b, test_x, test_y = PROCESSED_STOCK_DATA[label_]

            batch_labels = np.array([[label_pos]] * len(test_x))
            test_data_feed = {
                learning_rate: 0.0,
                keep_prob: 1.0,
                inputs: test_x,
                targets: test_y,
                symbols: batch_labels,
            }

            test_loss, test_pred = sess.run(
                [loss,pred], test_data_feed)

            print("stock: %s, loss: %3f" % (label_, test_loss))
            _print_comparison_gui(test_pred, test_y, label_)
            _print_comparison_text(test_pred, test_y)

        # writer=tf.summary.FileWriter('./logs',sess.graph)
        # writer.close()

# 把数据打印出来试试
def _print_comparison_gui(pred, real, stock_sym):
    truths = _flatten(real)[-200:]
    preds = 1 * (_flatten(pred)[-200:])
    days = range(len(truths))[-200:]

    plt.figure(figsize=(12, 6))
    plt.plot(days, truths, label='truth')
    plt.plot(days, preds, label='pred')
    plt.legend(loc='upper left', frameon=False)
    plt.xlabel("day")
    plt.ylabel("normalized price")
    plt.ylim((min(truths), max(truths)))
    plt.grid(ls='--')
    plt.title(stock_sym + " | Last %d days in test" % len(truths))
    plt.show()


def _print_comparison_text(pred, real):
    pred = _flatten(pred)
    real = _flatten(real)
    for each in pred:
        print("%5f" % each, end=' ')
    print('')
    print('')
    for each in real:
        print("%5f" % each, end=' ')
    print('')
    print('')

def _flatten(seq):
    return np.array([x for y in seq for x in y])

fetch_stock_list_from_web()
load_stock_data()
preprocess_data(TRAIN_STOCK_LIST)
#train()
test()









2018年4月6日星期五

linux内核源码初探

动机

作为一个java程序员,最遗憾的事情莫过于被jvm隔了一层所以很难真正接触到操作系统核心的东西,遇到一些诡异的现象或者性能问题也只能做一些大概的分析而无法精确的追本溯源。所以一直很想了解一下linux内核,也做过一些尝试,比如读那本经典的《深入理解linux内核》或者直接读内核源码,可惜效果都不太理想。感觉主要的原因是深入程度不够缺少可以展开思考的场景,看书和直接读源码都很容易陷入细节而且抓不住重点,也很难串联各个知识点。

一次偶然的机会接触到了孟宁老师的庖丁解牛linux内核,课程虽然不长却把内核最核心的概念讲得非常清晰。师傅引进门,在这个基础上终于可以比较自由的在内核源码中游览了,这篇博文算是这几个月内核学习的一个总结。学习基于的内核版本是3.18.6。

准备工作

要想把内核代码玩起来,能够debug是至关重要的。

编译内核

 
wget https://www.kernel.org/pub/linux/kernel/v3.x/linux-3.18.6.tar.xz
make i386_defconfig
在基础配置上加上debug信息
make menuconfig
kernel hacking—>
[*] compile the kernel with debug info
然后编译
make

基础调试环境

 安装模拟器
sudo apt-get install qemu
sudo ln -s /usr/bin/qemu-system-i386 /usr/bin/qemu 
这个时候就可以做基本的debug了
qemu -kernel linux-3.18.6/arch/x86/boot/bzImage -s -S 
gdb
 (gdb)file linux-3.18.6/vmlinux # 加载符号表
 (gdb)target remote:1234
 (gdb)break start_kernel

在最新的64位ubuntu上面运行可能还有小坑,需要
sudo apt-get installlibc6-dev-i386 

  (gdb)set_archi i386
这样就能debug linux内核的启动过程了,不过由于没有根文件系统,除了启动也没法干别的

加上根文件系统

首先写一个hello world命名为hello.c
gcc -o init hello.c -m32 -static
find init | cpio -o -Hnewc |gzip -9 > ./rootfs.img  
qemu -kernel linux-3.18.6/arch/x86/boot/bzImage -initrd rootfs.img -s -S

这样会看到内核启动以后调用hello.c,在hello.c里面写上需要debug的系统调用就可以自由的debug内核函数了

 eclipse debug

在eclipse里面选择 c++ Remote Application->Using GDB Manual Remote Debugging Laucher - Select other...  
更具体的可以参照https://www.cnblogs.com/yutingliuyl/p/7063326.html

系统调用的流程

简单流程图



重点代码

entry_32.S部分节选

ENTRY(system_call)
     RING0_INT_FRAME   # can't unwind into user space anyway
     ASM_CLAC
     pushl_cfi %eax   # save orig_eax
     #保存用户态现场
     SAVE_ALL
     GET_THREAD_INFO(%ebp)
     # system call tracing in operation / emulation
     testl $_TIF_WORK_SYSCALL_ENTRY,TI_flags(%ebp)
     jnz syscall_trace_entry
     cmpl $(NR_syscalls), %eax
     jae syscall_badsys
syscall_call:
     #查系统调用表,执行系统调用
     call *sys_call_table(,%eax,4)
     syscall_after_call:
     movl %eax,PT_EAX(%esp)  # store the return value
syscall_exit:
     LOCKDEP_SYS_EXIT
     DISABLE_INTERRUPTS(CLBR_ANY) # make sure we don't miss an interrupt
     # setting need_resched or sigpending
     # between sampling and the iret
     TRACE_IRQS_OFF
     movl TI_flags(%ebp), %ecx
     testl $_TIF_ALLWORK_MASK, %ecx # current->work
     #跳转到信号处理和调度,否则恢复用户态状态从系统调用返回
     jne syscall_exit_work

restore_all:
     TRACE_IRQS_IRET
restore_all_notrace:

restore_nocheck:
     #恢复用户态现场
     RESTORE_REGS 4   # skip orig_eax/error_code
irq_return:
     #从中断返回,基本是iret
     INTERRUPT_RETURN

syscall_exit_work:
     testl $_TIF_WORK_SYSCALL_EXIT, %ecx
     #还有工作没完成
     jz work_pending
     TRACE_IRQS_ON
     ENABLE_INTERRUPTS(CLBR_ANY) # could let syscall_trace_leave() call
     # schedule() instead
     movl %esp, %eax
     call syscall_trace_leave
     #恢复用户态现场
     jmp resume_userspace

work_pending:
     testb $_TIF_NEED_RESCHED, %cl
     #跳转到信号处理
     jz work_notifysig
work_resched:
     #开始进程调度
     call schedule
     LOCKDEP_SYS_EXIT
     DISABLE_INTERRUPTS(CLBR_ANY) # make sure we don't miss an interrupt
     # setting need_resched or sigpending
     # between sampling and the iret
     TRACE_IRQS_OFF
     movl TI_flags(%ebp), %ecx
     andl $_TIF_WORK_MASK, %ecx # is there any work to be done other
     # than syscall tracing?
     #跳转到系统调用返回逻辑
     jz restore_all
     testb $_TIF_NEED_RESCHED, %cl
     jnz work_resched

work_notifysig: 

一些重点问题

1.系统调用怎样初始化的?
main.c start_kernel->trap_init  
    set_system_trap_gate(SYSCALL_VECTOR, &system_call);
    set_bit(SYSCALL_VECTOR, used_vectors);

2.为啥要分用户态内核态?
为了系统安全稳定防止用户程序搞挂,一个是cpu安全级别,第二个是内存控制用户态程序无法访问内核内存空间。

3.eip指针的问题
save_all并没有保存eip,cs,eflags,ss,esp等最重要的指针,是谁做的呢?这个是系统调用cpu由ring3转到ring0的时候cpu自动做的,而且必须是保存在内核栈中,iret的时候就能恢复用户态的eip继续执行。

创建新进程的流程

简单流程图
























重点代码

1.调用方 
 int pid;
 /* fork another process */
 pid = fork();
 if (pid < 0) 
 { 
 } 
 else if (pid == 0) 
 {
  /*  child process  */
             printf("This is Child Process!\n");
 } 
 else 
 {  
  /*  parent process  */
             printf("This is Parent Process!\n");
  /* parent will wait for the child to complete*/
  wait(NULL);
  printf("Child Complete!\n");
 }

fork系统调用会有两次返回,一次是父进程,一次是子进程,父子进程在用户态使用同一份代码

2.task_struct浅拷贝(fork.c)
 
int node = tsk_fork_get_node(orig);
int err;

tsk = alloc_task_struct_node(node);
if (!tsk)
 return NULL;

ti = alloc_thread_info_node(tsk, node);
if (!ti)
 goto free_tsk;

err = arch_dup_task_struct(tsk, orig);


int __weak arch_dup_task_struct(struct task_struct *dst,
            struct task_struct *src)
{
 *dst = *src;
 return 0;
}

static inline void setup_thread_stack(struct task_struct *p, struct task_struct *org)
{
        //这里只拷贝了thread_info可没拷贝内核堆栈
 *task_thread_info(p) = *task_thread_info(org);
 task_thread_info(p)->task = p;
}

3.配置调度器(sched_fork)
//初始化调度器
__sched_fork(clone_flags, p);

p->state = TASK_RUNNING;

//设置进程vruntime,运行时间优先级相关
if (p->sched_class->task_fork)
 p->sched_class->task_fork(p);

raw_spin_lock_irqsave(&p->pi_lock, flags);
//为子进程指定cpu
set_task_cpu(p, cpu);
raw_spin_unlock_irqrestore(&p->pi_lock, flags);

4.copy_thread
首先说说重要的数据结构内核栈task_struct.stack
union thread_union {
 struct thread_info thread_info;
 unsigned long stack[THREAD_SIZE/sizeof(long)];
};

















上图应该表达的比较形象,另外还有一个比较重要的数据结构是task_struct.thread,存放当前进程的运行状态比如ip,sp,es,ds这些












//指向内核栈顶往下留出一个pt_regs结构的空间,注意这里的指针运算的-1是指一个结构而不是一个字节 
struct pt_regs *childregs = task_pt_regs(p);
struct task_struct *tsk;
//指向栈顶往下一个pt_regs结构
p->thread.sp = (unsigned long) childregs;
//指向栈顶
p->thread.sp0 = (unsigned long) (childregs+1);
memset(p->thread.ptrace_bps, 0, sizeof(p->thread.ptrace_bps));
//全拷贝父进程的寄存器
*childregs = *current_pt_regs();
//fork调用子进程的返回值0是从这里来的
childregs->ax = 0;
if (sp)
      childregs->sp = sp;
//子进程运行起点
p->thread.ip = (unsigned long) ret_from_fork;

一些重点问题

1.子进程执行执行起点和内核堆栈怎样保持一致?
子进程拷贝了父进程的所有寄存器相当于调用了父进程的save_all所以后面遇到restore all能够对称

2.fork,clone,vfork的区别

(1)vfork和clone的出现主要为了解决execve抛弃父进程的代码段,数据段这些东西另起炉灶所以复制浪费成本的问题和创建线程的问题

(2)fork是子进程复制父进程的各种资源,其中内存并没有深拷贝只是复制了页表指向共同的物理页然后用copy-on-write技术父进程或者子进程发生实际写的时候才把对应的页复制。vfork是父子进程共享内存空间,而且子进程会调度到父进程之前,父进程等待子进程,vfork已经不常用了。clone多用于创建线程,这种情况共享内存空间,但clone可以非常灵活,有参数控制各种资源都可以选择是共享还是复制。

(3)参考文章1 参考文章2

(4)看看代码里面clone标志怎么控制的

 if (clone_flags & CLONE_VM) {
  atomic_inc(&oldmm->mm_users);
  mm = oldmm;
  goto good_mm;
 }

 retval = -ENOMEM;
 mm = dup_mm(tsk);
 if (!mm)
  goto fail_nomem;

good_mm:

进程切换

简单流程图






























重点代码

1.前置判断
if (prev->state && !(preempt_count() & PREEMPT_ACTIVE)) {
        //如果当前进程是非就绪态,preempt_count==0不允许抢占说明是非自愿让出cpu,所以为了公平调度器会再给机会
        //当未决信号量未空的时候会设置成就绪态,否则就置成睡眠从执行队列里面拿出
        //并且从工作队列拿出一个任务放入运行队列
 if (unlikely(signal_pending_state(prev->state, prev))) {
  prev->state = TASK_RUNNING;
 } else {
  deactivate_task(rq, prev, DEQUEUE_SLEEP);
  prev->on_rq = 0;

  /*
   * If a worker went to sleep, notify and ask workqueue
   * whether it wants to wake up a task to maintain
   * concurrency.
   */
  if (prev->flags & PF_WQ_WORKER) {
   struct task_struct *to_wakeup;

   to_wakeup = wq_worker_sleeping(prev, cpu);
   if (to_wakeup)
    try_to_wake_up_local(to_wakeup);
  }
 }
 switch_count = &prev->nvcsw;
}

2.地址空间切换

//如果是内核线程,借用prev的地址空间,不过内核线程不访问用户态地址空间
if (!mm) {
 next->active_mm = oldmm;
 atomic_inc(&oldmm->mm_count);
 enter_lazy_tlb(oldmm, next);
//否则做地址空间切换
} else
 switch_mm(oldmm, mm, next);

if (!prev->mm) {
 prev->active_mm = NULL;
 rq->prev_mm = oldmm;
}

spin_release(&rq->lock.dep_map, 1, _THIS_IP_);

context_tracking_task_switch(prev, next);
/* Here we just switch the register state and the stack. */
switch_to(prev, next, prev);

barrier();

//主要是归还之前借用的地址空间
finish_task_switch(this_rq(), prev);

3.寄存器切换switch_to.h

#define switch_to(prev, next, last)     \
do {         \
 unsigned long ebx, ecx, edx, esi, edi;    \
         \
 asm volatile("pushfl\n\t"  /* save    flags */ \
       "pushl %%ebp\n\t"  /* save    EBP   */ \
       "movl %%esp,%[prev_sp]\n\t" /* save    ESP   */ \
       "movl %[next_sp],%%esp\n\t" /* restore ESP   */ \
       "movl $1f,%[prev_ip]\n\t" /* save    EIP   */ \
       "pushl %[next_ip]\n\t" /* restore EIP   */ \
       __switch_canary     \
       "jmp __switch_to\n" /* regparm call  */ \
       "1:\t"      \
       "popl %%ebp\n\t"  /* restore EBP   */ \
       "popfl\n"   /* restore flags */ \
         \
       /* output parameters */    \
       : [prev_sp] "=m" (prev->thread.sp),  \
         [prev_ip] "=m" (prev->thread.ip),  \
         "=a" (last),     \
         \
         /* clobbered output registers: */  \
         "=b" (ebx), "=c" (ecx), "=d" (edx),  \
         "=S" (esi), "=D" (edi)    \
                \
         __switch_canary_oparam    \
         \
         /* input parameters: */    \
       : [next_sp]  "m" (next->thread.sp),  \
         [next_ip]  "m" (next->thread.ip),  \
                \
         /* regparm parameters for __switch_to(): */ \
         [prev]     "a" (prev),    \
         [next]     "d" (next)    \
         \
         __switch_canary_iparam    \
         \
       : /* reloaded segment registers */   \
   "memory");     \
} while (0)

这一段主要在置换ip,sp,jmp __switch_to使用regparm call, 参数不是压入堆栈,而是使用寄存器传值,来调用__switch_to, eax存放prev,edx存放next。这里为什么不用call __switch_to而用jmp,因为call会导致自动把下面这句话的地址(也就是1:)压栈,然后__switch_to()就必然只能ret到这里,而无法根据需要ret到ret_from_fork,当一个进程再次被调度时,会从1:开始执行,把ebp弹出,然后把flags弹出。

重点问题

1.__switch_to干啥了?
__switch_to 这里被称作硬件上下文切换代码见process_32.c,最重要的是做了tss相关的处理。
先看看tss相关的知识
(1)分段有两个用一个是cpu位数不足,通过段机制扩大访问空间,二个是控制访问权限,linux只用了权限控制,其他的都是为了保持兼容
(2)linux有全局段描述符表gdt和局部的ldt,区别是一个是cpu级别一个是进程级别,cs,ds,ss这些里面放的是段选择符,为了去gdt里面拿实际的段描述符,段描述符存的有base:limit这样的地址
(3)linux不同进程的相同种类的段描述符(比如内核代码段)基本完全一样,因为寻址都是0-4g,也说明了linux没有用地址扩大机制
(4)tss也很重要保存了内核栈顶指针。这些重要的内容都有对应的cpu寄存器比如tss->tr cs->段描述符非编程寄存器 gdt->gdtr,也终于理解了为啥__switch_to叫硬件切换
(5)tss任务状态段,是和cpu相关的,linux只用tss保存ss0,esp0(内核栈顶),在跨段提权的时候,需要切换栈,CPU会通过 tr 寄存器找到 TSS,load_TLS()处理gdt这些东西
(6)可以参照一下这篇文章

2.什么时候会schedule呢?
 中断处理会调用,系统调用返回会调用,内核线程会调用,比如那个有名的0号idle进程

3.一般来说一个进程或者线程跑多久就被调度呢?
stackoverflow上的帖子,感觉CFS的这个答案比较符合自己测试的效果
实际计算逻辑可以看effective_prio(core.c),还可以参考一篇讲调度的好文

4.内核抢占的问题
用最简单的方式理解就是中断或者系统调用返回的时候调用了schedule()
一篇讲的不错的文章

5.O(1)和cfs
(1)O(1)最显著的特点是有一个优先级数组,数组存该优先级下面进程链表的指针,这样pick next就是O(1)可以理解成计数排序 谈谈调度 - Linux O(1)
(2)cfs进程优先级组织成红黑树,pick next每次取最左节点,优先级计算的最重要因素是进程已运行时间和等待时间
(3)两篇比较好的文章cfs1cfs2

6.TASK_UNINTERRUPTIBLE
(1)如果是D状态是会释放cpu的
(2)状态转换发生在系统调用中,主要特点是不响应信号,注意因为是睡眠的所以和中断没啥关系
(3)常见于io,但内核加锁,进程调度等过程中也可能使用该状态,主要是为了避免响应信号导致的复杂代码路径导致的复杂的状态处理而使用的一个保护状态

可执行程序的装载

重要过程

1.do_execve 入口
2.open_exec 获取可执行文件对象
3.sched_exec 确定负载最小的cpu用来执行新程序
4.bprm_mm_init 初始化linux_binprm内存结构
5.prepare_binprm 填充linux_binprm某些字段,读取目标文件头的128个字节帮助判断文件类型
6.search_binary_handler 根据可执行文件的格式寻找对应的load_binary处理函数,这里应该会找到load_elf_binary,然后进入load_elf_binary
7.flush_old_exec 清理从父进程继承过来的用户态内存空间
8.elf_map 用mmap建立可执行文件到用户态内存空间的映射
9.elf_entry = loc->elf_ex.e_entry 设置返回用户态之后的入口地址
10.create_elf_tables 填写目标文件的参数环境变量等必要信息
11.start_thread 直接设置ip,sp等指针作为返回用户态的起始点,原来的起始点应该是系统调用的下一条指令

因为execve的时候会用可执行文件的内容替换掉当前进程的内存空间,所以一般都是先fork一个子进程然后在子进程里面调用execve。

12.但这里并没有到达用户的main函数,还有一段过程在_start函数,这个函数在sysdeps/x86_64/start.S中,总是被默认的 ld 脚本链接到程序 .text 段的起始位置。详细见用户空间程序启动过程

13.32位text segment的起始地址是0x08048000,64位是0x400000。原因是32位的栈放在0x08048000往下生长(不过这好像是历史了,现在栈好像还是接近用户空间的头部),而64位放在0x80000000000,这主要是因为64位地址大了直接操作指令不工作需要用更慢的指令实现所以尽量把需要直接寻址的数据结构往低地址放。用4M这个大小可能是因为够用又是最大页的大小,详细可见stackoverflow

重点代码

1.对多种可执行文件格式的支持

static struct linux_binfmt elf_format = {
 .module  = THIS_MODULE,
 .load_binary = load_elf_binary,
 .load_shlib = load_elf_library,
 .core_dump = elf_core_dump,
 .min_coredump = ELF_EXEC_PAGESIZE,
};

static int __init init_elf_binfmt(void)
{
 register_binfmt(&elf_format);
 return 0;
}
 
2.静态链接和动态链接入口的差异

if (elf_interpreter) {
   unsigned long interp_map_addr = 0;
   //入口是动态链接器的程序入口
   elf_entry = load_elf_interp(&loc->interp_elf_ex,
   interpreter,
   &interp_map_addr,
   load_bias);

} else {       
   //入口是可执行程序的入口类似于0x8048d0a的地址
   elf_entry = loc->elf_ex.e_entry;

}

3.start_thread

//把内核栈放用户态状态的那几个位置更新掉
start_thread(struct pt_regs *regs, unsigned long new_ip, unsigned long new_sp)
{
 set_user_gs(regs, 0);
 regs->fs  = 0;
 regs->ds  = __USER_DS;
 regs->es  = __USER_DS;
 regs->ss  = __USER_DS;
 regs->cs  = __USER_CS;
 regs->ip  = new_ip;
 regs->sp  = new_sp;
 regs->flags  = X86_EFLAGS_IF;
 /*
  * force it to the iret return path by making it look as if there was
  * some work pending.
  */
 set_thread_flag(TIF_NOTIFY_RESUME);
}

重点问题

1.可执行文件装载到内存的哪个位置?
这个在编译的时候就能决定了,基本上就是把elf里面对应的地址原样搬迁到内存空间的逻辑地址中去,详细见linux 进程地址空间的一步步探究



启动过程

流程概览

1.整体流程这篇文章讲的很好下面是我的一些摘要性理解。

2.bios程序自举读取引导设备第一个扇区的前 512 字节(MBR),将其读入到内存 0x0000:7C00,并跳转至此处执行,MBR存放的是GRUB stage1。

3.GRUB是一个引导程序,装ubuntu,windows双系统的时候启动的时候就能看到这东东,负责载入操作系统内核。

4.linux内核镜像bzImage 由 setup.elf、setup.bin、vmlinux.bin.gz 三部分组成,其中setup.bin由setup.elf通过objcopy得来,setup.elf运行在实模式下,为保护模式的Linux内核启动准备环境。这个部分最后会切换进入保护模式,跳转到保护模式的内核执行,也就是0x10000的vmlinux处。而vmlinux是进入保护模式后的代码部分,由解压程序和内核映像压缩包组成。

5.GRUB 等 boot loader 将 setup.elf 读到 0x90000 处,将 vmlinux 读到 0x100000 处,然后跳转到 0x90200 开始执行,恰好跳过了 512 字节的 bootsector。

6.0x90200(_start)会跳到start_of_setup,start_of_setup做一些初始化工作最后跳转到arch/x86/boot/main.c。

7.main.c首先干一些比较硬的初始化工作比如copy_boot_params,console_init,set_bios_mode,detect_memory,keyboard_init...
最后调用go_to_protected_mode(boot/pm.c)跳转到保护模式。

8.go_to_protected_mode先干的比较重要的事情是,setup_idt-初始化中断描述符,setup_gdt-初始化GDT,最后调用protected_mode_jump(pmjump.S)实际从实模式跳转到保护模式。

9. 跳转到保护模式简单理解就是由没有分段的内存访问跳转到有分段的内存访问并且设置了cpu的cr0标志,其中做的比较重要的事情有(1)物理地址转换成线性地址 (2)设置cr0标志
  (3)跳到boot_params.hdr.code32_start即vmlinux的入口(header.S 0x100000即1m的bizImage的位置)

10.arch/x86/boot/compressed/head_32.S,这里拿到的内核还是压缩过的,所以要先解压,解压后跳转到真正的内核入口arch/x86/kernel/head_32.S。

11.head_32.S,(1)初始化参数  (2)开启分页机制,主要的工作是生成页目录页表,设置cr0,cr3寄存器等 (3)初始化eflags (4)初始化中断向量表 (5)载入gdt,idt (6)最后i386_start_kernel(head32.c)->start_kernel(main.c)终于到了能debug到的代码了。

12.start_kernel 这里开始进入体系结构无关的初始化部分。这里初始化的项目很多,比较有代表性的有trap_init(系统保留中断向量初始化),mm_init(内存分配器初始化),sched_init(初始化调度器),init_IRQ(其它中断向量初始化),signals_init(信号量管理初始化),rest_init(这个稍微特别一点)

13.rest_init 这里做的最重要的事情就是产生了1号进程和2号进程,而当前进程本身做位0号进程也蜕化成了idle进程,参考文章 。(1)创建1号进程有一个有意思的点,这里本来是在内核态,通过一个execve系统调用返回到用户态了,主处理器上的idle由原始进程(pid=0)演变而来,从处理器上的idle由init进程fork得到,但是它们的pid都为0。Idle进程为最低优先级,且不参与调度,只是在运行队列为空的时候才被调度。(2)1号进程,init进程,kernel_thread(kernel_init, NULL, CLONE_FS),加载init程序最终运行在内核态是其他用户态进程的祖先,并且作为守护进程守护其他进程,1号用户进程->getty进程->shell进程
(3)2号进程,kthreadd进程,运行在内核空间,负责内核线程的调度和管理

14.init进程也有一个进化史,参考文章
  (1)sysvinit 最早的实现,启动任务完全串行,runlevel定义不同的启动配置比如纯文本的,图形系统的,启动慢
  (2)Upstart ubuntu16之前的启动系统,事件驱动,部分并行化
  (3)Systemd ubuntu现在使用的版本,并行化程度更高,初始化一些被依赖的资源接口加大并行度,使用cgroup而不是strace跟踪进程

epoll

简单流程图








































一些重点理解

1.epoll性能高的重要原因是组织了一颗目标fd的红黑树,这个红黑树用来通过fd来快速查找epitem。还有一个原因是准备了一个就绪fd链表,这个链表是在有中断事件到来时添加的在使用的时候就不用去主动查询每个fd状态,添加链表的逻辑在ep_poll_callback。个人感觉围绕这红黑树和就绪链表去理解epoll是个不错的方法,还有一篇不错的文章 可以参考。  

2.epoll使用了linux等待队列的机制做进程间的通知,和java里面的Object.wait有点相似但是比较半自动。
在等待方,
init_waitqueue_entry(&wait, current);
__add_wait_queue_exclusive(&ep->wq, &wait);

for (;;) {

 set_current_state(TASK_INTERRUPTIBLE);
 if (ep_events_available(ep) || timed_out)
  break;
 if (signal_pending(current)) {
  res = -EINTR;
  break;
 }

 spin_unlock_irqrestore(&ep->lock, flags);
 if (!schedule_hrtimeout_range(to, slack, HRTIMER_MODE_ABS))
  timed_out = 1;

 spin_lock_irqsave(&ep->lock, flags);
}
__remove_wait_queue(&ep->wq, &wait);

set_current_state(TASK_RUNNING);

需要自己设置进程的状态,还需要自己主动调用schedule函数退出运行队列。
在唤醒方,要调用wait queue的wake_up系列方法,wake_up默认会调用try_to_wake_up(core.c)把目标进程放入运行队列。
有一篇不错的讲解wait queue的文章

3.ep_poll_callback的注册调用链是ep_insert->ep_item_poll->ep_eventpoll_poll->poll_wait->ep_ptable_queue_proc
ep_ptable_queue_proc的第二个参数whead来源于监听目标的file->private_data,这个file可以是个socket之类的东西。

malloc

glibc端的逻辑

malloc不是系统调用而是glibc库,会在用户空间攒内存,不够用了会调用sys_brk系统调用去扩展进程内存,详细内容可以参照这篇文章 ,借用一下图片


























sys_brk系统调用

1.这一步的主要目的就是取得一个vm_area_struct结构,取得的方式可能是复用旧的vma_merge,或者创建一个新的kmem_cache_zalloc,参考文章 
kmem_cache_zalloc调用slab_alloc(slab.c)拿到vma对象本身需要的内存。

2.vm_area_struct是描述虚存空间的基本结构,描述的是线性地址,刚建立vm_area_struct的时候物理页和映射应该没建立,此时访问会报缺页异常。

3.缺页异常产生后的调用链
ENTRY(page_fault)->do_page_fault(fault.c)->handle_mm_fault->handle_pte_fault->
do_anonymous_page->alloc_zeroed_user_highpage_movable
参考文章

4.alloc_zeroed_user_highpage_movable->alloc_page_vma->__alloc_pages_nodemask(page_alloc.c)
(最后的调用链是通过别人的贴的堆栈确认的)
说明申请用户态的内存最终通过buddy系统分配的内存页。内核态申请内存还有kmalloc和vmalloc,kmalloc走的slab分配而vmalloc走的是buddy分配,参考文章

5.和mmap的关系。mmap_region和do_brk的逻辑很相似,看do_brk的注释,

/*
 *  this is really a simplified "do_mmap".  it only handles
 *  anonymous maps.  eventually we may be able to do some
 *  brk-specific accounting here.
 */

使用cat /proc/pid/maps可以看进程的vm_area_struct,其中有的条目是绑定了文件的,有的条目显示为[heap]说明是malloc出来的。从直观上也能感觉到sys_brk和mmap的联系。

其他相关知识点

1.进程的内存分布状况可以参考这篇文章 ,其中的两个图比较好,摘录下来





















































2.所有进程共享1g内核内存空间,内核必须能访问所有的4g内存,但是内核内存空间只有1g,而且16-896m和物理内存是一一映射的不会因为进程切换而改变,所有896m-1000m区间被拿出来做动态映射,叫做高端内存区,每次要访问1g之外的范围就拿高端内存去的逻辑地址建立动态映射,用完之后再还回去,参考文章

并发控制  

主要是想看看linux内核的并发控制有什么魔法,有哪些在java并发体系中看不到的东东

spin_lock


void __lockfunc _spin_lock(spinlock_t *lock)
{
        preempt_disable();
        if (unlikely(!_raw_spin_trylock(lock)))
                __preempt_spin_lock(lock);
}

static inline int _raw_spin_trylock(spinlock_t *lock)
{
        char oldval;
        __asm__ __volatile__(
                "xchgb %b0,%1"
                :"=q" (oldval), "=m" (lock->lock)
                :"0" (0) : "memory");
        return oldval > 0;
}

static inline void __preempt_spin_lock(spinlock_t *lock)
{
        if (preempt_count() > 1) {
                _raw_spin_lock(lock);
                return;
        }
        do {
                preempt_enable();
                while (spin_is_locked(lock))
                        cpu_relax();
                preempt_disable();
        } while (!_raw_spin_trylock(lock));
}

1.linux内核可能有因为调度导致死锁的情形所以加锁前需要preempt_disable禁止内核抢占。

2.*lock必须是多核都能访问到的全局数据再加上xchgb原子指令才能起到抢锁的效果。

3.如果preempt_count() > 1关闭了内核抢占类似于单核的情况处理比较简单。

4.while循环反复抢锁直到成功,cpu_relax可以是优化过的节能指令。

5.spin_unlock比较简单,直接把lock->lock置1就可以了。


信号量

这里借用<<深入理解linux内核>>里面的代码

1.释放信号量
  movl $sem->count,%ecx
  lock; incl (%ecx)
  jg 1f
  lea %ecx,%eax
  pushl %edx
  pushl %ecx
  call __up
  popl %ecx
  popl %edx
1:  

这里__up函数仅仅做了一个唤醒等待进程的动作。可以看到这里真正的魔法就是lock;这个指令起到了类似于java中volatile的作用。

2.获取信号量
  movl $sem->count,%ecx
  lock; decl (%ecx);
  jns 1f
  lea %ecx,%eax
  pushl %edx
  pushl %ecx
  call __down
  popl %ecx
  popl %edx
1:  

__down函数相对比较复杂,会首先把当前进程放到信号量sem的等待队列里面,然后循环判断信号量是否释放,如果没有释放就调用schedule让自己睡眠,但是联想一下java里面的信号量实现,其实非常类似。而汇编代码这块就比较好理解了,魔法还是在lock;指令。

对比起来linux内核加锁只是更加底层,但逻辑上和java里面的其实比较类似。


参考资料

1.https://mooc.study.163.com/course/1000072000#/info 庖丁解牛linux内核
2.https://blog.csdn.net/gogokongyin/article/details/51178257  fork()、vfork()、clone()的区别
3.https://blog.csdn.net/gatieme/article/details/51417488 Linux中fork,vfork和clone详解
4.http://www.techbulo.com/708.html GDT,LDT,GDTR,LDTR 详解,包你理解透彻
5.https://www.cnblogs.com/hanyan225/archive/2011/07/12/2103545.html linux内核分析笔记----调度
6.https://blog.csdn.net/gatieme/article/details/51872618 Linux用户抢占和内核抢占详解(概念, 实现和触发时机)
7.https://blog.csdn.net/a2796749/article/details/47101533 Linux进程调度-------O(1)调度和CFS调度器
8.http://ju.outofmemory.cn/entry/105407 从几个问题开始理解CFS调度器
9.http://home.ustc.edu.cn/~boj/courses/linux_kernel/1_boot.html Linux源代码阅读——内核引导
10.https://blog.csdn.net/gatieme/article/details/51484562 Linux下0号进程的前世(init_task进程)今生(idle进程)
11.https://www.ibm.com/developerworks/cn/linux/1407_liuming_init1/index.html 浅析 Linux 初始化 init 系统,第 1 部分
12.https://www.cnblogs.com/apprentice89/p/3234677.html epoll源码实现分析[整理]
13.https://www.douban.com/group/topic/79167871/ Linux驱动开发笔记-内核等待队列机制(分享一下)
14.https://blog.csdn.net/ordeder/article/details/41654509/ Linux Malloc分析-从用户空间到内核空间
15.https://blog.csdn.net/mrpre/article/details/79115523 调用malloc时发生了什么(2) - sys_brk函数
16.http://edsionte.com/techblog/archives/4174 malloc()之后,内核发生了什么?
17.http://unicornx.github.io/2016/04/02/20160402-lk-mm/ Linux内存管理
18.https://blog.csdn.net/beyondhaven/article/details/6636561 Linux内存管理分析报告
19.https://www.cnblogs.com/zlcxbb/p/5841417.html linux 用户空间与内核空间——高端内存详解
20.linux 进程地址空间的一步步探究
21.<<深入理解linux内核>>